diff --git a/.gitignore b/.gitignore index f88aa021f..eeb022afe 100644 --- a/.gitignore +++ b/.gitignore @@ -21,3 +21,4 @@ python/flink_agents.egg-info/ python/flink_agents/flink_agents.egg-info/ python/flink_agents/lib/ python/uv.lock +**/superpowers/ diff --git a/docs/yaml-schema.json b/docs/yaml-schema.json new file mode 100644 index 000000000..905fa4b65 --- /dev/null +++ b/docs/yaml-schema.json @@ -0,0 +1,448 @@ +{ + "$defs": { + "ActionSpec": { + "additionalProperties": false, + "description": "An action references a user function and the event types it listens to.\n\n``function`` is written as ``:`` \u2014 the\ncolon separates the Python module (or Java class FQN) from the\nattribute path inside it.\n\nAction signatures are fixed (``(Event, RunnerContext)``), so there is\nno ``parameter_types`` knob \u2014 Python doesn't need it, and the Java\naction signature is determined by the action contract.", + "properties": { + "config": { + "anyOf": [ + { + "additionalProperties": true, + "type": "object" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Config" + }, + "function": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Function" + }, + "listen_to": { + "items": { + "type": "string" + }, + "minItems": 1, + "title": "Listen To", + "type": "array" + }, + "name": { + "title": "Name", + "type": "string" + }, + "type": { + "anyOf": [ + { + "enum": [ + "python", + "java" + ], + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Type" + } + }, + "required": [ + "name", + "listen_to" + ], + "title": "ActionSpec", + "type": "object" + }, + "AgentSpec": { + "additionalProperties": false, + "description": "One agent inside a YAML file's ``agents:`` list.\n\nHolds the agent's own resources and actions. Resources/actions declared\nat the file level (siblings of ``agents:``) are merged in by the loader.", + "properties": { + "actions": { + "items": { + "anyOf": [ + { + "$ref": "#/$defs/ActionSpec" + }, + { + "type": "string" + } + ] + }, + "title": "Actions", + "type": "array" + }, + "chat_model_connections": { + "items": { + "$ref": "#/$defs/DescriptorSpec" + }, + "title": "Chat Model Connections", + "type": "array" + }, + "chat_model_setups": { + "items": { + "$ref": "#/$defs/DescriptorSpec" + }, + "title": "Chat Model Setups", + "type": "array" + }, + "description": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Description" + }, + "embedding_model_connections": { + "items": { + "$ref": "#/$defs/DescriptorSpec" + }, + "title": "Embedding Model Connections", + "type": "array" + }, + "embedding_model_setups": { + "items": { + "$ref": "#/$defs/DescriptorSpec" + }, + "title": "Embedding Model Setups", + "type": "array" + }, + "mcp_servers": { + "items": { + "$ref": "#/$defs/DescriptorSpec" + }, + "title": "Mcp Servers", + "type": "array" + }, + "name": { + "title": "Name", + "type": "string" + }, + "prompts": { + "items": { + "$ref": "#/$defs/PromptSpec" + }, + "title": "Prompts", + "type": "array" + }, + "skills": { + "items": { + "$ref": "#/$defs/SkillsSpec" + }, + "title": "Skills", + "type": "array" + }, + "tools": { + "items": { + "$ref": "#/$defs/ToolSpec" + }, + "title": "Tools", + "type": "array" + }, + "vector_stores": { + "items": { + "$ref": "#/$defs/DescriptorSpec" + }, + "title": "Vector Stores", + "type": "array" + } + }, + "required": [ + "name" + ], + "title": "AgentSpec", + "type": "object" + }, + "DescriptorSpec": { + "additionalProperties": true, + "description": "Schema for any ResourceDescriptor-backed resource.\n\nRequired: ``name`` and ``clazz``. ``type`` selects the implementation\nlanguage (``\"python\"`` or ``\"java\"``; ``None`` means Python). All\nremaining fields are forwarded verbatim to ``ResourceDescriptor`` as\nkwargs (or as the Java wrapper's kwargs when ``type: java``); the\nforwarding and language-aware wrapping is done by ``loader._build_descriptor``.", + "properties": { + "clazz": { + "title": "Clazz", + "type": "string" + }, + "name": { + "title": "Name", + "type": "string" + }, + "type": { + "anyOf": [ + { + "enum": [ + "python", + "java" + ], + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Type" + } + }, + "required": [ + "name", + "clazz" + ], + "title": "DescriptorSpec", + "type": "object" + }, + "MessageRole": { + "description": "Role of a message in a chat conversation.", + "enum": [ + "system", + "user", + "assistant", + "tool" + ], + "title": "MessageRole", + "type": "string" + }, + "PromptMessage": { + "additionalProperties": false, + "description": "One message in a multi-turn prompt template.", + "properties": { + "content": { + "title": "Content", + "type": "string" + }, + "role": { + "$ref": "#/$defs/MessageRole", + "default": "user" + } + }, + "required": [ + "content" + ], + "title": "PromptMessage", + "type": "object" + }, + "PromptSpec": { + "additionalProperties": false, + "description": "Declarative prompt: either a single ``text`` template or a list of\nrole-tagged ``messages``. Exactly one of the two fields must be set.", + "properties": { + "messages": { + "anyOf": [ + { + "items": { + "$ref": "#/$defs/PromptMessage" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Messages" + }, + "name": { + "title": "Name", + "type": "string" + }, + "text": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Text" + } + }, + "required": [ + "name" + ], + "title": "PromptSpec", + "type": "object" + }, + "SkillsSpec": { + "additionalProperties": false, + "description": "Declarative Skills resource pointing at one or more skill source\ndirectories on the local filesystem.", + "properties": { + "name": { + "title": "Name", + "type": "string" + }, + "paths": { + "items": { + "type": "string" + }, + "title": "Paths", + "type": "array" + } + }, + "required": [ + "name", + "paths" + ], + "title": "SkillsSpec", + "type": "object" + }, + "ToolSpec": { + "additionalProperties": false, + "description": "Points ``function:`` at a callable tool.\n\n``function`` is written as ``:`` \u2014 the\ncolon separates the Python module (or Java class FQN) from the\nattribute path inside it. For Python, the right side may be a\nnested ``Class.method``.\n\n``parameter_types`` is required when ``type: java`` and is ignored\notherwise (Python tools are reflected from the callable signature).\nThe list contains one string per declared parameter of the Java\nmethod, in declaration order \u2014 the loader uses it to disambiguate\noverloaded methods on the Java class. Each string is one of:\n\n- A Java primitive name: one of ``boolean``, ``byte``, ``short``,\n ``int``, ``long``, ``float``, ``double``, ``char``.\n- A fully-qualified Java reference type (including boxed\n primitives), e.g. ``java.lang.Double``, ``java.lang.String``,\n ``java.util.List``.\n\nGeneric type arguments are not part of the JVM method descriptor\nand must not be included (``java.util.List``, not\n``java.util.List``).", + "properties": { + "function": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Function" + }, + "name": { + "title": "Name", + "type": "string" + }, + "parameter_types": { + "anyOf": [ + { + "items": { + "type": "string" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Parameter Types" + }, + "type": { + "anyOf": [ + { + "enum": [ + "python", + "java" + ], + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Type" + } + }, + "required": [ + "name" + ], + "title": "ToolSpec", + "type": "object" + } + }, + "additionalProperties": false, + "description": "Top-level YAML document.\n\nAlways wraps one or more agents under ``agents:``. Resources and\nactions declared at the same level as ``agents:`` are shared:\nresources are registered on the environment; actions can be\nreferenced from any agent by name string.", + "properties": { + "actions": { + "items": { + "$ref": "#/$defs/ActionSpec" + }, + "title": "Actions", + "type": "array" + }, + "agents": { + "items": { + "$ref": "#/$defs/AgentSpec" + }, + "title": "Agents", + "type": "array" + }, + "chat_model_connections": { + "items": { + "$ref": "#/$defs/DescriptorSpec" + }, + "title": "Chat Model Connections", + "type": "array" + }, + "chat_model_setups": { + "items": { + "$ref": "#/$defs/DescriptorSpec" + }, + "title": "Chat Model Setups", + "type": "array" + }, + "embedding_model_connections": { + "items": { + "$ref": "#/$defs/DescriptorSpec" + }, + "title": "Embedding Model Connections", + "type": "array" + }, + "embedding_model_setups": { + "items": { + "$ref": "#/$defs/DescriptorSpec" + }, + "title": "Embedding Model Setups", + "type": "array" + }, + "mcp_servers": { + "items": { + "$ref": "#/$defs/DescriptorSpec" + }, + "title": "Mcp Servers", + "type": "array" + }, + "prompts": { + "items": { + "$ref": "#/$defs/PromptSpec" + }, + "title": "Prompts", + "type": "array" + }, + "skills": { + "items": { + "$ref": "#/$defs/SkillsSpec" + }, + "title": "Skills", + "type": "array" + }, + "tools": { + "items": { + "$ref": "#/$defs/ToolSpec" + }, + "title": "Tools", + "type": "array" + }, + "vector_stores": { + "items": { + "$ref": "#/$defs/DescriptorSpec" + }, + "title": "Vector Stores", + "type": "array" + } + }, + "required": [ + "agents" + ], + "title": "YamlAgentsDocument", + "type": "object" +} diff --git a/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/pom.xml b/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/pom.xml index 8b2f02429..9553e3bcb 100644 --- a/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/pom.xml +++ b/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/pom.xml @@ -72,4 +72,29 @@ ${flink.version} - + + + + + + org.apache.maven.plugins + maven-jar-plugin + + + + test-jar + + + + + + + \ No newline at end of file diff --git a/python/flink_agents/api/agents/agent.py b/python/flink_agents/api/agents/agent.py index 01e6f8b2f..3a6aed852 100644 --- a/python/flink_agents/api/agents/agent.py +++ b/python/flink_agents/api/agents/agent.py @@ -18,6 +18,7 @@ from abc import ABC from typing import Any, Callable, Dict, List, Tuple +from flink_agents.api.function import Function, PythonFunction from flink_agents.api.resource import ( ResourceDescriptor, ResourceType, @@ -85,7 +86,7 @@ def my_chat_model() -> ResourceDescriptor: """ _actions: Dict[ - str, Tuple[List[str], Callable, Dict[str, Any]] + str, Tuple[List[str], Function, Dict[str, Any] | None] ] _resources: Dict[ResourceType, Dict[str, Any]] @@ -99,7 +100,7 @@ def __init__(self) -> None: @property def actions( self, - ) -> Dict[str, Tuple[List[str], Callable, Dict[str, Any]]]: + ) -> Dict[str, Tuple[List[str], Function, Dict[str, Any] | None]]: """Get added actions.""" return self._actions @@ -112,7 +113,7 @@ def add_action( self, name: str, events: List[str], - func: Callable, + func: Callable | Function, **config: Any, ) -> "Agent": """Add action to agent. @@ -123,8 +124,10 @@ def add_action( The name of the action, should be unique in the same Agent. events : list[str] Type-identifier strings listened by this action. - func : Callable - The function to be executed when receive listened events. + func : Callable | Function + Either a raw Python callable (it will be wrapped as a + ``PythonFunction``) or a pre-built flink-agents ``Function`` + (e.g. from the YAML loader). **config : Any Key named arguments can be used by this action in runtime. @@ -136,6 +139,8 @@ def add_action( if name in self._actions: msg = f"Action {name} already defined" raise ValueError(msg) + if not isinstance(func, Function): + func = PythonFunction.from_callable(func) self._actions[name] = (events, func, config if config else None) return self diff --git a/python/flink_agents/api/execution_environment.py b/python/flink_agents/api/execution_environment.py index 965daa4d4..115c86161 100644 --- a/python/flink_agents/api/execution_environment.py +++ b/python/flink_agents/api/execution_environment.py @@ -17,7 +17,10 @@ ################################################################################# import importlib from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, List +from typing import TYPE_CHECKING, Any, Callable, Dict, List + +if TYPE_CHECKING: + from pathlib import Path from importlib_resources import files from pyflink.common import TypeInformation @@ -38,13 +41,14 @@ class AgentBuilder(ABC): """Builder for integrating agent with input and output.""" @abstractmethod - def apply(self, agent: Agent) -> "AgentBuilder": + def apply(self, agent: "Agent | str") -> "AgentBuilder": """Set agent of AgentBuilder. Parameters ---------- - agent : Agent - The agent user defined to run in execution environment. + agent : Agent | str + Either an Agent instance, or the name of an agent registered + on the environment (e.g. by ``load_yaml``). """ @abstractmethod @@ -92,6 +96,7 @@ class AgentsExecutionEnvironment(ABC): """Base class for agent execution environment.""" _resources: Dict[ResourceType, Dict[str, Any]] + _agents: Dict[str, Agent] def __init__(self) -> None: """Init method.""" @@ -99,6 +104,7 @@ def __init__(self) -> None: self._resources = {} for type in ResourceType: self._resources[type] = {} + self._agents: Dict[str, Agent] = {} @property def resources(self) -> Dict[ResourceType, Dict[str, Any]]: @@ -264,3 +270,13 @@ def add_resource( self._resources[resource_type][name] = instance return self + + def load_yaml(self, paths: "Path | str | List[Path | str]") -> None: + """Load one or more YAML files and register their declared agents + and shared resources on this environment. + + See :mod:`flink_agents.api.yaml.loader` for the format reference. + """ + from flink_agents.api.yaml.loader import load_yaml as _load_yaml + + _load_yaml(self, paths) diff --git a/python/flink_agents/api/function.py b/python/flink_agents/api/function.py new file mode 100644 index 000000000..b5597664b --- /dev/null +++ b/python/flink_agents/api/function.py @@ -0,0 +1,109 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################# +"""Data-only descriptors for user-defined functions. + +These models carry the information needed to *identify* a Python or Java +function: ``module`` and ``qualname`` for Python; declaring class, +method name, and parameter types for Java. +""" + +import importlib +import inspect +from abc import ABC +from typing import Any, Callable, List + +from pydantic import BaseModel, model_serializer + + +class Function(BaseModel, ABC): + """Marker base class for function descriptors. Pure data — has no + ``__call__`` and no executable behavior. + """ + + +class PythonFunction(Function): + """Descriptor for a Python callable: module + qualified name. + + Attributes: + ---------- + module : str + Name of the Python module where the function is defined. + qualname : str + Qualified name of the function (e.g. ``ClassName.method`` for + class methods). + """ + + module: str + qualname: str + + @model_serializer + def __serialize(self) -> dict[str, Any]: + return { + "func_type": self.__class__.__qualname__, + "module": self.module, + "qualname": self.qualname, + } + + @staticmethod + def from_callable(func: Callable) -> "PythonFunction": + """Build a ``PythonFunction`` descriptor from a Python callable.""" + return PythonFunction( + module=inspect.getmodule(func).__name__, + qualname=func.__qualname__, + ) + + def as_callable(self) -> Callable: + """Resolve this descriptor to the underlying Python callable. + + Imports the target module and looks up ``qualname``. Pure Python + reflection — no execution, no JVM. ``ClassName.method`` is split + and resolved through the class attribute. + """ + module = importlib.import_module(self.module) + if "." in self.qualname: + classname, methodname = self.qualname.rsplit(".", 1) + clazz = getattr(module, classname) + return getattr(clazz, methodname) + return getattr(module, self.qualname) + + +class JavaFunction(Function): + """Descriptor for a Java method: class FQN + method name + parameter types. + + Attributes: + ---------- + qualname : str + Fully-qualified name of the declaring Java class. + method_name : str + The Java method name. + parameter_types : List[str] + The Java parameter types, in declaration order. + """ + + qualname: str + method_name: str + parameter_types: List[str] + + @model_serializer + def __serialize(self) -> dict[str, Any]: + return { + "func_type": self.__class__.__qualname__, + "qualname": self.qualname, + "method_name": self.method_name, + "parameter_types": self.parameter_types, + } diff --git a/python/flink_agents/api/tools/function_tool.py b/python/flink_agents/api/tools/function_tool.py new file mode 100644 index 000000000..48ecfb251 --- /dev/null +++ b/python/flink_agents/api/tools/function_tool.py @@ -0,0 +1,33 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################# +from typing_extensions import override + +from flink_agents.api.function import JavaFunction, PythonFunction +from flink_agents.api.resource import ResourceType, SerializableResource + + +class FunctionTool(SerializableResource): + """Declarative function tool: carries a function descriptor.""" + + func: PythonFunction | JavaFunction + + @classmethod + @override + def resource_type(cls) -> ResourceType: + """Return resource type of class.""" + return ResourceType.TOOL diff --git a/python/flink_agents/api/tools/tool.py b/python/flink_agents/api/tools/tool.py index a2ee041a5..1dd6ac130 100644 --- a/python/flink_agents/api/tools/tool.py +++ b/python/flink_agents/api/tools/tool.py @@ -18,14 +18,17 @@ import typing from abc import ABC, abstractmethod from enum import Enum -from typing import Any, Type +from typing import TYPE_CHECKING, Any, Type -from pydantic import BaseModel, Field, field_serializer, model_validator +from pydantic import BaseModel, field_serializer, model_validator from typing_extensions import override from flink_agents.api.resource import ResourceType, SerializableResource from flink_agents.api.tools.utils import create_model_from_schema +if TYPE_CHECKING: + from flink_agents.api.tools.function_tool import FunctionTool + class ToolType(Enum): """Tool type enum. @@ -99,34 +102,18 @@ def get_parameters_dict(self) -> dict: return parameters -class FunctionTool(SerializableResource): - """Tool container keeps a callable, mainly used to represent - a function which will be converted to BaseTool after compile. - """ - - func: typing.Callable = Field(exclude=True) - - @classmethod - def resource_type(cls) -> ResourceType: - """Get the resource type.""" - return ResourceType.TOOL - - class Tool(SerializableResource, ABC): - """Base abstract class of all kinds of tools. + """Base abstract class of all kinds of tools.""" - Attributes: - ---------- - metadata : ToolMetadata - The metadata of the tools, includes name, description and arguments schema. - """ - - metadata: ToolMetadata + metadata: ToolMetadata | None = None @staticmethod - def from_callable(func: typing.Callable) -> FunctionTool: - """Create a function tool from a callable.""" - return FunctionTool(func=func) + def from_callable(func: typing.Callable) -> "FunctionTool": + """Wrap a Python callable as a declarative ``FunctionTool``.""" + from flink_agents.api.function import PythonFunction + from flink_agents.api.tools.function_tool import FunctionTool + + return FunctionTool(func=PythonFunction.from_callable(func)) @property def name(self) -> str: diff --git a/python/flink_agents/api/yaml/__init__.py b/python/flink_agents/api/yaml/__init__.py new file mode 100644 index 000000000..e154fadd3 --- /dev/null +++ b/python/flink_agents/api/yaml/__init__.py @@ -0,0 +1,17 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################# diff --git a/python/flink_agents/api/yaml/aliases.py b/python/flink_agents/api/yaml/aliases.py new file mode 100644 index 000000000..3c80c1b09 --- /dev/null +++ b/python/flink_agents/api/yaml/aliases.py @@ -0,0 +1,154 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################# +"""Static alias tables for the YAML loader. + +Two tables: +- ``EVENT_ALIASES`` maps short event names to ``EVENT_TYPE`` constants. +- ``CLAZZ_ALIASES`` maps short provider names to fully-qualified class + paths. The bucket is keyed on the resource type *and* the + implementation language so the same alias (``ollama``) can refer to + different classes across sections and languages. + +For Java resources, the loader resolves the alias to the Java FQN and +wraps it in a Python-side wrapper class (see ``JAVA_WRAPPER_CLAZZ``). +""" + +from typing import Dict + +from flink_agents.api.events.chat_event import ( + ChatRequestEvent, + ChatResponseEvent, +) +from flink_agents.api.events.context_retrieval_event import ( + ContextRetrievalRequestEvent, + ContextRetrievalResponseEvent, +) +from flink_agents.api.events.event import InputEvent, OutputEvent +from flink_agents.api.events.tool_event import ( + ToolRequestEvent, + ToolResponseEvent, +) +from flink_agents.api.resource import ResourceName, ResourceType +from flink_agents.api.yaml.specs import Language + +EVENT_ALIASES: Dict[str, str] = { + "input": InputEvent.EVENT_TYPE, + "output": OutputEvent.EVENT_TYPE, + "chat_request": ChatRequestEvent.EVENT_TYPE, + "chat_response": ChatResponseEvent.EVENT_TYPE, + "tool_request": ToolRequestEvent.EVENT_TYPE, + "tool_response": ToolResponseEvent.EVENT_TYPE, + "context_retrieval_request": ContextRetrievalRequestEvent.EVENT_TYPE, + "context_retrieval_response": ContextRetrievalResponseEvent.EVENT_TYPE, +} + +# resource_type -> language -> alias -> fully-qualified class path +CLAZZ_ALIASES: Dict[ResourceType, Dict[str, Dict[str, str]]] = { + ResourceType.CHAT_MODEL_CONNECTION: { + "python": { + "ollama": ResourceName.ChatModel.OLLAMA_CONNECTION, + "openai": ResourceName.ChatModel.OPENAI_COMPLETIONS_CONNECTION, + "anthropic": ResourceName.ChatModel.ANTHROPIC_CONNECTION, + "tongyi": ResourceName.ChatModel.TONGYI_CONNECTION, + "azure_openai": ResourceName.ChatModel.AZURE_OPENAI_CONNECTION, + }, + "java": { + "ollama": ResourceName.ChatModel.Java.OLLAMA_CONNECTION, + "openai_completions": ResourceName.ChatModel.Java.OPENAI_COMPLETIONS_CONNECTION, + "openai_responses": ResourceName.ChatModel.Java.OPENAI_RESPONSES_CONNECTION, + "anthropic": ResourceName.ChatModel.Java.ANTHROPIC_CONNECTION, + "azure": ResourceName.ChatModel.Java.AZURE_CONNECTION, + }, + }, + ResourceType.CHAT_MODEL: { + "python": { + "ollama": ResourceName.ChatModel.OLLAMA_SETUP, + "openai": ResourceName.ChatModel.OPENAI_COMPLETIONS_SETUP, + "anthropic": ResourceName.ChatModel.ANTHROPIC_SETUP, + "tongyi": ResourceName.ChatModel.TONGYI_SETUP, + "azure_openai": ResourceName.ChatModel.AZURE_OPENAI_SETUP, + }, + "java": { + "ollama": ResourceName.ChatModel.Java.OLLAMA_SETUP, + "openai_completions": ResourceName.ChatModel.Java.OPENAI_COMPLETIONS_SETUP, + "openai_responses": ResourceName.ChatModel.Java.OPENAI_RESPONSES_SETUP, + "anthropic": ResourceName.ChatModel.Java.ANTHROPIC_SETUP, + "azure": ResourceName.ChatModel.Java.AZURE_SETUP, + }, + }, + ResourceType.EMBEDDING_MODEL_CONNECTION: { + "python": { + "ollama": ResourceName.EmbeddingModel.OLLAMA_CONNECTION, + "openai": ResourceName.EmbeddingModel.OPENAI_CONNECTION, + "tongyi": ResourceName.EmbeddingModel.TONGYI_CONNECTION, + }, + "java": { + "ollama": ResourceName.EmbeddingModel.Java.OLLAMA_CONNECTION, + }, + }, + ResourceType.EMBEDDING_MODEL: { + "python": { + "ollama": ResourceName.EmbeddingModel.OLLAMA_SETUP, + "openai": ResourceName.EmbeddingModel.OPENAI_SETUP, + "tongyi": ResourceName.EmbeddingModel.TONGYI_SETUP, + }, + "java": { + "ollama": ResourceName.EmbeddingModel.Java.OLLAMA_SETUP, + }, + }, + ResourceType.VECTOR_STORE: { + "python": { + "chroma": ResourceName.VectorStore.CHROMA_VECTOR_STORE, + }, + "java": { + "elasticsearch": ResourceName.VectorStore.Java.ELASTICSEARCH_VECTOR_STORE, + }, + }, +} + +# Python wrapper class for each cross-language-supported resource type. +# When the user writes ``type: java``, the loader resolves the alias in +# the java bucket to a Java FQN and constructs a ResourceDescriptor whose +# ``clazz`` is the wrapper below and whose ``java_clazz`` kwarg is the +# resolved Java FQN. +JAVA_WRAPPER_CLAZZ: Dict[ResourceType, str] = { + ResourceType.CHAT_MODEL_CONNECTION: ResourceName.ChatModel.JAVA_WRAPPER_CONNECTION, + ResourceType.CHAT_MODEL: ResourceName.ChatModel.JAVA_WRAPPER_SETUP, + ResourceType.EMBEDDING_MODEL_CONNECTION: ResourceName.EmbeddingModel.JAVA_WRAPPER_CONNECTION, + ResourceType.EMBEDDING_MODEL: ResourceName.EmbeddingModel.JAVA_WRAPPER_SETUP, + ResourceType.VECTOR_STORE: ResourceName.VectorStore.JAVA_WRAPPER_VECTOR_STORE, +} + + +def resolve_event_type(name: str) -> str: + """Replace an event alias with its fully-qualified event type string, + or pass through if no alias matches. + """ + return EVENT_ALIASES.get(name, name) + + +def resolve_clazz( + name: str, resource_type: ResourceType, language: Language = "python" +) -> str: + """Look up ``name`` in the alias bucket for ``(resource_type, language)``. + + Returns the fully-qualified class path on hit, or ``name`` unchanged + on miss (so users can supply a fully-qualified class path directly). + """ + bucket = CLAZZ_ALIASES.get(resource_type, {}).get(language, {}) + return bucket.get(name, name) diff --git a/python/flink_agents/api/yaml/loader.py b/python/flink_agents/api/yaml/loader.py new file mode 100644 index 000000000..599220533 --- /dev/null +++ b/python/flink_agents/api/yaml/loader.py @@ -0,0 +1,387 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################# +"""YAML loader: parse a YAML document and register agents on an execution +environment. +""" + +from pathlib import Path +from typing import TYPE_CHECKING, Any, Dict, List, Tuple + +if TYPE_CHECKING: + from flink_agents.api.execution_environment import AgentsExecutionEnvironment + +import yaml + +from flink_agents.api.agents.agent import Agent +from flink_agents.api.chat_message import ChatMessage, MessageRole +from flink_agents.api.function import Function, JavaFunction, PythonFunction +from flink_agents.api.prompts.prompt import Prompt +from flink_agents.api.resource import ResourceDescriptor, ResourceType +from flink_agents.api.skills import Skills +from flink_agents.api.tools.function_tool import FunctionTool +from flink_agents.api.yaml.aliases import ( + JAVA_WRAPPER_CLAZZ, + resolve_clazz, + resolve_event_type, +) +from flink_agents.api.yaml.specs import ( + ActionSpec, + AgentSpec, + DescriptorSpec, + Language, + PromptSpec, + SkillsSpec, + ToolSpec, + YamlAgentsDocument, +) + +# Default Java parameter types for an action. Action methods in +# flink-agents always have signature (Event, RunnerContext). +_JAVA_ACTION_PARAMETER_TYPES: list[str] = [ + "org.apache.flink.agents.api.Event", + "org.apache.flink.agents.api.context.RunnerContext", +] + +_DESCRIPTOR_TYPES: Dict[str, ResourceType] = { + "chat_model_connections": ResourceType.CHAT_MODEL_CONNECTION, + "chat_model_setups": ResourceType.CHAT_MODEL, + "embedding_model_connections": ResourceType.EMBEDDING_MODEL_CONNECTION, + "embedding_model_setups": ResourceType.EMBEDDING_MODEL, + "vector_stores": ResourceType.VECTOR_STORE, + "mcp_servers": ResourceType.MCP_SERVER, +} + + +def resolve_function( + *, + name: str, + function: str | None, + language: Language | None = None, + parameter_types: List[str] | None = None, +) -> PythonFunction | JavaFunction: + """Resolve a YAML function reference to a flink-agents Function. + + Returns a ``PythonFunction`` when ``language`` is ``"python"`` (or + None — the default). Returns a ``JavaFunction`` when ``language`` + is ``"java"``. Java parameter types must be passed in by the caller + (actions use a fixed signature; tools vary per method). + + ``function`` must be ``:`` — a colon separates the + module (or Java class FQN) from the attribute path inside it: + + - Python: ``flink_agents.tools:add`` or + ``flink_agents.tools:MyTools.add`` (the right side is the + ``PythonFunction.qualname``, so nested ``Class.method`` is fine). + - Java: ``com.example.MyClass:method`` (or + ``com.example.Outer$Inner:method`` for inner classes). + + The colon is what lets a cross-language YAML loader recognise the + module/class boundary without language-specific import probing. + """ + if function is None: + msg = ( + f"Action/tool {name!r}: 'function' is required and must be " + "of the form ':'." + ) + raise ValueError(msg) + + parts = function.split(":") + if len(parts) != 2 or not parts[0] or not parts[1]: + kind = "java" if language == "java" else "python" + msg = ( + f"Action/tool {name!r}: {kind} function {function!r} must be " + "of the form ':' (e.g. " + "'pkg.tools:add', 'pkg.tools:MyTools.add', " + "'com.example.MyClass:method')." + ) + raise ValueError(msg) + left, right = parts + + if language == "java": + return JavaFunction( + qualname=left, + method_name=right, + parameter_types=parameter_types or [], + ) + return PythonFunction(module=left, qualname=right) + + +def _load_document(path: Path | str) -> YamlAgentsDocument: + text = Path(path).read_text() + raw = yaml.load(text, Loader=yaml.SafeLoader) + if raw is None: + msg = f"YAML file {path} is empty" + raise ValueError(msg) + return YamlAgentsDocument.model_validate(raw) + + +def _build_descriptor( + spec: DescriptorSpec, resource_type: ResourceType +) -> ResourceDescriptor: + kwargs = dict(spec.model_extra or {}) + if spec.type == "java": + if resource_type not in JAVA_WRAPPER_CLAZZ: + msg = ( + f"Resource {spec.name!r}: type='java' is not supported " + f"for {resource_type.value} (no Python-side Java wrapper)." + ) + raise ValueError(msg) + java_fqn = resolve_clazz(spec.clazz, resource_type, "java") + wrapper_clazz = JAVA_WRAPPER_CLAZZ[resource_type] + return ResourceDescriptor(clazz=wrapper_clazz, java_clazz=java_fqn, **kwargs) + python_fqn = resolve_clazz(spec.clazz, resource_type, "python") + return ResourceDescriptor(clazz=python_fqn, **kwargs) + + +def _add_descriptors_to_agent( + agent: Agent, attr_name: str, descriptors: list[DescriptorSpec] +) -> None: + resource_type = _DESCRIPTOR_TYPES[attr_name] + for spec in descriptors: + agent.add_resource( + spec.name, resource_type, _build_descriptor(spec, resource_type) + ) + + +def _resolve_action_function(action: ActionSpec) -> Function: + parameter_types = _JAVA_ACTION_PARAMETER_TYPES if action.type == "java" else None + return resolve_function( + name=action.name, + function=action.function, + language=action.type, + parameter_types=parameter_types, + ) + + +def _add_action_to_agent(agent: Agent, action: ActionSpec) -> None: + func = _resolve_action_function(action) + events = [resolve_event_type(e) for e in action.listen_to] + config = action.config or {} + agent.add_action(action.name, events, func, **config) + + +def _build_tool(spec: ToolSpec) -> FunctionTool: + if spec.type == "java" and spec.parameter_types is None: + msg = f"Tool {spec.name!r}: java tools must declare 'parameter_types' in YAML." + raise ValueError(msg) + func = resolve_function( + name=spec.name, + function=spec.function, + language=spec.type, + parameter_types=spec.parameter_types, + ) + return FunctionTool(func=func) + + +def _build_prompt(spec: PromptSpec) -> Prompt: + if spec.text is not None: + return Prompt.from_text(spec.text) + messages = [ + ChatMessage(role=MessageRole(m.role.value), content=m.content) + for m in (spec.messages or []) + ] + return Prompt.from_messages(messages) + + +def _build_skills(spec: SkillsSpec) -> Skills: + return Skills(paths=list(spec.paths)) + + +def _build_agent(agent_spec: AgentSpec) -> Agent: + agent = Agent() + for attr in _DESCRIPTOR_TYPES: + descriptors = getattr(agent_spec, attr) + _add_descriptors_to_agent(agent, attr, descriptors) + for tool_spec in agent_spec.tools: + agent.add_resource( + tool_spec.name, + ResourceType.TOOL, + _build_tool(tool_spec), + ) + for prompt_spec in agent_spec.prompts: + agent.add_resource( + prompt_spec.name, ResourceType.PROMPT, _build_prompt(prompt_spec) + ) + for skills_spec in agent_spec.skills: + agent.add_resource( + skills_spec.name, ResourceType.SKILLS, _build_skills(skills_spec) + ) + for action in agent_spec.actions: + if isinstance(action, str): + continue # shared-action references handled by caller + _add_action_to_agent(agent, action) + return agent + + +def _build_in_file_state( + path: Path | str, +) -> Tuple[ + Dict[str, Agent], + Dict[ResourceType, Dict[str, Any]], + Dict[str, ActionSpec], + Dict[str, AgentSpec], + YamlAgentsDocument, +]: + """Parse one YAML file, perform in-file duplicate detection, and build + the in-memory state without touching any execution environment. + + Returns: + agents: name -> Agent + shared_resources: resource_type -> name -> descriptor/resource + shared_actions: name -> ActionSpec (file-level, for cross-agent reference) + agent_specs: name -> AgentSpec (kept so callers can resolve string + action references back to the originating spec). + + Both :func:`build_agents` and :func:`load_yaml` go through this helper + so the in-file rules (duplicate detection, build order) are defined in + exactly one place. + """ + doc = _load_document(path) + agent_specs: Dict[str, AgentSpec] = {} + agents: Dict[str, Agent] = {} + for spec in doc.agents: + if spec.name in agents: + msg = f"Duplicate agent name {spec.name!r} in {path}" + raise ValueError(msg) + agent_specs[spec.name] = spec + agents[spec.name] = _build_agent(spec) + + shared_resources: Dict[ResourceType, Dict[str, Any]] = {t: {} for t in ResourceType} + for attr, resource_type in _DESCRIPTOR_TYPES.items(): + for spec in getattr(doc, attr): + if spec.name in shared_resources[resource_type]: + msg = f"Duplicate shared resource name {spec.name!r} in {path}" + raise ValueError(msg) + shared_resources[resource_type][spec.name] = _build_descriptor( + spec, resource_type + ) + for tool_spec in doc.tools: + if tool_spec.name in shared_resources[ResourceType.TOOL]: + msg = f"Duplicate shared tool name {tool_spec.name!r} in {path}" + raise ValueError(msg) + shared_resources[ResourceType.TOOL][tool_spec.name] = _build_tool(tool_spec) + for prompt_spec in doc.prompts: + if prompt_spec.name in shared_resources[ResourceType.PROMPT]: + msg = f"Duplicate shared prompt name {prompt_spec.name!r} in {path}" + raise ValueError(msg) + shared_resources[ResourceType.PROMPT][prompt_spec.name] = _build_prompt( + prompt_spec + ) + for skills_spec in doc.skills: + if skills_spec.name in shared_resources[ResourceType.SKILLS]: + msg = f"Duplicate shared skills name {skills_spec.name!r} in {path}" + raise ValueError(msg) + shared_resources[ResourceType.SKILLS][skills_spec.name] = _build_skills( + skills_spec + ) + + shared_actions: Dict[str, ActionSpec] = {} + for action_spec in doc.actions: + if action_spec.name in shared_actions: + msg = f"Duplicate shared action name {action_spec.name!r} in {path}" + raise ValueError(msg) + shared_actions[action_spec.name] = action_spec + + return agents, shared_resources, shared_actions, agent_specs, doc + + +def build_agents( + path: Path | str, +) -> Tuple[Dict[str, Agent], Dict[ResourceType, Dict[str, Any]], Dict[str, ActionSpec]]: + """Parse one YAML file and build the agents it declares. + + Returns: + agents: name -> Agent + shared_resources: resource_type -> name -> descriptor/resource + shared_actions: name -> ActionSpec (file-level, for cross-agent reference) + + This function only handles in-file structure. It does NOT enforce + cross-file duplicate detection — that's the caller's job. + """ + agents, shared_resources, shared_actions, _, _ = _build_in_file_state(path) + return agents, shared_resources, shared_actions + + +def _resolve_shared_action_refs( + agents: Dict[str, "Agent"], + agent_specs: Dict[str, AgentSpec], + shared_actions: Dict[str, ActionSpec], + path: "Path | str", +) -> None: + """For each agent, replace any string action reference with a copy of + the shared action. + """ + for agent_name, agent in agents.items(): + spec = agent_specs[agent_name] + for item in spec.actions: + if not isinstance(item, str): + continue + if item not in shared_actions: + msg = ( + f"Agent {agent_name!r} references shared action " + f"{item!r} in {path}, but no shared action with that " + "name is defined at the file level." + ) + raise ValueError(msg) + shared = shared_actions[item] + _add_action_to_agent(agent, shared) + + +def load_yaml( + env: "AgentsExecutionEnvironment", + paths: Path | str | List[Path | str], +) -> None: + """Load one or more YAML files and register their agents and shared + resources on the environment. + + Multiple calls accumulate. Duplicate names — both within a single file + and across the current environment — raise ``ValueError``. In-file + duplicate detection is delegated to :func:`_build_in_file_state` so + that ``load_yaml`` and :func:`build_agents` share the same rules. + """ + if isinstance(paths, str | Path): + paths = [paths] + + for path in paths: + agents, shared_resources, shared_actions, agent_specs, _ = ( + _build_in_file_state(path) + ) + + # Cross-environment duplicate checks. In-file duplicates were + # already caught inside ``_build_in_file_state``. + for name in agents: + if name in env._agents: + msg = f"Duplicate agent name {name!r} (loading {path})" + raise ValueError(msg) + for resource_type, name_to_resource in shared_resources.items(): + for name in name_to_resource: + if name in env.resources[resource_type]: + msg = ( + f"Duplicate shared {resource_type.value} {name!r} " + f"(loading {path})" + ) + raise ValueError(msg) + + # Resolve string action refs (raises ValueError on unknown ref). + _resolve_shared_action_refs(agents, agent_specs, shared_actions, path) + + # Commit: write resources then agents to env. + for resource_type, name_to_resource in shared_resources.items(): + for name, resource in name_to_resource.items(): + env.add_resource(name, resource_type, resource) + env._agents.update(agents) diff --git a/python/flink_agents/api/yaml/specs.py b/python/flink_agents/api/yaml/specs.py new file mode 100644 index 000000000..f806bf29d --- /dev/null +++ b/python/flink_agents/api/yaml/specs.py @@ -0,0 +1,224 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################# +"""Pydantic schema for the declarative YAML API. + +The models in this module define the file-level wire format. Pydantic +validation is the ground truth for the JSON Schema published in +docs/yaml-schema.json. +""" + +import json +import sys +from enum import Enum +from typing import Any, Dict, List, Literal + +from pydantic import BaseModel, ConfigDict, Field, model_validator + +Language = Literal["python", "java"] +"""Implementation language of a YAML-declared resource, action, or tool.""" + + +class DescriptorSpec(BaseModel): + """Schema for any ResourceDescriptor-backed resource. + + Required: ``name`` and ``clazz``. ``type`` selects the implementation + language (``"python"`` or ``"java"``; ``None`` means Python). All + remaining fields are forwarded verbatim to ``ResourceDescriptor`` as + kwargs (or as the Java wrapper's kwargs when ``type: java``); the + forwarding and language-aware wrapping is done by ``loader._build_descriptor``. + """ + + model_config = ConfigDict(extra="allow") + + name: str + clazz: str + type: Language | None = None + + +class MessageRole(str, Enum): + """Role of a message in a chat conversation.""" + + SYSTEM = "system" + USER = "user" + ASSISTANT = "assistant" + TOOL = "tool" + + +class PromptMessage(BaseModel): + """One message in a multi-turn prompt template.""" + + model_config = ConfigDict(extra="forbid") + + role: MessageRole = MessageRole.USER + content: str + + +class PromptSpec(BaseModel): + """Declarative prompt: either a single ``text`` template or a list of + role-tagged ``messages``. Exactly one of the two fields must be set. + """ + + model_config = ConfigDict(extra="forbid") + + name: str + text: str | None = None + messages: List[PromptMessage] | None = None + + @model_validator(mode="after") + def _require_exactly_one(self) -> "PromptSpec": + # Treat empty string / empty list as "unset" so that ``text: ""`` and + # ``messages: []`` are rejected rather than silently producing a + # nonsense empty prompt at load time. + if bool(self.text) == bool(self.messages): + msg = "prompt must define exactly one non-empty 'text' or 'messages'" + raise ValueError(msg) + return self + + +class ToolSpec(BaseModel): + """Points ``function:`` at a callable tool. + + ``function`` is written as ``:`` — the + colon separates the Python module (or Java class FQN) from the + attribute path inside it. For Python, the right side may be a + nested ``Class.method``. + + ``parameter_types`` is required when ``type: java`` and is ignored + otherwise (Python tools are reflected from the callable signature). + The list contains one string per declared parameter of the Java + method, in declaration order — the loader uses it to disambiguate + overloaded methods on the Java class. Each string is one of: + + - A Java primitive name: one of ``boolean``, ``byte``, ``short``, + ``int``, ``long``, ``float``, ``double``, ``char``. + - A fully-qualified Java reference type (including boxed + primitives), e.g. ``java.lang.Double``, ``java.lang.String``, + ``java.util.List``. + + Generic type arguments are not part of the JVM method descriptor + and must not be included (``java.util.List``, not + ``java.util.List``). + """ + + model_config = ConfigDict(extra="forbid") + + name: str + function: str | None = None + type: Language | None = None + parameter_types: List[str] | None = None + + +class SkillsSpec(BaseModel): + """Declarative Skills resource pointing at one or more skill source + directories on the local filesystem. + """ + + model_config = ConfigDict(extra="forbid") + + name: str + paths: List[str] + + +class ActionSpec(BaseModel): + """An action references a user function and the event types it listens to. + + ``function`` is written as ``:`` — the + colon separates the Python module (or Java class FQN) from the + attribute path inside it. + + Action signatures are fixed (``(Event, RunnerContext)``), so there is + no ``parameter_types`` knob — Python doesn't need it, and the Java + action signature is determined by the action contract. + """ + + model_config = ConfigDict(extra="forbid") + + name: str + function: str | None = None + listen_to: List[str] = Field(..., min_length=1) + config: Dict[str, Any] | None = None + type: Language | None = None + + +class AgentSpec(BaseModel): + """One agent inside a YAML file's ``agents:`` list. + + Holds the agent's own resources and actions. Resources/actions declared + at the file level (siblings of ``agents:``) are merged in by the loader. + """ + + model_config = ConfigDict(extra="forbid") + + name: str + description: str | None = None + + prompts: List[PromptSpec] = Field(default_factory=list) + tools: List[ToolSpec] = Field(default_factory=list) + skills: List[SkillsSpec] = Field(default_factory=list) + actions: List[ActionSpec | str] = Field(default_factory=list) + + chat_model_connections: List[DescriptorSpec] = Field(default_factory=list) + chat_model_setups: List[DescriptorSpec] = Field(default_factory=list) + embedding_model_connections: List[DescriptorSpec] = Field(default_factory=list) + embedding_model_setups: List[DescriptorSpec] = Field(default_factory=list) + vector_stores: List[DescriptorSpec] = Field(default_factory=list) + mcp_servers: List[DescriptorSpec] = Field(default_factory=list) + + +class YamlAgentsDocument(BaseModel): + """Top-level YAML document. + + Always wraps one or more agents under ``agents:``. Resources and + actions declared at the same level as ``agents:`` are shared: + resources are registered on the environment; actions can be + referenced from any agent by name string. + """ + + model_config = ConfigDict(extra="forbid") + + agents: List[AgentSpec] + + prompts: List[PromptSpec] = Field(default_factory=list) + tools: List[ToolSpec] = Field(default_factory=list) + skills: List[SkillsSpec] = Field(default_factory=list) + actions: List[ActionSpec] = Field(default_factory=list) + + chat_model_connections: List[DescriptorSpec] = Field(default_factory=list) + chat_model_setups: List[DescriptorSpec] = Field(default_factory=list) + embedding_model_connections: List[DescriptorSpec] = Field(default_factory=list) + embedding_model_setups: List[DescriptorSpec] = Field(default_factory=list) + vector_stores: List[DescriptorSpec] = Field(default_factory=list) + mcp_servers: List[DescriptorSpec] = Field(default_factory=list) + + +def export() -> str: + """Return the JSON Schema for the YAML API as a string. + + Pydantic models in this module are the ground truth for the YAML + file format; this helper serialises them so downstream consumers + that can't read Python types directly (IDE YAML language servers, + a future Java-side loader, generated docs) can use the same + contract. The output is checked in at ``docs/yaml-schema.json``; + keep it in sync by re-running this helper after editing the specs. + """ + schema = YamlAgentsDocument.model_json_schema() + return json.dumps(schema, indent=2, sort_keys=True) + "\n" + + +if __name__ == "__main__": + sys.stdout.write(export()) diff --git a/python/flink_agents/api/yaml/tests/__init__.py b/python/flink_agents/api/yaml/tests/__init__.py new file mode 100644 index 000000000..e154fadd3 --- /dev/null +++ b/python/flink_agents/api/yaml/tests/__init__.py @@ -0,0 +1,17 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################# diff --git a/python/flink_agents/api/yaml/tests/fixtures/__init__.py b/python/flink_agents/api/yaml/tests/fixtures/__init__.py new file mode 100644 index 000000000..e154fadd3 --- /dev/null +++ b/python/flink_agents/api/yaml/tests/fixtures/__init__.py @@ -0,0 +1,17 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################# diff --git a/python/flink_agents/api/yaml/tests/fixtures/loader_targets.py b/python/flink_agents/api/yaml/tests/fixtures/loader_targets.py new file mode 100644 index 000000000..665434797 --- /dev/null +++ b/python/flink_agents/api/yaml/tests/fixtures/loader_targets.py @@ -0,0 +1,46 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################# +"""Module-level callables referenced by YAML fixture files.""" + +from flink_agents.api.events.event import Event, InputEvent, OutputEvent +from flink_agents.api.runner_context import RunnerContext + + +def increment(event: Event, ctx: RunnerContext) -> None: + value = InputEvent.from_event(event).input + ctx.send_event(OutputEvent(output=value + 1)) + + +def decrement(event: Event, ctx: RunnerContext) -> None: + value = InputEvent.from_event(event).input + ctx.send_event(OutputEvent(output=value - 1)) + + +def notify(id: str, message: str) -> str: + return f"notified {id}: {message}" + + +class Counter: + """Holder for a class-method action target — exercises the + ``module:Class.method`` form in YAML function references. + """ + + @staticmethod + def bump(event: Event, ctx: RunnerContext) -> None: + value = InputEvent.from_event(event).input + ctx.send_event(OutputEvent(output=value + 100)) diff --git a/python/flink_agents/api/yaml/tests/fixtures/multi_agent.yaml b/python/flink_agents/api/yaml/tests/fixtures/multi_agent.yaml new file mode 100644 index 000000000..5d8a30557 --- /dev/null +++ b/python/flink_agents/api/yaml/tests/fixtures/multi_agent.yaml @@ -0,0 +1,11 @@ +agents: + - name: a1 + actions: + - name: increment + function: flink_agents.api.yaml.tests.fixtures.loader_targets:increment + listen_to: [input] + - name: a2 + actions: + - name: decrement + function: flink_agents.api.yaml.tests.fixtures.loader_targets:decrement + listen_to: [input] \ No newline at end of file diff --git a/python/flink_agents/api/yaml/tests/fixtures/multi_file_a.yaml b/python/flink_agents/api/yaml/tests/fixtures/multi_file_a.yaml new file mode 100644 index 000000000..b876e39d7 --- /dev/null +++ b/python/flink_agents/api/yaml/tests/fixtures/multi_file_a.yaml @@ -0,0 +1,10 @@ +agents: + - name: file_a_agent + actions: + - name: increment + function: flink_agents.api.yaml.tests.fixtures.loader_targets:increment + listen_to: [input] +chat_model_connections: + - name: conn_from_a + clazz: ollama + base_url: http://a diff --git a/python/flink_agents/api/yaml/tests/fixtures/multi_file_b.yaml b/python/flink_agents/api/yaml/tests/fixtures/multi_file_b.yaml new file mode 100644 index 000000000..d50f3c43f --- /dev/null +++ b/python/flink_agents/api/yaml/tests/fixtures/multi_file_b.yaml @@ -0,0 +1,10 @@ +agents: + - name: file_b_agent + actions: + - name: decrement + function: flink_agents.api.yaml.tests.fixtures.loader_targets:decrement + listen_to: [input] +chat_model_connections: + - name: conn_from_b + clazz: ollama + base_url: http://b diff --git a/python/flink_agents/api/yaml/tests/fixtures/single_agent.yaml b/python/flink_agents/api/yaml/tests/fixtures/single_agent.yaml new file mode 100644 index 000000000..e273feadb --- /dev/null +++ b/python/flink_agents/api/yaml/tests/fixtures/single_agent.yaml @@ -0,0 +1,6 @@ +agents: + - name: incrementer + actions: + - name: increment + function: flink_agents.api.yaml.tests.fixtures.loader_targets:increment + listen_to: [input] \ No newline at end of file diff --git a/python/flink_agents/api/yaml/tests/fixtures/with_descriptors.yaml b/python/flink_agents/api/yaml/tests/fixtures/with_descriptors.yaml new file mode 100644 index 000000000..d43d5ad3a --- /dev/null +++ b/python/flink_agents/api/yaml/tests/fixtures/with_descriptors.yaml @@ -0,0 +1,14 @@ +agents: + - name: chat_agent + actions: + - name: increment + function: flink_agents.api.yaml.tests.fixtures.loader_targets:increment + listen_to: [input] + - name: decrement + function: flink_agents.api.yaml.tests.fixtures.loader_targets:decrement + listen_to: [chat_response] + chat_model_connections: + - name: ollama_conn + clazz: ollama + base_url: http://localhost:11434 + request_timeout: 30 diff --git a/python/flink_agents/api/yaml/tests/fixtures/with_shared.yaml b/python/flink_agents/api/yaml/tests/fixtures/with_shared.yaml new file mode 100644 index 000000000..26b0995cd --- /dev/null +++ b/python/flink_agents/api/yaml/tests/fixtures/with_shared.yaml @@ -0,0 +1,20 @@ +agents: + - name: a1 + actions: + - shared_inc + - name: own_dec + function: flink_agents.api.yaml.tests.fixtures.loader_targets:decrement + listen_to: [chat_response] + - name: a2 + actions: + - shared_inc + +chat_model_connections: + - name: shared_conn + clazz: ollama + base_url: http://example + +actions: + - name: shared_inc + function: flink_agents.api.yaml.tests.fixtures.loader_targets:increment + listen_to: [input] \ No newline at end of file diff --git a/python/flink_agents/api/yaml/tests/fixtures/with_skills.yaml b/python/flink_agents/api/yaml/tests/fixtures/with_skills.yaml new file mode 100644 index 000000000..3aaa9ea85 --- /dev/null +++ b/python/flink_agents/api/yaml/tests/fixtures/with_skills.yaml @@ -0,0 +1,12 @@ +agents: + - name: skills_agent + skills: + - name: agent_skills + paths: + - ./agent_skill_dir + +skills: + - name: shared_skills + paths: + - ./shared_skill_dir + - ./more diff --git a/python/flink_agents/api/yaml/tests/fixtures/with_tools_and_prompts.yaml b/python/flink_agents/api/yaml/tests/fixtures/with_tools_and_prompts.yaml new file mode 100644 index 000000000..27e0b827c --- /dev/null +++ b/python/flink_agents/api/yaml/tests/fixtures/with_tools_and_prompts.yaml @@ -0,0 +1,12 @@ +agents: + - name: tool_agent + tools: + - name: notify + function: flink_agents.api.yaml.tests.fixtures.loader_targets:notify + prompts: + - name: text_prompt + text: "hello {name}" + - name: messages_prompt + messages: + - {role: system, content: "be brief"} + - {role: user, content: "{q}"} diff --git a/python/flink_agents/api/yaml/tests/test_aliases.py b/python/flink_agents/api/yaml/tests/test_aliases.py new file mode 100644 index 000000000..b43db0d40 --- /dev/null +++ b/python/flink_agents/api/yaml/tests/test_aliases.py @@ -0,0 +1,141 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################# +from flink_agents.api.events.chat_event import ( + ChatRequestEvent, + ChatResponseEvent, +) +from flink_agents.api.events.context_retrieval_event import ( + ContextRetrievalRequestEvent, + ContextRetrievalResponseEvent, +) +from flink_agents.api.events.event import InputEvent, OutputEvent +from flink_agents.api.events.tool_event import ( + ToolRequestEvent, + ToolResponseEvent, +) +from flink_agents.api.resource import ResourceType +from flink_agents.api.yaml.aliases import ( + CLAZZ_ALIASES, + EVENT_ALIASES, + JAVA_WRAPPER_CLAZZ, + resolve_clazz, + resolve_event_type, +) + + +def test_event_aliases_map_to_real_event_types() -> None: + assert EVENT_ALIASES["input"] == InputEvent.EVENT_TYPE + assert EVENT_ALIASES["output"] == OutputEvent.EVENT_TYPE + assert EVENT_ALIASES["chat_request"] == ChatRequestEvent.EVENT_TYPE + assert EVENT_ALIASES["chat_response"] == ChatResponseEvent.EVENT_TYPE + assert EVENT_ALIASES["tool_request"] == ToolRequestEvent.EVENT_TYPE + assert EVENT_ALIASES["tool_response"] == ToolResponseEvent.EVENT_TYPE + assert ( + EVENT_ALIASES["context_retrieval_request"] + == ContextRetrievalRequestEvent.EVENT_TYPE + ) + assert ( + EVENT_ALIASES["context_retrieval_response"] + == ContextRetrievalResponseEvent.EVENT_TYPE + ) + + +def test_resolve_event_type_replaces_alias() -> None: + assert resolve_event_type("input") == InputEvent.EVENT_TYPE + + +def test_resolve_event_type_passes_through_custom() -> None: + assert resolve_event_type("my_custom_event") == "my_custom_event" + + +def test_clazz_aliases_are_strings_with_dots() -> None: + assert CLAZZ_ALIASES + for resource_type, lang_map in CLAZZ_ALIASES.items(): + assert isinstance(resource_type, ResourceType) + assert lang_map, f"empty lang map for {resource_type}" + for lang, bucket in lang_map.items(): + assert bucket, f"empty alias bucket for ({resource_type}, {lang})" + for alias, fqn in bucket.items(): + assert isinstance(alias, str) + assert isinstance(fqn, str) + assert "." in fqn, ( + f"alias {alias!r} -> {fqn!r} in ({resource_type}, {lang}) is " + "not a qualified name" + ) + + +def test_resolve_clazz_replaces_alias_per_resource_type() -> None: + # Same short name resolves differently per resource type + conn = resolve_clazz("ollama", ResourceType.CHAT_MODEL_CONNECTION) + setup = resolve_clazz("ollama", ResourceType.CHAT_MODEL) + embed_conn = resolve_clazz("ollama", ResourceType.EMBEDDING_MODEL_CONNECTION) + assert conn.endswith("OllamaChatModelConnection") + assert setup.endswith("OllamaChatModelSetup") + assert embed_conn.endswith("OllamaEmbeddingModelConnection") + + +def test_resolve_clazz_passes_through_fqn() -> None: + assert ( + resolve_clazz("my.custom.Klass", ResourceType.CHAT_MODEL) == "my.custom.Klass" + ) + + +def test_resolve_clazz_unknown_alias_passes_through() -> None: + assert resolve_clazz("nonexistent", ResourceType.CHAT_MODEL) == "nonexistent" + + +def test_clazz_aliases_have_per_language_buckets() -> None: + for resource_type, lang_map in CLAZZ_ALIASES.items(): + assert "python" in lang_map, f"missing python bucket for {resource_type}" + # Java bucket optional; some resource types are Python-only + for lang, bucket in lang_map.items(): + assert bucket, f"empty bucket for ({resource_type}, {lang})" + for alias, fqn in bucket.items(): + assert isinstance(alias, str) + assert isinstance(fqn, str) + assert "." in fqn + + +def test_resolve_clazz_dispatches_on_language() -> None: + py = resolve_clazz("ollama", ResourceType.CHAT_MODEL_CONNECTION, "python") + java = resolve_clazz("ollama", ResourceType.CHAT_MODEL_CONNECTION, "java") + assert "OllamaChatModelConnection" in py + assert "OllamaChatModelConnection" in java + # Java FQN starts with `org.apache.flink.agents` + assert java.startswith("org.apache.flink.agents") + assert py.startswith("flink_agents") + + +def test_resolve_clazz_default_language_is_python() -> None: + default = resolve_clazz("ollama", ResourceType.CHAT_MODEL_CONNECTION) + explicit = resolve_clazz("ollama", ResourceType.CHAT_MODEL_CONNECTION, "python") + assert default == explicit + + +def test_java_wrapper_clazz_table_covers_supported_types() -> None: + # The Python-side wrappers must exist for every cross-language type + expected = { + ResourceType.CHAT_MODEL_CONNECTION, + ResourceType.CHAT_MODEL, + ResourceType.EMBEDDING_MODEL_CONNECTION, + ResourceType.EMBEDDING_MODEL, + ResourceType.VECTOR_STORE, + } + assert set(JAVA_WRAPPER_CLAZZ.keys()) == expected + for fqn in JAVA_WRAPPER_CLAZZ.values(): + assert "." in fqn diff --git a/python/flink_agents/api/yaml/tests/test_loader.py b/python/flink_agents/api/yaml/tests/test_loader.py new file mode 100644 index 000000000..7569c9bd7 --- /dev/null +++ b/python/flink_agents/api/yaml/tests/test_loader.py @@ -0,0 +1,561 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################# +from pathlib import Path + +import pytest + +from flink_agents.api.agents.agent import Agent +from flink_agents.api.chat_message import MessageRole +from flink_agents.api.events.chat_event import ChatResponseEvent +from flink_agents.api.events.event import InputEvent +from flink_agents.api.execution_environment import AgentsExecutionEnvironment +from flink_agents.api.function import JavaFunction, PythonFunction +from flink_agents.api.prompts.prompt import LocalPrompt +from flink_agents.api.resource import ResourceDescriptor, ResourceName, ResourceType +from flink_agents.api.skills import Skills +from flink_agents.api.tools.function_tool import FunctionTool +from flink_agents.api.yaml.loader import build_agents, load_yaml, resolve_function +from flink_agents.api.yaml.tests.fixtures import loader_targets + +_FIXTURES = Path(__file__).parent / "fixtures" + +_TARGETS_MODULE = "flink_agents.api.yaml.tests.fixtures.loader_targets" + + +def test_resolve_function_python_with_module_attr() -> None: + func = resolve_function( + name="anything", function=f"{_TARGETS_MODULE}:increment" + ) + assert isinstance(func, PythonFunction) + assert func.module == _TARGETS_MODULE + assert func.qualname == "increment" + # still callable + assert func.as_callable() is loader_targets.increment + + +def test_resolve_function_python_with_class_method() -> None: + # ``module:Class.method`` — the right side becomes + # ``PythonFunction.qualname`` verbatim and ``as_callable`` does the + # ``Class.method`` split internally. + func = resolve_function( + name="bump", function=f"{_TARGETS_MODULE}:Counter.bump" + ) + assert isinstance(func, PythonFunction) + assert func.module == _TARGETS_MODULE + assert func.qualname == "Counter.bump" + assert func.as_callable() is loader_targets.Counter.bump + + +def test_resolve_function_no_function_fails() -> None: + with pytest.raises(ValueError, match="'function' is required"): + resolve_function(name="x", function=None) + + +def test_resolve_function_missing_colon_fails() -> None: + # The dotted form used to be valid; under the new ``:`` syntax it + # must be rejected so the user gets a clear "use module:qualname" + # hint instead of a deep import failure. + with pytest.raises(ValueError, match="module-or-class.:.qualname"): + resolve_function(name="x", function=f"{_TARGETS_MODULE}.increment") + + +def test_resolve_function_multiple_colons_fails() -> None: + with pytest.raises(ValueError, match="module-or-class.:.qualname"): + resolve_function(name="x", function="a:b:c") + + +def test_resolve_function_empty_module_fails() -> None: + with pytest.raises(ValueError, match="module-or-class.:.qualname"): + resolve_function(name="x", function=":increment") + + +def test_resolve_function_empty_qualname_fails() -> None: + with pytest.raises(ValueError, match="module-or-class.:.qualname"): + resolve_function(name="x", function=f"{_TARGETS_MODULE}:") + + +def test_resolve_function_missing_target_raises_importerror() -> None: + # PythonFunction loads lazily; trigger the import via as_callable(). + func = resolve_function( + name="x", + function=f"{_TARGETS_MODULE}:does_not_exist", + ) + with pytest.raises((ImportError, AttributeError)): + func.as_callable() + + +def test_build_agents_rejects_duplicate_agent_within_file(tmp_path: Path) -> None: + yaml_text = ( + "agents:\n" + " - name: dup\n" + " actions:\n" + " - name: increment\n" + f" function: {_TARGETS_MODULE}:increment\n" + " listen_to: [input]\n" + " - name: dup\n" + " actions:\n" + " - name: decrement\n" + f" function: {_TARGETS_MODULE}:decrement\n" + " listen_to: [input]\n" + ) + p = tmp_path / "dup.yaml" + p.write_text(yaml_text) + with pytest.raises(ValueError, match="dup"): + build_agents(p) + + +def test_build_agents_from_single_agent_yaml() -> None: + agents, shared_resources, shared_actions = build_agents( + _FIXTURES / "single_agent.yaml" + ) + assert list(agents) == ["incrementer"] + agent = agents["incrementer"] + assert isinstance(agent, Agent) + assert "increment" in agent.actions + events, func, config = agent.actions["increment"] + assert events == [InputEvent.EVENT_TYPE] + assert isinstance(func, PythonFunction) + assert func.qualname == "increment" + assert config is None + assert shared_resources == {t: {} for t in shared_resources} + assert shared_actions == {} + + +def test_build_agents_resolves_event_alias_and_clazz_alias() -> None: + agents, _, _ = build_agents(_FIXTURES / "with_descriptors.yaml") + agent = agents["chat_agent"] + + inc_events, _, _ = agent.actions["increment"] + dec_events, _, _ = agent.actions["decrement"] + assert inc_events == [InputEvent.EVENT_TYPE] + assert dec_events == [ChatResponseEvent.EVENT_TYPE] + + conn = agent.resources[ResourceType.CHAT_MODEL_CONNECTION]["ollama_conn"] + assert isinstance(conn, ResourceDescriptor) + expected_module, _, expected_class = ( + ResourceName.ChatModel.OLLAMA_CONNECTION.rpartition(".") + ) + assert conn.target_module == expected_module + assert conn.target_clazz == expected_class + assert conn.arguments == { + "base_url": "http://localhost:11434", + "request_timeout": 30, + } + + +def test_build_agents_loads_tools_and_prompts() -> None: + agents, _, _ = build_agents(_FIXTURES / "with_tools_and_prompts.yaml") + agent = agents["tool_agent"] + + tool = agent.resources[ResourceType.TOOL]["notify"] + assert isinstance(tool, FunctionTool) + assert isinstance(tool.func, PythonFunction) + assert tool.func.qualname == "notify" + + text_prompt = agent.resources[ResourceType.PROMPT]["text_prompt"] + assert isinstance(text_prompt, LocalPrompt) + assert text_prompt.template == "hello {name}" + + msg_prompt = agent.resources[ResourceType.PROMPT]["messages_prompt"] + assert isinstance(msg_prompt, LocalPrompt) + assert len(msg_prompt.template) == 2 + assert msg_prompt.template[0].role == MessageRole.SYSTEM + assert msg_prompt.template[1].content == "{q}" + + +def test_build_agents_handles_shared_resources_and_actions() -> None: + agents, shared_resources, shared_actions = build_agents( + _FIXTURES / "with_shared.yaml" + ) + + # shared resources surfaced to caller + assert "shared_conn" in shared_resources[ResourceType.CHAT_MODEL_CONNECTION] + # shared actions stored as ActionSpec for cross-agent reference resolution + assert "shared_inc" in shared_actions + + # both a1 and a2 own a copy of shared_inc after caller-side merge? + # NO — build_agents only handles in-file. The merge happens in load_yaml. + # Here we assert build_agents leaves string refs *unresolved* for the caller: + a1 = agents["a1"] + a2 = agents["a2"] + assert "shared_inc" not in a1.actions # not yet merged in + assert "own_dec" in a1.actions + assert "shared_inc" not in a2.actions + + +def test_load_yaml_registers_single_agent_on_env() -> None: + env = AgentsExecutionEnvironment.get_execution_environment() + load_yaml(env, _FIXTURES / "single_agent.yaml") + assert "incrementer" in env._agents + + +def test_load_yaml_registers_multiple_agents() -> None: + env = AgentsExecutionEnvironment.get_execution_environment() + load_yaml(env, _FIXTURES / "multi_agent.yaml") + assert set(env._agents.keys()) == {"a1", "a2"} + + +def test_load_yaml_merges_shared_action_into_agents() -> None: + env = AgentsExecutionEnvironment.get_execution_environment() + load_yaml(env, _FIXTURES / "with_shared.yaml") + a1 = env._agents["a1"] + a2 = env._agents["a2"] + assert "shared_inc" in a1.actions + assert "shared_inc" in a2.actions + events_a1, func_a1, _ = a1.actions["shared_inc"] + events_a2, func_a2, _ = a2.actions["shared_inc"] + assert events_a1 == [InputEvent.EVENT_TYPE] + assert events_a2 == [InputEvent.EVENT_TYPE] + assert isinstance(func_a1, PythonFunction) + assert func_a1.qualname == "increment" + assert isinstance(func_a2, PythonFunction) + assert func_a2.qualname == "increment" + + +def test_load_yaml_registers_shared_resources_on_env() -> None: + env = AgentsExecutionEnvironment.get_execution_environment() + load_yaml(env, _FIXTURES / "with_shared.yaml") + assert "shared_conn" in env.resources[ResourceType.CHAT_MODEL_CONNECTION] + + +def test_load_yaml_string_ref_to_missing_shared_action_errors(tmp_path: Path) -> None: + env = AgentsExecutionEnvironment.get_execution_environment() + bad = tmp_path / "bad_missing_shared_action.yaml" + bad.write_text("agents:\n - name: a\n actions:\n - undefined_action\n") + with pytest.raises(ValueError, match="undefined_action"): + load_yaml(env, bad) + + +def test_load_yaml_multi_call_merges() -> None: + env = AgentsExecutionEnvironment.get_execution_environment() + load_yaml(env, _FIXTURES / "multi_file_a.yaml") + load_yaml(env, _FIXTURES / "multi_file_b.yaml") + assert {"file_a_agent", "file_b_agent"} <= set(env._agents.keys()) + assert "conn_from_a" in env.resources[ResourceType.CHAT_MODEL_CONNECTION] + assert "conn_from_b" in env.resources[ResourceType.CHAT_MODEL_CONNECTION] + + +def test_load_yaml_accepts_list_of_paths() -> None: + env = AgentsExecutionEnvironment.get_execution_environment() + load_yaml(env, [_FIXTURES / "multi_file_a.yaml", _FIXTURES / "multi_file_b.yaml"]) + assert {"file_a_agent", "file_b_agent"} <= set(env._agents.keys()) + + +def test_load_yaml_duplicate_agent_across_calls_errors() -> None: + env = AgentsExecutionEnvironment.get_execution_environment() + load_yaml(env, _FIXTURES / "multi_file_a.yaml") + with pytest.raises(ValueError, match="file_a_agent"): + load_yaml(env, _FIXTURES / "multi_file_a.yaml") + + +def test_load_yaml_duplicate_shared_resource_within_file_errors(tmp_path) -> None: + # In-file duplicate detection used to differ between ``build_agents`` + # (raise) and ``load_yaml`` (silent last-wins). Both entrypoints now + # go through the same builder, so ``load_yaml`` rejects too. + bad = tmp_path / "dup_in_file.yaml" + bad.write_text( + "agents:\n" + " - name: a\n" + "chat_model_connections:\n" + " - name: conn\n" + " clazz: x.Y\n" + " - name: conn\n" + " clazz: x.Z\n" + ) + env = AgentsExecutionEnvironment.get_execution_environment() + with pytest.raises(ValueError, match="Duplicate shared resource name 'conn'"): + load_yaml(env, bad) + + +def test_load_yaml_duplicate_shared_action_within_file_errors(tmp_path) -> None: + bad = tmp_path / "dup_action_in_file.yaml" + bad.write_text( + "agents:\n" + " - name: a\n" + "actions:\n" + " - name: shared\n" + " listen_to: [input]\n" + " - name: shared\n" + " listen_to: [input]\n" + ) + env = AgentsExecutionEnvironment.get_execution_environment() + with pytest.raises(ValueError, match="Duplicate shared action name 'shared'"): + load_yaml(env, bad) + + +def test_load_yaml_duplicate_shared_resource_across_calls_errors(tmp_path) -> None: + env = AgentsExecutionEnvironment.get_execution_environment() + load_yaml(env, _FIXTURES / "multi_file_a.yaml") + dup = tmp_path / "dup.yaml" + dup.write_text( + "agents:\n - name: other\n" + "chat_model_connections:\n" + " - name: conn_from_a\n" + " clazz: ollama\n" + ) + with pytest.raises(ValueError, match="conn_from_a"): + load_yaml(env, dup) + + +def test_apply_by_agent_name_runs_yaml_loaded_agent() -> None: + env = AgentsExecutionEnvironment.get_execution_environment() + load_yaml(env, _FIXTURES / "single_agent.yaml") + + input_list = [] + output_list = env.from_list(input_list).apply("incrementer").to_list() + input_list.append({"key": "bob", "value": 1}) + input_list.append({"key": "john", "value": 2}) + env.execute() + assert output_list == [{"bob": 2}, {"john": 3}] + + +def test_apply_by_unknown_name_errors() -> None: + env = AgentsExecutionEnvironment.get_execution_environment() + with pytest.raises(ValueError, match="ghost"): + env.from_list([]).apply("ghost") + + +def test_build_agents_loads_skills_per_agent_and_shared() -> None: + agents, shared_resources, _ = build_agents(_FIXTURES / "with_skills.yaml") + agent = agents["skills_agent"] + + own = agent.resources[ResourceType.SKILLS]["agent_skills"] + assert isinstance(own, Skills) + assert own.paths == ["./agent_skill_dir"] + + shared = shared_resources[ResourceType.SKILLS]["shared_skills"] + assert isinstance(shared, Skills) + assert shared.paths == ["./shared_skill_dir", "./more"] + + +def test_load_yaml_registers_shared_skills_on_env() -> None: + env = AgentsExecutionEnvironment.get_execution_environment() + load_yaml(env, _FIXTURES / "with_skills.yaml") + shared = env.resources[ResourceType.SKILLS]["shared_skills"] + assert isinstance(shared, Skills) + assert shared.paths == ["./shared_skill_dir", "./more"] + + +def test_build_agents_supports_type_java(tmp_path: Path) -> None: + yaml_text = ( + "agents:\n" + " - name: a\n" + " chat_model_connections:\n" + " - name: java_conn\n" + " type: java\n" + " clazz: ollama\n" + " endpoint: http://localhost:11434\n" + " requestTimeout: 120\n" + ) + p = tmp_path / "java_resource.yaml" + p.write_text(yaml_text) + agents, _, _ = build_agents(p) + agent = agents["a"] + + conn = agent.resources[ResourceType.CHAT_MODEL_CONNECTION]["java_conn"] + # clazz is the Python-side Java wrapper + assert conn.target_clazz == "JavaChatModelConnection" + # java_clazz arg points at the Java implementation + assert ( + conn.arguments["java_clazz"] + == "org.apache.flink.agents.integrations.chatmodels.ollama.OllamaChatModelConnection" + ) + # other kwargs flow through + assert conn.arguments["endpoint"] == "http://localhost:11434" + assert conn.arguments["requestTimeout"] == 120 + + +def test_build_agents_rejects_type_java_for_unsupported_resource( + tmp_path: Path, +) -> None: + # MCP_SERVER has no Python-side Java wrapper, so type=java must error. + yaml_text = ( + "agents:\n" + " - name: a\n" + " mcp_servers:\n" + " - name: x\n" + " type: java\n" + " clazz: anything\n" + ) + p = tmp_path / "bad_java.yaml" + p.write_text(yaml_text) + with pytest.raises(ValueError, match="java"): + build_agents(p) + + +def test_clazz_alias_resolves_per_section(tmp_path: Path) -> None: + yaml_text = ( + "agents:\n" + " - name: a\n" + " chat_model_connections:\n" + " - name: conn\n" + " clazz: ollama\n" + " base_url: http://x\n" + " chat_model_setups:\n" + " - name: setup\n" + " clazz: ollama\n" + " connection: conn\n" + " embedding_model_connections:\n" + " - name: e_conn\n" + " clazz: ollama\n" + " base_url: http://y\n" + ) + p = tmp_path / "per_section.yaml" + p.write_text(yaml_text) + agents, _, _ = build_agents(p) + agent = agents["a"] + + conn = agent.resources[ResourceType.CHAT_MODEL_CONNECTION]["conn"] + setup = agent.resources[ResourceType.CHAT_MODEL]["setup"] + e_conn = agent.resources[ResourceType.EMBEDDING_MODEL_CONNECTION]["e_conn"] + assert conn.target_clazz == "OllamaChatModelConnection" + assert setup.target_clazz == "OllamaChatModelSetup" + assert e_conn.target_clazz == "OllamaEmbeddingModelConnection" + + +def test_resolve_function_builds_java_function_for_java_language() -> None: + func = resolve_function( + name="firstAction", + function="com.example.MyAgent:firstAction", + language="java", + parameter_types=[ + "org.apache.flink.agents.api.Event", + "org.apache.flink.agents.api.context.RunnerContext", + ], + ) + assert isinstance(func, JavaFunction) + assert func.qualname == "com.example.MyAgent" + assert func.method_name == "firstAction" + assert func.parameter_types == [ + "org.apache.flink.agents.api.Event", + "org.apache.flink.agents.api.context.RunnerContext", + ] + + +def test_resolve_function_java_supports_inner_classes() -> None: + func = resolve_function( + name="m", + function="com.example.Outer$Inner:m", + language="java", + parameter_types=[], + ) + assert isinstance(func, JavaFunction) + assert func.qualname == "com.example.Outer$Inner" + assert func.method_name == "m" + + +def test_resolve_function_python_is_default_language() -> None: + func1 = resolve_function( + name="x", function=f"{_TARGETS_MODULE}:increment" + ) + func2 = resolve_function( + name="x", + function=f"{_TARGETS_MODULE}:increment", + language="python", + ) + assert isinstance(func1, PythonFunction) + assert isinstance(func2, PythonFunction) + assert func1.module == func2.module + assert func1.qualname == func2.qualname + + +def test_build_agents_action_func_is_python_function() -> None: + agents, _, _ = build_agents(_FIXTURES / "single_agent.yaml") + agent = agents["incrementer"] + events, func, _ = agent.actions["increment"] + assert isinstance(func, PythonFunction) + assert func.qualname == "increment" + + +def test_build_agents_builds_java_action(tmp_path: Path) -> None: + yaml_text = ( + "agents:\n" + " - name: a\n" + " actions:\n" + " - name: a1\n" + " type: java\n" + " function: com.example.MyAgent:handle\n" + " listen_to: [input]\n" + ) + p = tmp_path / "java_action.yaml" + p.write_text(yaml_text) + agents, _, _ = build_agents(p) + agent = agents["a"] + _, func, _ = agent.actions["a1"] + assert isinstance(func, JavaFunction) + assert func.qualname == "com.example.MyAgent" + assert func.method_name == "handle" + assert func.parameter_types == [ + "org.apache.flink.agents.api.Event", + "org.apache.flink.agents.api.context.RunnerContext", + ] + + +def test_build_agents_rejects_java_tool_missing_parameter_types( + tmp_path: Path, +) -> None: + yaml_text = ( + "agents:\n" + " - name: a\n" + " tools:\n" + " - name: t1\n" + " type: java\n" + " function: com.example.Tools:add\n" + " actions:\n" + " - name: noop\n" + f" function: {_TARGETS_MODULE}:increment\n" + " listen_to: [input]\n" + ) + p = tmp_path / "java_tool_no_params.yaml" + p.write_text(yaml_text) + with pytest.raises(ValueError, match="parameter_types"): + build_agents(p) + + +def test_build_agents_builds_java_tool_descriptor(tmp_path: Path) -> None: + """YAML parsing of a Java tool yields an api ``FunctionTool`` wrapping + a ``JavaFunction`` descriptor — no JVM needed at parse time. + + Metadata extraction (via py4j on the plan side) is wired up later; + see ``flink_agents.plan.tools.function_tool.FunctionTool.metadata`` + which currently raises ``NotImplementedError`` for Java tools. + """ + yaml_text = ( + "agents:\n" + " - name: a\n" + " tools:\n" + " - name: add\n" + " type: java\n" + " function: com.example.Tools:add\n" + " parameter_types: [int, int]\n" + " actions:\n" + " - name: noop\n" + f" function: {_TARGETS_MODULE}:increment\n" + " listen_to: [input]\n" + ) + p = tmp_path / "java_tool.yaml" + p.write_text(yaml_text) + agents, _, _ = build_agents(p) + agent = agents["a"] + + tool = agent.resources[ResourceType.TOOL]["add"] + assert isinstance(tool, FunctionTool) + assert isinstance(tool.func, JavaFunction) + assert tool.func.qualname == "com.example.Tools" + assert tool.func.method_name == "add" + assert tool.func.parameter_types == ["int", "int"] diff --git a/python/flink_agents/api/yaml/tests/test_specs.py b/python/flink_agents/api/yaml/tests/test_specs.py new file mode 100644 index 000000000..a7c89dbde --- /dev/null +++ b/python/flink_agents/api/yaml/tests/test_specs.py @@ -0,0 +1,328 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################# +from pathlib import Path + +import pytest +from pydantic import ValidationError + +from flink_agents.api.yaml.specs import ( + ActionSpec, + AgentSpec, + DescriptorSpec, + MessageRole, + PromptMessage, + PromptSpec, + SkillsSpec, + ToolSpec, + YamlAgentsDocument, + export, +) + + +def test_descriptor_spec_requires_name_and_clazz() -> None: + with pytest.raises(ValidationError): + DescriptorSpec.model_validate({"clazz": "x.Y"}) + with pytest.raises(ValidationError): + DescriptorSpec.model_validate({"name": "n"}) + + +def test_descriptor_spec_passes_extras_through() -> None: + spec = DescriptorSpec.model_validate( + {"name": "n", "clazz": "x.Y", "base_url": "http://x", "timeout": 5} + ) + assert spec.name == "n" + assert spec.clazz == "x.Y" + assert spec.model_extra == {"base_url": "http://x", "timeout": 5} + + +def test_descriptor_spec_type_defaults_to_none() -> None: + spec = DescriptorSpec.model_validate({"name": "n", "clazz": "x.Y"}) + assert spec.type is None + + +def test_descriptor_spec_accepts_python_and_java() -> None: + py = DescriptorSpec.model_validate({"name": "n", "clazz": "x.Y", "type": "python"}) + java = DescriptorSpec.model_validate({"name": "n", "clazz": "x.Y", "type": "java"}) + assert py.type == "python" + assert java.type == "java" + + +def test_descriptor_spec_rejects_unknown_type() -> None: + with pytest.raises(ValidationError): + DescriptorSpec.model_validate({"name": "n", "clazz": "x.Y", "type": "go"}) + + +def test_message_role_values() -> None: + assert MessageRole.SYSTEM.value == "system" + assert MessageRole.USER.value == "user" + assert MessageRole.ASSISTANT.value == "assistant" + assert MessageRole.TOOL.value == "tool" + + +def test_prompt_message_defaults_to_user() -> None: + msg = PromptMessage.model_validate({"content": "hi"}) + assert msg.role == MessageRole.USER + assert msg.content == "hi" + + +def test_prompt_spec_with_text() -> None: + spec = PromptSpec.model_validate({"name": "p1", "text": "hello {x}"}) + assert spec.text == "hello {x}" + assert spec.messages is None + + +def test_prompt_spec_with_messages() -> None: + spec = PromptSpec.model_validate( + { + "name": "p1", + "messages": [ + {"role": "system", "content": "be nice"}, + {"role": "user", "content": "{input}"}, + ], + } + ) + assert spec.messages is not None + assert spec.messages[0].role == MessageRole.SYSTEM + assert spec.text is None + + +def test_prompt_spec_requires_text_xor_messages() -> None: + with pytest.raises(ValidationError): + PromptSpec.model_validate({"name": "p1"}) + with pytest.raises(ValidationError): + PromptSpec.model_validate( + {"name": "p1", "text": "x", "messages": [{"content": "y"}]} + ) + + +def test_prompt_spec_rejects_empty_text_or_messages() -> None: + # ``text: ""`` and ``messages: []`` used to slip past the "exactly one" + # check because the prior implementation tested ``is None`` rather than + # truthiness. Either alone, or together, must now be rejected so an + # empty prompt cannot be built silently. + with pytest.raises(ValidationError): + PromptSpec.model_validate({"name": "p1", "messages": []}) + with pytest.raises(ValidationError): + PromptSpec.model_validate({"name": "p1", "text": ""}) + with pytest.raises(ValidationError): + PromptSpec.model_validate({"name": "p1", "text": "", "messages": []}) + + +def test_tool_spec_name_only() -> None: + spec = ToolSpec.model_validate({"name": "t1"}) + assert spec.name == "t1" + assert spec.function is None + + +def test_tool_spec_with_function() -> None: + spec = ToolSpec.model_validate({"name": "t1", "function": "m.f"}) + assert spec.function == "m.f" + + +def test_tool_spec_forbids_extras() -> None: + with pytest.raises(ValidationError): + ToolSpec.model_validate({"name": "t1", "unknown": 1}) + + +def test_action_spec_requires_listen_to() -> None: + with pytest.raises(ValidationError): + ActionSpec.model_validate({"name": "a1"}) + + +def test_action_spec_rejects_empty_listen_to() -> None: + # An empty ``listen_to: []`` would silently register a dead action that + # never fires. The minimum-length constraint forces the mistake to + # surface at YAML validation time. + with pytest.raises(ValidationError): + ActionSpec.model_validate({"name": "a1", "listen_to": []}) + + +def test_action_spec_defaults() -> None: + spec = ActionSpec.model_validate({"name": "a1", "listen_to": ["input"]}) + assert spec.listen_to == ["input"] + assert spec.function is None + assert spec.config is None + + +def test_action_spec_with_config() -> None: + spec = ActionSpec.model_validate( + {"name": "a1", "listen_to": ["input"], "config": {"k": 1}} + ) + assert spec.config == {"k": 1} + + +def test_action_spec_accepts_type() -> None: + spec = ActionSpec.model_validate( + {"name": "a1", "listen_to": ["input"], "type": "java"} + ) + assert spec.type == "java" + + +def test_action_spec_type_defaults_to_none() -> None: + spec = ActionSpec.model_validate({"name": "a1", "listen_to": ["input"]}) + assert spec.type is None + + +def test_action_spec_rejects_unknown_type() -> None: + with pytest.raises(ValidationError): + ActionSpec.model_validate( + {"name": "a1", "listen_to": ["input"], "type": "rust"} + ) + + +def test_tool_spec_accepts_type() -> None: + spec = ToolSpec.model_validate({"name": "t1", "type": "java"}) + assert spec.type == "java" + + +def test_tool_spec_type_defaults_to_none() -> None: + spec = ToolSpec.model_validate({"name": "t1"}) + assert spec.type is None + + +def test_agent_spec_requires_name() -> None: + with pytest.raises(ValidationError): + AgentSpec.model_validate({}) + + +def test_agent_spec_minimal() -> None: + spec = AgentSpec.model_validate({"name": "a"}) + assert spec.name == "a" + assert spec.description is None + assert spec.actions == [] + assert spec.prompts == [] + assert spec.tools == [] + assert spec.chat_model_connections == [] + + +def test_agent_spec_action_can_be_string_reference() -> None: + spec = AgentSpec.model_validate( + { + "name": "a", + "actions": [ + "shared_action1", + {"name": "x", "listen_to": ["input"]}, + ], + } + ) + assert spec.actions[0] == "shared_action1" + assert isinstance(spec.actions[1], ActionSpec) + + +def test_yaml_document_requires_agents() -> None: + with pytest.raises(ValidationError): + YamlAgentsDocument.model_validate({}) + + +def test_yaml_document_minimal() -> None: + doc = YamlAgentsDocument.model_validate({"agents": [{"name": "a"}]}) + assert len(doc.agents) == 1 + assert doc.agents[0].name == "a" + assert doc.chat_model_connections == [] + assert doc.actions == [] + + +def test_yaml_document_with_shared_resources_and_actions() -> None: + doc = YamlAgentsDocument.model_validate( + { + "agents": [{"name": "a"}], + "chat_model_connections": [{"name": "c", "clazz": "x.Y"}], + "actions": [{"name": "shared", "listen_to": ["input"]}], + } + ) + assert doc.chat_model_connections[0].name == "c" + assert isinstance(doc.actions[0], ActionSpec) + assert doc.actions[0].name == "shared" + + +def test_skills_spec_requires_paths() -> None: + with pytest.raises(ValidationError): + SkillsSpec.model_validate({"name": "s"}) + + +def test_skills_spec_with_paths() -> None: + spec = SkillsSpec.model_validate({"name": "s", "paths": ["./a", "./b"]}) + assert spec.paths == ["./a", "./b"] + + +def test_skills_spec_forbids_extras() -> None: + with pytest.raises(ValidationError): + SkillsSpec.model_validate({"name": "s", "paths": ["./a"], "extra": 1}) + + +def test_agent_spec_has_skills_field() -> None: + spec = AgentSpec.model_validate({"name": "a"}) + assert spec.skills == [] + + +def test_yaml_document_has_skills_field() -> None: + doc = YamlAgentsDocument.model_validate({"agents": [{"name": "a"}]}) + assert doc.skills == [] + + +def test_yaml_document_and_agent_reject_events() -> None: + # ``events:`` declarations used to be accepted silently and dropped by + # the loader, at both the document level and the agent level. Both + # levels now forbid the key outright so the mistake surfaces at + # validation time. + with pytest.raises(ValidationError): + YamlAgentsDocument.model_validate( + {"agents": [{"name": "a"}], "events": [{"name": "evt"}]} + ) + with pytest.raises(ValidationError): + AgentSpec.model_validate({"name": "a", "events": [{"name": "evt"}]}) + + +_SCHEMA_FILE = Path(__file__).parents[5] / "docs" / "yaml-schema.json" + + +def test_action_spec_rejects_parameter_types() -> None: + # Action signatures are fixed; parameter_types is not exposed. + import pytest + + with pytest.raises(ValueError, match="parameter_types"): + ActionSpec.model_validate( + { + "name": "a1", + "listen_to": ["input"], + "type": "java", + "parameter_types": ["x.Y"], + } + ) + + +def test_tool_spec_accepts_parameter_types() -> None: + spec = ToolSpec.model_validate( + {"name": "t1", "type": "java", "parameter_types": ["a.B", "a.C"]} + ) + assert spec.parameter_types == ["a.B", "a.C"] + + +def test_tool_spec_parameter_types_defaults_to_none() -> None: + spec = ToolSpec.model_validate({"name": "t1"}) + assert spec.parameter_types is None + + +def test_checked_in_schema_matches_pydantic_models() -> None: + on_disk = _SCHEMA_FILE.read_text() + fresh = export() + assert on_disk == fresh, ( + "docs/yaml-schema.json is out of sync with Pydantic models. " + "Run: python -m flink_agents.api.yaml.specs " + "> docs/yaml-schema.json" + ) diff --git a/python/flink_agents/e2e_tests/e2e_tests_integration/yaml_test.py b/python/flink_agents/e2e_tests/e2e_tests_integration/yaml_test.py new file mode 100644 index 000000000..cb9ebd81c --- /dev/null +++ b/python/flink_agents/e2e_tests/e2e_tests_integration/yaml_test.py @@ -0,0 +1,216 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################# +"""E2E test: parse a single YAML file (chat model + function tool), +load it via :func:`AgentsExecutionEnvironment.load_yaml`, and run the +declared agent through the Flink remote runner. + +Uses Ollama for the chat backend — the YAML's ``model`` field is +hardcoded to ``qwen3:1.7b`` (same default as the other Ollama e2e +tests). Skipped when the Ollama client/model is not available. +""" + +import json +import os +import sysconfig +from pathlib import Path + +import pytest +from pyflink.common import Configuration, Encoder, WatermarkStrategy +from pyflink.common.typeinfo import Types +from pyflink.datastream import ( + RuntimeExecutionMode, + StreamExecutionEnvironment, +) +from pyflink.datastream.connectors.file_system import ( + FileSource, + StreamFormat, + StreamingFileSink, +) + +from flink_agents.api.execution_environment import AgentsExecutionEnvironment +from flink_agents.e2e_tests.e2e_tests_integration.yaml_test_actions import ( + YamlChatInput, + YamlChatKeySelector, + YamlChatOutput, +) +from flink_agents.e2e_tests.test_utils import pull_model + +current_dir = Path(__file__).parent +_RESOURCES = current_dir.parent / "resources" + +os.environ["PYTHONPATH"] = sysconfig.get_paths()["purelib"] + +_OLLAMA_MODEL = "qwen3:1.7b" +_client = pull_model(_OLLAMA_MODEL) + + +@pytest.mark.skipif( + _client is None, + reason="Ollama client is not available or test model is missing.", +) +def test_single_yaml_agent(tmp_path: Path) -> None: + """``load_yaml`` → ``apply(by name)`` through the Flink remote + runner, exercising both the tool-using math chat model and the + plain creative chat model declared in the same YAML. + """ + config = Configuration() + config.set_string("python.pythonpath", sysconfig.get_paths()["purelib"]) + env = StreamExecutionEnvironment.get_execution_environment(config) + env.set_runtime_mode(RuntimeExecutionMode.STREAMING) + env.set_parallelism(1) + + input_datastream = env.from_source( + source=FileSource.for_record_stream_format( + StreamFormat.text_line_format(), + f"file:///{_RESOURCES}/yaml_test_input", + ).build(), + watermark_strategy=WatermarkStrategy.no_watermarks(), + source_name="yaml_test_source", + ) + + deserialize_datastream = input_datastream.map( + lambda x: YamlChatInput.model_validate_json(x) + ) + + agents_env = AgentsExecutionEnvironment.get_execution_environment(env=env) + agents_env.load_yaml(_RESOURCES / "yaml_test_agent.yaml") + + output_datastream = ( + agents_env.from_datastream( + input=deserialize_datastream, key_selector=YamlChatKeySelector() + ) + .apply("yaml_test_agent") + .to_datastream() + ) + + result_dir = tmp_path / "results" + result_dir.mkdir(parents=True, exist_ok=True) + + output_datastream.map(lambda x: json.dumps(x), Types.STRING()).add_sink( + StreamingFileSink.for_row_format( + base_path=str(result_dir.absolute()), + encoder=Encoder.simple_string_encoder(), + ).build() + ) + + agents_env.execute() + + answers = _read_answers(result_dir) + # Math path went through the ``add`` tool: the model's final + # answer should mention ``3``. Creative path should mention cats. + assert "3" in answers[1], f"math answer missing '3': {answers[1]!r}" + assert "cat" in answers[2].lower(), f"creative answer missing 'cat': {answers[2]!r}" + + +@pytest.mark.skipif( + _client is None, + reason="Ollama client is not available or test model is missing.", +) +def test_chained_yaml_agents(tmp_path: Path) -> None: + """One YAML file declares ``math_agent`` and ``commentator_agent``; + both reuse a file-level ``ollama_connection`` and the file-level + ``process_chat_response`` action. + + The two agents register on the environment via ``load_yaml`` and + run as a single chained Flink pipeline: + + FileSource → math_agent → commentator_agent → StreamingFileSink + + The math agent's output ``DataStream`` is fed straight into the + commentator agent — same job, same ``agents_env.execute()`` — so + the test exercises chaining two YAML-loaded agents end-to-end, + proves the file-level shared connection + shared action are + reusable across both agents, and asserts the math digit survives + the second LLM hop. + """ + config = Configuration() + config.set_string("python.pythonpath", sysconfig.get_paths()["purelib"]) + env = StreamExecutionEnvironment.get_execution_environment(config) + env.set_runtime_mode(RuntimeExecutionMode.STREAMING) + env.set_parallelism(1) + + agents_env = AgentsExecutionEnvironment.get_execution_environment(env=env) + agents_env.load_yaml(_RESOURCES / "yaml_multi_agent.yaml") + + input_datastream = env.from_source( + source=FileSource.for_record_stream_format( + StreamFormat.text_line_format(), + f"file:///{_RESOURCES}/yaml_test_math_input", + ).build(), + watermark_strategy=WatermarkStrategy.no_watermarks(), + source_name="yaml_chained_source", + ) + deserialize_datastream = input_datastream.map( + lambda x: YamlChatInput.model_validate_json(x) + ) + + # Stage 1: math_agent answers the question. Stage 2 reads the + # YamlChatOutput stream straight from stage 1 and calls a second + # chat model — both stages reuse the file-level ``ollama_connection`` + # and the file-level ``process_chat_response`` action. + math_output = ( + agents_env.from_datastream( + input=deserialize_datastream, key_selector=YamlChatKeySelector() + ) + .apply("math_agent") + .to_datastream() + ) + # ``to_datastream`` serialises ``OutputEvent.output`` into a plain + # dict at the Flink boundary, so re-validate into the pydantic + # model before feeding stage 2 (so the key selector and the + # ``commentary_request`` action both see a typed ``YamlChatOutput``). + math_output_typed = math_output.map(lambda x: YamlChatOutput.model_validate(x)) + final_output = ( + agents_env.from_datastream( + input=math_output_typed, key_selector=YamlChatKeySelector() + ) + .apply("commentator_agent") + .to_datastream() + ) + + result_dir = tmp_path / "results" + result_dir.mkdir(parents=True, exist_ok=True) + final_output.map(lambda x: json.dumps(x), Types.STRING()).add_sink( + StreamingFileSink.for_row_format( + base_path=str(result_dir.absolute()), + encoder=Encoder.simple_string_encoder(), + ).build() + ) + + agents_env.execute() + + answers = _read_answers(result_dir) + final_answer = answers[1] + assert "3" in final_answer, ( + f"math result missing from chained output: {final_answer!r}" + ) + + +def _read_answers(result_dir: Path) -> dict[int, str]: + """Collect ``{id: answer}`` from every JSON line under ``result_dir``.""" + answers: dict[int, str] = {} + for path in result_dir.rglob("*"): + if not path.is_file(): + continue + for line in path.read_text().splitlines(): + line = line.strip() + if not line: + continue + record = json.loads(line) + answers[record["id"]] = record["answer"] + return answers diff --git a/python/flink_agents/e2e_tests/e2e_tests_integration/yaml_test_actions.py b/python/flink_agents/e2e_tests/e2e_tests_integration/yaml_test_actions.py new file mode 100644 index 000000000..afdac2878 --- /dev/null +++ b/python/flink_agents/e2e_tests/e2e_tests_integration/yaml_test_actions.py @@ -0,0 +1,156 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################# +"""Functions referenced by ``resources/yaml_test_agent.yaml``. + +Each action and tool entry in the YAML points its ``function:`` at one +of the callables in this module by fully-qualified dotted path. +""" + +from pydantic import BaseModel +from pyflink.datastream import KeySelector + +from flink_agents.api.chat_message import ChatMessage, MessageRole +from flink_agents.api.events.chat_event import ChatRequestEvent, ChatResponseEvent +from flink_agents.api.events.event import Event, InputEvent, OutputEvent +from flink_agents.api.runner_context import RunnerContext + + +class YamlChatInput(BaseModel): + """Input record: a question routed to a chat model.""" + + id: int + text: str + + +class YamlChatOutput(BaseModel): + """Output record: the chat model's textual answer.""" + + id: int + answer: str + + +class YamlChatKeySelector(KeySelector): + """KeySelector partitioning records by their ``id`` attribute. + + Works for both ``YamlChatInput`` (upstream of the math agent) and + ``YamlChatOutput`` (the math agent's output, piped into a downstream + agent like ``formatter_agent``). + """ + + def get_key(self, value: "YamlChatInput | YamlChatOutput") -> int: + """Use the record id as the partition key.""" + return value.id + + +def add(a: int, b: int) -> int: + """Calculate the sum of a and b. + + Parameters + ---------- + a : int + The first operand + b : int + The second operand + + Returns: + ------- + int: + The sum of a and b + """ + return a + b + + +def process_input(event: Event, ctx: RunnerContext) -> None: + """Route the incoming text to the math or creative chat model. + + The math model has access to the ``add`` tool; the creative model + does not. Routing is a simple keyword check on the input. The input + record's ``id`` is stashed in short-term memory so + ``process_chat_response`` can attach it back to the output. + """ + data = YamlChatInput.model_validate(InputEvent.from_event(event).input) + ctx.short_term_memory.set("input_id", data.id) + lower = data.text.lower() + model_name = ( + "math_chat_model" + if ("calculate" in lower or "sum" in lower) + else "creative_chat_model" + ) + ctx.send_event( + ChatRequestEvent( + model=model_name, + messages=[ChatMessage(role=MessageRole.USER, content=data.text)], + ) + ) + + +def chat_request(event: Event, ctx: RunnerContext) -> None: + """Send the input text to the agent-local ``chat_model``. + + Used by the multi-agent YAML, where each agent declares its own + ``chat_model`` (math one with the ``add`` tool, creative one + without) and the action simply forwards the user message. + """ + data = YamlChatInput.model_validate(InputEvent.from_event(event).input) + ctx.short_term_memory.set("input_id", data.id) + ctx.send_event( + ChatRequestEvent( + model="chat_model", + messages=[ChatMessage(role=MessageRole.USER, content=data.text)], + ) + ) + + +def process_chat_response(event: Event, ctx: RunnerContext) -> None: + """Emit the model's text response, tagged with the original input id.""" + chat_response = ChatResponseEvent.from_event(event) + response = chat_response.response + if not response or not response.content: + return + input_id = ctx.short_term_memory.get("input_id") + ctx.send_event( + OutputEvent(output=YamlChatOutput(id=input_id, answer=response.content)) + ) + + +def commentary_request(event: Event, ctx: RunnerContext) -> None: + """Stage-2 action: feed the upstream answer to a second chat model. + + The upstream record is a ``YamlChatOutput`` produced by the math + agent. We prompt the model to restate the same answer — the test + only needs the chain to actually pass through stage 2 (verifiable + by the math digit surviving the second LLM hop). Stashes the id in + short-term memory so the shared ``process_chat_response`` action + can re-attach it. + """ + data = YamlChatOutput.model_validate(InputEvent.from_event(event).input) + ctx.short_term_memory.set("input_id", data.id) + ctx.send_event( + ChatRequestEvent( + model="chat_model", + messages=[ + ChatMessage( + role=MessageRole.USER, + content=( + "Here is a math answer from another assistant: " + f"{data.answer!r}. Reply with the numeric result only." + ), + ) + ], + ) + ) diff --git a/python/flink_agents/e2e_tests/e2e_tests_resource_cross_language/yaml_cross_language_actions.py b/python/flink_agents/e2e_tests/e2e_tests_resource_cross_language/yaml_cross_language_actions.py new file mode 100644 index 000000000..d8a891683 --- /dev/null +++ b/python/flink_agents/e2e_tests/e2e_tests_resource_cross_language/yaml_cross_language_actions.py @@ -0,0 +1,54 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################# +"""Python actions for ``resources/yaml_cross_language_agent.yaml``. + +The YAML declares a Python Ollama chat model bound to a **Java** +function tool (``calculateBMI`` on the Java cross-language agent). +These actions route input to the math (Java-tool-equipped) or creative +chat model and emit the model's final reply as an ``OutputEvent``. +""" + +from flink_agents.api.chat_message import ChatMessage, MessageRole +from flink_agents.api.events.chat_event import ChatRequestEvent, ChatResponseEvent +from flink_agents.api.events.event import Event, InputEvent, OutputEvent +from flink_agents.api.runner_context import RunnerContext + + +def process_input(event: Event, ctx: RunnerContext) -> None: + """Route the input to math (Java tool) or creative chat model.""" + text = str(InputEvent.from_event(event).input) + lower = text.lower() + model_name = ( + "math_chat_model" + if ("calculate" in lower or "bmi" in lower) + else "creative_chat_model" + ) + ctx.send_event( + ChatRequestEvent( + model=model_name, + messages=[ChatMessage(role=MessageRole.USER, content=text)], + ) + ) + + +def process_chat_response(event: Event, ctx: RunnerContext) -> None: + """Emit the model's textual response.""" + chat_response = ChatResponseEvent.from_event(event) + response = chat_response.response + if response and response.content: + ctx.send_event(OutputEvent(output=response.content)) diff --git a/python/flink_agents/e2e_tests/e2e_tests_resource_cross_language/yaml_cross_language_test.py b/python/flink_agents/e2e_tests/e2e_tests_resource_cross_language/yaml_cross_language_test.py new file mode 100644 index 000000000..ae9920f4e --- /dev/null +++ b/python/flink_agents/e2e_tests/e2e_tests_resource_cross_language/yaml_cross_language_test.py @@ -0,0 +1,159 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################# +"""E2E test: a YAML-declared agent whose function tool is a Java method. + +Parses ``resources/yaml_cross_language_agent.yaml`` via +``AgentsExecutionEnvironment.load_yaml`` and runs the declared agent +through the Flink remote runner. The math chat model resolves a tool +named ``calculateBMI`` that is backed by the Java static method +``org.apache.flink.agents.resource.test.ChatModelCrossLanguageAgent.calculateBMI``, +exercising the cross-language path from a Python chat model to a Java +function tool. + +Skipped when the Ollama client/model is not available. +""" + +import os +import sysconfig +from pathlib import Path + +import pytest +from pyflink.common import Configuration, Encoder, WatermarkStrategy +from pyflink.common.typeinfo import Types +from pyflink.datastream import ( + RuntimeExecutionMode, + StreamExecutionEnvironment, +) +from pyflink.datastream.connectors.file_system import ( + FileSource, + StreamFormat, + StreamingFileSink, +) + +from flink_agents.api.execution_environment import AgentsExecutionEnvironment +from flink_agents.e2e_tests.test_utils import pull_model + +current_dir = Path(__file__).parent +_RESOURCES = current_dir.parent / "resources" + +# Locate the Java test-jar produced by the same e2e module. It ships the +# ``@Tool``-annotated static methods declared in the module's +# ``src/test/java`` (e.g. ``ChatModelCrossLanguageAgent.calculateBMI``). +# Building this jar is opted-in by the ``maven-jar-plugin`` ``test-jar`` +# execution in the module's ``pom.xml``; the Python test skips itself when +# the jar isn't present (i.e. the user hasn't run ``mvn package`` yet). +_REPO_ROOT = current_dir.parent.parent.parent.parent +_TEST_JAR = ( + _REPO_ROOT + / "e2e-test" + / "flink-agents-end-to-end-tests-resource-cross-language" + / "target" + / "flink-agents-end-to-end-tests-resource-cross-language-0.3-SNAPSHOT-tests.jar" +) + +os.environ["PYTHONPATH"] = sysconfig.get_paths()["purelib"] + +OLLAMA_MODEL = os.environ.get("OLLAMA_CHAT_MODEL", "qwen3:1.7b") +os.environ["OLLAMA_CHAT_MODEL"] = OLLAMA_MODEL + +_client = pull_model(OLLAMA_MODEL) + + +@pytest.mark.skipif( + _client is None, + reason="Ollama client is not available or test model is missing.", +) +@pytest.mark.skipif( + not _TEST_JAR.is_file(), + reason=( + "Cross-language test-jar is missing; run " + "'mvn package -DskipTests -pl e2e-test/" + "flink-agents-end-to-end-tests-resource-cross-language' first." + ), +) +def test_yaml_cross_language_agent(tmp_path: Path) -> None: + """``load_yaml`` → ``apply(by name)`` with a YAML-declared Java tool. + + Exercises a Python Ollama chat model that calls a Java + ``calculateBMI`` tool declared in YAML and resolved against the + cross-language test JAR. + """ + config = Configuration() + config.set_string("python.pythonpath", sysconfig.get_paths()["purelib"]) + env = StreamExecutionEnvironment.get_execution_environment(config) + env.set_runtime_mode(RuntimeExecutionMode.STREAMING) + env.set_parallelism(1) + # Make the Java ``@Tool`` static methods declared in this module's + # ``src/test/java`` visible to the Flink classpath. + env.add_jars(f"file://{_TEST_JAR}") + + input_datastream = env.from_source( + source=FileSource.for_record_stream_format( + StreamFormat.text_line_format(), + f"file:///{_RESOURCES}/yaml_cross_language_input", + ).build(), + watermark_strategy=WatermarkStrategy.no_watermarks(), + source_name="yaml_cross_language_source", + ) + + deserialize_datastream = input_datastream.map(lambda x: str(x)) + + agents_env = AgentsExecutionEnvironment.get_execution_environment(env=env) + agents_env.load_yaml(_RESOURCES / "yaml_cross_language_agent.yaml") + + output_datastream = ( + agents_env.from_datastream( + input=deserialize_datastream, key_selector=lambda x: "orderKey" + ) + .apply("yaml_cross_language_agent") + .to_datastream() + ) + + result_dir = tmp_path / "results" + result_dir.mkdir(parents=True, exist_ok=True) + + ( + output_datastream.map( + lambda x: str(x).replace("\n", "").replace("\r", ""), Types.STRING() + ).add_sink( + StreamingFileSink.for_row_format( + base_path=str(result_dir.absolute()), + encoder=Encoder.simple_string_encoder(), + ).build() + ) + ) + + agents_env.execute() + + actual_result = [] + for file in result_dir.iterdir(): + if file.is_dir(): + for child in file.iterdir(): + with child.open() as f: + actual_result.extend(f.readlines()) + if file.is_file(): + with file.open() as f: + actual_result.extend(f.readlines()) + + # Math path went through the Java ``calculateBMI`` tool: + # 70 / (1.75 * 1.75) ≈ 22.86, so the final answer should mention 22. + assert "22" in actual_result[0], f"math answer missing '22': {actual_result[0]!r}" + # Creative path doesn't use any tool. + assert "cat" in actual_result[1].lower(), ( + f"creative answer missing 'cat': {actual_result[1]!r}" + ) diff --git a/python/flink_agents/e2e_tests/resources/yaml_cross_language_agent.yaml b/python/flink_agents/e2e_tests/resources/yaml_cross_language_agent.yaml new file mode 100644 index 000000000..decebe129 --- /dev/null +++ b/python/flink_agents/e2e_tests/resources/yaml_cross_language_agent.yaml @@ -0,0 +1,48 @@ +agents: + - name: yaml_cross_language_agent + description: | + YAML-driven cross-language e2e agent. + + - math path: Python Ollama chat model calling a Java function tool + (``calculateBMI``) — exercises the Python→Java tool bridge. + - creative path: Java Ollama chat model with no tools — exercises + the Python→Java chat-model resource bridge (``type: java`` on + both the connection and the setup). + + chat_model_connections: + - name: ollama_connection + clazz: ollama + request_timeout: 240.0 + - name: ollama_connection_java + clazz: ollama + type: java + endpoint: http://localhost:11434 + requestTimeout: 240 + + chat_model_setups: + - name: math_chat_model + clazz: ollama + connection: ollama_connection + model: qwen3:1.7b + tools: [calculateBMI] + extract_reasoning: true + - name: creative_chat_model + clazz: ollama + type: java + connection: ollama_connection_java + model: qwen3:1.7b + extract_reasoning: true + + tools: + - name: calculateBMI + type: java + function: org.apache.flink.agents.resource.test.ChatModelCrossLanguageAgent:calculateBMI + parameter_types: [java.lang.Double, java.lang.Double] + + actions: + - name: process_input + function: flink_agents.e2e_tests.e2e_tests_resource_cross_language.yaml_cross_language_actions:process_input + listen_to: [input] + - name: process_chat_response + function: flink_agents.e2e_tests.e2e_tests_resource_cross_language.yaml_cross_language_actions:process_chat_response + listen_to: [chat_response] \ No newline at end of file diff --git a/python/flink_agents/e2e_tests/resources/yaml_cross_language_input/input.txt b/python/flink_agents/e2e_tests/resources/yaml_cross_language_input/input.txt new file mode 100644 index 000000000..b85035e19 --- /dev/null +++ b/python/flink_agents/e2e_tests/resources/yaml_cross_language_input/input.txt @@ -0,0 +1,2 @@ +Calculate BMI for someone who is 1.75 meters tall and weighs 70 kg. +Tell me a joke about cats. diff --git a/python/flink_agents/e2e_tests/resources/yaml_multi_agent.yaml b/python/flink_agents/e2e_tests/resources/yaml_multi_agent.yaml new file mode 100644 index 000000000..06d0138ac --- /dev/null +++ b/python/flink_agents/e2e_tests/resources/yaml_multi_agent.yaml @@ -0,0 +1,46 @@ +agents: + - name: math_agent + description: Stage 1 — solves the math question via the ``add`` tool. + chat_model_setups: + - name: chat_model + clazz: ollama + connection: ollama_connection + model: qwen3:1.7b + tools: [add] + extract_reasoning: true + tools: + - name: add + function: flink_agents.e2e_tests.e2e_tests_integration.yaml_test_actions:add + actions: + - name: chat_request + function: flink_agents.e2e_tests.e2e_tests_integration.yaml_test_actions:chat_request + listen_to: [input] + - process_chat_response + + - name: commentator_agent + description: | + Stage 2 — takes the upstream answer and asks a second chat model + to restate it. Reuses the file-level ``ollama_connection`` and + the file-level ``process_chat_response`` action. + chat_model_setups: + - name: chat_model + clazz: ollama + connection: ollama_connection + model: qwen3:1.7b + extract_reasoning: true + actions: + - name: commentary_request + function: flink_agents.e2e_tests.e2e_tests_integration.yaml_test_actions:commentary_request + listen_to: [input] + - process_chat_response + +# File-level shared resources reused by every agent above. +chat_model_connections: + - name: ollama_connection + clazz: ollama + request_timeout: 240.0 + +actions: + - name: process_chat_response + function: flink_agents.e2e_tests.e2e_tests_integration.yaml_test_actions:process_chat_response + listen_to: [chat_response] \ No newline at end of file diff --git a/python/flink_agents/e2e_tests/resources/yaml_test_agent.yaml b/python/flink_agents/e2e_tests/resources/yaml_test_agent.yaml new file mode 100644 index 000000000..6c29f5089 --- /dev/null +++ b/python/flink_agents/e2e_tests/resources/yaml_test_agent.yaml @@ -0,0 +1,32 @@ +agents: + - name: yaml_test_agent + description: YAML-driven e2e agent — chat model with a function tool. + + chat_model_connections: + - name: ollama_connection + clazz: ollama + request_timeout: 240.0 + + chat_model_setups: + - name: math_chat_model + clazz: ollama + connection: ollama_connection + model: qwen3:1.7b + tools: [add] + extract_reasoning: true + - name: creative_chat_model + clazz: ollama + connection: ollama_connection + model: qwen3:1.7b + extract_reasoning: true + + tools: + - name: add + function: flink_agents.e2e_tests.e2e_tests_integration.yaml_test_actions:add + actions: + - name: process_input + function: flink_agents.e2e_tests.e2e_tests_integration.yaml_test_actions:process_input + listen_to: [input] + - name: process_chat_response + function: flink_agents.e2e_tests.e2e_tests_integration.yaml_test_actions:process_chat_response + listen_to: [chat_response] \ No newline at end of file diff --git a/python/flink_agents/e2e_tests/resources/yaml_test_input/input.txt b/python/flink_agents/e2e_tests/resources/yaml_test_input/input.txt new file mode 100644 index 000000000..a2ae71831 --- /dev/null +++ b/python/flink_agents/e2e_tests/resources/yaml_test_input/input.txt @@ -0,0 +1,2 @@ +{"id": 1, "text": "calculate the sum of 1 and 2."} +{"id": 2, "text": "Tell me a joke about cats."} diff --git a/python/flink_agents/e2e_tests/resources/yaml_test_math_input/input.txt b/python/flink_agents/e2e_tests/resources/yaml_test_math_input/input.txt new file mode 100644 index 000000000..f47a8e606 --- /dev/null +++ b/python/flink_agents/e2e_tests/resources/yaml_test_math_input/input.txt @@ -0,0 +1 @@ +{"id": 1, "text": "calculate the sum of 1 and 2."} diff --git a/python/flink_agents/integrations/chat_models/anthropic/tests/test_anthropic_chat_model.py b/python/flink_agents/integrations/chat_models/anthropic/tests/test_anthropic_chat_model.py index 247759a23..9741e054a 100644 --- a/python/flink_agents/integrations/chat_models/anthropic/tests/test_anthropic_chat_model.py +++ b/python/flink_agents/integrations/chat_models/anthropic/tests/test_anthropic_chat_model.py @@ -23,12 +23,12 @@ from flink_agents.api.chat_message import ChatMessage, MessageRole from flink_agents.api.resource import Resource, ResourceType from flink_agents.api.resource_context import ResourceContext +from flink_agents.api.tools.tool import Tool from flink_agents.integrations.chat_models.anthropic.anthropic_chat_model import ( DEFAULT_ANTHROPIC_MODEL, AnthropicChatModelConnection, AnthropicChatModelSetup, ) -from flink_agents.plan.tools.function_tool import from_callable test_model = os.environ.get("TEST_MODEL") api_key = os.environ.get("TEST_API_KEY") @@ -84,7 +84,7 @@ def get_resource(name: str, type: ResourceType) -> Resource: if type == ResourceType.CHAT_MODEL_CONNECTION: return connection else: - return from_callable(func=add) + return Tool.from_callable(func=add) mock_ctx = MagicMock(spec=ResourceContext) mock_ctx.get_resource = get_resource diff --git a/python/flink_agents/integrations/chat_models/azure/tests/test_azure_openai_chat_model.py b/python/flink_agents/integrations/chat_models/azure/tests/test_azure_openai_chat_model.py index 30753d755..79d2fb5c1 100644 --- a/python/flink_agents/integrations/chat_models/azure/tests/test_azure_openai_chat_model.py +++ b/python/flink_agents/integrations/chat_models/azure/tests/test_azure_openai_chat_model.py @@ -27,7 +27,8 @@ AzureOpenAIChatModelConnection, AzureOpenAIChatModelSetup, ) -from flink_agents.plan.tools.function_tool import from_callable +from flink_agents.plan.function import PythonFunction +from flink_agents.plan.tools.function_tool import FunctionTool test_deployment = os.environ.get("TEST_AZURE_DEPLOYMENT") api_key = os.environ.get("AZURE_OPENAI_API_KEY") @@ -95,7 +96,7 @@ def get_resource(name: str, type: ResourceType) -> Resource: if type == ResourceType.CHAT_MODEL_CONNECTION: return connection else: - return from_callable(func=add) + return FunctionTool(func=PythonFunction.from_callable(add)) mock_ctx = MagicMock(spec=ResourceContext) mock_ctx.get_resource = get_resource diff --git a/python/flink_agents/integrations/chat_models/openai/tests/test_openai_chat_model.py b/python/flink_agents/integrations/chat_models/openai/tests/test_openai_chat_model.py index 2280bf44d..7ccb6c225 100644 --- a/python/flink_agents/integrations/chat_models/openai/tests/test_openai_chat_model.py +++ b/python/flink_agents/integrations/chat_models/openai/tests/test_openai_chat_model.py @@ -28,7 +28,8 @@ OpenAIChatModelConnection, OpenAIChatModelSetup, ) -from flink_agents.plan.tools.function_tool import from_callable +from flink_agents.plan.function import PythonFunction +from flink_agents.plan.tools.function_tool import FunctionTool test_model = os.environ.get("TEST_MODEL") api_key = os.environ.get("TEST_API_KEY") @@ -86,7 +87,7 @@ def get_resource(name: str, type: ResourceType) -> Resource: if type == ResourceType.CHAT_MODEL_CONNECTION: return connection else: - return from_callable(func=add) + return FunctionTool(func=PythonFunction.from_callable(add)) mock_ctx = MagicMock(spec=ResourceContext) mock_ctx.get_resource = get_resource diff --git a/python/flink_agents/integrations/chat_models/tests/test_ollama_chat_model.py b/python/flink_agents/integrations/chat_models/tests/test_ollama_chat_model.py index 5503185bf..6a2a47117 100644 --- a/python/flink_agents/integrations/chat_models/tests/test_ollama_chat_model.py +++ b/python/flink_agents/integrations/chat_models/tests/test_ollama_chat_model.py @@ -31,7 +31,8 @@ OllamaChatModelConnection, OllamaChatModelSetup, ) -from flink_agents.plan.tools.function_tool import FunctionTool, from_callable +from flink_agents.plan.function import PythonFunction +from flink_agents.plan.tools.function_tool import FunctionTool test_model = os.environ.get("OLLAMA_CHAT_MODEL", "qwen3:1.7b") current_dir = Path(__file__).parent @@ -90,7 +91,7 @@ def add(a: int, b: int) -> int: def get_tool(name: str, type: ResourceType) -> FunctionTool: - return from_callable(func=add) + return FunctionTool(func=PythonFunction.from_callable(add)) @pytest.mark.skipif( diff --git a/python/flink_agents/integrations/chat_models/tests/test_tongyi_chat_model.py b/python/flink_agents/integrations/chat_models/tests/test_tongyi_chat_model.py index fc80ac6a0..c33a792ce 100644 --- a/python/flink_agents/integrations/chat_models/tests/test_tongyi_chat_model.py +++ b/python/flink_agents/integrations/chat_models/tests/test_tongyi_chat_model.py @@ -29,7 +29,8 @@ TongyiChatModelConnection, TongyiChatModelSetup, ) -from flink_agents.plan.tools.function_tool import FunctionTool, from_callable +from flink_agents.plan.function import PythonFunction +from flink_agents.plan.tools.function_tool import FunctionTool test_model = os.environ.get("TONGYI_CHAT_MODEL", "qwen-plus") api_key_available = "DASHSCOPE_API_KEY" in os.environ @@ -68,7 +69,7 @@ def add(a: int, b: int) -> int: def get_tool(name: str, type: ResourceType) -> FunctionTool: """Helper function to create a tool for testing.""" - return from_callable(func=add) + return FunctionTool(func=PythonFunction.from_callable(add)) @pytest.mark.skipif(not api_key_available, reason="DashScope API key is not set") diff --git a/python/flink_agents/integrations/mcp/tests/test_mcp.py b/python/flink_agents/integrations/mcp/tests/test_mcp.py index 7a6e14877..1bb81ce8a 100644 --- a/python/flink_agents/integrations/mcp/tests/test_mcp.py +++ b/python/flink_agents/integrations/mcp/tests/test_mcp.py @@ -26,7 +26,8 @@ from pydantic import AnyUrl from flink_agents.api.chat_message import ChatMessage, MessageRole -from flink_agents.integrations.mcp.mcp import MCPServer +from flink_agents.api.tools.tool import ToolMetadata +from flink_agents.integrations.mcp.mcp import MCPServer, MCPTool def run_server() -> None: @@ -124,3 +125,27 @@ def test_serialize_mcp_server() -> None: deserialized.auth.context.client_metadata == mcp_server.auth.context.client_metadata ) + + +def test_mcp_tool_roundtrip_preserves_metadata() -> None: + metadata = ToolMetadata( + name="add", + description="Add two integers.", + args_schema={ + "type": "object", + "properties": { + "a": {"type": "integer"}, + "b": {"type": "integer"}, + }, + "required": ["a", "b"], + }, + ) + tool = MCPTool(metadata=metadata, mcp_server=MCPServer(endpoint="http://x")) + + dumped = tool.model_dump() + assert "metadata" in dumped, "serialized form must expose `metadata` key" + assert "metadata_" not in dumped + + restored = MCPTool.model_validate(dumped) + assert restored.metadata == metadata + assert restored.name == "add" diff --git a/python/flink_agents/plan/agent_plan.py b/python/flink_agents/plan/agent_plan.py index f520ee63d..f38c0fc77 100644 --- a/python/flink_agents/plan/agent_plan.py +++ b/python/flink_agents/plan/agent_plan.py @@ -20,6 +20,9 @@ from pydantic import BaseModel, field_serializer, model_validator from flink_agents.api.agents.agent import Agent +from flink_agents.api.function import Function as ApiFunction +from flink_agents.api.function import JavaFunction as ApiJavaFunction +from flink_agents.api.function import PythonFunction as ApiPythonFunction from flink_agents.api.resource import ( ResourceDescriptor, ResourceType, @@ -30,12 +33,14 @@ LOAD_SKILL_TOOL, Skills, ) +from flink_agents.api.tools.function_tool import FunctionTool as ApiFunctionTool +from flink_agents.api.tools.tool import Tool from flink_agents.plan.actions.action import Action from flink_agents.plan.actions.chat_model_action import CHAT_MODEL_ACTION from flink_agents.plan.actions.context_retrieval_action import CONTEXT_RETRIEVAL_ACTION from flink_agents.plan.actions.tool_call_action import TOOL_CALL_ACTION from flink_agents.plan.configuration import AgentConfiguration -from flink_agents.plan.function import PythonFunction +from flink_agents.plan.function import JavaFunction, PythonFunction from flink_agents.plan.resource_provider import ( JavaResourceProvider, JavaSerializableResourceProvider, @@ -43,7 +48,7 @@ PythonSerializableResourceProvider, ResourceProvider, ) -from flink_agents.plan.tools.function_tool import from_callable +from flink_agents.plan.tools.function_tool import FunctionTool if TYPE_CHECKING: from flink_agents.api.resource import ( @@ -262,7 +267,7 @@ def _get_actions(agent: Agent) -> List[Action]: actions.append( Action( name=name, - exec=PythonFunction.from_callable(action_tuple[1]), + exec=_to_plan_function(action_tuple[1]), listen_event_types=[ _resolve_event_type(et) for et in action_tuple[0] @@ -273,6 +278,27 @@ def _get_actions(agent: Agent) -> List[Action]: return actions +def _to_plan_function(func: ApiFunction) -> PythonFunction | JavaFunction: + """Promote an api Function descriptor to its executable plan counterpart. + + Agent stores api-layer descriptors (pure data). Action.exec needs the + plan-layer executable variants for ``check_signature`` and + ``__call__``, so we rebuild here. + """ + if isinstance(func, ApiPythonFunction): + return PythonFunction(module=func.module, qualname=func.qualname) + if isinstance(func, ApiJavaFunction): + return JavaFunction( + qualname=func.qualname, + method_name=func.method_name, + parameter_types=list(func.parameter_types), + ) + msg = f"Unsupported function descriptor: {type(func).__name__}" + raise TypeError(msg) + + + + def _get_resource_providers( agent: Agent, config: AgentConfiguration ) -> List[ResourceProvider]: @@ -307,10 +333,11 @@ def _get_resource_providers( if callable(value): # TODO: support other tool type. - tool = from_callable(func=value) + tool = Tool.from_callable(func=value) resource_providers.append( PythonSerializableResourceProvider.from_resource( - name=name, resource=tool + name=name, + resource=FunctionTool(func=_to_plan_function(tool.func)), ) ) elif hasattr(value, "_is_prompt"): @@ -342,7 +369,12 @@ def _get_resource_providers( for name, tool in agent.resources[ResourceType.TOOL].items(): resource_providers.append( PythonSerializableResourceProvider.from_resource( - name=name, resource=from_callable(tool.func) + name=name, + resource=( + FunctionTool(func=_to_plan_function(tool.func)) + if isinstance(tool, ApiFunctionTool) + else tool + ), ) ) diff --git a/python/flink_agents/plan/function.py b/python/flink_agents/plan/function.py index d8c10a4ba..55086414a 100644 --- a/python/flink_agents/plan/function.py +++ b/python/flink_agents/plan/function.py @@ -22,7 +22,7 @@ from abc import ABC, abstractmethod from typing import Any, Callable, Dict, List, Tuple, get_type_hints -from pydantic import BaseModel, model_serializer +from pydantic import BaseModel, PrivateAttr, model_serializer from flink_agents.plan.utils import check_type_match @@ -102,7 +102,7 @@ def _is_function_cacheable(func: Callable) -> bool: class Function(BaseModel, ABC): - """Base interface for user defined functions, includes python and java.""" + """Base interface for user-defined functions.""" @abstractmethod def check_signature(self, *args: Tuple[Any, ...]) -> None: @@ -216,6 +216,10 @@ def __call__(self, *args: Tuple[Any, ...], **kwargs: Dict[str, Any]) -> Any: """ return self.__get_func()(*args, **kwargs) + def as_callable(self) -> Callable: + """Return the underlying Python callable, importing the module if needed.""" + return self.__get_func() + def __get_func(self) -> Callable: if self.__func is None: module = importlib.import_module(self.module) @@ -237,14 +241,20 @@ def is_cacheable(self) -> bool: return self.__is_cacheable -# TODO: Implement JavaFunction. class JavaFunction(Function): - """Descriptor for a java callable function.""" + """Descriptor for a Java callable function. + + Invocation goes through the JVM resource adapter, injected by the + runtime via :meth:`set_java_resource_adapter`; until then + ``__call__`` raises ``RuntimeError``. + """ qualname: str method_name: str parameter_types: List[str] + _j_resource_adapter: Any = PrivateAttr(default=None) + @model_serializer def __custom_serializer(self) -> dict[str, Any]: data = { @@ -255,8 +265,30 @@ def __custom_serializer(self) -> dict[str, Any]: } return data + def set_java_resource_adapter(self, adapter: Any) -> None: + """Inject the JVM adapter used to invoke this Java method.""" + self._j_resource_adapter = adapter + def __call__(self, *args: Tuple[Any, ...], **kwargs: Dict[str, Any]) -> Any: - """Execute the stored function with provided arguments.""" + """Invoke the Java method via the JVM resource adapter. + + LLM tool calls always arrive as keyword arguments — positional + ``*args`` are ignored because the Java side reorders parameters + by name via reflection. + """ + if self._j_resource_adapter is None: + msg = ( + "JavaFunction requires the JVM resource adapter; not set " + "on this descriptor. The runtime should inject it via " + "set_java_resource_adapter before invocation." + ) + raise RuntimeError(msg) + return self._j_resource_adapter.invokeJavaTool( + self.qualname, + self.method_name, + self.parameter_types, + kwargs, + ) def check_signature(self, *args: Tuple[Any, ...]) -> None: """Check function signature is legal or not.""" diff --git a/python/flink_agents/plan/tests/test_function.py b/python/flink_agents/plan/tests/test_function.py index a42ac01e5..6ce4c8536 100644 --- a/python/flink_agents/plan/tests/test_function.py +++ b/python/flink_agents/plan/tests/test_function.py @@ -34,7 +34,7 @@ ) if TYPE_CHECKING: - from flink_agents.plan.function import Function + from flink_agents.api.function import Function def check_class(input_event: InputEvent, output_event: OutputEvent) -> None: @@ -256,10 +256,6 @@ def test_cache_performance_benefit() -> None: """Test that caching provides performance benefits.""" clear_python_function_cache() - # This test verifies that the same PythonFunction instance is reused - # We can't easily test performance directly, but we can verify that - # the cache key mechanism works correctly - # First call creates cache entry call_python_function( "flink_agents.plan.tests.test_function", "function_for_caching", (1,) @@ -285,7 +281,6 @@ def test_selective_caching_pure_functions() -> None: """Test that pure functions are cached.""" clear_python_function_cache() - # Pure functions should be cached call_python_function( "flink_agents.plan.tests.test_function", "simple_pure_function", (1, 2) ) @@ -293,7 +288,6 @@ def test_selective_caching_pure_functions() -> None: "flink_agents.plan.tests.test_function", "function_for_caching", (5,) ) - # Both should be in cache assert get_python_function_cache_size() == 2 cache_keys = get_python_function_cache_keys() assert ( @@ -310,17 +304,13 @@ def test_selective_caching_generator_functions() -> None: """Test that generator functions are not cached.""" clear_python_function_cache() - # Generator function should not be cached result = call_python_function( "flink_agents.plan.tests.test_function", "generator_function", (3,) ) - # Result is now a generator directly (no wrapper) assert isinstance(result, Generator) - # Convert generator to list for testing result_list = list(result) assert result_list == [0, 1, 2] - # Should not be cached assert get_python_function_cache_size() == 0 @@ -328,7 +318,6 @@ def test_selective_caching_mutable_defaults() -> None: """Test that functions with mutable defaults are not cached.""" clear_python_function_cache() - # Function with mutable default should not be cached result1 = call_python_function( "flink_agents.plan.tests.test_function", "function_with_mutable_default", () ) @@ -336,57 +325,35 @@ def test_selective_caching_mutable_defaults() -> None: "flink_agents.plan.tests.test_function", "function_with_mutable_default", () ) - # Should not be cached (each call creates a new function instance) assert get_python_function_cache_size() == 0 - # Results should be different if function is correctly not cached - # (mutable default behavior depends on not caching) assert isinstance(result1, list) assert isinstance(result2, list) def test_is_function_cacheable() -> None: """Test the _is_function_cacheable function directly.""" - # Pure functions should be cacheable assert _is_function_cacheable(simple_pure_function) is True assert _is_function_cacheable(function_for_caching) is True - - # Generator functions should not be cacheable assert _is_function_cacheable(generator_function) is False - - # Functions with mutable defaults should not be cacheable assert _is_function_cacheable(function_with_mutable_default) is False - - # Closures should not be cacheable closure_func = make_closure(5) assert _is_function_cacheable(closure_func) is False - - # None should not be cacheable assert _is_function_cacheable(None) is False def test_python_function_cacheability_optimization() -> None: """Test that PythonFunction caches the cacheability check result.""" - # Test cacheable function cacheable_func = PythonFunction.from_callable(simple_pure_function) - # First call should compute and cache the result assert cacheable_func.is_cacheable() is True - - # Second call should use cached result (we can't directly test this, - # but we can verify it returns the same result) assert cacheable_func.is_cacheable() is True - # Test non-cacheable function non_cacheable_func = PythonFunction.from_callable(generator_function) - # First call should compute and cache the result assert non_cacheable_func.is_cacheable() is False - - # Second call should use cached result assert non_cacheable_func.is_cacheable() is False - # Test that the cacheability check is consistent with direct _is_function_cacheable assert cacheable_func.is_cacheable() == _is_function_cacheable(simple_pure_function) assert non_cacheable_func.is_cacheable() == _is_function_cacheable( generator_function diff --git a/python/flink_agents/plan/tests/tools/resources/function_tool.json b/python/flink_agents/plan/tests/tools/resources/function_tool.json index 50dd0e98f..a218c83ff 100644 --- a/python/flink_agents/plan/tests/tools/resources/function_tool.json +++ b/python/flink_agents/plan/tests/tools/resources/function_tool.json @@ -1,4 +1,9 @@ { + "func": { + "func_type": "PythonFunction", + "module": "flink_agents.plan.tests.tools.test_function_tool", + "qualname": "foo" + }, "metadata": { "name": "foo", "description": "Function for testing ToolMetadata.\n", @@ -22,10 +27,5 @@ "title": "foo", "type": "object" } - }, - "func": { - "module": "flink_agents.plan.tests.tools.test_function_tool", - "qualname": "foo", - "func_type": "PythonFunction" } -} \ No newline at end of file +} diff --git a/python/flink_agents/plan/tests/tools/test_function_tool.py b/python/flink_agents/plan/tests/tools/test_function_tool.py index d4a3f512c..18fec91ea 100644 --- a/python/flink_agents/plan/tests/tools/test_function_tool.py +++ b/python/flink_agents/plan/tests/tools/test_function_tool.py @@ -17,10 +17,12 @@ ################################################################################# import json from pathlib import Path +from unittest.mock import MagicMock import pytest -from flink_agents.plan.tools.function_tool import FunctionTool, from_callable +from flink_agents.plan.function import JavaFunction, PythonFunction +from flink_agents.plan.tools.function_tool import FunctionTool current_dir = Path(__file__).parent @@ -45,7 +47,7 @@ def foo(bar: int, baz: str) -> str: @pytest.fixture(scope="module") def func_tool() -> FunctionTool: - return from_callable(foo) + return FunctionTool(func=PythonFunction.from_callable(foo)) def test_serialize_function_tool(func_tool: FunctionTool) -> None: @@ -61,4 +63,123 @@ def test_deserialize_function_tool(func_tool: FunctionTool) -> None: with Path(f"{current_dir}/resources/function_tool.json").open() as f: json_value = f.read() actual_func_tool = FunctionTool.model_validate_json(json_value) - assert actual_func_tool == func_tool + # ``PythonFunction`` carries a private ``__func`` cache that is only + # populated once the callable has been resolved (e.g. via the eager + # metadata derivation in the fixture). The deserialized instance hasn't + # resolved the callable yet, so a full BaseModel ``==`` would differ on + # the cache. Compare the public, serialized state instead. + assert actual_func_tool.metadata == func_tool.metadata + assert actual_func_tool.func.module == func_tool.func.module + assert actual_func_tool.func.qualname == func_tool.func.qualname + + +def test_python_function_tool_metadata_filled_eagerly() -> None: + # ``PythonFunction`` metadata can be derived without external context, + # so ``FunctionTool`` fills it during model validation. The field is + # therefore already populated immediately after construction. + tool = FunctionTool(func=PythonFunction.from_callable(foo)) + assert tool.metadata is not None + assert tool.metadata.name == "foo" + + +# ---- Java function tool path ------------------------------------------------- + + +def _java_func() -> JavaFunction: + # Fresh instance per test — the adapter now lives on JavaFunction, so + # sharing one would leak state between tests. + return JavaFunction( + qualname="com.example.Tools", + method_name="add", + parameter_types=["int", "int"], + ) + +_FAKE_JAVA_SCHEMA = json.dumps( + { + "type": "object", + "properties": { + "a": {"type": "integer", "description": "First operand."}, + "b": {"type": "integer", "description": "Second operand."}, + }, + "required": ["a", "b"], + "title": "add", + } +) + + +def _fake_adapter() -> MagicMock: + """Build a mock ``_j_resource_adapter`` that mirrors the Java + ``JavaResourceAdapter`` surface used by ``plan.FunctionTool``. + + ``getJavaToolMetadata`` returns a flat ``Map`` (see + ``JavaResourceAdapter.getJavaToolMetadata`` Java side for why), + so mock it as a plain Python dict. + """ + adapter = MagicMock() + adapter.getJavaToolMetadata.return_value = { + "name": "add", + "description": "Add two ints.", + "inputSchema": _FAKE_JAVA_SCHEMA, + } + adapter.invokeJavaTool.return_value = 1065 + return adapter + + +def test_java_function_tool_constructs_without_adapter() -> None: + # Plan compile time: no JVM adapter yet. Construction (and its + # SerializableResource self-validation) must not call into the adapter. + # ``metadata`` stays ``None`` until the adapter is injected. + tool = FunctionTool(func=_java_func()) + assert tool.func._j_resource_adapter is None + assert isinstance(tool.func, JavaFunction) + assert tool.metadata is None + + +def test_java_function_tool_metadata_filled_on_adapter_injection() -> None: + tool = FunctionTool(func=_java_func()) + adapter = _fake_adapter() + + tool.set_java_resource_adapter(adapter) + + # Adapter is consulted exactly once at injection time and the result is + # stored in the regular ``metadata`` field. Subsequent accesses just read + # the field, so the adapter is not hit again. + adapter.getJavaToolMetadata.assert_called_once_with( + "com.example.Tools", "add", ["int", "int"] + ) + assert tool.metadata is not None + assert tool.metadata.name == "add" + assert tool.metadata.description == "Add two ints." + assert set(tool.metadata.args_schema.model_fields) == {"a", "b"} + _ = tool.metadata + adapter.getJavaToolMetadata.assert_called_once() + + +def test_java_function_tool_metadata_is_none_without_adapter() -> None: + # Before the runtime injects the adapter, the metadata is intentionally + # absent — this is the only legal window where ``Tool.metadata`` is + # ``None``. Accessing the field must not raise. + tool = FunctionTool(func=_java_func()) + assert tool.metadata is None + + +def test_java_function_tool_call_dispatches_through_adapter() -> None: + tool = FunctionTool(func=_java_func()) + adapter = _fake_adapter() + tool.set_java_resource_adapter(adapter) + + result = tool.call(a=377, b=688) + + assert result == 1065 + adapter.invokeJavaTool.assert_called_once_with( + "com.example.Tools", + "add", + ["int", "int"], + {"a": 377, "b": 688}, + ) + + +def test_java_function_tool_call_without_adapter_raises() -> None: + tool = FunctionTool(func=_java_func()) + with pytest.raises(RuntimeError, match="JVM resource adapter"): + tool.call(a=1, b=2) diff --git a/python/flink_agents/plan/tools/bash/bash_tool.py b/python/flink_agents/plan/tools/bash/bash_tool.py index d579c5acb..54797e97e 100644 --- a/python/flink_agents/plan/tools/bash/bash_tool.py +++ b/python/flink_agents/plan/tools/bash/bash_tool.py @@ -70,8 +70,6 @@ class BashTool(Tool): time by the framework (not visible to the LLM through ``args_schema``). """ - metadata: ToolMetadata = Field(exclude=True) - def __init__(self, **kwargs: Any) -> None: """Initialize the tool.""" super().__init__( diff --git a/python/flink_agents/plan/tools/function_tool.py b/python/flink_agents/plan/tools/function_tool.py index f686bf227..85be3df17 100644 --- a/python/flink_agents/plan/tools/function_tool.py +++ b/python/flink_agents/plan/tools/function_tool.py @@ -15,50 +15,89 @@ # See the License for the specific language governing permissions and # limitations under the License. ################################################################################# -from typing import Any, Callable +from typing import Any from docstring_parser import parse +from pydantic import model_validator from typing_extensions import override from flink_agents.api.tools.tool import Tool, ToolMetadata, ToolType -from flink_agents.api.tools.utils import create_schema_from_function +from flink_agents.api.tools.utils import ( + create_model_from_java_tool_schema_str, + create_schema_from_function, +) from flink_agents.plan.function import JavaFunction, PythonFunction class FunctionTool(Tool): - """Tool that takes in a function. + """Executable function tool. - Attributes: - ---------- - func : Function - User defined function. + ``metadata`` is filled eagerly as soon as the value is derivable — + during model validation for ``PythonFunction`` (from the callable's + docstring/signature), and inside :meth:`set_java_resource_adapter` + once the runtime injects the JVM bridge for ``JavaFunction``. Until + that injection the field stays ``None``. """ func: PythonFunction | JavaFunction + @model_validator(mode="after") + def _eager_derive_python_metadata(self) -> "FunctionTool": + if self.metadata is None and isinstance(self.func, PythonFunction): + self.metadata = _python_metadata(self.func) + return self + + def set_java_resource_adapter(self, adapter: Any) -> None: + """Inject the JVM resource adapter and derive ``metadata``. Called + by the runtime resource cache when the tool is first materialised; + no-op when ``func`` is not a ``JavaFunction``. + """ + if not isinstance(self.func, JavaFunction): + return + self.func.set_java_resource_adapter(adapter) + if self.metadata is None: + self.metadata = _java_metadata(self.func) + @classmethod @override def tool_type(cls) -> ToolType: """Get the tool type.""" return ToolType.FUNCTION + @override def call(self, *args: Any, **kwargs: Any) -> Any: - """Call the function tool.""" + """Invoke the underlying function.""" return self.func(*args, **kwargs) -def from_callable(func: Callable) -> FunctionTool: - """Create FunctionTool from a user defined function. - - Parameters - ---------- - func : Callable - The function to analyze. - """ - description = parse(func.__doc__).description - metadata = ToolMetadata( - name=func.__name__, +def _python_metadata(func: PythonFunction) -> ToolMetadata: + callable_ = func.as_callable() + description = parse(callable_.__doc__).description or "" + return ToolMetadata( + name=callable_.__name__, description=description, - args_schema=create_schema_from_function(func.__name__, func=func), + args_schema=create_schema_from_function(callable_.__name__, func=callable_), + ) + + +def _java_metadata(func: JavaFunction) -> ToolMetadata: + adapter = func._j_resource_adapter + if adapter is None: + msg = ( + "Java function tool metadata requires the JVM resource adapter; " + "not set on the underlying JavaFunction. The runtime should " + "inject it via FunctionTool.set_java_resource_adapter before " + "metadata access." + ) + raise RuntimeError(msg) + flat = adapter.getJavaToolMetadata( + func.qualname, func.method_name, func.parameter_types + ) + name = flat["name"] + return ToolMetadata( + name=name, + description=flat["description"], + args_schema=create_model_from_java_tool_schema_str( + name, flat["inputSchema"] + ), ) - return FunctionTool(func=PythonFunction.from_callable(func), metadata=metadata) diff --git a/python/flink_agents/runtime/java/java_resource_wrapper.py b/python/flink_agents/runtime/java/java_resource_wrapper.py index 833057310..886e4c84c 100644 --- a/python/flink_agents/runtime/java/java_resource_wrapper.py +++ b/python/flink_agents/runtime/java/java_resource_wrapper.py @@ -17,19 +17,29 @@ ################################################################################# from typing import Any, List -from pydantic import Field +from pydantic import ConfigDict, Field from typing_extensions import override from flink_agents.api.chat_message import ChatMessage, MessageRole from flink_agents.api.prompts.prompt import Prompt from flink_agents.api.resource import Resource, ResourceType from flink_agents.api.resource_context import ResourceContext -from flink_agents.api.tools.tool import Tool, ToolType +from flink_agents.api.tools.tool import Tool, ToolMetadata, ToolType class JavaTool(Tool): """Java Tool that carries tool metadata and can be recognized by PythonChatModel.""" + model_config = ConfigDict(populate_by_name=True) + + metadata_: ToolMetadata = Field(exclude=True, alias="metadata") + + @property + @override + def metadata(self) -> ToolMetadata: + """Return the tool metadata.""" + return self.metadata_ + @classmethod @override def tool_type(cls) -> ToolType: diff --git a/python/flink_agents/runtime/local_execution_environment.py b/python/flink_agents/runtime/local_execution_environment.py index 4ed7bd3ab..890963072 100644 --- a/python/flink_agents/runtime/local_execution_environment.py +++ b/python/flink_agents/runtime/local_execution_environment.py @@ -52,7 +52,7 @@ def __init__( self.__output = [] self.__config = config - def apply(self, agent: Agent) -> AgentBuilder: + def apply(self, agent: Agent | str) -> AgentBuilder: """Create local runner to execute given agent. Doesn't support apply multiple Agents. @@ -60,6 +60,14 @@ def apply(self, agent: Agent) -> AgentBuilder: if self.__runner is not None: err_msg = "LocalAgentBuilder doesn't support apply multiple agents." raise RuntimeError(err_msg) + if isinstance(agent, str): + if agent not in self.__env._agents: + msg = ( + f"No agent named {agent!r} is registered on this " + "environment. Did you call load_yaml first?" + ) + raise ValueError(msg) + agent = self.__env._agents[agent] # inspect resources from environment to agent instance. registered_resources = self.__env.resources for type, name_to_resource in registered_resources.items(): diff --git a/python/flink_agents/runtime/remote_execution_environment.py b/python/flink_agents/runtime/remote_execution_environment.py index 3d755520d..c5ac93f87 100644 --- a/python/flink_agents/runtime/remote_execution_environment.py +++ b/python/flink_agents/runtime/remote_execution_environment.py @@ -56,6 +56,7 @@ class RemoteAgentBuilder(AgentBuilder): __t_env: StreamTableEnvironment __config: AgentConfiguration __resources: Dict[ResourceType, Dict[str, Any]] = None + __agents: Dict[str, Agent] def __init__( self, @@ -63,12 +64,14 @@ def __init__( config: AgentConfiguration, t_env: StreamTableEnvironment | None = None, resources: Dict[ResourceType, Dict[str, Any]] | None = None, + agents: Dict[str, Agent] | None = None, ) -> None: """Init method of RemoteAgentBuilder.""" self.__input = input self.__t_env = t_env self.__config = config self.__resources = resources + self.__agents = agents or {} @property def t_env(self) -> StreamTableEnvironment: @@ -79,17 +82,26 @@ def t_env(self) -> StreamTableEnvironment: ) return self.__t_env - def apply(self, agent: Agent) -> "AgentBuilder": + def apply(self, agent: Agent | str) -> "AgentBuilder": """Set agent of execution environment. Parameters ---------- - agent : Agent - The agent user defined to run in execution environment. + agent : Agent | str + Either an Agent instance, or the name of an agent registered + on the environment (e.g. by ``load_yaml``). """ if self.__agent_plan is not None: err_msg = "RemoteAgentBuilder doesn't support apply multiple agents yet." raise RuntimeError(err_msg) + if isinstance(agent, str): + if agent not in self.__agents: + msg = ( + f"No agent named {agent!r} is registered on this " + "environment. Did you call load_yaml first?" + ) + raise ValueError(msg) + agent = self.__agents[agent] # inspect refer actions and resources from env to agent. for type, name_to_resource in self.__resources.items(): @@ -229,6 +241,7 @@ def from_datastream( config=self.__config, t_env=self.__t_env, resources=self.resources, + agents=self._agents, ) def from_table( @@ -255,6 +268,7 @@ def from_table( config=self.__config, t_env=self.t_env, resources=self.resources, + agents=self._agents, ) def from_list(self, input: List[Dict[str, Any]]) -> "AgentsExecutionEnvironment": diff --git a/python/flink_agents/runtime/resource_cache.py b/python/flink_agents/runtime/resource_cache.py index 0e7ded885..f66cc7588 100644 --- a/python/flink_agents/runtime/resource_cache.py +++ b/python/flink_agents/runtime/resource_cache.py @@ -20,7 +20,9 @@ from flink_agents.api.resource import Resource, ResourceType from flink_agents.api.resource_context import ResourceContext from flink_agents.plan.configuration import AgentConfiguration +from flink_agents.plan.function import JavaFunction from flink_agents.plan.resource_provider import JavaResourceProvider, ResourceProvider +from flink_agents.plan.tools.function_tool import FunctionTool class ResourceCache: @@ -85,6 +87,8 @@ def get_resource(self, name: str, type: ResourceType) -> Resource: resource = resource_provider.provide( resource_context=self._resource_context, config=self._config ) + if isinstance(resource, FunctionTool) and isinstance(resource.func, JavaFunction): + resource.set_java_resource_adapter(self._j_resource_adapter) resource.open() self._cache.setdefault(type, {})[name] = resource return resource diff --git a/python/flink_agents/runtime/skill/skill_tools.py b/python/flink_agents/runtime/skill/skill_tools.py index b5cb8ede7..3f5da277b 100644 --- a/python/flink_agents/runtime/skill/skill_tools.py +++ b/python/flink_agents/runtime/skill/skill_tools.py @@ -53,8 +53,6 @@ class LoadSkillTool(Tool): (not the public ResourceContext interface). """ - metadata: ToolMetadata = Field(exclude=True) - def __init__(self, **kwargs: Any) -> None: """Initialize the load skill tool.""" super().__init__( diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java index 279d46529..899f463d6 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java @@ -166,7 +166,8 @@ public void open() throws Exception { getRuntimeContext().getJobInfo().getJobId(), metricGroup, this::checkMailboxThread, - jobIdentifier); + jobIdentifier, + getRuntimeContext().getUserCodeClassLoader()); // Capture the wired Mem0 long-term memory, if any, so it can be plumbed into the Java // runner context created by ActionTaskContextManager. diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/PythonBridgeManager.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/PythonBridgeManager.java index 1f6b190df..73eb6a500 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/PythonBridgeManager.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/PythonBridgeManager.java @@ -111,6 +111,9 @@ class PythonBridgeManager implements AutoCloseable { * @param metricGroup the agent metric group, exposed to Python via the runner context. * @param mailboxThreadChecker hook used by the runner context to assert mailbox-thread access. * @param jobIdentifier the job identifier used to scope Python state. + * @param userCodeClassLoader the operator's user-code class loader, propagated to {@link + * JavaResourceAdapter} so reflective Java tool resolution sees user jars added via {@code + * env.add_jars(...)}. */ void open( AgentPlan agentPlan, @@ -121,7 +124,8 @@ void open( JobID jobId, FlinkAgentsMetricGroupImpl metricGroup, Runnable mailboxThreadChecker, - String jobIdentifier) + String jobIdentifier, + ClassLoader userCodeClassLoader) throws Exception { boolean containPythonAction = agentPlan.getActions().values().stream() @@ -169,7 +173,8 @@ void open( throw new RuntimeException(e); } }), - pythonInterpreter); + pythonInterpreter, + userCodeClassLoader); if (containPythonResource || mem0Configured) { initPythonResourceAdapter(agentPlan, resourceCache); } diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/JavaResourceAdapter.java b/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/JavaResourceAdapter.java index f17d7ce79..3c00d96dc 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/JavaResourceAdapter.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/JavaResourceAdapter.java @@ -22,9 +22,16 @@ import org.apache.flink.agents.api.resource.Resource; import org.apache.flink.agents.api.resource.ResourceContext; import org.apache.flink.agents.api.resource.ResourceType; +import org.apache.flink.agents.api.tools.ToolMetadata; +import org.apache.flink.agents.api.tools.ToolParameters; +import org.apache.flink.agents.api.tools.ToolResponse; import org.apache.flink.agents.api.vectorstores.Document; +import org.apache.flink.agents.plan.tools.FunctionTool; +import org.apache.flink.agents.plan.tools.ToolMetadataFactory; import pemja.core.PythonInterpreter; +import java.lang.reflect.Method; +import java.util.HashMap; import java.util.List; import java.util.Map; @@ -34,9 +41,21 @@ public class JavaResourceAdapter { private final transient PythonInterpreter interpreter; - public JavaResourceAdapter(ResourceContext resourceContext, PythonInterpreter interpreter) { + /** + * Class loader used to resolve Java tool methods declared by name. Captured at construction + * (the operator passes its {@code RuntimeContext.getUserCodeClassLoader()}) because pemja + * worker threads inherit the JVM system loader as their context loader and would not see + * user-supplied jars added via {@code env.add_jars(...)}. + */ + private final transient ClassLoader userCodeClassLoader; + + public JavaResourceAdapter( + ResourceContext resourceContext, + PythonInterpreter interpreter, + ClassLoader userCodeClassLoader) { this.resourceContext = resourceContext; this.interpreter = interpreter; + this.userCodeClassLoader = userCodeClassLoader; } /** @@ -105,4 +124,99 @@ public Document fromPythonDocument( Float score) { return new Document(content, metadata, id, embedding, score); } + + /** + * Resolve the metadata for a Java static tool method declared by fully-qualified class name, + * method name and parameter type names. + * + *

Invoked from the Python side via the {@code _j_resource_adapter} bridge when a {@code + * plan.FunctionTool} backed by a {@code JavaFunction} first materialises its metadata. + * Delegates to {@link ToolMetadataFactory#fromStaticMethod(Method)} once the {@code Method} is + * resolved, then flattens the resulting {@link ToolMetadata} into a {@code Map} + * before returning. + * + *

The flattening is required because pemja can crash with a SIGSEGV inside {@code + * JcpPyJObject_New} when Java returns an arbitrary Java object to a Python call that originated + * on a non-main interpreter thread (e.g. a Flink mailbox worker that resolves a tool's + * metadata). Returning only String fields — which pemja maps natively to {@code str} — + * sidesteps the reverse Java→Python object wrap entirely. The Python side rebuilds {@link + * ToolMetadata} from the flat map. + */ + public Map getJavaToolMetadata( + String className, String methodName, List parameterTypes) throws Exception { + Method method = resolveMethod(className, methodName, parameterTypes); + ToolMetadata metadata = ToolMetadataFactory.fromStaticMethod(method); + Map result = new HashMap<>(); + result.put("name", metadata.getName()); + result.put("description", metadata.getDescription()); + result.put("inputSchema", metadata.getInputSchema()); + return result; + } + + /** + * Invoke a Java static tool method with keyword arguments coming from a Python tool call. + * + *

Delegates to {@link FunctionTool#call(ToolParameters)} so the Python-driven tool-call path + * shares every detail of argument resolution with the Java agent path — {@link + * org.apache.flink.agents.api.annotation.ToolParam} name override, {@link ToolParameters} + * numeric coercion (covers the LLM-emitted JSON Number → Java box type mismatch that reflective + * {@code Method.invoke} otherwise rejects), required-parameter checking, and {@link + * ToolResponse} success / error semantics. The success result is unwrapped for the Python + * caller; an unsuccessful response is re-thrown as a {@link RuntimeException}. + */ + public Object invokeJavaTool( + String className, + String methodName, + List parameterTypes, + Map arguments) + throws Exception { + Method method = resolveMethod(className, methodName, parameterTypes); + FunctionTool tool = FunctionTool.fromStaticMethod(method); + ToolResponse response = + tool.call(new ToolParameters(arguments == null ? new HashMap<>() : arguments)); + if (!response.isSuccess()) { + throw new RuntimeException(response.getError()); + } + return response.getResult(); + } + + private Method resolveMethod(String className, String methodName, List parameterTypes) + throws ClassNotFoundException, NoSuchMethodException { + ClassLoader classLoader = + userCodeClassLoader != null + ? userCodeClassLoader + : Thread.currentThread().getContextClassLoader(); + Class clazz = Class.forName(className, true, classLoader); + Class[] paramClasses = new Class[parameterTypes.size()]; + for (int i = 0; i < parameterTypes.size(); i++) { + paramClasses[i] = resolveType(parameterTypes.get(i), classLoader); + } + return clazz.getMethod(methodName, paramClasses); + } + + private static Class resolveType(String typeName, ClassLoader classLoader) + throws ClassNotFoundException { + switch (typeName) { + case "boolean": + return boolean.class; + case "byte": + return byte.class; + case "short": + return short.class; + case "int": + return int.class; + case "long": + return long.class; + case "float": + return float.class; + case "double": + return double.class; + case "char": + return char.class; + case "void": + return void.class; + default: + return Class.forName(typeName, true, classLoader); + } + } } diff --git a/runtime/src/test/java/org/apache/flink/agents/runtime/operator/PythonBridgeManagerTest.java b/runtime/src/test/java/org/apache/flink/agents/runtime/operator/PythonBridgeManagerTest.java index f7226826e..9ee685638 100644 --- a/runtime/src/test/java/org/apache/flink/agents/runtime/operator/PythonBridgeManagerTest.java +++ b/runtime/src/test/java/org/apache/flink/agents/runtime/operator/PythonBridgeManagerTest.java @@ -50,7 +50,8 @@ void openIsNoOpWhenPlanHasNeitherPythonActionsNorResources() throws Exception { /* jobId */ new JobID(), /* metricGroup */ null, /* mailboxThreadChecker */ () -> {}, - /* jobIdentifier */ "job-1"); + /* jobIdentifier */ "job-1", + /* userCodeClassLoader */ Thread.currentThread().getContextClassLoader()); // No-op contract: nothing initialized, no Pemja interpreter created. assertThat(bridge.isInitialized()).isFalse(); diff --git a/tools/.rat-excludes b/tools/.rat-excludes index ff986781b..dceb40d37 100644 --- a/tools/.rat-excludes +++ b/tools/.rat-excludes @@ -17,4 +17,5 @@ PULL_REQUEST_TEMPLATE.md .ruff_cache/* .*\.egg-info/* licenses/* -skills/* \ No newline at end of file +skills/* +.*\.yaml$ \ No newline at end of file