diff --git a/pydantic_ai_slim/pydantic_ai/_tool_manager.py b/pydantic_ai_slim/pydantic_ai/_tool_manager.py index a2730bd852..f9de3a00a1 100644 --- a/pydantic_ai_slim/pydantic_ai/_tool_manager.py +++ b/pydantic_ai_slim/pydantic_ai/_tool_manager.py @@ -35,6 +35,8 @@ class ToolManager(Generic[AgentDepsT]): """The cached tools for this run step.""" failed_tools: set[str] = field(default_factory=set) """Names of tools that failed in this run step.""" + default_max_retries: int = 1 + """Default number of times to retry a tool""" @classmethod @contextmanager @@ -62,6 +64,7 @@ async def for_run_step(self, ctx: RunContext[AgentDepsT]) -> ToolManager[AgentDe toolset=self.toolset, ctx=ctx, tools=await self.toolset.get_tools(ctx), + default_max_retries=self.default_max_retries, ) @property @@ -174,7 +177,7 @@ async def _call_tool( return await self.toolset.call_tool(name, args_dict, ctx, tool) except (ValidationError, ModelRetry) as e: - max_retries = tool.max_retries if tool is not None else 1 + max_retries = tool.max_retries if tool is not None else self.default_max_retries current_retry = self.ctx.retries.get(name, 0) if current_retry == max_retries: diff --git a/pydantic_ai_slim/pydantic_ai/agent/__init__.py b/pydantic_ai_slim/pydantic_ai/agent/__init__.py index 9ecd5769a2..f5d2b72075 100644 --- a/pydantic_ai_slim/pydantic_ai/agent/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/agent/__init__.py @@ -616,7 +616,7 @@ async def main(): output_toolset.max_retries = self._max_result_retries output_toolset.output_validators = output_validators toolset = self._get_toolset(output_toolset=output_toolset, additional_toolsets=toolsets) - tool_manager = ToolManager[AgentDepsT](toolset) + tool_manager = ToolManager[AgentDepsT](toolset, default_max_retries=self._max_tool_retries) # Build the graph graph = _agent_graph.build_agent_graph(self.name, self._deps_type, output_type_) diff --git a/tests/test_agent.py b/tests/test_agent.py index 0dfb0245fa..fec7724d2a 100644 --- 
def test_unknown_tool_multiple_retries():
    """An agent built with ``retries=2`` gives up after repeated calls to an unknown tool.

    The function model always responds with a call to ``foobar``, which is not a
    registered tool. The run is expected to raise ``UnexpectedModelBehavior`` and the
    captured history is expected to hold three model responses separated by two
    retry prompts.
    """
    retry_budget = 2

    # NOTE(review): keep this function named ``empty`` -- the snapshot below pins
    # ``model_name='function:empty:'``, presumably derived from the function name.
    def empty(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse:
        return ModelResponse(parts=[ToolCallPart('foobar', '{}')])

    agent = Agent(FunctionModel(empty), retries=retry_budget)

    # The reported limit in the error message matches the ``retries`` value given above.
    expect_failure = pytest.raises(
        UnexpectedModelBehavior, match=r'Exceeded maximum retries \(2\) for output validation'
    )
    with capture_run_messages() as messages, expect_failure:
        agent.run_sync('Hello')

    assert messages == snapshot(
        [
            ModelRequest(
                parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))],
                timestamp=IsNow(tz=timezone.utc),
                run_id=IsStr(),
            ),
            ModelResponse(
                parts=[ToolCallPart(tool_name='foobar', args='{}', tool_call_id=IsStr())],
                usage=RequestUsage(input_tokens=51, output_tokens=2),
                model_name='function:empty:',
                timestamp=IsNow(tz=timezone.utc),
                run_id=IsStr(),
            ),
            ModelRequest(
                parts=[
                    RetryPromptPart(
                        tool_name='foobar',
                        content="Unknown tool name: 'foobar'. No tools available.",
                        tool_call_id=IsStr(),
                        timestamp=IsNow(tz=timezone.utc),
                    )
                ],
                timestamp=IsNow(tz=timezone.utc),
                run_id=IsStr(),
            ),
            ModelResponse(
                parts=[ToolCallPart(tool_name='foobar', args='{}', tool_call_id=IsStr())],
                usage=RequestUsage(input_tokens=65, output_tokens=4),
                model_name='function:empty:',
                timestamp=IsNow(tz=timezone.utc),
                run_id=IsStr(),
            ),
            ModelRequest(
                parts=[
                    RetryPromptPart(
                        tool_name='foobar',
                        content="Unknown tool name: 'foobar'. No tools available.",
                        tool_call_id=IsStr(),
                        timestamp=IsNow(tz=timezone.utc),
                    )
                ],
                timestamp=IsNow(tz=timezone.utc),
                run_id=IsStr(),
            ),
            ModelResponse(
                parts=[ToolCallPart(tool_name='foobar', args='{}', tool_call_id=IsStr())],
                usage=RequestUsage(input_tokens=79, output_tokens=6),
                model_name='function:empty:',
                timestamp=IsNow(tz=timezone.utc),
                run_id=IsStr(),
            ),
        ]
    )