Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 34 additions & 2 deletions pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_event_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from collections.abc import AsyncIterator, Mapping
from dataclasses import dataclass
from typing import Any
from uuid import uuid4

from pydantic_core import to_json

Expand All @@ -13,7 +14,9 @@
BuiltinToolCallPart,
BuiltinToolReturnPart,
FilePart,
FinishReason as PydanticFinishReason,
FunctionToolResultEvent,
ModelResponse,
RetryPromptPart,
TextPart,
TextPartDelta,
Expand All @@ -23,7 +26,8 @@
ToolCallPartDelta,
)
from ...output import OutputDataT
from ...tools import AgentDepsT
from ...run import AgentRunResultEvent
from ...tools import AgentDepsT, DeferredToolRequests
from .. import UIEventStream
from .request_types import RequestData
from .response_types import (
Expand All @@ -32,6 +36,7 @@
ErrorChunk,
FileChunk,
FinishChunk,
FinishReason,
FinishStepChunk,
ReasoningDeltaChunk,
ReasoningEndChunk,
Expand All @@ -41,13 +46,23 @@
TextDeltaChunk,
TextEndChunk,
TextStartChunk,
ToolApprovalRequestChunk,
ToolInputAvailableChunk,
ToolInputDeltaChunk,
ToolInputStartChunk,
ToolOutputAvailableChunk,
ToolOutputErrorChunk,
)

# Map Pydantic AI finish reasons to Vercel AI format.
# Reasons with no Vercel counterpart resolve to None via `.get` at the call site.
_FINISH_REASON_MAP: dict[PydanticFinishReason, FinishReason] = {
    'stop': 'stop',
    'length': 'length',
    'content_filter': 'content-filter',  # snake_case -> kebab-case on the wire
    'tool_call': 'tool-calls',  # Vercel uses the plural form
    'error': 'error',
}

__all__ = ['VercelAIEventStream']

# See https://ai-sdk.dev/docs/ai-sdk-ui/stream-protocol#data-stream-protocol
Expand All @@ -64,6 +79,7 @@ class VercelAIEventStream(UIEventStream[RequestData, BaseChunk, AgentDepsT, Outp
"""UI event stream transformer for the Vercel AI protocol."""

_step_started: bool = False
_finish_reason: FinishReason = None

@property
def response_headers(self) -> Mapping[str, str] | None:
Expand All @@ -85,9 +101,25 @@ async def before_response(self) -> AsyncIterator[BaseChunk]:
async def after_stream(self) -> AsyncIterator[BaseChunk]:
    """Emit the closing chunks once the main stream is exhausted.

    Ends the current step, then sends a single `finish` chunk carrying the
    finish reason recorded by `handle_run_result` (``None`` when unknown),
    followed by the stream terminator.
    """
    yield FinishStepChunk()

    # Fix: the scraped diff left both the old `yield FinishChunk()` and the
    # new parameterized yield in place, which would emit two finish chunks;
    # only the post-diff line belongs here.
    yield FinishChunk(finish_reason=self._finish_reason)
    yield DoneChunk()

async def handle_run_result(self, event: AgentRunResultEvent) -> AsyncIterator[BaseChunk]:
    """Record the run's finish reason and emit approval requests for deferred tools.

    Translates the final `ModelResponse` finish reason (if present) into the
    Vercel AI vocabulary and stashes it for `after_stream` to attach to the
    `finish` chunk. When the run output is a `DeferredToolRequests`, yields one
    `tool-approval-request` chunk (with a freshly generated approval id) per
    tool call awaiting human approval.
    """
    history = event.result.all_messages()
    last_message = history[-1] if history else None
    if isinstance(last_message, ModelResponse) and last_message.finish_reason:
        self._finish_reason = _FINISH_REASON_MAP.get(last_message.finish_reason)

    result_output = event.result.output
    if isinstance(result_output, DeferredToolRequests):
        # One chunk per pending approval so the UI can prompt for each call.
        for pending_call in result_output.approvals:
            yield ToolApprovalRequestChunk(
                approval_id=str(uuid4()),
                tool_call_id=pending_call.tool_call_id,
            )

async def on_error(self, error: Exception) -> AsyncIterator[BaseChunk]:
    """Surface a stream-side exception to the client as a Vercel AI error chunk."""
    message = str(error)
    yield ErrorChunk(error_text=message)

Expand Down
21 changes: 21 additions & 0 deletions pydantic_ai_slim/pydantic_ai/ui/vercel_ai/response_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
ProviderMetadata = dict[str, dict[str, JSONValue]]
"""Provider metadata."""

FinishReason = Literal['stop', 'length', 'content-filter', 'tool-calls', 'error', 'other', 'unknown'] | None
"""Reason why the model finished generating."""


class BaseChunk(CamelBaseModel, ABC):
"""Abstract base class for response SSE events."""
Expand Down Expand Up @@ -145,6 +148,21 @@ class ToolOutputErrorChunk(BaseChunk):
dynamic: bool | None = None


class ToolApprovalRequestChunk(BaseChunk):
    """Tool approval request chunk for human-in-the-loop approval."""

    type: Literal['tool-approval-request'] = 'tool-approval-request'
    # Server-generated identifier the client echoes back when answering the request.
    approval_id: str
    # ID of the tool call awaiting the user's approve/deny decision.
    tool_call_id: str


class ToolOutputDeniedChunk(BaseChunk):
    """Tool output denied chunk when user denies tool execution."""

    # Fix: scraped review-comment prose was embedded inside the class body,
    # which is not valid Python; the class is restored to its clean form.
    type: Literal['tool-output-denied'] = 'tool-output-denied'
    # ID of the tool call whose execution the user declined.
    tool_call_id: str


class SourceUrlChunk(BaseChunk):
"""Source URL chunk."""

Expand Down Expand Up @@ -178,7 +196,9 @@ class DataChunk(BaseChunk):
"""Data chunk with dynamic type."""

type: Annotated[str, Field(pattern=r'^data-')]
id: str | None = None
data: Any
transient: bool | None = None


class StartStepChunk(BaseChunk):
Expand All @@ -205,6 +225,7 @@ class FinishChunk(BaseChunk):
"""Finish chunk."""

type: Literal['finish'] = 'finish'
finish_reason: FinishReason = None
message_metadata: Any | None = None


Expand Down
84 changes: 83 additions & 1 deletion tests/test_vercel_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -1039,7 +1039,7 @@ def client_response\
{'type': 'text-delta', 'delta': ' bodies safely?', 'id': IsStr()},
{'type': 'text-end', 'id': IsStr()},
{'type': 'finish-step'},
{'type': 'finish'},
{'type': 'finish', 'finishReason': 'stop'},
'[DONE]',
]
)
Expand Down Expand Up @@ -1619,6 +1619,88 @@ async def on_complete(run_result: AgentRunResult[Any]) -> AsyncIterator[BaseChun
)


