Merged
23 changes: 22 additions & 1 deletion pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_event_stream.py
@@ -13,6 +13,7 @@
    BuiltinToolCallPart,
    BuiltinToolReturnPart,
    FilePart,
    FinishReason as PydanticFinishReason,
    FunctionToolResultEvent,
    RetryPromptPart,
    TextPart,
@@ -23,6 +24,7 @@
    ToolCallPartDelta,
)
from ...output import OutputDataT
from ...run import AgentRunResultEvent
from ...tools import AgentDepsT
from .. import UIEventStream
from .request_types import RequestData
@@ -32,6 +34,7 @@
    ErrorChunk,
    FileChunk,
    FinishChunk,
    FinishReason,
    FinishStepChunk,
    ReasoningDeltaChunk,
    ReasoningEndChunk,
@@ -48,6 +51,15 @@
    ToolOutputErrorChunk,
)

# Map Pydantic AI finish reasons to Vercel AI format
_FINISH_REASON_MAP: dict[PydanticFinishReason, FinishReason] = {
    'stop': 'stop',
    'length': 'length',
    'content_filter': 'content-filter',
    'tool_call': 'tool-calls',
    'error': 'error',
}

__all__ = ['VercelAIEventStream']

# See https://ai-sdk.dev/docs/ai-sdk-ui/stream-protocol#data-stream-protocol
@@ -64,6 +76,7 @@ class VercelAIEventStream(UIEventStream[RequestData, BaseChunk, AgentDepsT, Outp
"""UI event stream transformer for the Vercel AI protocol."""

    _step_started: bool = False
    _finish_reason: FinishReason = None

    @property
    def response_headers(self) -> Mapping[str, str] | None:
@@ -85,10 +98,18 @@ async def before_response(self) -> AsyncIterator[BaseChunk]:
    async def after_stream(self) -> AsyncIterator[BaseChunk]:
        yield FinishStepChunk()

-       yield FinishChunk()
+       yield FinishChunk(finish_reason=self._finish_reason)
        yield DoneChunk()

    async def handle_run_result(self, event: AgentRunResultEvent) -> AsyncIterator[BaseChunk]:
        pydantic_reason = event.result.response.finish_reason
        if pydantic_reason:
            self._finish_reason = _FINISH_REASON_MAP.get(pydantic_reason)
        return
        yield

    async def on_error(self, error: Exception) -> AsyncIterator[BaseChunk]:
        self._finish_reason = 'error'
        yield ErrorChunk(error_text=str(error))

    async def handle_text_start(self, part: TextPart, follows_text: bool = False) -> AsyncIterator[BaseChunk]:
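A note on the _event_stream.py changes above: _FINISH_REASON_MAP.get(...) is a plain dict lookup, so content_filter and tool_call are the only reasons renamed to the hyphenated Vercel AI spellings, and any Pydantic AI reason without an entry falls back to None, leaving the finish chunk without a reason. The bare yield after return in handle_run_result is the standard trick that makes the method an async generator while emitting no chunks. A minimal, self-contained sketch of the lookup behavior (the dict is copied from the diff; the assertions are illustrative):

# Copied from the diff above; the values follow the Vercel AI data stream
# protocol, which hyphenates 'content-filter' and 'tool-calls'.
_FINISH_REASON_MAP = {
    'stop': 'stop',
    'length': 'length',
    'content_filter': 'content-filter',
    'tool_call': 'tool-calls',
    'error': 'error',
}

assert _FINISH_REASON_MAP.get('tool_call') == 'tool-calls'  # renamed
assert _FINISH_REASON_MAP.get('stop') == 'stop'             # passed through unchanged
assert _FINISH_REASON_MAP.get('new_reason') is None         # unmapped reasons fall back to None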
21 changes: 21 additions & 0 deletions pydantic_ai_slim/pydantic_ai/ui/vercel_ai/response_types.py
@@ -16,6 +16,9 @@
ProviderMetadata = dict[str, dict[str, JSONValue]]
"""Provider metadata."""

FinishReason = Literal['stop', 'length', 'content-filter', 'tool-calls', 'error', 'other', 'unknown'] | None
"""Reason why the model finished generating."""


class BaseChunk(CamelBaseModel, ABC):
"""Abstract base class for response SSE events."""
@@ -145,6 +148,21 @@ class ToolOutputErrorChunk(BaseChunk):
    dynamic: bool | None = None


class ToolApprovalRequestChunk(BaseChunk):
    """Tool approval request chunk for human-in-the-loop approval."""

    type: Literal['tool-approval-request'] = 'tool-approval-request'
    approval_id: str
    tool_call_id: str


class ToolOutputDeniedChunk(BaseChunk):
    """Tool output denied chunk when user denies tool execution."""

    type: Literal['tool-output-denied'] = 'tool-output-denied'
    tool_call_id: str

Review comment (Collaborator), on ToolOutputDeniedChunk: Do we ever receive (the part version) of this, and should we handle it? We'll likely need it for https://ai.pydantic.dev/deferred-tools/#human-in-the-loop-tool-approval

class SourceUrlChunk(BaseChunk):
"""Source URL chunk."""

@@ -178,7 +196,9 @@ class DataChunk(BaseChunk):
"""Data chunk with dynamic type."""

type: Annotated[str, Field(pattern=r'^data-')]
id: str | None = None
data: Any
transient: bool | None = None


class StartStepChunk(BaseChunk):
@@ -205,6 +225,7 @@ class FinishChunk(BaseChunk):
"""Finish chunk."""

type: Literal['finish'] = 'finish'
finish_reason: FinishReason = None
message_metadata: Any | None = None


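A note on the wire format these response types produce: because the chunk models extend CamelBaseModel, the new snake_case fields serialize under camelCase keys and None-valued fields are omitted, which is what test expectations like {'type': 'finish', 'finishReason': 'stop'} rely on. A minimal sketch of that behavior, assuming CamelBaseModel is a pydantic model with a to_camel alias generator that dumps by alias and excludes None (an assumption made for illustration; the real base class lives elsewhere in the package):

from typing import Any, Literal

from pydantic import BaseModel, ConfigDict
from pydantic.alias_generators import to_camel


# Hypothetical stand-in for the package's CamelBaseModel.
class CamelBaseModel(BaseModel):
    model_config = ConfigDict(alias_generator=to_camel, populate_by_name=True)

    def encode(self) -> dict[str, Any]:
        # Dump by alias and drop None fields, matching the keys seen in the tests.
        return self.model_dump(by_alias=True, exclude_none=True)


class FinishChunk(CamelBaseModel):
    type: Literal['finish'] = 'finish'
    finish_reason: str | None = None
    message_metadata: Any | None = None


class ToolApprovalRequestChunk(CamelBaseModel):
    type: Literal['tool-approval-request'] = 'tool-approval-request'
    approval_id: str
    tool_call_id: str


print(FinishChunk(finish_reason='tool-calls').encode())
# {'type': 'finish', 'finishReason': 'tool-calls'}
print(FinishChunk().encode())
# {'type': 'finish'}  (the reason key is simply omitted when None)
print(ToolApprovalRequestChunk(approval_id='appr_1', tool_call_id='call_1').encode())
# {'type': 'tool-approval-request', 'approvalId': 'appr_1', 'toolCallId': 'call_1'}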
40 changes: 36 additions & 4 deletions tests/test_vercel_ai.py
@@ -1039,7 +1039,7 @@ def client_response\
            {'type': 'text-delta', 'delta': ' bodies safely?', 'id': IsStr()},
            {'type': 'text-end', 'id': IsStr()},
            {'type': 'finish-step'},
-           {'type': 'finish'},
+           {'type': 'finish', 'finishReason': 'stop'},
            '[DONE]',
        ]
    )
@@ -1488,7 +1488,7 @@ async def stream_function(
            {'type': 'tool-input-available', 'toolCallId': IsStr(), 'toolName': 'unknown_tool', 'input': {}},
            {'type': 'error', 'errorText': 'Exceeded maximum retries (1) for output validation'},
            {'type': 'finish-step'},
-           {'type': 'finish'},
+           {'type': 'finish', 'finishReason': 'error'},
            '[DONE]',
        ]
    )
@@ -1531,7 +1531,7 @@ async def tool(query: str) -> str:
            },
            {'type': 'error', 'errorText': 'Unknown tool'},
            {'type': 'finish-step'},
-           {'type': 'finish'},
+           {'type': 'finish', 'finishReason': 'error'},
            '[DONE]',
        ]
    )
@@ -1572,7 +1572,7 @@ def raise_error(run_result: AgentRunResult[Any]) -> None:
            {'type': 'text-end', 'id': IsStr()},
            {'type': 'error', 'errorText': 'Faulty on_complete'},
            {'type': 'finish-step'},
-           {'type': 'finish'},
+           {'type': 'finish', 'finishReason': 'error'},
            '[DONE]',
        ]
    )
@@ -1619,6 +1619,38 @@ async def on_complete(run_result: AgentRunResult[Any]) -> AsyncIterator[BaseChun
    )


async def test_data_chunk_with_id_and_transient():
    """Test DataChunk supports optional id and transient fields for AI SDK compatibility."""
    agent = Agent(model=TestModel())

    request = SubmitMessage(
        id='foo',
        messages=[
            UIMessage(
                id='bar',
                role='user',
                parts=[TextUIPart(text='Hello')],
            ),
        ],
    )

    async def on_complete(run_result: AgentRunResult[Any]) -> AsyncIterator[BaseChunk]:
        # Yield a data chunk with id for reconciliation
        yield DataChunk(type='data-task', id='task-123', data={'status': 'complete'})
        # Yield a transient data chunk (not persisted to history)
        yield DataChunk(type='data-progress', data={'percent': 100}, transient=True)

    adapter = VercelAIAdapter(agent, request)
    events = [
        '[DONE]' if '[DONE]' in event else json.loads(event.removeprefix('data: '))
        async for event in adapter.encode_stream(adapter.run_stream(on_complete=on_complete))
    ]

    # Verify the data chunks are present in the events with correct fields
    assert {'type': 'data-task', 'id': 'task-123', 'data': {'status': 'complete'}} in events
    assert {'type': 'data-progress', 'data': {'percent': 100}, 'transient': True} in events


@pytest.mark.skipif(not starlette_import_successful, reason='Starlette is not installed')
async def test_adapter_dispatch_request():
    agent = Agent(model=TestModel())
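End to end, a successful run now closes the stream with the mapped reason. Under the SSE framing the adapter uses (the 'data: ' prefix is visible in the test above), the final events of a normal run would look like the following, with values taken from the updated snapshots:

data: {"type":"finish-step"}
data: {"type":"finish","finishReason":"stop"}
data: [DONE]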