Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
([#38139](https://github.com/apache/beam/issues/38139)).
* (Python) Added type alias for with_exception_handling to be used for typehints. ([#38173](https://github.com/apache/beam/issues/38173)).
* Added plugin mechanism to support different Lineage implementations (Java) ([#36790](https://github.com/apache/beam/issues/36790)).
* (Python) Added [Qdrant](https://qdrant.tech/) VectorDatabaseWriteConfig implementation ([#38141](https://github.com/apache/beam/issues/38141)).

## Breaking Changes

Expand Down
212 changes: 212 additions & 0 deletions sdks/python/apache_beam/ml/rag/ingestion/qdrant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from collections.abc import Callable
from dataclasses import dataclass, field
from typing import Any, Optional

try:
from qdrant_client import QdrantClient, models
except ImportError:
logging.warning("Qdrant client library is not installed.")

import apache_beam as beam
from apache_beam.ml.rag.ingestion.base import VectorDatabaseWriteConfig
from apache_beam.ml.rag.types import EmbeddableItem

DEFAULT_WRITE_BATCH_SIZE = 1000


@dataclass
class QdrantConnectionParameters:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A docstring outlining each field and the mandatory information required to create a valid set of parameters will make this much more user friendly

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good point! added docstring

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add classmethod factories to make it clearer which combinations of parameters are valid? Something like

@dataclass
class QdrantConnectionParameters:
    # ... existing fields unchanged ...

    @classmethod
    def for_cloud(
        cls,
        url: str,
        api_key: str,
        *,
        prefer_grpc: bool = False,
        timeout: Optional[int] = None,
        **kwargs: Any,
    ) -> "QdrantConnectionParameters":
        """Connect to Qdrant Cloud. Requires the cluster URL and an API key."""
        return cls(
            url=url,
            api_key=api_key,
            https=True,
            prefer_grpc=prefer_grpc,
            timeout=timeout,
            kwargs=kwargs,
        )

    @classmethod
    def for_host(
        cls,
        host: str,
        port: int = 6333,
        *,
        grpc_port: int = 6334,
        prefer_grpc: bool = False,
        https: bool = False,
        api_key: Optional[str] = None,
        timeout: Optional[int] = None,
        **kwargs: Any,
    ) -> "QdrantConnectionParameters":
        """Connect to a self-hosted Qdrant instance by host and port."""
        return cls(
            host=host, port=port, grpc_port=grpc_port,
            prefer_grpc=prefer_grpc, https=https,
            api_key=api_key, timeout=timeout, kwargs=kwargs,
        )

    @classmethod
    def for_url(
        cls,
        url: str,
        *,
        api_key: Optional[str] = None,
        prefer_grpc: bool = False,
        timeout: Optional[int] = None,
        **kwargs: Any,
    ) -> "QdrantConnectionParameters":
        """Connect using a full URL like 'https://my-qdrant.example.com:6333'."""
        return cls(url=url, api_key=api_key, prefer_grpc=prefer_grpc,
                   timeout=timeout, kwargs=kwargs)

    @classmethod
    def local(cls, path: str) -> "QdrantConnectionParameters":
        """Use an embedded Qdrant instance persisted to the given path."""
        return cls(path=path)

    @classmethod
    def in_memory(cls) -> "QdrantConnectionParameters":
        """Use an embedded in-memory Qdrant instance. Useful for tests."""
        return cls(location=":memory:")

"""Configuration parameters for connecting to Qdrant service.

Either `location`, `url`, `host`, or `path` must be provided to establish
a connection.

Args:
location:
If `str` - use it as a `url` parameter.
If `None` - use default values for `host` and `port`.
url: either host or str of "<scheme>//<host>:<port>/<prefix>".
Default: `None`
port: Port of the REST API interface. Default: 6333
grpc_port: Port of the gRPC interface. Default: 6334
prefer_grpc: If `true` - use gPRC interface whenever possible.
https: If `true` - use HTTPS(SSL) protocol. Default: `None`
api_key: API key for authentication in Qdrant Cloud. Default: `None`
prefix:
If not `None` - add `prefix` to the REST URL path.
Example: `service/v1` will result in
`http://localhost:6333/service/v1/{qdrant-endpoint}` for REST API.
Default: `None`
timeout:
Timeout for REST and gRPC API requests.
Default: 5 seconds for REST and unlimited for gRPC
host:
Host name of Qdrant service.
If url and host are None, set to 'localhost'.
Default: `None`
path: Persistence path for QdrantLocal. Default: `None`
**kwargs: Additional arguments passed directly into client initialization
"""

location: Optional[str] = None
url: Optional[str] = None
port: Optional[int] = 6333
grpc_port: int = 6334
prefer_grpc: bool = False
https: Optional[bool] = None
api_key: Optional[str] = None
prefix: Optional[str] = None
timeout: Optional[int] = None
host: Optional[str] = None
path: Optional[str] = None
kwargs: dict[str, Any] = field(default_factory=dict)

def __post_init__(self):
if not (self.location or self.url or self.host or self.path):
raise ValueError(
"One of location, url, host, or path must be provided for Qdrant")


@dataclass
class QdrantWriteConfig(VectorDatabaseWriteConfig):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar thought here, a docstring should be provided here since this is the entrypoint for users to drop the qdrant write into their pipelines

"""Configuration for writing to Qdrant vector database.

This class defines the parameters needed to write data to a qdrant collection,
including collection targeting, batching behavior, and operation timeouts.

Args:
connection_params: QdrantConnectionParameters with connection settings.
collection_name: Name of the Qdrant collection to write to.
timeout: Optional timeout for write operations in seconds. Default is None.
batch_size: Number of points to write in each batch. Default is 1000.
kwargs: Additional keyword arguments to pass to the client's upsert method.
dense_embedding_key: name for the dense vector in the qdrant collection.
sparse_embedding_key: name for the sparse vector in the qdrant collection.
"""

connection_params: QdrantConnectionParameters
collection_name: str
timeout: Optional[float] = None
batch_size: int = DEFAULT_WRITE_BATCH_SIZE
kwargs: dict[str, Any] = field(default_factory=dict)
dense_embedding_key: str = "dense"
sparse_embedding_key: str = "sparse"

def __post_init__(self):
if not self.collection_name:
raise ValueError("Collection name must be provided")
Comment on lines +112 to +114
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

It is recommended to validate that batch_size is a positive integer in __post_init__ to prevent potential issues with empty or negative batch sizes during ingestion.

Suggested change
def __post_init__(self):
if not self.collection_name:
raise ValueError("Collection name must be provided")
def __post_init__(self):
if not self.collection_name:
raise ValueError("Collection name must be provided")
if self.batch_size <= 0:
raise ValueError("batch_size must be positive")


def create_write_transform(self) -> beam.PTransform[EmbeddableItem, Any]:
return _QdrantWriteTransform(self)

def create_converter(
self,
) -> Callable[[EmbeddableItem], "models.PointStruct"]:
def convert(item: EmbeddableItem) -> "models.PointStruct":
if item.dense_embedding is None and item.sparse_embedding is None:
raise ValueError(
"EmbeddableItem must have at least one embedding (dense or sparse)")
vector = {}
if item.dense_embedding is not None:
vector[self.dense_embedding_key] = item.dense_embedding
if item.sparse_embedding is not None:
sparse_indices, sparse_values = item.sparse_embedding
vector[self.sparse_embedding_key] = models.SparseVector(
indices=sparse_indices,
values=sparse_values,
)
id = (
int(item.id)
if isinstance(item.id, str) and item.id.isdigit() else item.id)
return models.PointStruct(
id=id,
vector=vector,
payload=item.metadata if item.metadata else None,
)

return convert


class _QdrantWriteTransform(beam.PTransform):
def __init__(self, config: QdrantWriteConfig):
self.config = config

def expand(self, input_or_inputs: beam.PCollection[EmbeddableItem]):
return (
input_or_inputs
| "Convert to Records" >> beam.Map(self.config.create_converter())
| beam.ParDo(_QdrantWriteFn(self.config)))


class _QdrantWriteFn(beam.DoFn):
def __init__(self, config: QdrantWriteConfig):
self.config = config
self._batch = []
self._client: "Optional[QdrantClient]" = None
Comment on lines +159 to +162
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

In Apache Beam, DoFn instances can be reused across bundles. To ensure that state is not leaked between bundles (especially in case of retries), it is a best practice to initialize bundle-specific state like self._batch in the start_bundle method rather than in __init__.

Suggested change
def __init__(self, config: QdrantWriteConfig):
self.config = config
self._batch = []
self._client: "Optional[QdrantClient]" = None
def __init__(self, config: QdrantWriteConfig):
self.config = config
self._client: "Optional[QdrantClient]" = None
def start_bundle(self):
self._batch = []


def process(self, element, *args, **kwargs):
self._batch.append(element)
if len(self._batch) >= self.config.batch_size:
self._flush()
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider adding a byte size limit for individual batches, similar to BigQuery streaming inserts

> self._max_insert_payload_size) or


def setup(self):
params = self.config.connection_params
self._client = QdrantClient(
location=params.location,
url=params.url,
port=params.port,
grpc_port=params.grpc_port,
prefer_grpc=params.prefer_grpc,
https=params.https,
api_key=params.api_key,
prefix=params.prefix,
timeout=params.timeout,
host=params.host,
path=params.path,
check_compatibility=False,
**params.kwargs,
Comment on lines +183 to +184
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Disabling the compatibility check (check_compatibility=False) can lead to difficult-to-debug issues if there is a version mismatch between the client and the Qdrant server. Unless there is a specific reason to disable it, it is safer to leave it enabled (which is the default).

Suggested change
check_compatibility=False,
**params.kwargs,
**params.kwargs,

)

def teardown(self):
if self._client:
self._client.close()
self._client = None

def finish_bundle(self):
self._flush()

def _flush(self):
if len(self._batch) == 0:
return
if not self._client:
raise RuntimeError("Qdrant client is not initialized")
self._client.upsert(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are there any retriable errors that we should handle?

collection_name=self.config.collection_name,
points=self._batch,
timeout=self.config.timeout,
**self.config.kwargs,
)
self._batch = []

def display_data(self):
res = super().display_data()
res["collection"] = self.config.collection_name
res["batch_size"] = self.config.batch_size
return res
Loading
Loading