async def test_data_chunk_with_id_and_transient():
    """Test DataChunk supports optional id and transient fields for AI SDK compatibility."""
    agent = Agent(model=TestModel())

    request = SubmitMessage(
        id='foo',
        messages=[UIMessage(id='bar', role='user', parts=[TextUIPart(text='Hello')])],
    )

    async def on_complete(run_result: AgentRunResult[Any]) -> AsyncIterator[BaseChunk]:
        # One chunk with an id (for client-side reconciliation) ...
        yield DataChunk(type='data-task', id='task-123', data={'status': 'complete'})
        # ... and one transient chunk (not persisted to history).
        yield DataChunk(type='data-progress', data={'percent': 100}, transient=True)

    adapter = VercelAIAdapter(agent, request)
    events = []
    async for event in adapter.encode_stream(adapter.run_stream(on_complete=on_complete)):
        events.append('[DONE]' if '[DONE]' in event else json.loads(event.removeprefix('data: ')))

    # Both custom data chunks appear with their optional fields serialized.
    assert {'type': 'data-task', 'id': 'task-123', 'data': {'status': 'complete'}} in events
    assert {'type': 'data-progress', 'data': {'percent': 100}, 'transient': True} in events


async def test_tool_approval_request_emission():
    """Test that ToolApprovalRequestChunk is emitted when tools require approval."""
    from pydantic_ai.tools import DeferredToolRequests

    async def stream_function(
        messages: list[ModelMessage], agent_info: AgentInfo
    ) -> AsyncIterator[DeltaToolCalls | str]:
        # The model immediately requests a tool call that needs human approval.
        yield {
            0: DeltaToolCall(
                name='delete_file',
                json_args='{"path": "test.txt"}',
                tool_call_id='delete_1',
            )
        }

    agent: Agent[None, str | DeferredToolRequests] = Agent(
        model=FunctionModel(stream_function=stream_function), output_type=[str, DeferredToolRequests]
    )

    @agent.tool_plain(requires_approval=True)
    def delete_file(path: str) -> str:
        return f'Deleted {path}'

    request = SubmitMessage(
        id='foo',
        messages=[UIMessage(id='bar', role='user', parts=[TextUIPart(text='Delete test.txt')])],
    )

    adapter = VercelAIAdapter(agent, request)
    events: list[str | dict[str, Any]] = []
    async for event in adapter.encode_stream(adapter.run_stream()):
        events.append('[DONE]' if '[DONE]' in event else json.loads(event.removeprefix('data: ')))

    # Find the approval request: matching tool call id plus a generated approval id.
    approval_event: dict[str, Any] | None = next(
        (e for e in events if isinstance(e, dict) and e.get('type') == 'tool-approval-request'),
        None,
    )
    assert approval_event is not None
    assert approval_event['toolCallId'] == 'delete_1'
    assert 'approvalId' in approval_event


@pytest.mark.skipif(not starlette_import_successful, reason='Starlette is not installed')
async def test_adapter_dispatch_request():
agent = Agent(model=TestModel())
Expand Down
Loading