diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 4001aee9d..484f0e48b 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -222,21 +222,6 @@ jobs: runs-on: ${{ matrix.os }} env: SKIP_SPOTLESS_CHECK: true - services: - elasticsearch: - image: docker.elastic.co/elasticsearch/elasticsearch:8.19.0 - env: - discovery.type: single-node - xpack.security.enabled: false - ES_JAVA_OPTS: "-Xms512m -Xmx512m" - ports: - - 9200:9200 - options: >- - --health-cmd "curl -f http://localhost:9200/_cluster/health || exit 1" - --health-interval 10s - --health-timeout 5s - --health-retries 10 - --health-start-period 30s strategy: fail-fast: false matrix: @@ -273,9 +258,26 @@ jobs: run: bash tools/build.sh - name: Install ollama run: bash tools/start_ollama_server.sh + - name: Start Elasticsearch + run: | + docker compose -f tools/docker/elasticsearch/docker-compose.yml down -v + docker compose -f tools/docker/elasticsearch/docker-compose.yml up -d + timeout 180 bash -c 'until curl -fsS http://localhost:9200/_cluster/health; do sleep 5; done' + - name: Start Milvus + run: | + docker compose -f tools/docker/milvus/docker-compose.yml down -v + docker compose -f tools/docker/milvus/docker-compose.yml up -d + timeout 180 bash -c 'until curl -fsS http://localhost:9091/healthz; do sleep 5; done' - name: Run e2e tests env: LOG_LEVEL: INFO run: | export ES_HOST="http://localhost:9200" - tools/e2e.sh \ No newline at end of file + export MILVUS_URI="http://localhost:19530" + tools/e2e.sh + - name: Stop Milvus + if: always() + run: docker compose -f tools/docker/milvus/docker-compose.yml down -v + - name: Stop Elasticsearch + if: always() + run: docker compose -f tools/docker/elasticsearch/docker-compose.yml down -v diff --git a/api/src/main/java/org/apache/flink/agents/api/resource/ResourceName.java b/api/src/main/java/org/apache/flink/agents/api/resource/ResourceName.java index e4002be7c..b2c31d522 100644 --- a/api/src/main/java/org/apache/flink/agents/api/resource/ResourceName.java +++ b/api/src/main/java/org/apache/flink/agents/api/resource/ResourceName.java @@ -171,6 +171,10 @@ public static final class VectorStore { public static final String ELASTICSEARCH_VECTOR_STORE = "org.apache.flink.agents.integrations.vectorstores.elasticsearch.ElasticsearchVectorStore"; + // Milvus + public static final String MILVUS_VECTOR_STORE = + "org.apache.flink.agents.integrations.vectorstores.milvus.MilvusVectorStore"; + // Python Wrapper public static final String PYTHON_WRAPPER_VECTOR_STORE = "org.apache.flink.agents.api.vectorstores.python.PythonVectorStore"; diff --git a/dist/pom.xml b/dist/pom.xml index f7e064f66..0274bc090 100644 --- a/dist/pom.xml +++ b/dist/pom.xml @@ -100,6 +100,11 @@ under the License. flink-agents-integrations-vector-stores-elasticsearch ${project.version} + + org.apache.flink + flink-agents-integrations-vector-stores-milvus + ${project.version} + org.apache.flink flink-agents-integrations-vector-stores-opensearch @@ -156,4 +161,4 @@ under the License. - \ No newline at end of file + diff --git a/docs/content/docs/development/vector_stores.md b/docs/content/docs/development/vector_stores.md index decc9c0d0..4efcbe0a3 100644 --- a/docs/content/docs/development/vector_stores.md +++ b/docs/content/docs/development/vector_stores.md @@ -172,7 +172,7 @@ For vector stores that implement `CollectionManageableVectorStore`, you can crea * `delete_collection` / `deleteCollection`: Delete a collection by name. {{< hint info >}} -Collection-level operations are only supported for vector stores that implement `CollectionManageableVectorStore`. Among the built-in providers, Chroma (Python), Elasticsearch (Java) and OpenSearch (Java) implement this interface. +Collection-level operations are only supported for vector stores that implement `CollectionManageableVectorStore`. Among the built-in providers, Chroma (Python), Elasticsearch (Java), OpenSearch (Java), and Milvus (Java) implement this interface. {{< /hint >}} {{< tabs "Collection level operations" >}} @@ -642,9 +642,86 @@ public static ResourceDescriptor vectorStore() { {{< /tabs >}} +### Milvus + +[Milvus](https://milvus.io/) is an open-source vector database designed for high-dimensional vector search at scale. + +{{< hint info >}} +Milvus is currently supported in the Java API only. To use Milvus from Python agents, see [Using Cross-Language Providers](#using-cross-language-providers). +{{< /hint >}} + +#### Prerequisites + +1. A Milvus server. + +#### MilvusVectorStore Parameters + +| Parameter | Type | Default | Description | +|-----------------------------|------|--------------------------------------|-----------------------------------------------------------------------------| +| `embedding_model` | str | Required | Reference to embedding model resource name | +| `collection` | str | `"flink_agents_milvus_collection"` | Default target Milvus collection name | +| `collection_name` | str | None | Alias for `collection` | +| `index` | str | None | Alias for `collection`, mainly for cross-provider compatibility | +| `id_field` | str | `"id"` | Name of the primary key field | +| `content_field` | str | `"content"` | Name of the field storing document content | +| `metadata_field` | str | `"metadata"` | Name of the JSON field storing document metadata | +| `vector_field` | str | `"embedding"` | Name of the FloatVector field used for vector search | +| `dims` | int | `768` | Vector dimensionality | +| `id_max_length` | int | `65535` | Maximum length for the VarChar primary key field | +| `content_max_length` | int | `65535` | Maximum length for the VarChar content field | +| `metric_type` | str | `"COSINE"` | Milvus metric type used by vector search | +| `index_type` | str | `"AUTOINDEX"` | Milvus vector index type | +| `index_params` | map | `{}` | Extra vector index parameters passed to Milvus | +| `metadata_index_keys` | list | `user_id`, `agent_id`, `run_id`, `actor_id`, `category` | Additional metadata JSON keys indexed with path indexes | +| `metadata_index_cast_types` | map | Default keys use `"VARCHAR"` | Per-metadata-key JSON path index cast type overrides | +| `num_shards` | int | `1` | Number of Milvus shards for newly created collections | +| `consistency_level` | str | `"BOUNDED"` | Milvus consistency level for collection creation, query, and search | +| `max_get_limit` | int | `10000` | Maximum number of documents returned by `get` when no limit is specified | +| `load_timeout_ms` | long | `120000` | Timeout for loading collections | +| `uri` | str | `"http://localhost:19530"` | Milvus endpoint | +| `host` | str | `"localhost"` | Milvus host used when `uri` is not set | +| `port` | int | `19530` | Milvus port used when `uri` is not set | +| `db_name` | str | None | Milvus database name | +| `token` | str | None | Token for Milvus authentication | +| `username` | str | None | Username for basic authentication | +| `password` | str | None | Password for basic authentication | +| `enable_precheck` | bool | `false` | Whether to enable Milvus client precheck | + +{{< hint info >}} +When creating a collection, MilvusVectorStore creates a primary-key field, content field, JSON metadata field, vector field, vector index, and JSON metadata indexes. The default metadata JSON path indexes cover common filter keys such as `user_id`, `agent_id`, `run_id`, `actor_id`, and `category`; add `metadata_index_keys` for application-specific filter keys. + +The default shard count is `1`. As a rough capacity-planning rule, use about one shard per 100 million vectors, and increase it for heavier write throughput. +{{< /hint >}} + +#### Usage Example + +{{< tabs "Milvus Usage Example" >}} + +{{< tab "Java" >}} + +```java +@VectorStore +public static ResourceDescriptor vectorStore() { + return ResourceDescriptor.Builder.newBuilder(ResourceName.VectorStore.MILVUS_VECTOR_STORE) + .addInitialArgument("embedding_model", "embeddingModel") + .addInitialArgument("uri", "http://localhost:19530") + .addInitialArgument("collection", "my_documents") + .addInitialArgument("dims", 1536) + .addInitialArgument("metric_type", "COSINE") + .addInitialArgument("index_type", "AUTOINDEX") + // Optional metadata JSON path indexes + // .addInitialArgument("metadata_index_keys", List.of("user_id", "agent_id", "run_id")) + .build(); +} +``` + +{{< /tab >}} + +{{< /tabs >}} + ## Using Cross-Language Providers -Flink Agents supports cross-language vector store integration, allowing you to use vector stores implemented in one language (Java or Python) from agents written in the other language. This is particularly useful when a vector store provider is only available in one language (e.g., Elasticsearch is currently Java-only, Chroma is currently Python-only). +Flink Agents supports cross-language vector store integration, allowing you to use vector stores implemented in one language (Java or Python) from agents written in the other language. This is particularly useful when a vector store provider is only available in one language (e.g., Elasticsearch and Milvus are currently Java-only, Chroma is currently Python-only). {{< hint warning >}} **Limitations:** @@ -1101,4 +1178,4 @@ public class MyVectorStore extends BaseVectorStore {{< /tab >}} -{{< /tabs >}} \ No newline at end of file +{{< /tabs >}} diff --git a/docs/content/docs/faq/faq.md b/docs/content/docs/faq/faq.md index 6931ec3ca..0f3b47861 100644 --- a/docs/content/docs/faq/faq.md +++ b/docs/content/docs/faq/faq.md @@ -117,6 +117,7 @@ Flink Agents provides built-in integrations for many ecosystem providers. Some i |---|---|---| | [Chroma]({{< ref "docs/development/vector_stores#chroma" >}}) | ✅ | ❌ | | [Elasticsearch]({{< ref "docs/development/vector_stores#elasticsearch" >}}) | ❌ | ✅ | +| [Milvus]({{< ref "docs/development/vector_stores#milvus" >}}) | ❌ | ✅ | **MCP Server** @@ -131,4 +132,4 @@ Flink Agents provides built-in integrations for many ecosystem providers. Some i To avoid potential conflict with Flink cluster, the scope of the dependencies related to Flink and Flink Agents for agent job are provided. See [Maven Dependencies]({{< ref "docs/get-started/installation#maven-dependencies-for-java" >}}) for details. To run the examples in IDE, users must enable the IDE feature: `add dependencies with provided scope to classpath`. -* For **IDEA**, edit the **`Run/Debug Configuration`** and enable **`add dependencies with provided scope to classpath`**. See [Run/Debug Configuration](https://www.jetbrains.com/help/idea/run-debug-configuration-scala.html) for details. \ No newline at end of file +* For **IDEA**, edit the **`Run/Debug Configuration`** and enable **`add dependencies with provided scope to classpath`**. See [Run/Debug Configuration](https://www.jetbrains.com/help/idea/run-debug-configuration-scala.html) for details. diff --git a/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/pom.xml b/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/pom.xml index 2d19a8e1b..8b2f02429 100644 --- a/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/pom.xml +++ b/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/pom.xml @@ -35,7 +35,7 @@ flink-agents-integrations-embedding-models-ollama ${project.version} - + org.apache.flink flink-agents-integrations-chat-models-openai @@ -46,6 +46,11 @@ flink-agents-integrations-vector-stores-elasticsearch ${project.version} + + org.apache.flink + flink-agents-integrations-vector-stores-milvus + ${project.version} + org.apache.flink flink-streaming-java @@ -67,4 +72,4 @@ ${flink.version} - \ No newline at end of file + diff --git a/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/Mem0LongTermMemoryAgent.java b/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/Mem0LongTermMemoryAgent.java index d68bb2588..9a7655b61 100644 --- a/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/Mem0LongTermMemoryAgent.java +++ b/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/Mem0LongTermMemoryAgent.java @@ -52,19 +52,20 @@ * full retrieved item set. * *

All resources are declared as native Java implementations (Ollama chat / embedding, - * Elasticsearch vector store). Python's mem0 adapter consumes them through the cross-language - * bridge: {@code ctx.get_resource(name, type)} on the Python side returns a Java*Impl wrapper that - * delegates back into Java via Pemja. + * Elasticsearch or Milvus vector store). Python's mem0 adapter consumes them through the + * cross-language bridge: {@code ctx.get_resource(name, type)} on the Python side returns a + * Java*Impl wrapper that delegates back into Java via Pemja. * - *

The test driving this agent must (1) pull the Ollama models and (2) provide ES connection env - * vars ({@code ES_HOST}, {@code ES_INDEX}, {@code ES_DIMS}, {@code ES_VECTOR_FIELD}, optional - * {@code ES_USERNAME}/{@code ES_PASSWORD}); see {@link Mem0LongTermMemoryTest}. + *

The test driving this agent must (1) pull the Ollama models and (2) provide vector-store + * connection env vars; see {@link Mem0LongTermMemoryTest}. */ public class Mem0LongTermMemoryAgent extends Agent { public static final String CHAT_MODEL = "qwen3.6-plus"; public static final String OLLAMA_EMBEDDING_MODEL = "nomic-embed-text"; public static final String MEMORY_SET_NAME = "test_ltm"; + public static final String ES_LTM_STORE = "esLtmStore"; + public static final String MILVUS_LTM_STORE = "milvusLtmStore"; /** Mirrors the Python e2e: dashscope-hosted OpenAI-compatible endpoint, env-overridable. */ private static final String DEFAULT_BASE_URL = "https://coding.dashscope.aliyuncs.com/v1"; @@ -139,7 +140,9 @@ public static ResourceDescriptor esLtmStore() { ResourceDescriptor.Builder.newBuilder( ResourceName.VectorStore.ELASTICSEARCH_VECTOR_STORE) .addInitialArgument("embedding_model", "ollamaNomicEmbedText") - .addInitialArgument("host", System.getenv("ES_HOST")) + .addInitialArgument( + "host", + System.getenv().getOrDefault("ES_HOST", "http://localhost:9200")) .addInitialArgument( "collection", UUID.randomUUID().toString().substring(0, 8) + "-context"); @@ -152,6 +155,23 @@ public static ResourceDescriptor esLtmStore() { return builder.build(); } + @VectorStore + public static ResourceDescriptor milvusLtmStore() { + return ResourceDescriptor.Builder.newBuilder(ResourceName.VectorStore.MILVUS_VECTOR_STORE) + .addInitialArgument("embedding_model", "ollamaNomicEmbedText") + .addInitialArgument( + "uri", System.getenv().getOrDefault("MILVUS_URI", "http://localhost:19530")) + .addInitialArgument( + "collection", + "flink_agents_mem0_" + UUID.randomUUID().toString().replace("-", "")) + .addInitialArgument("dims", 768) + // Test-only: Mem0 e2e reads immediately after writes. Production should use the + // default BOUNDED consistency unless immediate read-after-write visibility is + // required. + .addInitialArgument("consistency_level", "STRONG") + .build(); + } + @Action(listenEventTypes = {InputEvent.EVENT_TYPE}) public static void addItems(Event event, RunnerContext ctx) throws Exception { InputEvent inputEvent = InputEvent.fromEvent(event); diff --git a/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/Mem0LongTermMemoryTest.java b/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/Mem0LongTermMemoryTest.java index d11b7dfe9..166db7adb 100644 --- a/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/Mem0LongTermMemoryTest.java +++ b/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/Mem0LongTermMemoryTest.java @@ -28,7 +28,8 @@ import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Assumptions; import org.junit.jupiter.api.Disabled; -import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import java.io.IOException; import java.time.Instant; @@ -38,6 +39,8 @@ import java.util.Map; import static org.apache.flink.agents.resource.test.CrossLanguageTestPreparationUtils.pullModel; +import static org.apache.flink.agents.resource.test.Mem0LongTermMemoryAgent.ES_LTM_STORE; +import static org.apache.flink.agents.resource.test.Mem0LongTermMemoryAgent.MILVUS_LTM_STORE; import static org.apache.flink.agents.resource.test.Mem0LongTermMemoryAgent.OLLAMA_EMBEDDING_MODEL; /** @@ -54,7 +57,8 @@ *

  • {@code ACTION_API_KEY} env var (and optionally {@code ACTION_BASE_URL}) for the * OpenAI-compatible chat model — mirrors the Python e2e test's setup *
  • {@code python} on PATH with {@code mem0ai} and {@code flink_agents} installed - *
  • Elasticsearch reachable via the {@code ES_HOST} env var + *
  • Elasticsearch reachable via the {@code ES_HOST} env var, or Milvus reachable via the {@code + * MILVUS_URI} env var * */ public class Mem0LongTermMemoryTest { @@ -62,25 +66,29 @@ public class Mem0LongTermMemoryTest { private final boolean embeddingReady; private final boolean pythonReady; private final boolean esConfigured; + private final boolean milvusConfigured; private final boolean apiKeySet; public Mem0LongTermMemoryTest() throws IOException { embeddingReady = pullModel(OLLAMA_EMBEDDING_MODEL); pythonReady = isPythonAvailable(); esConfigured = System.getenv("ES_HOST") != null; + milvusConfigured = System.getenv("MILVUS_URI") != null; apiKeySet = System.getenv("ACTION_API_KEY") != null; } - @Test + @ParameterizedTest(name = "vectorStore={0}") + @ValueSource(strings = {ES_LTM_STORE, MILVUS_LTM_STORE}) @Disabled("Using mem0 in java depends on the pemja fix.") - public void testMem0LongTermMemory() throws Exception { + public void testMem0LongTermMemory(String vectorStore) throws Exception { Assumptions.assumeTrue( embeddingReady, "Ollama is not reachable or the embedding model could not be pulled"); Assumptions.assumeTrue( pythonReady, "`python` executable not found on PATH; this test requires Python with mem0ai installed"); - Assumptions.assumeTrue(esConfigured, "Elasticsearch env var (ES_HOST) is not set"); + Assumptions.assumeTrue( + isVectorStoreConfigured(vectorStore), vectorStoreMissingMessage(vectorStore)); Assumptions.assumeTrue( apiKeySet, "ACTION_API_KEY env var is not set; required for the OpenAI-compatible chat model"); @@ -105,7 +113,7 @@ public void testMem0LongTermMemory() throws Exception { agentsEnv .getConfig() .set(LongTermMemoryOptions.Mem0.EMBEDDING_MODEL_SETUP, "ollamaNomicEmbedText"); - agentsEnv.getConfig().set(LongTermMemoryOptions.Mem0.VECTOR_STORE, "esLtmStore"); + agentsEnv.getConfig().set(LongTermMemoryOptions.Mem0.VECTOR_STORE, vectorStore); DataStream outputStream = agentsEnv @@ -129,6 +137,26 @@ private static boolean isPythonAvailable() { } } + private boolean isVectorStoreConfigured(String vectorStore) { + if (ES_LTM_STORE.equals(vectorStore)) { + return esConfigured; + } + if (MILVUS_LTM_STORE.equals(vectorStore)) { + return milvusConfigured; + } + throw new IllegalArgumentException("Unknown vector store: " + vectorStore); + } + + private static String vectorStoreMissingMessage(String vectorStore) { + if (ES_LTM_STORE.equals(vectorStore)) { + return "Elasticsearch env var (ES_HOST) is not set"; + } + if (MILVUS_LTM_STORE.equals(vectorStore)) { + return "Milvus env var (MILVUS_URI) is not set"; + } + return "Unknown vector store: " + vectorStore; + } + @SuppressWarnings("unchecked") private void checkResult(CloseableIterator results) throws Exception { Map> records = new HashMap<>(); diff --git a/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/VectorStoreCrossLanguageAgent.java b/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/VectorStoreCrossLanguageAgent.java index e252fa082..cacf62444 100644 --- a/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/VectorStoreCrossLanguageAgent.java +++ b/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/VectorStoreCrossLanguageAgent.java @@ -59,6 +59,7 @@ public class VectorStoreCrossLanguageAgent extends Agent { public static final String OLLAMA_MODEL = "nomic-embed-text"; public static final String TEST_COLLECTION = "test_collection"; + private static final String VECTOR_STORE_BACKEND = "CHROMA"; @EmbeddingModelConnection public static ResourceDescriptor embeddingConnection() { @@ -121,7 +122,8 @@ public static void inputEvent(Event event, RunnerContext ctx) throws Exception { TEST_COLLECTION, Map.of("metadata", Map.of("key1", "value1", "key2", "value2"))); - System.out.println("[TEST] Vector store Collection Management PASSED"); + System.out.printf( + "[TEST][%s] Vector store Collection Management PASSED%n", VECTOR_STORE_BACKEND); vectorStore.deleteCollection(TEST_COLLECTION); Assertions.assertThrows( @@ -168,7 +170,8 @@ public static void inputEvent(Event event, RunnerContext ctx) throws Exception { Assertions.assertEquals( Map.of("category", "database", "source", "test"), doc.getMetadata()); - System.out.println("[TEST] Vector store Document Management PASSED"); + System.out.printf( + "[TEST][%s] Vector store Document Management PASSED%n", VECTOR_STORE_BACKEND); // Verify VectorStoreQuery.filters survives the Java->Python bridge. // ChromaDB applies the unified-DSL filter to its `where` clause, so the @@ -191,7 +194,8 @@ public static void inputEvent(Event event, RunnerContext ctx) throws Exception { filteredDocs.get(0).getId(), "Filter {category=database} should match doc2"); - System.out.println("[TEST] Vector store filter query PASSED"); + System.out.printf( + "[TEST][%s] Vector store filter query PASSED%n", VECTOR_STORE_BACKEND); ctx.getShortTermMemory().set("is_initialized", true); } @@ -244,12 +248,16 @@ public static void contextRetrievalResponseEvent(Event event, RunnerContext ctx) first.getContent().substring(0, Math.min(50, first.getContent().length()))); ctx.sendEvent(new OutputEvent(result)); - System.out.printf("[TEST] Vector store retrieval PASSED, count=%d%n", documents.size()); + System.out.printf( + "[TEST][%s] Vector store retrieval PASSED, count=%d%n", + VECTOR_STORE_BACKEND, documents.size()); } catch (Exception e) { result.put("test_status", "FAILED"); result.put("error", e.getMessage()); ctx.sendEvent(new OutputEvent(result)); - System.err.printf("[TEST] Vector store retrieval FAILED: %s%n", e.getMessage()); + System.err.printf( + "[TEST][%s] Vector store retrieval FAILED: %s%n", + VECTOR_STORE_BACKEND, e.getMessage()); throw e; } } diff --git a/integrations/pom.xml b/integrations/pom.xml index 9989a5f01..754048813 100644 --- a/integrations/pom.xml +++ b/integrations/pom.xml @@ -33,6 +33,7 @@ under the License. 1.1.5 8.19.0 + 2.6.18 4.8.0 2.11.1 2.32.16 diff --git a/integrations/vector-stores/milvus/pom.xml b/integrations/vector-stores/milvus/pom.xml new file mode 100644 index 000000000..d65210215 --- /dev/null +++ b/integrations/vector-stores/milvus/pom.xml @@ -0,0 +1,52 @@ + + + + 4.0.0 + + + org.apache.flink + flink-agents-integrations-vector-stores + 0.3-SNAPSHOT + ../pom.xml + + + flink-agents-integrations-vector-stores-milvus + Flink Agents : Integrations: Vector Stores: Milvus + jar + + + + org.apache.flink + flink-agents-api + ${project.version} + + + io.milvus + milvus-sdk-java + ${milvus.version} + + + org.slf4j + slf4j-api + ${slf4j.version} + + + + diff --git a/integrations/vector-stores/milvus/src/main/java/org/apache/flink/agents/integrations/vectorstores/milvus/MilvusVectorStore.java b/integrations/vector-stores/milvus/src/main/java/org/apache/flink/agents/integrations/vectorstores/milvus/MilvusVectorStore.java new file mode 100644 index 000000000..7eafcd351 --- /dev/null +++ b/integrations/vector-stores/milvus/src/main/java/org/apache/flink/agents/integrations/vectorstores/milvus/MilvusVectorStore.java @@ -0,0 +1,1194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.integrations.vectorstores.milvus; + +import com.google.gson.Gson; +import com.google.gson.JsonElement; +import com.google.gson.JsonObject; +import io.milvus.v2.client.ConnectConfig; +import io.milvus.v2.client.MilvusClientV2; +import io.milvus.v2.common.ConsistencyLevel; +import io.milvus.v2.common.DataType; +import io.milvus.v2.common.IndexParam; +import io.milvus.v2.service.collection.request.AddFieldReq; +import io.milvus.v2.service.collection.request.CreateCollectionReq; +import io.milvus.v2.service.collection.request.DropCollectionReq; +import io.milvus.v2.service.collection.request.GetLoadStateReq; +import io.milvus.v2.service.collection.request.HasCollectionReq; +import io.milvus.v2.service.collection.request.LoadCollectionReq; +import io.milvus.v2.service.vector.request.DeleteReq; +import io.milvus.v2.service.vector.request.InsertReq; +import io.milvus.v2.service.vector.request.QueryReq; +import io.milvus.v2.service.vector.request.SearchReq; +import io.milvus.v2.service.vector.request.UpsertReq; +import io.milvus.v2.service.vector.request.data.BaseVector; +import io.milvus.v2.service.vector.request.data.FloatVec; +import io.milvus.v2.service.vector.response.QueryResp; +import io.milvus.v2.service.vector.response.SearchResp; +import org.apache.flink.agents.api.resource.ResourceContext; +import org.apache.flink.agents.api.resource.ResourceDescriptor; +import org.apache.flink.agents.api.vectorstores.BaseVectorStore; +import org.apache.flink.agents.api.vectorstores.CollectionManageableVectorStore; +import org.apache.flink.agents.api.vectorstores.Document; + +import javax.annotation.Nullable; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; +import java.util.UUID; + +/** + * Milvus-backed implementation of a vector store. + * + *

    This implementation executes dense-vector similarity search against a Milvus collection. It + * integrates with an embedding model (configured via the {@code embedding_model} resource argument + * inherited from {@link BaseVectorStore}) to convert query text into embeddings and then performs + * vector search using Milvus' search API. + * + *

    The store creates collections with a simple dense-vector schema: + * + *

      + *
    • {@code id}: VarChar primary key + *
    • {@code content}: VarChar document content + *
    • {@code metadata}: JSON metadata map + *
    • {@code embedding}: FloatVector + *
    + * + *

    Configuration is provided through {@link ResourceDescriptor} arguments. The most relevant ones + * are: + * + *

      + *
    • {@code collection} or {@code index} (optional): Target collection name. If omitted, + * defaults to {@link #DEFAULT_COLLECTION}. + *
    • {@code dims} (optional): Vector dimensionality; defaults to {@link #DEFAULT_DIMENSION}. + *
    • {@code vector_field}, {@code content_field}, {@code metadata_field}, {@code id_field} + * (optional): Schema field names. + *
    • {@code metric_type} (optional): Milvus metric type; defaults to {@code COSINE}. + *
    • {@code index_type} and {@code index_params} (optional): Milvus vector index settings. + *
    • {@code metadata_index_keys} (optional): Additional top-level metadata keys to index as JSON + * path indexes. The default keys are {@code user_id}, {@code agent_id}, {@code run_id}, + * {@code actor_id}, and {@code category}. + *
    • {@code metadata_index_cast_types} (optional): Map from metadata key to Milvus JSON index + * cast type. Defaults to {@code VARCHAR}; use values such as {@code DOUBLE} for numeric + * metadata keys. + *
    • {@code num_shards} (optional): Number of Milvus shards to create with the collection; + * defaults to {@link #DEFAULT_NUM_SHARDS}. As a rough capacity-planning rule, use about one + * shard per 100 million vectors, and increase it for heavier write throughput. + *
    • {@code consistency_level} (optional): Milvus consistency level for query and search; + * defaults to {@code BOUNDED}. Use {@code STRONG} when immediate read-after-write visibility + * is required. + *
    • {@code load_timeout_ms} (optional): Timeout used when loading a collection from {@link + * #createCollectionIfNotExists(String, Map)}; defaults to {@link #DEFAULT_LOAD_TIMEOUT_MS}. + *
    • {@code uri}, or {@code host}/{@code port} (optional): Milvus endpoint. If omitted, defaults + * to {@code http://localhost:19530}. + *
    • Authentication (optional): Either token auth via {@code token}, or basic auth via {@code + * username}/{@code password}. + *
    + * + *

    Example usage: + * + *

    {@code
    + * ResourceDescriptor desc = ResourceDescriptor.Builder
    + *     .newBuilder(MilvusVectorStore.class.getName())
    + *     .addInitialArgument("embedding_model", "textEmbedder")
    + *     .addInitialArgument("uri", "http://localhost:19530")
    + *     .addInitialArgument("collection", "my_documents")
    + *     .addInitialArgument("dims", 768)
    + *     .addInitialArgument("metric_type", "COSINE")
    + *     .addInitialArgument("index_type", "AUTOINDEX")
    + *     .build();
    + * }
    + */ +public class MilvusVectorStore extends BaseVectorStore implements CollectionManageableVectorStore { + + /** + * Default collection name used when {@code collection}, {@code collection_name}, and {@code + * index} are omitted. + */ + public static final String DEFAULT_COLLECTION = "flink_agents_milvus_collection"; + /** Default primary key field name. */ + public static final String DEFAULT_ID_FIELD = "id"; + /** Default field name used to store document content. */ + public static final String DEFAULT_CONTENT_FIELD = "content"; + /** Default JSON field name used to store document metadata. */ + public static final String DEFAULT_METADATA_FIELD = "metadata"; + /** Default FloatVector field name on which Milvus search is executed. */ + public static final String DEFAULT_VECTOR_FIELD = "embedding"; + /** Default index name for the full metadata JSON index. */ + public static final String DEFAULT_METADATA_INDEX_NAME = "metadata_json_index"; + /** Metadata keys that are commonly used by Mem0 and vector-store filter callers. */ + public static final List DEFAULT_METADATA_INDEX_KEYS = + List.of("user_id", "agent_id", "run_id", "actor_id", "category"); + /** Default Milvus JSON cast type used for metadata path indexes. */ + public static final String DEFAULT_METADATA_INDEX_CAST_TYPE = "VARCHAR"; + /** Default vector dimensionality used when {@code dims} is not provided. */ + public static final int DEFAULT_DIMENSION = 768; + /** The maximum number of documents that can be retrieved by get when limit is omitted. */ + public static final int DEFAULT_MAX_GET_LIMIT = 10000; + /** Default maximum length for the VarChar primary key field. */ + public static final int DEFAULT_ID_MAX_LENGTH = 65535; + /** Default maximum length for the VarChar content field. */ + public static final int DEFAULT_CONTENT_MAX_LENGTH = 65535; + /** Default number of Milvus shards used when creating a collection. */ + public static final int DEFAULT_NUM_SHARDS = 1; + /** Default timeout for synchronous collection load operations. */ + public static final long DEFAULT_LOAD_TIMEOUT_MS = 120000L; + + /** Milvus connection configuration built from the resource descriptor. */ + private final ConnectConfig connectConfig; + /** Lazily-created Milvus client used to execute collection and vector requests. */ + private transient volatile @Nullable MilvusClientV2 client; + + private final Gson gson = new Gson(); + + /** Resolved Milvus endpoint URI. */ + private final String uri; + /** Optional Milvus database name. */ + private final @Nullable String databaseName; + /** Default collection name used when a per-call collection is not supplied. */ + private final String defaultCollection; + /** Name of the primary key field. */ + private final String idField; + /** Name of the content field to store the document content. */ + private final String contentField; + /** Name of the JSON field to store document metadata. */ + private final String metadataField; + /** Name of the FloatVector field on which vector search is executed. */ + private final String vectorField; + /** Vector dimensionality of the {@link #vectorField}. */ + private final int dims; + /** Default query limit used by get when no limit is provided. */ + private final int maxGetLimit; + /** Maximum length for the VarChar primary key field. */ + private final int idMaxLength; + /** Maximum length for the VarChar content field. */ + private final int contentMaxLength; + /** Default Milvus metric type used for collection indexes and search. */ + private final IndexParam.MetricType metricType; + /** Default Milvus index type used when creating collections. */ + private final IndexParam.IndexType indexType; + /** Extra index parameters passed to Milvus collection creation. */ + private final Map indexParams; + /** Metadata JSON keys indexed with path-specific indexes during collection creation. */ + private final List metadataIndexKeys; + /** Per-metadata-key JSON cast type overrides for path-specific indexes. */ + private final Map metadataIndexCastTypes; + /** Number of shards used when creating collections. */ + private final int numShards; + /** Consistency level used for collection creation, query, and search requests. */ + private final ConsistencyLevel consistencyLevel; + /** Timeout used when loading collections from create-collection paths. */ + private final long loadTimeoutMs; + /** + * Creates a new {@code MilvusVectorStore} from the provided descriptor and resource resolver. + * + *

    The constructor reads connection, authentication, schema, index, and query defaults from + * the descriptor and prepares a {@link ConnectConfig}. The Milvus client itself is created + * lazily on first use. + * + * @param descriptor Resource descriptor containing configuration arguments + * @param resourceContext Context used to resolve other resources by name and type + */ + public MilvusVectorStore(ResourceDescriptor descriptor, ResourceContext resourceContext) { + super(descriptor, resourceContext); + + this.uri = resolveUri(descriptor); + this.databaseName = stringArg(descriptor, "db_name", null); + this.defaultCollection = + stringArg( + descriptor, + "collection", + stringArg( + descriptor, + "collection_name", + stringArg(descriptor, "index", DEFAULT_COLLECTION))); + this.idField = stringArg(descriptor, "id_field", DEFAULT_ID_FIELD); + this.contentField = stringArg(descriptor, "content_field", DEFAULT_CONTENT_FIELD); + this.metadataField = stringArg(descriptor, "metadata_field", DEFAULT_METADATA_FIELD); + this.vectorField = stringArg(descriptor, "vector_field", DEFAULT_VECTOR_FIELD); + this.dims = intArg(descriptor, "dims", DEFAULT_DIMENSION); + this.maxGetLimit = intArg(descriptor, "max_get_limit", DEFAULT_MAX_GET_LIMIT); + this.idMaxLength = intArg(descriptor, "id_max_length", DEFAULT_ID_MAX_LENGTH); + this.contentMaxLength = + intArg(descriptor, "content_max_length", DEFAULT_CONTENT_MAX_LENGTH); + this.metricType = + enumArg( + IndexParam.MetricType.class, + stringArg(descriptor, "metric_type", IndexParam.MetricType.COSINE.name())); + this.indexType = + enumArg( + IndexParam.IndexType.class, + stringArg(descriptor, "index_type", IndexParam.IndexType.AUTOINDEX.name())); + this.indexParams = mapArg(descriptor, "index_params"); + this.metadataIndexCastTypes = metadataIndexCastTypesArg(descriptor); + this.metadataIndexKeys = metadataIndexKeysArg(descriptor, this.metadataIndexCastTypes); + this.numShards = intArg(descriptor, "num_shards", DEFAULT_NUM_SHARDS); + this.consistencyLevel = + enumArg( + ConsistencyLevel.class, + stringArg( + descriptor, "consistency_level", ConsistencyLevel.BOUNDED.name())); + this.loadTimeoutMs = longArg(descriptor, "load_timeout_ms", DEFAULT_LOAD_TIMEOUT_MS); + + ConnectConfig.ConnectConfigBuilder builder = + ConnectConfig.builder() + .uri(this.uri) + .secure(this.uri.startsWith("https://")) + .enablePrecheck(booleanArg(descriptor, "enable_precheck", false)); + + String token = stringArg(descriptor, "token", null); + if (token != null && !token.isEmpty()) { + builder.token(token); + } + String username = stringArg(descriptor, "username", null); + String password = stringArg(descriptor, "password", null); + if (username != null && password != null) { + builder.username(username).password(password); + } + if (this.databaseName != null) { + builder.dbName(this.databaseName); + } + + this.connectConfig = builder.build(); + } + + @Override + public void close() { + synchronized (this) { + if (this.client != null) { + this.client.close(); + this.client = null; + } + } + } + + /** + * Returns default store-level arguments collected from the descriptor. + * + *

    The returned map can be merged with per-query arguments to form the complete set of + * parameters for Milvus collection creation, retrieval, and vector search operations. + * + * @return map of default store arguments such as {@code uri}, {@code collection}, {@code + * vector_field}, {@code dims}, {@code metric_type}, {@code index_type}, and {@code + * num_shards}. + */ + @Override + public Map getStoreKwargs() { + Map kwargs = new HashMap<>(); + kwargs.put("uri", this.uri); + kwargs.put("collection", this.defaultCollection); + kwargs.put("index", this.defaultCollection); + kwargs.put("id_field", this.idField); + kwargs.put("content_field", this.contentField); + kwargs.put("metadata_field", this.metadataField); + kwargs.put("vector_field", this.vectorField); + kwargs.put("dims", this.dims); + kwargs.put("metric_type", this.metricType.name()); + kwargs.put("index_type", this.indexType.name()); + kwargs.put("index_params", new HashMap<>(this.indexParams)); + kwargs.put("metadata_index_keys", new ArrayList<>(this.metadataIndexKeys)); + kwargs.put("metadata_index_cast_types", new HashMap<>(this.metadataIndexCastTypes)); + kwargs.put("num_shards", this.numShards); + kwargs.put("consistency_level", this.consistencyLevel.name()); + kwargs.put("load_timeout_ms", this.loadTimeoutMs); + if (this.databaseName != null) { + kwargs.put("db_name", this.databaseName); + } + return kwargs; + } + + /** Returns the lazily-created Milvus client. */ + private MilvusClientV2 client() { + MilvusClientV2 current = this.client; + if (current == null) { + synchronized (this) { + current = this.client; + if (current == null) { + current = new MilvusClientV2(this.connectConfig); + this.client = current; + } + } + } + return current; + } + + /** + * Creates the Milvus collection for the given name if it does not already exist. + * + *

    The created schema contains a VarChar primary key, a VarChar content field, a nullable + * JSON metadata field, and one FloatVector field. Vector index settings can be supplied through + * {@code kwargs}; otherwise descriptor defaults are used. The collection is loaded before this + * method returns, including the case where it already existed. + */ + @Override + public void createCollectionIfNotExists(String name, Map kwargs) { + if (hasCollection(name)) { + ensureCollectionLoaded(name, kwargs); + return; + } + + int dimension = intFromMap(kwargs, "dims", this.dims); + IndexParam.MetricType metric = + enumFromMap(IndexParam.MetricType.class, kwargs, "metric_type", this.metricType); + IndexParam.IndexType index = + enumFromMap(IndexParam.IndexType.class, kwargs, "index_type", this.indexType); + int numShards = intFromMap(kwargs, "num_shards", this.numShards); + Map params = + kwargs.containsKey("index_params") + ? objectToMap(kwargs.get("index_params")) + : this.indexParams; + + CreateCollectionReq.CollectionSchema schema = client().createSchema(); + schema.setEnableDynamicField(false); + schema.addField( + AddFieldReq.builder() + .fieldName(this.idField) + .dataType(DataType.VarChar) + .isPrimaryKey(Boolean.TRUE) + .autoID(Boolean.FALSE) + .maxLength(intFromMap(kwargs, "id_max_length", this.idMaxLength)) + .build()); + schema.addField( + AddFieldReq.builder() + .fieldName(this.contentField) + .dataType(DataType.VarChar) + .maxLength(intFromMap(kwargs, "content_max_length", this.contentMaxLength)) + .build()); + schema.addField( + AddFieldReq.builder() + .fieldName(this.metadataField) + .dataType(DataType.JSON) + .isNullable(Boolean.TRUE) + .build()); + schema.addField( + AddFieldReq.builder() + .fieldName(this.vectorField) + .dataType(DataType.FloatVector) + .dimension(dimension) + .build()); + + IndexParam vectorIndex = + IndexParam.builder() + .fieldName(this.vectorField) + .indexType(index) + .metricType(metric) + .extraParams(params) + .build(); + List indexParams = new ArrayList<>(); + indexParams.add(vectorIndex); + indexParams.add(metadataJsonIndexParam()); + Map metadataCastTypes = metadataIndexCastTypesFromArgs(kwargs); + for (String key : metadataIndexKeysFromArgs(kwargs, metadataCastTypes)) { + indexParams.add( + metadataJsonPathIndexParam( + key, + metadataCastTypes.getOrDefault(key, DEFAULT_METADATA_INDEX_CAST_TYPE))); + } + + CreateCollectionReq.CreateCollectionReqBuilder builder = + CreateCollectionReq.builder() + .collectionName(name) + .collectionSchema(schema) + .consistencyLevel(this.consistencyLevel) + .numShards(numShards) + .indexParams(indexParams); + if (this.databaseName != null) { + builder.databaseName(this.databaseName); + } + client().createCollection(builder.build()); + ensureCollectionLoaded(name, kwargs); + } + + /** Deletes the Milvus collection with the given name. */ + @Override + public void deleteCollection(String name) { + DropCollectionReq.DropCollectionReqBuilder builder = + DropCollectionReq.builder().collectionName(name).async(Boolean.FALSE); + if (this.databaseName != null) { + builder.databaseName(this.databaseName); + } + client().dropCollection(builder.build()); + } + + /** + * Retrieve documents from the vector store. + * + *

    If {@code ids} are provided, this method queries by primary key and ignores {@code + * filters} and {@code limit} per the {@link BaseVectorStore} contract. Otherwise it queries + * with either the unified equality-only {@code filters} DSL or an all-rows fallback expression. + * + * @param ids The ids of the documents. + * @param collection The name of the collection to retrieve from. If null, retrieve from the + * default collection. + * @param filters Unified equality-only filter DSL matched against metadata JSON fields. + * @param limit Maximum number of documents to return; falls back to {@link + * #DEFAULT_MAX_GET_LIMIT} when null. + * @param extraArgs Additional Milvus-specific arguments. + * @return List of documents retrieved. + */ + @Override + public List get( + @Nullable List ids, + @Nullable String collection, + @Nullable Map filters, + @Nullable Integer limit, + Map extraArgs) + throws IOException { + String targetCollection = resolveCollection(collection); + + if (ids != null && !ids.isEmpty()) { + // Get specific documents by IDs; filters and limit are ignored per + // BaseVectorStore contract. + return getDocumentsByIds(targetCollection, ids); + } + + // Get all documents with optional filters and limit. + return getDocuments(targetCollection, filtersToExpression(filters), limit); + } + + /** + * Retrieves documents by their IDs using Milvus query API. + * + * @param collection The collection to query + * @param ids List of document IDs to retrieve + * @return List of Documents + */ + private List getDocumentsByIds(String collection, List ids) { + QueryReq.QueryReqBuilder builder = + QueryReq.builder() + .collectionName(collection) + .ids(toObjectIds(ids)) + .outputFields(outputFields()) + .consistencyLevel(this.consistencyLevel); + if (this.databaseName != null) { + builder.databaseName(this.databaseName); + } + QueryResp resp = client().query(builder.build()); + return queryResultsToDocuments(resp.getQueryResults()); + } + + /** + * Retrieves documents using Milvus query API with optional filters and limit. + * + * @param collection The collection to query + * @param filter Optional Milvus boolean expression + * @param limit Maximum number of documents to return + * @return List of Documents + */ + private List getDocuments( + String collection, @Nullable String filter, @Nullable Integer limit) { + QueryReq.QueryReqBuilder builder = + QueryReq.builder() + .collectionName(collection) + .outputFields(outputFields()) + .filter(filter == null ? allRowsFilter() : filter) + .consistencyLevel(this.consistencyLevel) + .limit(limit == null ? this.maxGetLimit : limit); + if (this.databaseName != null) { + builder.databaseName(this.databaseName); + } + QueryResp resp = client().query(builder.build()); + return queryResultsToDocuments(resp.getQueryResults()); + } + + /** + * Delete documents in the vector store. + * + *

    If ids are provided, this method deletes the corresponding primary keys. Otherwise it + * deletes documents matched by the unified equality-only {@code filters} DSL. If no filter is + * provided, it deletes all documents in the target collection. + * + * @param ids The ids of the documents. + * @param collection The name of the collection the documents belong to. If null, use the + * default collection. + * @param filters Unified equality-only filter DSL matched against metadata JSON fields. + * @param extraArgs Additional Milvus-specific arguments. + */ + @Override + public void delete( + @Nullable List ids, + @Nullable String collection, + @Nullable Map filters, + Map extraArgs) + throws IOException { + String targetCollection = resolveCollection(collection); + if (ids != null && !ids.isEmpty()) { + // Delete specific documents by IDs. + deleteDocumentsByIds(targetCollection, ids); + } else { + // Delete all documents with optional filters. + deleteDocuments(targetCollection, filters); + } + } + + /** + * Deletes documents by their IDs using Milvus delete API. + * + * @param collection The collection to delete from + * @param ids List of document IDs to delete + */ + private void deleteDocumentsByIds(String collection, List ids) { + DeleteReq.DeleteReqBuilder builder = + DeleteReq.builder().collectionName(collection).ids(toObjectIds(ids)); + if (this.databaseName != null) { + builder.databaseName(this.databaseName); + } + client().delete(builder.build()); + } + + /** + * Deletes documents using Milvus delete API with optional filters. + * + * @param collection The collection to delete from + * @param filters Unified equality-only filter DSL matched against metadata JSON fields + */ + private void deleteDocuments(String collection, @Nullable Map filters) { + String filter = filtersToExpression(filters); + DeleteReq.DeleteReqBuilder builder = + DeleteReq.builder() + .collectionName(collection) + .filter(filter == null ? allRowsFilter() : filter); + if (this.databaseName != null) { + builder.databaseName(this.databaseName); + } + client().delete(builder.build()); + } + + /** + * Executes a Milvus vector search using a pre-computed embedding. + * + *

    The method searches the configured vector field and returns only document id, content, and + * metadata as output fields. The returned {@link Document#getScore()} value is populated from + * the Milvus search result score, not from an output field. + * + * @param embedding The embedding vector to search with + * @param limit Maximum number of nearest neighbors to return + * @param collection The collection to search. If null, search the default collection. + * @param filters Unified equality-only filter DSL matched against metadata JSON fields. + * @param args Additional arguments. Supported keys include {@code metric_type} and {@code + * search_params}. + * @return A list of matching documents, possibly empty + */ + @Override + public List queryEmbedding( + float[] embedding, + int limit, + @Nullable String collection, + @Nullable Map filters, + Map args) { + String targetCollection = resolveCollection(collection); + SearchReq.SearchReqBuilder builder = + SearchReq.builder() + .collectionName(targetCollection) + .annsField(this.vectorField) + .metricType( + enumFromMap( + IndexParam.MetricType.class, + args, + "metric_type", + this.metricType)) + .data(Collections.singletonList(new FloatVec(embedding))) + .limit(limit) + .outputFields(outputFields()) + .consistencyLevel(this.consistencyLevel); + String filter = filtersToExpression(filters); + if (filter != null) { + builder.filter(filter); + } + Map searchParams = objectToMap(args.get("search_params")); + if (!searchParams.isEmpty()) { + builder.searchParams(searchParams); + } + if (this.databaseName != null) { + builder.databaseName(this.databaseName); + } + + SearchResp resp = client().search(builder.build()); + List> groups = resp.getSearchResults(); + if (groups == null || groups.isEmpty()) { + return Collections.emptyList(); + } + return searchResultsToDocuments(groups.get(0)); + } + + /** + * Add documents with pre-computed embeddings to the vector store. + * + *

    Documents without ids get generated UUIDs. Add always uses Milvus insert; callers should + * use {@link #updateEmbedding(List, String, Map)} when they need to replace existing documents. + * Each document must already contain an embedding; the public {@link BaseVectorStore#add(List, + * String, Map)} path handles auto-embedding before it reaches this method. + * + * @return List of document ids written to Milvus + */ + @Override + public List addEmbedding( + List documents, @Nullable String collection, Map extraArgs) + throws IOException { + if (documents == null || documents.isEmpty()) { + return Collections.emptyList(); + } + String targetCollection = resolveCollection(collection); + + List rows = new ArrayList<>(); + List ids = new ArrayList<>(); + for (Document doc : documents) { + String id = doc.getId(); + if (id == null || id.isEmpty()) { + id = UUID.randomUUID().toString(); + } + ids.add(id); + rows.add(toRow(id, doc)); + } + insertRows(targetCollection, rows); + return ids; + } + + /** + * Update documents with pre-computed embeddings. + * + *

    Milvus upsert rewrites rows by primary key. The public {@link BaseVectorStore#update(List, + * String, Map)} path already enforces that every document carries an id. + */ + @Override + public void updateEmbedding( + List documents, @Nullable String collection, Map extraArgs) + throws IOException { + String targetCollection = resolveCollection(collection); + + List rows = new ArrayList<>(); + for (Document doc : documents) { + rows.add(toRow(Objects.requireNonNull(doc.getId()), doc)); + } + upsertRows(targetCollection, rows); + } + + /** Writes rows using Milvus insert. */ + private void insertRows(String targetCollection, List rows) { + InsertReq.InsertReqBuilder builder = + InsertReq.builder().collectionName(targetCollection).data(rows); + if (this.databaseName != null) { + builder.databaseName(this.databaseName); + } + client().insert(builder.build()); + } + + /** Writes rows using Milvus upsert, so repeated ids replace the existing entity. */ + private void upsertRows(String targetCollection, List rows) { + UpsertReq.UpsertReqBuilder builder = + UpsertReq.builder() + .collectionName(targetCollection) + .data(rows) + .partialUpdate(false); + if (this.databaseName != null) { + builder.databaseName(this.databaseName); + } + client().upsert(builder.build()); + } + + /** Converts a {@link Document} into the row object expected by Milvus insert/upsert APIs. */ + private JsonObject toRow(String id, Document doc) { + if (doc.getEmbedding() == null) { + throw new IllegalArgumentException("Document embedding must not be null."); + } + JsonObject row = new JsonObject(); + row.addProperty(this.idField, id); + row.addProperty(this.contentField, doc.getContent()); + row.add(this.metadataField, this.gson.toJsonTree(doc.getMetadata())); + row.add(this.vectorField, this.gson.toJsonTree(toFloatList(doc.getEmbedding()))); + return row; + } + + /** Checks whether a collection exists in the configured Milvus database. */ + private boolean hasCollection(String collectionName) { + HasCollectionReq.HasCollectionReqBuilder builder = + HasCollectionReq.builder().collectionName(collectionName); + if (this.databaseName != null) { + builder.databaseName(this.databaseName); + } + return client().hasCollection(builder.build()); + } + + /** Loads the collection when Milvus reports that it is not loaded yet. */ + private void ensureCollectionLoaded(String collectionName, Map extraArgs) { + GetLoadStateReq.GetLoadStateReqBuilder stateBuilder = + GetLoadStateReq.builder().collectionName(collectionName); + if (this.databaseName != null) { + stateBuilder.databaseName(this.databaseName); + } + if (Boolean.TRUE.equals(client().getLoadState(stateBuilder.build()))) { + return; + } + + LoadCollectionReq.LoadCollectionReqBuilder loadBuilder = + LoadCollectionReq.builder() + .collectionName(collectionName) + .sync(Boolean.TRUE) + .timeout(longFromMap(extraArgs, "load_timeout_ms", this.loadTimeoutMs)); + if (this.databaseName != null) { + loadBuilder.databaseName(this.databaseName); + } + client().loadCollection(loadBuilder.build()); + } + + /** + * Creates an index on the full metadata JSON object. + * + *

    Unlike per-key JSON path indexes, this uses {@code json_path=metadata} and {@code + * json_cast_type=JSON}. It is the closest Milvus equivalent to Elasticsearch's dynamic metadata + * object mapping: upper layers still pass unified filters by logical metadata key, and {@link + * #filtersToExpression(Map)} expands them to {@code metadata["key"]} predicates. + */ + private IndexParam metadataJsonIndexParam() { + return IndexParam.builder() + .fieldName(this.metadataField) + .indexName(DEFAULT_METADATA_INDEX_NAME) + .indexType(IndexParam.IndexType.AUTOINDEX) + .extraParams(Map.of("json_path", this.metadataField, "json_cast_type", "JSON")) + .build(); + } + + /** + * Creates an index on a high-value metadata JSON key. + * + *

    Upper layers pass filters as logical keys such as {@code user_id}; this method maps each + * key to the Milvus path expression {@code metadata["user_id"]}. String-like keys use {@code + * VARCHAR} by default, while callers can override cast types for numeric or boolean metadata + * keys with {@code metadata_index_cast_types}. + */ + private IndexParam metadataJsonPathIndexParam(String key, String castType) { + return IndexParam.builder() + .fieldName(this.metadataField) + .indexName(metadataJsonPathIndexName(key)) + .indexType(IndexParam.IndexType.AUTOINDEX) + .extraParams( + Map.of( + "json_path", + this.metadataField + "[\"" + key + "\"]", + "json_cast_type", + castType)) + .build(); + } + + /** Returns the deterministic index name used for a metadata JSON path index. */ + private String metadataJsonPathIndexName(String key) { + return this.metadataField + "_" + key + "_json_index"; + } + + /** + * Resolves the effective collection name. + * + *

    Precedence is: explicit method argument, then the descriptor default. The descriptor + * accepts the {@code index} alias to stay compatible with Elasticsearch-style configuration, + * but per-call target collection selection should use the dedicated method parameter. + */ + private String resolveCollection(@Nullable String collectionName) { + if (collectionName != null) { + return collectionName; + } + return this.defaultCollection; + } + + /** + * Output fields requested from Milvus for retrieval/search results. + * + *

    The vector field is intentionally omitted: upper layers only require id, content, + * metadata, and the search score supplied separately by Milvus search results. + */ + private List outputFields() { + return List.of(this.idField, this.contentField, this.metadataField); + } + + /** Converts Milvus query/get results into Flink Agents documents. */ + private List queryResultsToDocuments(@Nullable List results) { + if (results == null || results.isEmpty()) { + return Collections.emptyList(); + } + List docs = new ArrayList<>(); + for (QueryResp.QueryResult result : results) { + docs.add(entityToDocument(result.getEntity(), null, null)); + } + return docs; + } + + /** Converts Milvus search results into Flink Agents documents and preserves hit scores. */ + private List searchResultsToDocuments(List results) { + List docs = new ArrayList<>(); + for (SearchResp.SearchResult result : results) { + docs.add(entityToDocument(result.getEntity(), result.getId(), result.getScore())); + } + return docs; + } + + /** + * Converts a Milvus entity map into a {@link Document}. + * + *

    Embeddings are not reconstructed from result rows because the vector field is not + * requested in {@link #outputFields()}. Search scores are carried separately from Milvus search + * results. + */ + private Document entityToDocument( + Map entity, @Nullable Object resultId, @Nullable Float score) { + Object idValue = resultId == null ? entity.get(this.idField) : resultId; + String id = idValue == null ? null : String.valueOf(idValue); + Object contentValue = entity.get(this.contentField); + String content = contentValue == null ? "" : String.valueOf(contentValue); + Map metadata = objectToMap(entity.get(this.metadataField)); + return new Document(content, metadata, id, null, score); + } + + /** Converts known map-like values returned by the Milvus SDK into a Java map. */ + @SuppressWarnings("unchecked") + private Map objectToMap(@Nullable Object value) { + if (value == null) { + return Collections.emptyMap(); + } + if (value instanceof Map) { + return new LinkedHashMap<>((Map) value); + } + if (value instanceof JsonObject) { + return this.gson.fromJson((JsonObject) value, Map.class); + } + if (value instanceof JsonElement) { + return this.gson.fromJson((JsonElement) value, Map.class); + } + return Collections.emptyMap(); + } + + /** Resolves metadata JSON path index cast types from descriptor defaults plus per-call args. */ + private Map metadataIndexCastTypesFromArgs(Map args) { + Map castTypes = new LinkedHashMap<>(this.metadataIndexCastTypes); + if (args.containsKey("metadata_index_cast_types")) { + putMetadataIndexCastTypes( + castTypes, objectToMap(args.get("metadata_index_cast_types"))); + } + return castTypes; + } + + /** Resolves metadata keys to index from defaults plus descriptor and per-call args. */ + private List metadataIndexKeysFromArgs( + Map args, Map castTypes) { + LinkedHashMap keys = new LinkedHashMap<>(); + putMetadataIndexKeys(keys, this.metadataIndexKeys); + putMetadataIndexKeys(keys, stringList(args.get("metadata_index_keys"))); + putMetadataIndexKeys(keys, castTypes.keySet()); + return new ArrayList<>(keys.keySet()); + } + + /** + * Translates the unified equality-only filter DSL into a Milvus boolean expression. + * + *

    Metadata is stored in a JSON field, so equality filters are translated into JSON subscript + * predicates such as {@code metadata["user_id"] == "alice"}. + */ + private String filtersToExpression(@Nullable Map filters) { + if (filters == null || filters.isEmpty()) { + return null; + } + List clauses = new ArrayList<>(); + for (Map.Entry entry : filters.entrySet()) { + Object value = entry.getValue(); + if (value == null || value instanceof Map) { + throw new UnsupportedOperationException( + "MilvusVectorStore filters support equality shorthand only."); + } + clauses.add( + this.metadataField + + "[\"" + + escapeString(entry.getKey()) + + "\"] == " + + literal(value)); + } + return String.join(" and ", clauses); + } + + /** Returns a Milvus expression that matches every row with a non-empty primary key. */ + private String allRowsFilter() { + return this.idField + " != \"\""; + } + + /** Formats a Java value as a Milvus expression literal. */ + private String literal(Object value) { + if (value instanceof Number || value instanceof Boolean) { + return String.valueOf(value); + } + return "\"" + escapeString(String.valueOf(value)) + "\""; + } + + /** Converts Java primitive float arrays into boxed lists accepted by the Milvus SDK. */ + private static List toFloatList(float[] embedding) { + List vector = new ArrayList<>(embedding.length); + for (float value : embedding) { + vector.add(value); + } + return vector; + } + + /** Converts string ids into the object list shape expected by Milvus get/delete requests. */ + private static List toObjectIds(List ids) { + return new ArrayList<>(ids); + } + + /** Resolves the Milvus endpoint from {@code uri}, {@code host}, and {@code port}. */ + private static String resolveUri(ResourceDescriptor descriptor) { + String uri = stringArg(descriptor, "uri", null); + if (uri != null && !uri.isEmpty()) { + return uri; + } + String host = stringArg(descriptor, "host", "localhost"); + int port = intArg(descriptor, "port", 19530); + if (host.startsWith("http://") || host.startsWith("https://")) { + return host; + } + if (host.contains(":")) { + return "http://" + host; + } + return "http://" + host + ":" + port; + } + + /** Reads a descriptor argument as a map, returning an empty map when absent. */ + @SuppressWarnings("unchecked") + private static Map mapArg(ResourceDescriptor descriptor, String key) { + Object value = descriptor.getArgument(key); + if (value instanceof Map) { + return new HashMap<>((Map) value); + } + return Collections.emptyMap(); + } + + /** Reads metadata JSON path index keys from the descriptor. */ + private static List metadataIndexKeysArg( + ResourceDescriptor descriptor, Map castTypes) { + LinkedHashMap keys = new LinkedHashMap<>(); + putMetadataIndexKeys(keys, DEFAULT_METADATA_INDEX_KEYS); + putMetadataIndexKeys(keys, stringList(descriptor.getArgument("metadata_index_keys"))); + putMetadataIndexKeys(keys, castTypes.keySet()); + return new ArrayList<>(keys.keySet()); + } + + /** Reads metadata JSON path index cast types from the descriptor. */ + private static Map metadataIndexCastTypesArg(ResourceDescriptor descriptor) { + Map castTypes = new LinkedHashMap<>(); + for (String key : DEFAULT_METADATA_INDEX_KEYS) { + castTypes.put(key, DEFAULT_METADATA_INDEX_CAST_TYPE); + } + putMetadataIndexCastTypes(castTypes, mapArg(descriptor, "metadata_index_cast_types")); + return castTypes; + } + + /** Adds metadata index keys while preserving insertion order and removing duplicates. */ + private static void putMetadataIndexKeys( + LinkedHashMap target, Iterable keys) { + for (String key : keys) { + target.put(normalizeMetadataIndexKey(key), Boolean.TRUE); + } + } + + /** Adds or overrides metadata index cast types. */ + private static void putMetadataIndexCastTypes( + Map target, Map castTypes) { + for (Map.Entry entry : castTypes.entrySet()) { + target.put( + normalizeMetadataIndexKey(entry.getKey()), + normalizeMetadataIndexCastType(entry.getValue())); + } + } + + /** Parses a descriptor or per-call argument as a string list. */ + private static List stringList(@Nullable Object value) { + if (value == null) { + return Collections.emptyList(); + } + if (value instanceof Iterable) { + List result = new ArrayList<>(); + for (Object item : (Iterable) value) { + if (item != null) { + result.add(String.valueOf(item)); + } + } + return result; + } + if (value instanceof Object[]) { + List result = new ArrayList<>(); + for (Object item : (Object[]) value) { + if (item != null) { + result.add(String.valueOf(item)); + } + } + return result; + } + String text = String.valueOf(value).trim(); + if (text.isEmpty()) { + return Collections.emptyList(); + } + List result = new ArrayList<>(); + for (String part : text.split(",")) { + String trimmed = part.trim(); + if (!trimmed.isEmpty()) { + result.add(trimmed); + } + } + return result; + } + + /** + * Validates a top-level JSON key used for a generated Milvus path index. + * + *

    Milvus recommends JSON keys made from letters, digits, and underscores. Keeping the + * configurable index list to that subset also gives deterministic, valid index names. + */ + private static String normalizeMetadataIndexKey(String key) { + String trimmed = key == null ? "" : key.trim(); + if (trimmed.isEmpty()) { + throw new IllegalArgumentException("metadata_index_keys cannot contain empty keys."); + } + if (!isIdentifierStart(trimmed.charAt(0))) { + throw new IllegalArgumentException( + "metadata_index_keys must contain only JSON-safe identifiers: " + key); + } + for (int i = 1; i < trimmed.length(); i++) { + if (!isIdentifierPart(trimmed.charAt(i))) { + throw new IllegalArgumentException( + "metadata_index_keys must contain only JSON-safe identifiers: " + key); + } + } + return trimmed; + } + + /** Normalizes Milvus JSON index cast types. */ + private static String normalizeMetadataIndexCastType(@Nullable Object value) { + String type = + value == null + ? DEFAULT_METADATA_INDEX_CAST_TYPE + : String.valueOf(value).trim().toUpperCase(Locale.ROOT); + switch (type) { + case "BOOL": + case "DOUBLE": + case "VARCHAR": + case "ARRAY_BOOL": + case "ARRAY_DOUBLE": + case "ARRAY_VARCHAR": + case "JSON": + return type; + default: + throw new IllegalArgumentException( + "Unsupported Milvus metadata JSON index cast type: " + value); + } + } + + private static boolean isIdentifierStart(char value) { + return (value >= 'A' && value <= 'Z') || (value >= 'a' && value <= 'z') || value == '_'; + } + + private static boolean isIdentifierPart(char value) { + return isIdentifierStart(value) || (value >= '0' && value <= '9'); + } + + /** Reads a descriptor argument as a string, returning the supplied default when absent. */ + private static String stringArg( + ResourceDescriptor descriptor, String key, @Nullable String defaultValue) { + Object value = descriptor.getArgument(key); + return value == null ? defaultValue : String.valueOf(value); + } + + /** Reads a descriptor argument as an integer, accepting both numeric and string values. */ + private static int intArg(ResourceDescriptor descriptor, String key, int defaultValue) { + Object value = descriptor.getArgument(key); + if (value instanceof Number) { + return ((Number) value).intValue(); + } + if (value != null) { + return Integer.parseInt(String.valueOf(value)); + } + return defaultValue; + } + + /** Reads a descriptor argument as a long, accepting both numeric and string values. */ + private static long longArg(ResourceDescriptor descriptor, String key, long defaultValue) { + Object value = descriptor.getArgument(key); + if (value instanceof Number) { + return ((Number) value).longValue(); + } + if (value != null) { + return Long.parseLong(String.valueOf(value)); + } + return defaultValue; + } + + /** Reads a per-call argument as an integer, accepting both numeric and string values. */ + private static int intFromMap(Map args, String key, int defaultValue) { + Object value = args.get(key); + if (value instanceof Number) { + return ((Number) value).intValue(); + } + if (value != null) { + return Integer.parseInt(String.valueOf(value)); + } + return defaultValue; + } + + /** Reads a per-call argument as a long, accepting both numeric and string values. */ + private static long longFromMap(Map args, String key, long defaultValue) { + Object value = args.get(key); + if (value instanceof Number) { + return ((Number) value).longValue(); + } + if (value != null) { + return Long.parseLong(String.valueOf(value)); + } + return defaultValue; + } + + /** Reads a descriptor argument as a boolean. */ + private static boolean booleanArg( + ResourceDescriptor descriptor, String key, boolean defaultValue) { + Object value = descriptor.getArgument(key); + if (value instanceof Boolean) { + return (Boolean) value; + } + if (value != null) { + return Boolean.parseBoolean(String.valueOf(value)); + } + return defaultValue; + } + + /** Reads a per-call enum argument, returning the default when absent. */ + private static > E enumFromMap( + Class enumClass, Map args, String key, E defaultValue) { + Object value = args.get(key); + return value == null ? defaultValue : enumArg(enumClass, String.valueOf(value)); + } + + /** Parses enum names case-insensitively and accepts dash-separated names. */ + private static > E enumArg(Class enumClass, String value) { + return Enum.valueOf(enumClass, value.trim().replace('-', '_').toUpperCase(Locale.ROOT)); + } + + /** Escapes string content embedded in Milvus expression literals. */ + private static String escapeString(String value) { + return value.replace("\\", "\\\\").replace("\"", "\\\""); + } +} diff --git a/integrations/vector-stores/milvus/src/test/java/org/apache/flink/agents/integrations/vectorstores/milvus/MilvusVectorStoreTest.java b/integrations/vector-stores/milvus/src/test/java/org/apache/flink/agents/integrations/vectorstores/milvus/MilvusVectorStoreTest.java new file mode 100644 index 000000000..bfd44ceb7 --- /dev/null +++ b/integrations/vector-stores/milvus/src/test/java/org/apache/flink/agents/integrations/vectorstores/milvus/MilvusVectorStoreTest.java @@ -0,0 +1,642 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.integrations.vectorstores.milvus; + +import io.milvus.v2.client.ConnectConfig; +import io.milvus.v2.client.MilvusClientV2; +import io.milvus.v2.common.IndexParam; +import io.milvus.v2.service.collection.request.DescribeCollectionReq; +import io.milvus.v2.service.collection.request.GetLoadStateReq; +import io.milvus.v2.service.collection.request.ReleaseCollectionReq; +import io.milvus.v2.service.collection.response.DescribeCollectionResp; +import io.milvus.v2.service.index.request.DescribeIndexReq; +import io.milvus.v2.service.index.request.ListIndexesReq; +import io.milvus.v2.service.index.response.DescribeIndexResp; +import org.apache.flink.agents.api.embedding.model.BaseEmbeddingModelSetup; +import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceContext; +import org.apache.flink.agents.api.resource.ResourceDescriptor; +import org.apache.flink.agents.api.resource.ResourceType; +import org.apache.flink.agents.api.vectorstores.BaseVectorStore; +import org.apache.flink.agents.api.vectorstores.CollectionManageableVectorStore; +import org.apache.flink.agents.api.vectorstores.Document; +import org.apache.flink.agents.api.vectorstores.VectorStoreQuery; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.mockito.Mockito; + +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.UUID; + +import static org.assertj.core.api.Assertions.assertThat; + +/** Tests for {@link MilvusVectorStore}. */ +public class MilvusVectorStoreTest { + + private static final ResourceContext NOOP = ResourceContext.fromGetResource((a, b) -> null); + + @Test + void testConstructorAndStoreKwargs() { + ResourceDescriptor desc = + ResourceDescriptor.Builder.newBuilder(MilvusVectorStore.class.getName()) + .addInitialArgument("embedding_model", "embeddingModel") + .addInitialArgument("uri", "http://localhost:19530") + .addInitialArgument("collection", "test_collection") + .addInitialArgument("dims", 5) + .addInitialArgument("num_shards", 3) + .addInitialArgument("metric_type", "IP") + .addInitialArgument("index_type", "IVF_FLAT") + .addInitialArgument("metadata_index_keys", List.of("source")) + .addInitialArgument("metadata_index_cast_types", Map.of("score", "DOUBLE")) + // Test-only custom consistency value to verify descriptor plumbing. + // Production should use the default BOUNDED consistency unless immediate + // read-after-write visibility is required. + .addInitialArgument("consistency_level", "STRONG") + .addInitialArgument("load_timeout_ms", 12345L) + .build(); + + MilvusVectorStore store = new MilvusVectorStore(desc, NOOP); + Map kwargs = store.getStoreKwargs(); + assertThat(store).isInstanceOf(BaseVectorStore.class); + assertThat(store).isInstanceOf(CollectionManageableVectorStore.class); + assertThat(kwargs).containsEntry("collection", "test_collection"); + assertThat(kwargs).containsEntry("index", "test_collection"); + assertThat(kwargs).containsEntry("dims", 5); + assertThat(kwargs).containsEntry("num_shards", 3); + assertThat(kwargs).containsEntry("metric_type", "IP"); + assertThat(kwargs).containsEntry("index_type", "IVF_FLAT"); + assertThat((List) kwargs.get("metadata_index_keys")) + .containsExactly( + "user_id", "agent_id", "run_id", "actor_id", "category", "source", "score"); + assertThat((Map) kwargs.get("metadata_index_cast_types")) + .containsEntry("user_id", "VARCHAR") + .containsEntry("agent_id", "VARCHAR") + .containsEntry("run_id", "VARCHAR") + .containsEntry("actor_id", "VARCHAR") + .containsEntry("category", "VARCHAR") + .containsEntry("score", "DOUBLE"); + assertThat(kwargs).containsEntry("consistency_level", "STRONG"); + assertThat(kwargs).containsEntry("load_timeout_ms", 12345L); + assertThat(kwargs).doesNotContainKey("flush_on_write"); + assertThat(kwargs).doesNotContainKey("auto_create_collection"); + assertThat(kwargs).doesNotContainKey("metadata_index_enabled"); + store.close(); + } + + @Test + @EnabledIfEnvironmentVariable(named = "MILVUS_URI", matches = ".+") + void testCreateCollectionUsesShardCount() throws Exception { + String collection = collectionName("shard_count"); + MilvusVectorStore store = openStore(collection); + MilvusClientV2 client = + new MilvusClientV2( + ConnectConfig.builder().uri(System.getenv("MILVUS_URI")).build()); + try { + store.createCollectionIfNotExists(collection, Map.of("num_shards", 2)); + + DescribeCollectionResp resp = + client.describeCollection( + DescribeCollectionReq.builder().collectionName(collection).build()); + Assertions.assertEquals(2, resp.getShardsNum()); + } finally { + store.deleteCollection(collection); + store.close(); + client.close(); + } + } + + @Test + @EnabledIfEnvironmentVariable(named = "MILVUS_URI", matches = ".+") + void testCreateCollectionUsesMetadataJsonIndex() throws Exception { + String collection = collectionName("metadata_json_index"); + ResourceDescriptor desc = + ResourceDescriptor.Builder.newBuilder(MilvusVectorStore.class.getName()) + .addInitialArgument("embedding_model", "embeddingModel") + .addInitialArgument("uri", System.getenv("MILVUS_URI")) + .addInitialArgument("collection", collection) + .addInitialArgument("dims", 5) + .addInitialArgument("index_type", "AUTOINDEX") + .addInitialArgument("metric_type", "COSINE") + .addInitialArgument("metadata_index_keys", List.of("source")) + .addInitialArgument("metadata_index_cast_types", Map.of("score", "DOUBLE")) + .build(); + MilvusVectorStore store = + new MilvusVectorStore( + desc, ResourceContext.fromGetResource(MilvusVectorStoreTest::getResource)); + store.open(); + MilvusClientV2 client = + new MilvusClientV2( + ConnectConfig.builder().uri(System.getenv("MILVUS_URI")).build()); + try { + store.createCollectionIfNotExists(collection, Map.of()); + + List indexNames = + client.listIndexes( + ListIndexesReq.builder() + .collectionName(collection) + .fieldName(MilvusVectorStore.DEFAULT_METADATA_FIELD) + .build()); + assertThat(indexNames).contains(MilvusVectorStore.DEFAULT_METADATA_INDEX_NAME); + assertThat(indexNames) + .contains( + metadataPathIndexName("user_id"), + metadataPathIndexName("agent_id"), + metadataPathIndexName("run_id"), + metadataPathIndexName("actor_id"), + metadataPathIndexName("category"), + metadataPathIndexName("source"), + metadataPathIndexName("score")); + + DescribeIndexResp resp = + client.describeIndex( + DescribeIndexReq.builder() + .collectionName(collection) + .indexName(MilvusVectorStore.DEFAULT_METADATA_INDEX_NAME) + .build()); + DescribeIndexResp.IndexDesc scoreIndex = + resp.getIndexDescByIndexName(MilvusVectorStore.DEFAULT_METADATA_INDEX_NAME); + Assertions.assertEquals( + MilvusVectorStore.DEFAULT_METADATA_FIELD, scoreIndex.getFieldName()); + Assertions.assertEquals(IndexParam.IndexType.AUTOINDEX, scoreIndex.getIndexType()); + assertThat(scoreIndex.getExtraParams()) + .containsEntry("json_path", MilvusVectorStore.DEFAULT_METADATA_FIELD) + .containsEntry("json_cast_type", "JSON"); + + assertMetadataPathIndex(client, collection, "user_id", "VARCHAR"); + assertMetadataPathIndex(client, collection, "actor_id", "VARCHAR"); + assertMetadataPathIndex(client, collection, "category", "VARCHAR"); + assertMetadataPathIndex(client, collection, "source", "VARCHAR"); + assertMetadataPathIndex(client, collection, "score", "DOUBLE"); + } finally { + store.deleteCollection(collection); + store.close(); + client.close(); + } + } + + @Test + @EnabledIfEnvironmentVariable(named = "MILVUS_URI", matches = ".+") + void testCollectionManagement() throws Exception { + String collection = collectionName("collection_management"); + MilvusVectorStore store = openStore(collection); + + try { + createCollection(store, collection); + Assertions.assertTrue( + store.get(null, collection, null, 10, Collections.emptyMap()).isEmpty()); + + store.deleteCollection(collection); + + Assertions.assertThrows( + Exception.class, + () -> store.get(null, collection, null, 10, Collections.emptyMap())); + } finally { + store.close(); + } + } + + @Test + @EnabledIfEnvironmentVariable(named = "MILVUS_URI", matches = ".+") + void testCreateCollectionIfNotExistsLoadsReleasedCollection() throws Exception { + String collection = collectionName("load_released"); + MilvusVectorStore store = openStore(collection); + MilvusClientV2 client = + new MilvusClientV2( + ConnectConfig.builder().uri(System.getenv("MILVUS_URI")).build()); + + try { + createCollection(store, collection); + store.add( + List.of( + new Document( + "Milvus is a vector database", + Map.of("category", "database"), + "doc1")), + collection, + Collections.emptyMap()); + + client.releaseCollection( + ReleaseCollectionReq.builder() + .collectionName(collection) + .async(Boolean.FALSE) + .build()); + Assertions.assertFalse( + client.getLoadState( + GetLoadStateReq.builder().collectionName(collection).build())); + + createCollection(store, collection); + Assertions.assertTrue( + client.getLoadState( + GetLoadStateReq.builder().collectionName(collection).build())); + + List loaded = + store.get(List.of("doc1"), collection, null, null, Collections.emptyMap()); + Assertions.assertEquals(1, loaded.size()); + assertDocument( + loaded.get(0), + "doc1", + "Milvus is a vector database", + Map.of("category", "database")); + } finally { + store.deleteCollection(collection); + store.close(); + client.close(); + } + } + + @Test + @EnabledIfEnvironmentVariable(named = "MILVUS_URI", matches = ".+") + void testDocumentManagement() throws Exception { + String collection = collectionName("document_management"); + MilvusVectorStore store = openStore(collection); + + try { + createCollection(store, collection); + + store.add( + List.of( + new Document( + "Milvus is a vector database", + Map.of("category", "database", "source", "test"), + "doc1"), + new Document( + "Apache Flink Agents is an AI framework", + Map.of("category", "ai-agent", "source", "test"), + "doc2")), + collection, + Collections.emptyMap()); + + List all = store.get(null, collection, null, 10, Collections.emptyMap()); + Assertions.assertEquals(2, all.size()); + assertDocument( + documentById(all, "doc1"), + "doc1", + "Milvus is a vector database", + Map.of("category", "database", "source", "test")); + assertDocument( + documentById(all, "doc2"), + "doc2", + "Apache Flink Agents is an AI framework", + Map.of("category", "ai-agent", "source", "test")); + + List byId = + store.get(List.of("doc1"), collection, null, null, Collections.emptyMap()); + Assertions.assertEquals(1, byId.size()); + assertDocument( + byId.get(0), + "doc1", + "Milvus is a vector database", + Map.of("category", "database", "source", "test")); + + store.delete(List.of("doc1"), collection, null, Collections.emptyMap()); + List remaining = + store.get(null, collection, null, 10, Collections.emptyMap()); + Assertions.assertEquals(1, remaining.size()); + assertDocument( + remaining.get(0), + "doc2", + "Apache Flink Agents is an AI framework", + Map.of("category", "ai-agent", "source", "test")); + + store.delete(null, collection, null, Collections.emptyMap()); + Assertions.assertTrue( + store.get(null, collection, null, 10, Collections.emptyMap()).isEmpty()); + } finally { + store.deleteCollection(collection); + store.close(); + } + } + + @Test + @EnabledIfEnvironmentVariable(named = "MILVUS_URI", matches = ".+") + void testFiltersDsl() throws Exception { + String collection = collectionName("filters_dsl"); + MilvusVectorStore store = openStore(collection); + + try { + createCollection(store, collection); + + store.add( + List.of( + new Document( + "Milvus is a vector database", + Map.of("category", "database", "user_id", "alice"), + "doc_alice"), + new Document( + "Apache Flink Agents is an AI framework", + Map.of("category", "ai-agent", "user_id", "bob"), + "doc_bob")), + collection, + Collections.emptyMap()); + + List aliceOnly = + store.get( + null, + collection, + Map.of("user_id", "alice"), + 10, + Collections.emptyMap()); + Assertions.assertEquals(1, aliceOnly.size()); + Assertions.assertEquals("doc_alice", aliceOnly.get(0).getId()); + + List aliceQueried = + store.queryEmbedding( + new float[] {1.0f, 0.0f, 0.0f, 0.0f, 0.0f}, + 5, + collection, + Map.of("user_id", "alice"), + Collections.emptyMap()); + Assertions.assertFalse(aliceQueried.isEmpty()); + Assertions.assertTrue( + aliceQueried.stream().allMatch(d -> "doc_alice".equals(d.getId()))); + } finally { + store.deleteCollection(collection); + store.close(); + } + } + + @Test + @EnabledIfEnvironmentVariable(named = "MILVUS_URI", matches = ".+") + void testAddGeneratesIdsForDocumentsWithoutIds() throws Exception { + String collection = collectionName("generated_ids"); + MilvusVectorStore store = openStore(collection); + + try { + createCollection(store, collection); + + List ids = + store.add( + List.of( + new Document( + "Milvus is a vector database", + Map.of("category", "database"), + null)), + collection, + Collections.emptyMap()); + + Assertions.assertEquals(1, ids.size()); + Assertions.assertNotNull(ids.get(0)); + Assertions.assertFalse(ids.get(0).isEmpty()); + + List stored = + store.get(List.of(ids.get(0)), collection, null, null, Collections.emptyMap()); + Assertions.assertEquals(1, stored.size()); + assertDocument( + stored.get(0), + ids.get(0), + "Milvus is a vector database", + Map.of("category", "database")); + } finally { + store.deleteCollection(collection); + store.close(); + } + } + + @Test + @EnabledIfEnvironmentVariable(named = "MILVUS_URI", matches = ".+") + void testAddPreservesCallerProvidedId() throws Exception { + String collection = collectionName("add_with_id"); + MilvusVectorStore store = openStore(collection); + + try { + createCollection(store, collection); + + Document document = + new Document( + "Milvus is a vector database", Map.of("category", "database"), "doc1"); + List ids = store.add(List.of(document), collection, Collections.emptyMap()); + Assertions.assertEquals(List.of("doc1"), ids); + + List stored = + store.get(List.of("doc1"), collection, null, null, Collections.emptyMap()); + Assertions.assertEquals(1, stored.size()); + assertDocument( + stored.get(0), + "doc1", + "Milvus is a vector database", + Map.of("category", "database")); + } finally { + store.deleteCollection(collection); + store.close(); + } + } + + @Test + @EnabledIfEnvironmentVariable(named = "MILVUS_URI", matches = ".+") + void testExtraArgsCollectionDoesNotOverrideTargetCollection() throws Exception { + String collection = collectionName("target_collection"); + String ignoredCollection = collectionName("ignored_collection"); + MilvusVectorStore store = openStore(collection); + + try { + createCollection(store, collection); + + List ids = + store.add( + List.of( + new Document( + "Milvus is a vector database", + Map.of("category", "database"), + "doc1")), + null, + Map.of("collection", ignoredCollection)); + Assertions.assertEquals(List.of("doc1"), ids); + + List stored = + store.get(List.of("doc1"), collection, null, null, Collections.emptyMap()); + Assertions.assertEquals(1, stored.size()); + assertDocument( + stored.get(0), + "doc1", + "Milvus is a vector database", + Map.of("category", "database")); + } finally { + dropCollectionQuietly(store, collection); + dropCollectionQuietly(store, ignoredCollection); + store.close(); + } + } + + @Test + @EnabledIfEnvironmentVariable(named = "MILVUS_URI", matches = ".+") + void testUpdateOverwritesExistingDocument() throws Exception { + String collection = collectionName("update_overwrite"); + MilvusVectorStore store = openStore(collection); + + try { + createCollection(store, collection); + + Document original = + new Document( + "Milvus is a vector database", Map.of("category", "database"), "doc1"); + store.add(List.of(original), collection, Collections.emptyMap()); + + Document rewritten = + new Document( + "Milvus stores dense vectors", Map.of("category", "updated"), "doc1"); + store.update(List.of(rewritten), collection, Collections.emptyMap()); + + List after = + store.get(List.of("doc1"), collection, null, null, Collections.emptyMap()); + Assertions.assertEquals(1, after.size()); + Assertions.assertEquals("Milvus stores dense vectors", after.get(0).getContent()); + Assertions.assertEquals("updated", after.get(0).getMetadata().get("category")); + } finally { + store.deleteCollection(collection); + store.close(); + } + } + + @Test + @EnabledIfEnvironmentVariable(named = "MILVUS_URI", matches = ".+") + void testQueryEmbeddingPopulatesScore() throws Exception { + String collection = collectionName("score_populated"); + MilvusVectorStore store = openStore(collection); + + try { + createCollection(store, collection); + + store.add( + List.of( + new Document( + "Milvus is a vector database", Map.of("src", "test"), "doc1"), + new Document( + "Apache Flink Agents is an AI framework", + Map.of("src", "test"), + "doc2")), + collection, + Collections.emptyMap()); + + VectorStoreQuery q = + new VectorStoreQuery( + "Milvus is a vector database", 5, collection, Collections.emptyMap()); + List hits = store.query(q).getDocuments(); + Assertions.assertFalse(hits.isEmpty()); + Assertions.assertTrue( + hits.stream().allMatch(d -> d.getScore() != null), + "Every Milvus search hit should carry a score"); + + List byId = + store.get(List.of("doc1"), collection, null, null, Collections.emptyMap()); + Assertions.assertEquals(1, byId.size()); + Assertions.assertNull(byId.get(0).getScore()); + } finally { + store.deleteCollection(collection); + store.close(); + } + } + + /** + * Builds a descriptor for integration tests. + * + *

    Test-only: STRONG consistency is used here so reads immediately see preceding writes + * within the same test method. Production should use the default BOUNDED consistency unless + * immediate read-after-write visibility is required. + */ + private static ResourceDescriptor descriptor(String collection) { + return ResourceDescriptor.Builder.newBuilder(MilvusVectorStore.class.getName()) + .addInitialArgument("embedding_model", "embeddingModel") + .addInitialArgument("uri", System.getenv("MILVUS_URI")) + .addInitialArgument("collection", collection) + .addInitialArgument("dims", 5) + .addInitialArgument("index_type", "AUTOINDEX") + .addInitialArgument("metric_type", "COSINE") + // Test-only: avoid timing-sensitive assertions after insert/update. Production + // should use the default BOUNDED consistency unless immediate read-after-write + // visibility is required. + .addInitialArgument("consistency_level", "STRONG") + .build(); + } + + private static MilvusVectorStore openStore(String collection) throws Exception { + MilvusVectorStore store = + new MilvusVectorStore( + descriptor(collection), + ResourceContext.fromGetResource(MilvusVectorStoreTest::getResource)); + store.open(); + return store; + } + + private static void createCollection(MilvusVectorStore store, String collection) { + store.createCollectionIfNotExists(collection, Map.of()); + } + + private static String collectionName(String prefix) { + return "fa_milvus_" + prefix + "_" + UUID.randomUUID().toString().replace("-", ""); + } + + private static Document documentById(List documents, String id) { + return documents.stream() + .filter(d -> id.equals(d.getId())) + .findFirst() + .orElseThrow(() -> new AssertionError("Missing document " + id)); + } + + private static void assertDocument( + Document document, String id, String content, Map metadata) { + Assertions.assertEquals(id, document.getId()); + Assertions.assertEquals(content, document.getContent()); + Assertions.assertEquals(metadata, document.getMetadata()); + Assertions.assertNull(document.getScore()); + } + + private static void assertMetadataPathIndex( + MilvusClientV2 client, String collection, String key, String castType) { + DescribeIndexResp resp = + client.describeIndex( + DescribeIndexReq.builder() + .collectionName(collection) + .indexName(metadataPathIndexName(key)) + .build()); + DescribeIndexResp.IndexDesc index = + resp.getIndexDescByIndexName(metadataPathIndexName(key)); + Assertions.assertEquals(MilvusVectorStore.DEFAULT_METADATA_FIELD, index.getFieldName()); + Assertions.assertEquals(IndexParam.IndexType.AUTOINDEX, index.getIndexType()); + assertThat(index.getExtraParams()) + .containsEntry( + "json_path", MilvusVectorStore.DEFAULT_METADATA_FIELD + "[\"" + key + "\"]") + .containsEntry("json_cast_type", castType); + } + + private static String metadataPathIndexName(String key) { + return MilvusVectorStore.DEFAULT_METADATA_FIELD + "_" + key + "_json_index"; + } + + private static void dropCollectionQuietly(MilvusVectorStore store, String collection) { + try { + store.deleteCollection(collection); + } catch (Exception ignored) { + // Best-effort cleanup for negative-path assertions. + } + } + + private static Resource getResource(String name, ResourceType type) { + BaseEmbeddingModelSetup embeddingModel = Mockito.mock(BaseEmbeddingModelSetup.class); + Mockito.when(embeddingModel.embed("Milvus is a vector database")) + .thenReturn(new float[] {1.0f, 0.0f, 0.0f, 0.0f, 0.0f}); + Mockito.when(embeddingModel.embed("Milvus stores dense vectors")) + .thenReturn(new float[] {1.0f, 0.0f, 0.0f, 0.0f, 0.0f}); + Mockito.when(embeddingModel.embed("Apache Flink Agents is an AI framework")) + .thenReturn(new float[] {0.0f, 1.0f, 0.0f, 0.0f, 0.0f}); + return embeddingModel; + } +} diff --git a/integrations/vector-stores/pom.xml b/integrations/vector-stores/pom.xml index 4d4766d95..4e1655d00 100644 --- a/integrations/vector-stores/pom.xml +++ b/integrations/vector-stores/pom.xml @@ -32,8 +32,9 @@ under the License. elasticsearch + milvus opensearch s3vectors - \ No newline at end of file + diff --git a/python/flink_agents/api/resource.py b/python/flink_agents/api/resource.py index f8d7741a0..905025a22 100644 --- a/python/flink_agents/api/resource.py +++ b/python/flink_agents/api/resource.py @@ -331,5 +331,8 @@ class Java: # Elasticsearch ELASTICSEARCH_VECTOR_STORE = "org.apache.flink.agents.integrations.vectorstores.elasticsearch.ElasticsearchVectorStore" + # Milvus + MILVUS_VECTOR_STORE = "org.apache.flink.agents.integrations.vectorstores.milvus.MilvusVectorStore" + # MCP resource names MCP_SERVER = "flink_agents.integrations.mcp.mcp.MCPServer" diff --git a/python/flink_agents/e2e_tests/e2e_tests_resource_cross_language/vector_store_cross_language_agent.py b/python/flink_agents/e2e_tests/e2e_tests_resource_cross_language/vector_store_cross_language_agent.py index 715b23791..18671bea2 100644 --- a/python/flink_agents/e2e_tests/e2e_tests_resource_cross_language/vector_store_cross_language_agent.py +++ b/python/flink_agents/e2e_tests/e2e_tests_resource_cross_language/vector_store_cross_language_agent.py @@ -44,9 +44,42 @@ ) TEST_COLLECTION = "test_collection" +DEFAULT_COLLECTION = "my_documents" +EMBEDDING_MODEL_RESOURCE = "embedding_model" +VECTOR_STORE_RESOURCE = "vector_store" +BACKEND_ELASTICSEARCH = "ELASTICSEARCH" +BACKEND_MILVUS = "MILVUS" MAX_RETRIES_TIMES = 10 +def _selected_backend() -> str: + return os.environ.get("VECTOR_STORE_BACKEND", BACKEND_ELASTICSEARCH).upper() + + +def _vector_store_backend_from_resource( + vector_store: CollectionManageableVectorStore, +) -> str: + # Java wrapper: reflect on the underlying Java class name + j_resource = getattr(vector_store, "_j_resource", None) + if j_resource is not None: + try: + class_name = j_resource.getClass().getName().lower() + if "milvus" in class_name: + return BACKEND_MILVUS + if "elasticsearch" in class_name: + return BACKEND_ELASTICSEARCH + except Exception: + pass + + # Pure Python store: fallback to env var (cross-language test only + # uses Java wrappers, so this path should not be hit in practice) + return _selected_backend() + + +def _backend_from_context(ctx: RunnerContext) -> str: + return ctx.short_term_memory.get("vector_store_backend") or _selected_backend() + + class VectorStoreCrossLanguageAgent(Agent): """Example agent demonstrating cross-language embedding model testing.""" @@ -88,14 +121,37 @@ def embedding_model() -> ResourceDescriptor: @staticmethod def vector_store() -> ResourceDescriptor: """Vector store setup for knowledge base.""" - return ResourceDescriptor( - clazz=ResourceName.VectorStore.JAVA_WRAPPER_COLLECTION_MANAGEABLE_VECTOR_STORE, - java_clazz=ResourceName.VectorStore.Java.ELASTICSEARCH_VECTOR_STORE, - embedding_model="embedding_model", - host=os.environ.get("ES_HOST"), - index="my_documents", - dims=768, - ) + backend = _selected_backend() + collection = os.environ.get("VECTOR_STORE_COLLECTION", DEFAULT_COLLECTION) + + if backend == BACKEND_ELASTICSEARCH: + return ResourceDescriptor( + clazz=ResourceName.VectorStore.JAVA_WRAPPER_COLLECTION_MANAGEABLE_VECTOR_STORE, + java_clazz=ResourceName.VectorStore.Java.ELASTICSEARCH_VECTOR_STORE, + embedding_model=EMBEDDING_MODEL_RESOURCE, + host=os.environ.get("ES_HOST"), + index=collection, + dims=768, + ) + if backend == BACKEND_MILVUS: + return ResourceDescriptor( + clazz=ResourceName.VectorStore.JAVA_WRAPPER_COLLECTION_MANAGEABLE_VECTOR_STORE, + java_clazz=ResourceName.VectorStore.Java.MILVUS_VECTOR_STORE, + embedding_model=EMBEDDING_MODEL_RESOURCE, + uri=os.environ.get("MILVUS_URI"), + collection=collection, + dims=768, + metric_type="COSINE", + index_type="AUTOINDEX", + # Test-only: this e2e checks read-after-write behavior immediately. + # Production should use the default BOUNDED consistency unless immediate + # read-after-write visibility is required. + consistency_level="STRONG", + metadata_index_keys=["category", "source"], + ) + + msg = f"Unsupported vector store backend: {backend}" + raise ValueError(msg) @action(InputEvent.EVENT_TYPE) @staticmethod @@ -112,16 +168,23 @@ def process_input(event: Event, ctx: RunnerContext) -> None: is_initialized = stm.get("is_initialized") or False if not is_initialized: - print("[TEST] Initializing vector store...") + vector_store = ctx.get_resource( + VECTOR_STORE_RESOURCE, ResourceType.VECTOR_STORE + ) + backend = _vector_store_backend_from_resource(vector_store) + stm.set("vector_store_backend", backend) + test_collection = os.environ.get( + "VECTOR_STORE_TEST_COLLECTION", TEST_COLLECTION + ) + + print(f"[TEST][{backend}] Initializing vector store...") - vector_store = ctx.get_resource("vector_store", ResourceType.VECTOR_STORE) if isinstance(vector_store, CollectionManageableVectorStore): vector_store.create_collection_if_not_exists( - TEST_COLLECTION, metadata={"key1": "value1", "key2": "value2"} + test_collection, metadata={"key1": "value1", "key2": "value2"} ) - vector_store.delete_collection(name=TEST_COLLECTION) - - print("[TEST] Vector store Collection Management PASSED") + vector_store.delete_collection(name=test_collection) + print(f"[TEST][{backend}] Vector store Collection Management PASSED") documents = [ Document( @@ -140,6 +203,9 @@ def process_input(event: Event, ctx: RunnerContext) -> None: metadata={"category": "utility", "source": "test"}, ), ] + + collection = vector_store.collection or DEFAULT_COLLECTION + vector_store.create_collection_if_not_exists(collection) vector_store.add(documents=documents) assert len(vector_store.get()) == 3 @@ -152,7 +218,10 @@ def process_input(event: Event, ctx: RunnerContext) -> None: while len(vector_store.get()) > 2 and retry_time < MAX_RETRIES_TIMES: retry_time += 1 time.sleep(2) - print(f"[TEST] Retrying to delete doc3, retry_time={retry_time}") + print( + f"[TEST][{backend}] Vector store Retrying to delete doc3, " + f"retry_time={retry_time}" + ) assert len(vector_store.get()) == 2 @@ -165,10 +234,10 @@ def process_input(event: Event, ctx: RunnerContext) -> None: == "Why did the cat sit on the computer? Because it wanted to keep an eye on the mouse." ) - print("[TEST] Vector store Document Management PASSED") + print(f"[TEST][{backend}] Vector store Document Management PASSED") # Verify VectorStoreQuery.filters survives the Python->Java bridge. - # Elasticsearch translates the unified-DSL filter to a bool/must term + # Each backend translates the unified-DSL filter to a bool/must term # post-filter, so the result must contain only the doc tagged # ``category=calculate`` (doc1). filtered_query = VectorStoreQuery( @@ -178,7 +247,6 @@ def process_input(event: Event, ctx: RunnerContext) -> None: filters={"category": "calculate"}, ) - # ES is eventually consistent; allow a few retries. retry_time = 0 filtered_docs = vector_store.query(filtered_query).documents while len(filtered_docs) != 1 and retry_time < MAX_RETRIES_TIMES: @@ -192,23 +260,27 @@ def process_input(event: Event, ctx: RunnerContext) -> None: ) assert filtered_docs[0].id == "doc1" - print("[TEST] Vector store filter query PASSED") + print(f"[TEST][{backend}] Vector store filter query PASSED") + else: + msg = "vector_store must implement CollectionManageableVectorStore" + raise TypeError(msg) stm.set("is_initialized", True) ctx.send_event( - ContextRetrievalRequestEvent(query=input_text, vector_store="vector_store") + ContextRetrievalRequestEvent( + query=input_text, vector_store=VECTOR_STORE_RESOURCE + ) ) @action(ContextRetrievalResponseEvent.EVENT_TYPE) @staticmethod - def contextRetrievalResponseEvent( - event: Event, ctx: RunnerContext - ) -> None: + def contextRetrievalResponseEvent(event: Event, ctx: RunnerContext) -> None: """User defined action for processing context retrieval response. In this action, we will test Vector store Context Retrieval. """ + backend = _backend_from_context(ctx) documents = ContextRetrievalResponseEvent.from_event(event).documents assert documents is not None @@ -221,7 +293,9 @@ def contextRetrievalResponseEvent( test_result = f"[PASS] retrieved_count={len(documents)}, first_doc_id={documents[0].id}, first_doc_preview={documents[0].content[:50]}" print( - f"[TEST] Vector store Context Retrieval PASSED, first_doc_id={documents[0].id}, first_doc_preview={documents[0].content[:50]}" + f"[TEST][{backend}] Vector store Context Retrieval PASSED, " + f"first_doc_id={documents[0].id}, " + f"first_doc_preview={documents[0].content[:50]}" ) ctx.send_event(OutputEvent(output=test_result)) diff --git a/python/flink_agents/e2e_tests/e2e_tests_resource_cross_language/vector_store_cross_language_test.py b/python/flink_agents/e2e_tests/e2e_tests_resource_cross_language/vector_store_cross_language_test.py index 36825a779..ac49eecdf 100644 --- a/python/flink_agents/e2e_tests/e2e_tests_resource_cross_language/vector_store_cross_language_test.py +++ b/python/flink_agents/e2e_tests/e2e_tests_resource_cross_language/vector_store_cross_language_test.py @@ -16,7 +16,9 @@ # limitations under the License. ################################################################################# import os +import sys import sysconfig +import uuid from pathlib import Path import pytest @@ -44,23 +46,31 @@ os.environ["OLLAMA_EMBEDDING_MODEL"] = OLLAMA_MODEL ES_HOST = os.environ.get("ES_HOST") +MILVUS_URI = os.environ.get("MILVUS_URI") client = pull_model(OLLAMA_MODEL) +EMBEDDING_TYPES = ["JAVA", "PYTHON"] os.environ["PYTHONPATH"] = sysconfig.get_paths()["purelib"] -@pytest.mark.skipif( - client is None or ES_HOST is None, - reason="Ollama client or Elasticsearch host is missing.", -) -@pytest.mark.parametrize("embedding_type", ["JAVA", "PYTHON"]) -def test_java_vector_store_integration(tmp_path: Path, embedding_type: str) -> None: - os.environ["EMBEDDING_TYPE"] = embedding_type +def _run_vector_store_integration( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, + embedding_type: str, + backend: str, +) -> None: + print(f"[TEST][{backend}] Vector store e2e embedding={embedding_type}") + monkeypatch.setenv("EMBEDDING_TYPE", embedding_type) + monkeypatch.setenv("VECTOR_STORE_BACKEND", backend) + suffix = uuid.uuid4().hex + monkeypatch.setenv("VECTOR_STORE_COLLECTION", f"my_documents_{suffix}") + monkeypatch.setenv("VECTOR_STORE_TEST_COLLECTION", f"test_collection_{suffix}") env = StreamExecutionEnvironment.get_execution_environment() env.set_runtime_mode(RuntimeExecutionMode.STREAMING) env.set_parallelism(1) + env.set_python_executable(sys.executable) # currently, bounded source is not supported due to runtime implementation, so # we use continuous file source here. @@ -110,5 +120,30 @@ def test_java_vector_store_integration(tmp_path: Path, embedding_type: str) -> N with file.open() as f: actual_result.extend(f.readlines()) + assert len(actual_result) >= 2 assert "PASS" in actual_result[0] assert "PASS" in actual_result[1] + + +@pytest.mark.skipif( + client is None or ES_HOST is None, + reason="Embedding model client or Elasticsearch host is missing.", +) +@pytest.mark.parametrize("embedding_type", EMBEDDING_TYPES) +def test_elasticsearch_vector_store_integration( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch, embedding_type: str +) -> None: + _run_vector_store_integration( + tmp_path, monkeypatch, embedding_type, "ELASTICSEARCH" + ) + + +@pytest.mark.skipif( + client is None or MILVUS_URI is None, + reason="Embedding model client or Milvus URI is missing.", +) +@pytest.mark.parametrize("embedding_type", EMBEDDING_TYPES) +def test_milvus_vector_store_integration( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch, embedding_type: str +) -> None: + _run_vector_store_integration(tmp_path, monkeypatch, embedding_type, "MILVUS") diff --git a/python/flink_agents/runtime/java/java_vector_store.py b/python/flink_agents/runtime/java/java_vector_store.py index 923482df2..7c62bfd07 100644 --- a/python/flink_agents/runtime/java/java_vector_store.py +++ b/python/flink_agents/runtime/java/java_vector_store.py @@ -56,6 +56,10 @@ def __init__(self, j_resource: Any, j_resource_adapter: Any, **kwargs: Any) -> N """ # embedding_model are required parameters for BaseVectorStore embedding_model = kwargs.pop("embedding_model", "") + # Elasticsearch/OpenSearch call their document container "index"; + # expose it as BaseVectorStore's generic default collection. + if kwargs.get("collection") is None and kwargs.get("index") is not None: + kwargs["collection"] = kwargs["index"] super().__init__(embedding_model=embedding_model, **kwargs) self._j_resource = j_resource diff --git a/tools/docker/elasticsearch/docker-compose.yml b/tools/docker/elasticsearch/docker-compose.yml new file mode 100644 index 000000000..378f85227 --- /dev/null +++ b/tools/docker/elasticsearch/docker-compose.yml @@ -0,0 +1,46 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +services: + elasticsearch: + image: docker.elastic.co/elasticsearch/elasticsearch:8.19.0 + container_name: flink-agents-elasticsearch + environment: + discovery.type: single-node + xpack.security.enabled: "false" + ES_JAVA_OPTS: "-Xms512m -Xmx512m" + ports: + - "9200:9200" + volumes: + - elasticsearch_data:/usr/share/elasticsearch/data + healthcheck: + test: ["CMD-SHELL", "curl -f http://localhost:9200/_cluster/health || exit 1"] + interval: 10s + timeout: 5s + retries: 10 + start_period: 30s + networks: + - flink-agents-elasticsearch + +volumes: + elasticsearch_data: + +networks: + flink-agents-elasticsearch: + driver: bridge diff --git a/tools/docker/milvus/docker-compose.yml b/tools/docker/milvus/docker-compose.yml new file mode 100644 index 000000000..9a45b471e --- /dev/null +++ b/tools/docker/milvus/docker-compose.yml @@ -0,0 +1,98 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +services: + milvus-etcd: + image: quay.io/coreos/etcd:v3.5.5 + container_name: flink-agents-milvus-etcd + environment: + - ETCD_AUTO_COMPACTION_MODE=revision + - ETCD_AUTO_COMPACTION_RETENTION=1000 + - ETCD_QUOTA_BACKEND_BYTES=4294967296 + - ETCD_SNAPSHOT_COUNT=50000 + command: > + etcd + -advertise-client-urls=http://127.0.0.1:2379 + -listen-client-urls http://0.0.0.0:2379 + --data-dir /etcd + healthcheck: + test: ["CMD", "etcdctl", "--endpoints=http://127.0.0.1:2379", "endpoint", "health"] + interval: 30s + timeout: 20s + retries: 3 + volumes: + - milvus_etcd_data:/etcd + networks: + - flink-agents-milvus + + milvus-minio: + image: minio/minio:RELEASE.2023-03-20T20-16-18Z + container_name: flink-agents-milvus-minio + environment: + MINIO_ACCESS_KEY: minioadmin + MINIO_SECRET_KEY: minioadmin + command: minio server /minio_data --console-address ":9001" + ports: + - "9000:9000" + - "9001:9001" + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"] + interval: 30s + timeout: 20s + retries: 3 + volumes: + - milvus_minio_data:/minio_data + networks: + - flink-agents-milvus + + milvus-standalone: + image: milvusdb/milvus:v2.6.15 + container_name: flink-agents-milvus-standalone + command: ["milvus", "run", "standalone"] + environment: + ETCD_ENDPOINTS: milvus-etcd:2379 + MINIO_ADDRESS: milvus-minio:9000 + ports: + - "19530:19530" + - "9091:9091" + volumes: + - milvus_data:/var/lib/milvus + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:9091/healthz"] + interval: 30s + timeout: 20s + retries: 3 + start_period: 90s + depends_on: + - milvus-etcd + - milvus-minio + networks: + - flink-agents-milvus + +volumes: + milvus_etcd_data: + driver: local + milvus_minio_data: + driver: local + milvus_data: + driver: local + +networks: + flink-agents-milvus: + driver: bridge