5 changes: 4 additions & 1 deletion pydantic_ai_slim/pydantic_ai/_tool_manager.py
@@ -35,6 +35,8 @@ class ToolManager(Generic[AgentDepsT]):
     """The cached tools for this run step."""
     failed_tools: set[str] = field(default_factory=set)
     """Names of tools that failed in this run step."""
+    default_max_retries: int = 1
+    """Default number of times to retry a tool."""
 
     @classmethod
     @contextmanager
@@ -62,6 +64,7 @@ async def for_run_step(self, ctx: RunContext[AgentDepsT]) -> ToolManager[AgentDepsT]:
             toolset=self.toolset,
             ctx=ctx,
             tools=await self.toolset.get_tools(ctx),
+            default_max_retries=self.default_max_retries,
         )
 
     @property
@@ -174,7 +177,7 @@ async def _call_tool(
 
             return await self.toolset.call_tool(name, args_dict, ctx, tool)
         except (ValidationError, ModelRetry) as e:
-            max_retries = tool.max_retries if tool is not None else 1
+            max_retries = tool.max_retries if tool is not None else self.default_max_retries
             current_retry = self.ctx.retries.get(name, 0)
 
             if current_retry == max_retries:
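This hunk is the core of the fix: when the model calls a tool name that doesn't exist, `tool` is `None`, so the old code capped unknown-tool calls at a hard-coded single retry regardless of the agent's configuration. A minimal sketch of the fallback rule, with hypothetical names (`resolve_max_retries` and `FakeTool` are illustrative, not pydantic-ai API):

```python
# Sketch of the fallback rule above; resolve_max_retries and FakeTool are
# hypothetical names for illustration, not pydantic-ai API.


def resolve_max_retries(tool, default_max_retries: int) -> int:
    # A known tool carries its own per-tool budget; an unknown tool arrives
    # here as None and now uses the manager-wide default instead of 1.
    return tool.max_retries if tool is not None else default_max_retries


class FakeTool:
    max_retries = 5


assert resolve_max_retries(FakeTool(), default_max_retries=2) == 5  # per-tool budget wins
assert resolve_max_retries(None, default_max_retries=2) == 2  # unknown tool -> default
```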
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pydantic_ai/agent/__init__.py
@@ -589,7 +589,7 @@ async def main():
         output_toolset.max_retries = self._max_result_retries
         output_toolset.output_validators = output_validators
         toolset = self._get_toolset(output_toolset=output_toolset, additional_toolsets=toolsets)
-        tool_manager = ToolManager[AgentDepsT](toolset)
+        tool_manager = ToolManager[AgentDepsT](toolset, default_max_retries=self._max_tool_retries)
 
         # Build the graph
         graph = _agent_graph.build_agent_graph(self.name, self._deps_type, output_type_)
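With the manager seeded from `self._max_tool_retries`, the `retries=` argument to `Agent(...)` now bounds retries for unknown tool names as well as ordinary tool failures. A hedged usage sketch — the `'test'` model string is a stand-in, and the comments describe the intended behavior rather than verified output:

```python
from pydantic_ai import Agent

# Agent-level `retries` previously only applied to tools the agent knows
# about; calls to undefined tool names always got exactly one retry.
agent = Agent('test', retries=3)

# After this change, if the model calls a tool that doesn't exist, the
# "Unknown tool name" retry prompt is re-sent up to 3 times before
# UnexpectedModelBehavior is raised.
```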
67 changes: 67 additions & 0 deletions tests/test_agent.py
@@ -2765,6 +2765,73 @@ def empty(m: list[ModelMessage], _info: AgentInfo) -> ModelResponse:
     )
 
 
+def test_unknown_tool_multiple_retries():
+    num_retries = 2
+
+    def empty(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse:
+        return ModelResponse(parts=[ToolCallPart('foobar', '{}')])
+
+    def foo() -> str:
+        return 'Hello from foo'
+
+    agent = Agent(FunctionModel(empty), tools=[foo], retries=num_retries)
+
+    with capture_run_messages() as messages:
+        with pytest.raises(UnexpectedModelBehavior, match=r'Exceeded maximum retries \(2\) for output validation'):
+            agent.run_sync('Hello')
+    assert messages == snapshot(
+        [
+            ModelRequest(
+                parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))],
+                run_id=IsStr(),
+            ),
+            ModelResponse(
+                parts=[ToolCallPart(tool_name='foobar', args='{}', tool_call_id=IsStr())],
+                usage=RequestUsage(input_tokens=51, output_tokens=2),
+                model_name='function:empty:',
+                timestamp=IsNow(tz=timezone.utc),
+                run_id=IsStr(),
+            ),
+            ModelRequest(
+                parts=[
+                    RetryPromptPart(
+                        tool_name='foobar',
+                        content="Unknown tool name: 'foobar'. Available tools: 'foo'",
+                        tool_call_id=IsStr(),
+                        timestamp=IsNow(tz=timezone.utc),
+                    )
+                ],
+                run_id=IsStr(),
+            ),
+            ModelResponse(
+                parts=[ToolCallPart(tool_name='foobar', args='{}', tool_call_id=IsStr())],
+                usage=RequestUsage(input_tokens=65, output_tokens=4),
+                model_name='function:empty:',
+                timestamp=IsNow(tz=timezone.utc),
+                run_id=IsStr(),
+            ),
+            ModelRequest(
+                parts=[
+                    RetryPromptPart(
+                        tool_name='foobar',
+                        content="Unknown tool name: 'foobar'. Available tools: 'foo'",
+                        tool_call_id=IsStr(),
+                        timestamp=IsNow(tz=timezone.utc),
+                    )
+                ],
+                run_id=IsStr(),
+            ),
+            ModelResponse(
+                parts=[ToolCallPart(tool_name='foobar', args='{}', tool_call_id=IsStr())],
+                usage=RequestUsage(input_tokens=79, output_tokens=6),
+                model_name='function:empty:',
+                timestamp=IsNow(tz=timezone.utc),
+                run_id=IsStr(),
+            ),
+        ]
+    )
+
+
 def test_tool_exceeds_token_limit_error():
     def return_incomplete_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
         resp = ModelResponse(parts=[ToolCallPart('dummy_tool', args='{"foo": "bar",')])
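The snapshot's shape follows from the retry arithmetic: with a budget of two, the model gets its initial attempt plus two retry prompts, and the third failed call exhausts the budget and raises instead of producing another `RetryPromptPart`. A quick sanity check of that count (names are local to this sketch):

```python
# Transcript shape for retries=2, per the snapshot above: the third bad
# tool call hits current_retry == max_retries and raises, so it yields a
# ModelResponse but no further RetryPromptPart.
num_retries = 2
model_responses = num_retries + 1  # 3 ToolCallPart responses
retry_prompts = num_retries  # 2 RetryPromptPart requests
total_messages = 1 + model_responses + retry_prompts  # plus the user prompt
assert total_messages == 6
```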