def upgrade() -> None:
    """Extend the `messagestatus` enum with MODEL_CALL_LIMIT and INTERRUPTED.

    Every label is added with `IF NOT EXISTS`, so a label that is already
    present is left alone instead of failing the migration.

    Note: `ALTER TYPE ... ADD VALUE` cannot run inside a transaction block
    on PostgreSQL pre-v12, so the statements run in an autocommit block.
    """
    new_statuses = ("MODEL_CALL_LIMIT", "INTERRUPTED")

    with op.get_context().autocommit_block():
        for new_status in new_statuses:
            op.execute(
                f"ALTER TYPE messagestatus ADD VALUE IF NOT EXISTS '{new_status}'"
            )


def downgrade() -> None:
    """Refuse to downgrade; this migration is intentionally one-way.

    Postgres cannot drop enum values, and rebuilding the type would require
    remapping existing MODEL_CALL_LIMIT and INTERRUPTED rows to another status,
    losing the diagnostic signal these statuses were added to capture.

    Raises:
        NotImplementedError: Always.
    """
    reason = (
        "Downgrade not supported: removing enum values would silently rewrite rows."
    )
    raise NotImplementedError(reason)
FeedbackCreate, FeedbackPayload, @@ -96,14 +98,19 @@ async def list_messages( return await database.get_messages(thread.id, order_by) -@router.post("/threads/{thread_id}/messages") +@router.post( + "/threads/{thread_id}/messages", + response_class=StreamingResponse, + status_code=status.HTTP_201_CREATED, +) async def send_message( thread_id: str, user_message: UserMessage, - agent: Agent, database: AsyncDB, + agent: Agent, + running_runs: RunningRuns, user_id: UserID, -) -> Message: +) -> StreamingResponse: run_id = str(uuid.uuid4()) config = ConfigDict( @@ -120,15 +127,35 @@ async def send_message( message = await database.create_message(message_create) - return StreamingResponse( - stream_response( - database=database, + queue: asyncio.Queue[StreamEvent] = asyncio.Queue() + + task = asyncio.create_task( + run_agent( agent=agent, - user_message=message, config=config, thread_id=thread_id, + user_message=message, model_uri=settings.MODEL_URI, + queue=queue, ), + name=f"run_agent:{run_id}", + ) + + running_runs[run_id] = task + + def _cleanup(task: asyncio.Task): # pragma: no cover + del running_runs[run_id] + if task.cancelled(): + logger.warning(f"run_agent task {run_id} was cancelled mid-run") + return + e = task.exception() + if e is not None: + logger.opt(exception=e).error(f"run_agent task {run_id} crashed mid-run:") + + task.add_done_callback(_cleanup) + + return StreamingResponse( + stream_events(queue), status_code=status.HTTP_201_CREATED, ) diff --git a/app/api/streaming/__init__.py b/app/api/streaming/__init__.py index 5748056..a36d133 100644 --- a/app/api/streaming/__init__.py +++ b/app/api/streaming/__init__.py @@ -1,3 +1,4 @@ -from app.api.streaming.stream import stream_response +from app.api.streaming.agent_runner import run_agent +from app.api.streaming.stream import stream_events -__all__ = ["stream_response"] +__all__ = ["run_agent", "stream_events"] diff --git a/app/api/streaming/agent_runner.py b/app/api/streaming/agent_runner.py new file mode 
100644 index 0000000..73dbbf3 --- /dev/null +++ b/app/api/streaming/agent_runner.py @@ -0,0 +1,303 @@ +import asyncio +import json +from typing import Any + +from langchain_core.messages import AIMessage, ToolMessage +from langgraph.graph.state import CompiledStateGraph +from loguru import logger + +from app.api.schemas import ConfigDict +from app.api.streaming.schemas import EventData, StreamEvent, ToolCall, ToolOutput +from app.api.streaming.security import sanitize_markdown_links +from app.db.database import AsyncDatabase, sessionmaker +from app.db.models import Message, MessageCreate, MessageRole, MessageStatus + + +class ErrorMessage: + INTERRUPTED = ( + "A conexão com o servidor foi interrompida. Por favor, tente novamente." + ) + + MODEL_CALL_LIMIT_REACHED = ( + "Essa pergunta gerou um raciocínio muito longo e não consegui chegar a uma conclusão. " + "Por favor, tente ser mais específico ou divida sua pergunta em partes menores." + ) + + UNEXPECTED = "Ocorreu um erro inesperado. Por favor, tente novamente. Se o problema persistir, avise-nos." + + +def _truncate_json( + json_string: str, max_list_len: int = 10, max_str_len: int = 300 +) -> str: + """Iteratively truncates a serialized JSON object by shortening lists and strings + and adding human-readable placeholders. + + Note: + This function only processes JSON objects (dictionaries). If the serialized JSON + represents any other type, the original JSON string will be returned unchanged. + + Args: + json_string (str): The serialized JSON to process. + max_list_len (int, optional): The max number of items to keep in a list. Defaults to 10. + max_str_len (int, optional): The max length for any single string. Defaults to 300. + + Returns: + str: The truncated, formatted, and serialized JSON object. 
+ """ + try: + data = json.loads(json_string) + except json.JSONDecodeError: + return json_string + + if not isinstance(data, dict): + return json_string + + stack = [data] + + while stack: + current_node = stack.pop() + + if isinstance(current_node, dict): + items_to_process = current_node.items() + else: + items_to_process = enumerate(current_node) + + for key_or_idx, item in items_to_process: + if isinstance(item, str): + if len(item) > max_str_len: + truncated_str = ( + item[:max_str_len] + + f"... ({len(item) - max_str_len} more characters)" + ) + current_node[key_or_idx] = truncated_str + + elif isinstance(item, list): + if len(item) > max_list_len: + original_len = len(item) + del item[max_list_len:] + item.append(f"... ({original_len - max_list_len} more items)") + stack.append(item) + + elif isinstance(item, dict): + stack.append(item) + + return json.dumps(data, ensure_ascii=False, indent=2) + + +def _parse_thinking(message: AIMessage) -> str | None: + """Parse thinking content from an AI message. + + Some models (e.g., Gemini 3) return `message.content` as a list of typed blocks, + which may include `{"type": "thinking", "thinking": "..."}` entries. When + `content` is a plain string, no thinking is available. + + Args: + message (AIMessage): The AI message from where to parse the thinking. + + Returns: + str | None: The concatenated thinking text, or None if no thinking blocks exist. + """ + if isinstance(message.content, str): + return None + + blocks = [ + block + for block in message.content + if isinstance(block, dict) + and block.get("type") == "thinking" + and isinstance(block.get("thinking"), str) + ] + + thinking = "".join(block["thinking"] for block in blocks) + + return thinking or None + + +def _process_chunk(chunk: dict[str, Any]) -> StreamEvent | None: + """Process a streaming chunk from a react agent workflow into a standardized StreamEvent. + + Args: + chunk (dict[str, Any]): Raw chunk from agent workflow. 
+ Only processes "agent" and "tools" nodes. + + Returns: + StreamEvent | None: Structured event or None if the chunk is ignored: + - "tool_call" for agent messages with tool calls + - "tool_output" for tool execution results + - "final_answer" for agent messages without tool calls + - None for ignored chunks + """ + if "model" in chunk: + ai_messages: list[AIMessage] = chunk["model"]["messages"] + + # If no messages are returned, the model returned an empty response + # with no tool calls. This also counts as a final (but empty) answer. + if not ai_messages: + return StreamEvent(type="final_answer", data=EventData(content="")) + + message = ai_messages[0] + + if message.tool_calls: + event_type = "tool_call" + tool_calls = [ + ToolCall( + id=tool_call["id"], name=tool_call["name"], args=tool_call["args"] + ) + for tool_call in message.tool_calls + ] + content = _parse_thinking(message) or message.text + else: + event_type = "final_answer" + tool_calls = None + content = sanitize_markdown_links(message.text) + + event_data = EventData(content=content, tool_calls=tool_calls) + + return StreamEvent(type=event_type, data=event_data) + elif "tools" in chunk: + updates = chunk["tools"] + + # single tool call + if isinstance(updates, dict): + tool_messages: list[ToolMessage] = updates["messages"] + + # multiple parallel tool calls + elif isinstance(updates, list): + tool_messages: list[ToolMessage] = [ + update["messages"][0] for update in updates if "messages" in update + ] + + # defensive handling (langgraph should only output dicts and lists) + else: + tool_messages = [] + + tool_outputs = [ + ToolOutput( + status=message.status, + tool_call_id=message.tool_call_id, + tool_name=message.name, + content=_truncate_json(message.content), + artifact=message.artifact, + ) + for message in tool_messages + ] + + return StreamEvent( + type="tool_output", data=EventData(tool_outputs=tool_outputs) + ) + elif "ModelCallLimitMiddleware.before_model" in chunk: + # before_model runs 
async def run_agent(
    agent: CompiledStateGraph,
    config: ConfigDict,
    thread_id: str,
    user_message: Message,
    model_uri: str,
    queue: asyncio.Queue[StreamEvent],
) -> None:
    """Run the agent to completion and push its events onto the queue.

    Streams the agent's "updates" chunks, converts each one with
    `_process_chunk`, and forwards every resulting event to `queue` while
    accumulating the events, tool artifacts, final answer and status.

    Owns persistence: the assistant message is written in `finally`, so the
    write is attempted after success, failure and cancellation alike. A
    terminal `complete` event is then emitted carrying either the persisted
    message id as `run_id` (on success) or `error_details` (if persistence did
    not succeed). The `complete` event is emitted from a nested `finally`, so
    it is not skipped when persistence raises or is itself cancelled.

    Args:
        agent (CompiledStateGraph): Agent compiled state graph.
        config (ConfigDict): Config for agent execution. `config["run_id"]` is
            used as the id of the persisted assistant message.
        thread_id (str): Thread unique identifier.
        user_message (Message): User message.
        model_uri (str): Model URI.
        queue (asyncio.Queue[StreamEvent]): Events queue.

    Raises:
        asyncio.CancelledError: Re-raised after the `finally` block has run
            when the task is cancelled.
    """
    events = []
    artifacts = []
    assistant_message = ""
    # Stays None until a final answer or a failure decides the outcome.
    status: MessageStatus | None = None

    try:
        async for mode, chunk in agent.astream(  # pragma: no cover
            input={"messages": [{"role": "user", "content": user_message.content}]},
            config=config,
            stream_mode=["updates", "values"],
        ):
            if mode == "values":
                continue

            event = _process_chunk(chunk)

            if event is None:
                continue

            if event.type == "tool_output":
                for output in event.data.tool_outputs:
                    if output.artifact:
                        artifacts.append(output.artifact)
            elif event.type == "final_answer":
                assistant_message = event.data.content
                # Distinguish model-call-limit from a normal final answer (fragile)
                if assistant_message == ErrorMessage.MODEL_CALL_LIMIT_REACHED:
                    status = MessageStatus.MODEL_CALL_LIMIT
                else:
                    status = MessageStatus.SUCCESS

            events.append(event.model_dump())
            await queue.put(event)
    except asyncio.CancelledError:
        # Keep an outcome that was already decided; only mark the run as
        # interrupted when cancellation arrived before any final answer.
        if status is None:
            assistant_message = ErrorMessage.INTERRUPTED
            status = MessageStatus.INTERRUPTED
        raise
    except Exception:
        logger.exception(f"Unexpected error in run {config['run_id']}:")
        assistant_message = ErrorMessage.UNEXPECTED
        status = MessageStatus.ERROR
        event = StreamEvent(
            type="error",
            data=EventData(
                content=assistant_message,
                error_details={"reason": "agent_failed"},
            ),
        )
        events.append(event.model_dump())
        await queue.put(event)
    finally:
        # Pessimistic defaults: they are what the `complete` event reports if
        # persistence is cut short by something `except Exception` does not
        # catch (e.g. a CancelledError raised while the write is in flight).
        message_id = None
        error_details = {"reason": "persistence_failed"}
        try:
            # Built inside the `try` so that a failure while constructing the
            # payload is reported as a persistence failure instead of
            # escaping before `complete` is sent.
            message_create = MessageCreate(
                id=config["run_id"],
                thread_id=thread_id,
                user_message_id=user_message.id,
                model_uri=model_uri,
                role=MessageRole.ASSISTANT,
                content=assistant_message,
                artifacts=artifacts or None,
                events=events or None,
                status=status or MessageStatus.ERROR,
            )
            async with sessionmaker() as session:
                database = AsyncDatabase(session)
                message = await database.create_message(message_create)
            # Only reached when the session context exited without raising.
            message_id = str(message.id)
            error_details = None
        except Exception:
            logger.exception(
                f"Failed to persist assistant message for run {config['run_id']}:"
            )
        finally:
            # The consumer stops only on `complete`, so send it on every path.
            await queue.put(
                StreamEvent(
                    type="complete",
                    data=EventData(run_id=message_id, error_details=error_details),
                )
            )
- ) - - MODEL_CALL_LIMIT_REACHED = ( - "Ops, essa pergunta gerou um raciocínio muito longo e não consegui chegar a uma conclusão. " - "Por favor, tente ser mais específico ou divida sua pergunta em partes menores." - ) - - -def _truncate_json( - json_string: str, max_list_len: int = 10, max_str_len: int = 300 -) -> str: - """Iteratively truncates a serialized JSON object by shortening lists and strings - and adding human-readable placeholders. - - Note: - This function only processes JSON objects (dictionaries). If the serialized JSON - represents any other type, the original JSON string will be returned unchanged. - - Args: - json_string (str): The serialized JSON to process. - max_list_len (int, optional): The max number of items to keep in a list. Defaults to 10. - max_str_len (int, optional): The max length for any single string. Defaults to 300. - - Returns: - str: The truncated, formatted, and serialized JSON object. - """ - try: - data = json.loads(json_string) - except json.JSONDecodeError: - return json_string - - if not isinstance(data, dict): - return json_string - - stack = [data] - - while stack: - current_node = stack.pop() - - if isinstance(current_node, dict): - items_to_process = current_node.items() - else: - items_to_process = enumerate(current_node) - - for key_or_idx, item in items_to_process: - if isinstance(item, str): - if len(item) > max_str_len: - truncated_str = ( - item[:max_str_len] - + f"... ({len(item) - max_str_len} more characters)" - ) - current_node[key_or_idx] = truncated_str - - elif isinstance(item, list): - if len(item) > max_list_len: - original_len = len(item) - del item[max_list_len:] - item.append(f"... ({original_len - max_list_len} more items)") - stack.append(item) - - elif isinstance(item, dict): - stack.append(item) - - return json.dumps(data, ensure_ascii=False, indent=2) - - -def _parse_thinking(message: AIMessage) -> str | None: - """Parse thinking content from an AI message. 
- - Some models (e.g., Gemini 3) return `message.content` as a list of typed blocks, - which may include `{"type": "thinking", "thinking": "..."}` entries. When - `content` is a plain string, no thinking is available. - - Args: - message (AIMessage): The AI message from where to parse the thinking. - - Returns: - str | None: The concatenated thinking text, or None if no thinking blocks exist. - """ - if isinstance(message.content, str): - return None - - blocks = [ - block - for block in message.content - if isinstance(block, dict) - and block.get("type") == "thinking" - and isinstance(block.get("thinking"), str) - ] - - thinking = "".join(block["thinking"] for block in blocks) - - return thinking or None - - -def _process_chunk(chunk: dict[str, Any]) -> StreamEvent | None: - """Process a streaming chunk from a react agent workflow into a standardized StreamEvent. - - Args: - chunk (dict[str, Any]): Raw chunk from agent workflow. - Only processes "agent" and "tools" nodes. - - Returns: - StreamEvent | None: Structured event or None if the chunk is ignored: - - "tool_call" for agent messages with tool calls - - "tool_output" for tool execution results - - "final_answer" for agent messages without tool calls - - None for ignored chunks - """ - if "model" in chunk: - ai_messages: list[AIMessage] = chunk["model"]["messages"] - - # If no messages are returned, the model returned an empty response - # with no tool calls. This also counts as a final (but empty) answer. 
- if not ai_messages: - return StreamEvent(type="final_answer", data=EventData(content="")) - - message = ai_messages[0] - - if message.tool_calls: - event_type = "tool_call" - tool_calls = [ - ToolCall( - id=tool_call["id"], name=tool_call["name"], args=tool_call["args"] - ) - for tool_call in message.tool_calls - ] - content = _parse_thinking(message) or message.text - else: - event_type = "final_answer" - tool_calls = None - content = sanitize_markdown_links(message.text) - - event_data = EventData(content=content, tool_calls=tool_calls) - - return StreamEvent(type=event_type, data=event_data) - elif "tools" in chunk: - updates = chunk["tools"] - - # single tool call - if isinstance(updates, dict): - tool_messages: list[ToolMessage] = updates["messages"] - - # multiple parallel tool calls - elif isinstance(updates, list): - tool_messages: list[ToolMessage] = [ - update["messages"][0] for update in updates if "messages" in update - ] - - # defensive handling (langgraph should only output dicts and lists) - else: - tool_messages = [] - - tool_outputs = [ - ToolOutput( - status=message.status, - tool_call_id=message.tool_call_id, - tool_name=message.name, - content=_truncate_json(message.content), - artifact=message.artifact, - ) - for message in tool_messages - ] - - return StreamEvent( - type="tool_output", data=EventData(tool_outputs=tool_outputs) - ) - elif "ModelCallLimitMiddleware.before_model" in chunk: - # before_model runs on every model iteration; only the limit-exceeded - # path sets jump_to="end", so check that rather than the key's presence. 
- update = chunk["ModelCallLimitMiddleware.before_model"] or {} - if update.get("jump_to") == "end": - event_data = EventData( - content=ErrorMessage.MODEL_CALL_LIMIT_REACHED, tool_calls=None - ) - return StreamEvent(type="final_answer", data=event_data) - return None - - -async def stream_response( - database: AsyncDatabase, - agent: CompiledStateGraph, - user_message: Message, - config: ConfigDict, - thread_id: str, - model_uri: str, -) -> AsyncIterator[str]: - """Stream ReAct Agent's execution progress. + The producer is responsible for ensuring exactly one `complete` event is emitted per run. + This generator does no accumulation and no persistence: cancelling it on client disconnect + is safe and has no side effects on the in-flight run. Args: - message (str): User's input message. - config (ConfigDict): Configuration for the agent's execution. - thread_id (str): Unique identifier for the conversation thread. + queue (asyncio.Queue[StreamEvent]): Events queue. Yields: - Iterator[str]: JSON string containing the streaming status and the current step data. + AsyncIterator[str]: Iterator of serialized events. 
""" - events = [] - artifacts = [] - assistant_message = "" - status = MessageStatus.SUCCESS - - try: - async for mode, chunk in agent.astream( # pragma: no cover - input={"messages": [{"role": "user", "content": user_message.content}]}, - config=config, - stream_mode=["updates", "values"], - ): - if mode == "values": - continue - - event = _process_chunk(chunk) - - if event is not None: - if event.type == "tool_output": - for output in event.data.tool_outputs: - if output.artifact: - artifacts.append(output.artifact) - - elif event.type == "final_answer": - assistant_message = event.data.content - status = MessageStatus.SUCCESS - - events.append(event.model_dump()) - yield event.to_sse() - except Exception: - logger.exception(f"Unexpected error responding message {config['run_id']}:") - assistant_message = ErrorMessage.UNEXPECTED - status = MessageStatus.ERROR - yield StreamEvent( - type="error", data=EventData(error_details={"message": assistant_message}) - ).to_sse() - - message_create = MessageCreate( - id=config["run_id"], - thread_id=thread_id, - user_message_id=user_message.id, - model_uri=model_uri, - role=MessageRole.ASSISTANT, - content=assistant_message, - artifacts=artifacts or None, - events=events or None, - status=status, - ) - - message = await database.create_message(message_create) - - yield StreamEvent(type="complete", data=EventData(run_id=message.id)).to_sse() + while True: + event = await queue.get() + yield event.to_sse() + if event.type == "complete": + return diff --git a/app/db/models.py b/app/db/models.py index 61905b6..a989971 100644 --- a/app/db/models.py +++ b/app/db/models.py @@ -34,13 +34,15 @@ class Thread(ThreadCreate, table=True): # == Message Models == # ============================================================================== class MessageRole(str, Enum): - ASSISTANT = "ASSISTANT" USER = "USER" + ASSISTANT = "ASSISTANT" class MessageStatus(str, Enum): ERROR = "ERROR" SUCCESS = "SUCCESS" + INTERRUPTED = "INTERRUPTED" + 
MODEL_CALL_LIMIT = "MODEL_CALL_LIMIT" class MessageCreate(SQLModel): diff --git a/app/main.py b/app/main.py index b2d38ee..aefa38b 100644 --- a/app/main.py +++ b/app/main.py @@ -1,3 +1,4 @@ +import asyncio from contextlib import asynccontextmanager from datetime import date @@ -90,13 +91,39 @@ async def lifespan(app: FastAPI): # pragma: no cover ) app.state.agent = agent + app.state.running_runs = {} yield - await engine.dispose() + # Drain in-flight agent runs so they have a chance to persist + # NOTE: The connection pool must be open to persist checkpoints + running = list(app.state.running_runs.values()) + if running: + logger.info(f"Draining {len(running)} in-flight agent runs") + done, pending = await asyncio.wait( + running, timeout=settings.SHUTDOWN_DRAIN_TIMEOUT_SECONDS + ) + if pending: + logger.warning( + f"{len(pending)} agent runs did not finish within " + f"{settings.SHUTDOWN_DRAIN_TIMEOUT_SECONDS}s; cancelling..." + ) + for task in pending: + task.cancel() + # Wait for cancelled tasks + await asyncio.wait(pending) + logger.info( + f"Drain complete: {len(done)} finished, {len(pending)} cancelled" + ) except Exception: logger.exception("Lifespan failed:") raise + finally: + await engine.dispose() + # No-ops when LOG_ENQUEUE=False; wait for enqueued messages + # and remove handlers to avoid leaked-semaphore warnings when LOG_ENQUEUE=True. 
+ await logger.complete() + logger.remove() app = FastAPI(lifespan=lifespan) diff --git a/app/settings.py b/app/settings.py index 63484c7..d9e1232 100644 --- a/app/settings.py +++ b/app/settings.py @@ -149,6 +149,18 @@ def GOOGLE_CREDENTIALS(self) -> Credentials: # pragma: no cover ), ) + # ============================================================ + # == Shutdown settings == + # ============================================================ + SHUTDOWN_DRAIN_TIMEOUT_SECONDS: float = Field( + default=25.0, + description=( + "Seconds to wait for in-flight agent runs to finish during lifespan " + "shutdown before cancelling. Must stay below the pod's " + "terminationGracePeriodSeconds (k8s default 30s)." + ), + ) + model_config = SettingsConfigDict(env_file=".env", extra="ignore", frozen=True) diff --git a/tests/app/api/dependencies/test_agent.py b/tests/app/api/dependencies/test_agent.py index 0248192..1091a7d 100644 --- a/tests/app/api/dependencies/test_agent.py +++ b/tests/app/api/dependencies/test_agent.py @@ -1,6 +1,6 @@ from unittest.mock import MagicMock -from app.api.dependencies.agent import get_agent +from app.api.dependencies.agent import get_agent, get_running_runs class TestGetAgent: @@ -16,3 +16,19 @@ def test_returns_agent_from_app_state(self): result = get_agent(mock_request) assert result is mock_agent + + +class TestGetRunningRuns: + """Tests for get_running_runs dependency.""" + + def test_returns_running_runs_from_app_state(self): + """Test that get_running_runs returns the running runs dictionary + from request.app.state.running_runs""" + mock_running_runs = {} + mock_request = MagicMock() + + mock_request.app.state.running_runs = mock_running_runs + + result = get_running_runs(mock_request) + + assert result is mock_running_runs diff --git a/tests/app/api/routers/test_chatbot.py b/tests/app/api/routers/test_chatbot.py index 1e5e56e..84607c8 100644 --- a/tests/app/api/routers/test_chatbot.py +++ b/tests/app/api/routers/test_chatbot.py @@ -82,12 
+82,13 @@ def access_token(user_id: str) -> str: @pytest.fixture -def client(database: AsyncDatabase): +def client(database: AsyncDatabase, monkeypatch: pytest.MonkeyPatch): """Test client with mocked agent, database, and feedback sender.""" @asynccontextmanager async def mock_lifespan(app: FastAPI): app.state.agent = MockAgent() + app.state.running_runs = {} yield def get_database_override(): @@ -100,6 +101,20 @@ def get_feedback_sender_override(): app.dependency_overrides[get_feedback_sender] = get_feedback_sender_override app.router.lifespan_context = mock_lifespan + # The producer (run_agent) builds its own AsyncDatabase via sessionmaker() + # so the session lifetime matches the producer rather than the request. + # In tests, route it to the same testcontainers DB the dependency uses. + @asynccontextmanager + async def fake_sessionmaker(): + yield None + + monkeypatch.setattr( + "app.api.streaming.agent_runner.sessionmaker", fake_sessionmaker + ) + monkeypatch.setattr( + "app.api.streaming.agent_runner.AsyncDatabase", lambda session: database + ) + with TestClient(app) as client: yield client @@ -411,6 +426,41 @@ def test_send_message_unauthorized(self, client: TestClient, thread: Thread): ) assert response.status_code == status.HTTP_401_UNAUTHORIZED + def test_send_message_persists_when_client_disconnects( + self, client: TestClient, access_token: str, thread: Thread + ): + """Test producer survives client disconnect and still writes the assistant message. + + Reads one SSE event then aborts the stream. The background producer + must still complete (run the agent, persist the message). 
+ """ + import time + + with client.stream( + method="POST", + url=f"/api/v1/chatbot/threads/{thread.id}/messages", + json={"content": "Hello, chatbot!"}, + headers={"Authorization": f"Bearer {access_token}"}, + ) as response: + assert response.status_code == status.HTTP_201_CREATED + # Read only the first SSE event, then drop the connection + for _ in response.iter_lines(): + break + + # Give the background producer task time to finish persisting + time.sleep(1.0) + + # Both the user message (written by the route handler) and the + # assistant message (written by run_agent's finally block) must be + # in the DB, regardless of whether the consumer was still listening. + messages = client.get( + url=f"/api/v1/chatbot/threads/{thread.id}/messages", + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert messages.status_code == status.HTTP_200_OK + assert len(messages.json()) == 2 + class TestUpsertFeedbackEndpoint: """Tests for PUT /api/v1/chatbot/messages/{message_id}/feedback""" diff --git a/tests/app/api/streaming/test_agent_runner.py b/tests/app/api/streaming/test_agent_runner.py new file mode 100644 index 0000000..0cae5a1 --- /dev/null +++ b/tests/app/api/streaming/test_agent_runner.py @@ -0,0 +1,793 @@ +import asyncio +import json +import uuid +from contextlib import asynccontextmanager +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +import pytest +from langchain_core.messages import AIMessage, ToolMessage + +from app.api.schemas import ConfigDict +from app.api.streaming.agent_runner import ( + ErrorMessage, + _parse_thinking, + _process_chunk, + _truncate_json, + run_agent, +) +from app.api.streaming.schemas import StreamEvent +from app.db.models import Message, MessageCreate, MessageRole, MessageStatus + +MODEL_URI = "mock-model" + + +@pytest.fixture +def thread_id() -> str: + return str(uuid.uuid4()) + + +@pytest.fixture +def config(thread_id: str) -> ConfigDict: + return ConfigDict( + run_id=str(uuid.uuid4()), + 
configurable={"thread_id": thread_id}, + ) + + +@pytest.fixture +def mock_user_message(thread_id: str) -> Message: + return Message( + thread_id=thread_id, + model_uri=MODEL_URI, + role=MessageRole.USER, + content="Mock user message", + status=MessageStatus.SUCCESS, + ) + + +@pytest.fixture +def mock_database( + monkeypatch: pytest.MonkeyPatch, + config: ConfigDict, + mock_user_message: Message, +) -> MagicMock: + """Patch AsyncDatabase + sessionmaker so run_agent uses a mock instead + of opening a real DB connection.""" + db = MagicMock() + + db.create_message = AsyncMock( + return_value=Message( + id=config["run_id"], + thread_id=mock_user_message.thread_id, + user_message_id=mock_user_message.id, + model_uri=mock_user_message.model_uri, + role=MessageRole.ASSISTANT, + content="Mock assistant message", + status=MessageStatus.SUCCESS, + ) + ) + + @asynccontextmanager + async def mock_sessionmaker(): + yield # session is unused because AsyncDatabase is itself mocked + + monkeypatch.setattr( + "app.api.streaming.agent_runner.sessionmaker", mock_sessionmaker + ) + + monkeypatch.setattr( + "app.api.streaming.agent_runner.AsyncDatabase", lambda session: db + ) + + return db + + +class TestTruncateJSON: + """Tests for _truncate_json function.""" + + STR_MAX_LEN = 300 + STR_LONG_LEN = 400 + STR_REMAINING = STR_LONG_LEN - STR_MAX_LEN + + LIST_MAX_LEN = 10 + LIST_LONG_LEN = 15 + LIST_REMAINING = LIST_LONG_LEN - LIST_MAX_LEN + + @staticmethod + def _format_json(data: Any) -> str: + return json.dumps(data, ensure_ascii=False, indent=2) + + def test_truncate_json_long_string(self): + """Test that long strings are truncated with a remaining count.""" + data = {"long_string": "a" * self.STR_LONG_LEN} + json_string = json.dumps(data) + truncated = _truncate_json(json_string, max_str_len=self.STR_MAX_LEN) + expected_str = ( + "a" * self.STR_MAX_LEN + f"... 
({self.STR_REMAINING} more characters)" + ) + expected_json = self._format_json({"long_string": expected_str}) + assert truncated == expected_json + + def test_truncate_json_long_list(self): + """Test that long lists are truncated with a remaining count.""" + data = {"long_list": list(range(self.LIST_LONG_LEN))} + json_string = json.dumps(data) + truncated = _truncate_json(json_string, max_list_len=self.LIST_MAX_LEN) + expected_list = list(range(self.LIST_MAX_LEN)) + [ + f"... ({self.LIST_REMAINING} more items)" + ] + expected_json = self._format_json({"long_list": expected_list}) + assert truncated == expected_json + + def test_truncate_json_nested(self): + """Test that nested structures have both strings and lists truncated.""" + data = { + "short_string": "a" * 100, + "nested_list": [ + { + "short_string": "b" * 100, + "long_string": "c" * self.STR_LONG_LEN, + "int": 1, + "float": 1.0, + } + for _ in range(self.LIST_LONG_LEN) + ], + "nested_dict": {"long_string": "d" * self.STR_LONG_LEN}, + } + json_string = json.dumps(data) + truncated = _truncate_json( + json_string, max_list_len=self.LIST_MAX_LEN, max_str_len=self.STR_MAX_LEN + ) + expected_data = { + "short_string": "a" * 100, + "nested_list": [ + { + "short_string": "b" * 100, + "long_string": "c" * self.STR_MAX_LEN + + f"... ({self.STR_REMAINING} more characters)", + "int": 1, + "float": 1.0, + } + for _ in range(self.LIST_MAX_LEN) + ] + + [f"... ({self.LIST_REMAINING} more items)"], + "nested_dict": { + "long_string": "d" * self.STR_MAX_LEN + + f"... 
({self.STR_REMAINING} more characters)" + }, + } + expected_json = self._format_json(expected_data) + assert truncated == expected_json + + def test_truncate_json_not_dict(self): + """Test that non-dict JSON is returned as-is.""" + data = list(range(self.LIST_LONG_LEN)) + json_string = json.dumps(data) + truncated = _truncate_json(json_string) + assert truncated == json_string + + def test_truncate_json_not_needed(self): + """Test that short strings and lists are not truncated.""" + data = { + "short_string": "hello", + "short_list": [1, 2, 3], + } + json_string = json.dumps(data) + expected_json = self._format_json(data) + assert _truncate_json(json_string) == expected_json + + def test_truncate_json_invalid(self): + """Test that invalid JSON is returned as-is.""" + invalid_json_string = '{"key": "value"' + assert _truncate_json(invalid_json_string) == invalid_json_string + + +class TestParseThinking: + """Tests for _parse_thinking function.""" + + def test_string_content_returns_none(self): + """Test that plain string content returns None.""" + message = AIMessage(content="Hello, world!") + assert _parse_thinking(message) is None + + def test_single_thinking_block(self): + """Test extraction of a single thinking block.""" + message = AIMessage( + content=[ + {"type": "thinking", "thinking": "Let me reason about this."}, + {"type": "text", "text": "Here is my answer."}, + ] + ) + assert _parse_thinking(message) == "Let me reason about this." + + def test_multiple_thinking_blocks_are_concatenated(self): + """Test that multiple thinking blocks are concatenated.""" + message = AIMessage( + content=[ + {"type": "thinking", "thinking": "First thought. "}, + {"type": "text", "text": "Some text."}, + {"type": "thinking", "thinking": "Second thought."}, + ] + ) + assert _parse_thinking(message) == "First thought. Second thought." 
+
+    def test_no_thinking_blocks_returns_none(self):
+        """Expect None when the content list holds only text blocks."""
+        text_only = [{"type": "text", "text": "Just text."}]
+
+        parsed = _parse_thinking(AIMessage(content=text_only))
+
+        assert parsed is None
+
+    def test_empty_thinking_block_returns_none(self):
+        """Expect None when the only thinking block carries an empty string."""
+        blank_thinking = [{"type": "thinking", "thinking": ""}]
+
+        parsed = _parse_thinking(AIMessage(content=blank_thinking))
+
+        assert parsed is None
+
+    def test_non_dict_blocks_are_skipped(self):
+        """Expect a bare-string item to be ignored next to a real thinking block."""
+        mixed_blocks = [
+            "plain string block",
+            {"type": "thinking", "thinking": "Actual thinking."},
+        ]
+
+        parsed = _parse_thinking(AIMessage(content=mixed_blocks))
+
+        assert parsed == "Actual thinking."
+
+
+class TestProcessChunk:
+    """Tests covering how _process_chunk turns agent chunks into events."""
+
+    def test_agent_chunk_with_tool_calls(self):
+        """A model message with one tool call should become a tool_call event."""
+        search_call = {"id": "call_123", "name": "search", "args": {"query": "foo"}}
+        ai_message = AIMessage(
+            content="Let me search for that.",
+            tool_calls=[search_call],
+        )
+
+        event = _process_chunk({"model": {"messages": [ai_message]}})
+
+        assert event is not None
+        assert event.type == "tool_call"
+
+        # Besides the content and the single tool call, these fields must be None.
+        payload = event.data
+        assert payload.run_id is None
+        assert payload.tool_outputs is None
+        assert payload.error_details is None
+        assert payload.content == "Let me search for that."
+        assert len(payload.tool_calls) == 1
+
+        emitted_call = payload.tool_calls[0]
+        assert emitted_call.id == "call_123"
+        assert emitted_call.name == "search"
+        assert emitted_call.args == {"query": "foo"}
+
+    def test_agent_chunk_with_multiple_tool_calls(self):
+        """Parallel tool calls should all surface, in order, on one event."""
+        ai_message = AIMessage(
+            content="I'll search both.",
+            tool_calls=[
+                {"id": "call_1", "name": "search", "args": {"query": "foo"}},
+                {"id": "call_2", "name": "lookup", "args": {"id": "123"}},
+            ],
+        )
+
+        event = _process_chunk({"model": {"messages": [ai_message]}})
+
+        assert event is not None
+        assert event.type == "tool_call"
+        # One comparison checks both the number of calls and their order.
+        assert [call.name for call in event.data.tool_calls] == ["search", "lookup"]
+
+    def test_agent_chunk_final_answer(self):
+        """A model message without tool calls should become a final_answer event."""
+        ai_message = AIMessage(content="Here is your answer.")
+
+        event = _process_chunk({"model": {"messages": [ai_message]}})
+
+        assert event is not None
+        assert event.type == "final_answer"
+
+        # These four fields must stay None for a plain final answer.
+        payload = event.data
+        assert payload.run_id is None
+        assert payload.tool_calls is None
+        assert payload.tool_outputs is None
+        assert payload.error_details is None
+        assert payload.content == "Here is your answer."
+ + def test_agent_chunk_empty_messages(self): + """Test agent chunk with empty messages list returns empty final_answer.""" + chunk = {"model": {"messages": []}} + + event = _process_chunk(chunk) + + assert event is not None + assert event.type == "final_answer" + assert event.data.content == "" + + def test_tools_chunk_single_tool(self): + """Test tools chunk with single tool output (dict format).""" + chunk = { + "tools": { + "messages": [ + ToolMessage( + content='{"result": "found"}', + tool_call_id="call_123", + name="search", + status="success", + artifact={"url": "http://example.com"}, + ) + ] + } + } + + event = _process_chunk(chunk) + + assert event is not None + assert event.type == "tool_output" + assert len(event.data.tool_outputs) == 1 + + tool_output = event.data.tool_outputs[0] + + assert tool_output.status == "success" + assert tool_output.tool_call_id == "call_123" + assert tool_output.tool_name == "search" + assert tool_output.content == '{\n "result": "found"\n}' + assert tool_output.artifact == {"url": "http://example.com"} + assert tool_output.metadata is None + + def test_tools_chunk_multiple_parallel_tools(self): + """Test tools chunk with multiple parallel tool outputs (list format).""" + chunk = { + "tools": [ + { + "messages": [ + ToolMessage( + content='{"data": "foo"}', + tool_call_id="call_1", + name="search", + status="success", + ) + ] + }, + { + "messages": [ + ToolMessage( + content='{"data": "bar"}', + tool_call_id="call_2", + name="lookup", + status="success", + ) + ] + }, + ] + } + + event = _process_chunk(chunk) + + assert event is not None + assert event.type == "tool_output" + assert len(event.data.tool_outputs) == 2 + assert event.data.tool_outputs[0].tool_call_id == "call_1" + assert event.data.tool_outputs[1].tool_call_id == "call_2" + + def test_tools_chunk_with_error_status(self): + """Test tools chunk with error status.""" + chunk = { + "tools": { + "messages": [ + ToolMessage( + content="Tool execution failed", + 
tool_call_id="call_123", + name="search", + status="error", + ) + ] + } + } + + event = _process_chunk(chunk) + + assert event is not None + assert event.type == "tool_output" + assert event.data.tool_outputs[0].status == "error" + + def test_tools_chunk_unexpected_format(self): + """Test tools chunk with unexpected format returns empty tool_outputs.""" + chunk = {"tools": "unexpected string"} + + event = _process_chunk(chunk) + + assert event is not None + assert event.type == "tool_output" + assert event.data.tool_outputs == [] + + def test_model_call_limit_triggered_chunk(self): + """Test before_model chunk with jump_to=end yields final_answer event.""" + chunk = { + "ModelCallLimitMiddleware.before_model": { + "jump_to": "end", + "messages": [AIMessage(content="Model call limits exceeded: ...")], + } + } + + event = _process_chunk(chunk) + + assert event is not None + assert event.type == "final_answer" + assert event.data.content == ErrorMessage.MODEL_CALL_LIMIT_REACHED + + def test_model_call_limit_passthrough_chunk_returns_none(self): + """Test before_model passthrough chunk (None payload) returns None.""" + chunk = {"ModelCallLimitMiddleware.before_model": None} + assert _process_chunk(chunk) is None + + def test_model_call_limit_passthrough_chunk_no_jump_returns_none(self): + """Test before_model chunk without jump_to returns None.""" + chunk = {"ModelCallLimitMiddleware.before_model": {"messages": []}} + assert _process_chunk(chunk) is None + + def test_unrecognized_chunk_returns_none(self): + """Test unrecognized chunk returns None.""" + chunk = {"unknown_node": {"data": "something"}} + event = _process_chunk(chunk) + assert event is None + + def test_empty_chunk_returns_none(self): + """Test empty chunk returns None.""" + chunk = {} + event = _process_chunk(chunk) + assert event is None + + +class TestRunAgent: + """Tests for run_agent function.""" + + async def _drain(self, queue: asyncio.Queue[StreamEvent]) -> list[StreamEvent]: + events = [] + while 
True: + event = await queue.get() + events.append(event) + if event.type == "complete": + return events + + async def test_emits_events_and_persists_success( + self, + mock_database: MagicMock, + mock_user_message: Message, + config: ConfigDict, + thread_id: str, + ): + """Test happy path emits events and persists success message.""" + agent = MagicMock() + + async def astream(*args, **kwargs): + yield ( + "updates", + {"model": {"messages": [AIMessage(content="Final answer")]}}, + ) + + agent.astream = astream + queue: asyncio.Queue[StreamEvent] = asyncio.Queue() + + await run_agent( + agent=agent, + config=config, + thread_id=thread_id, + user_message=mock_user_message, + model_uri=MODEL_URI, + queue=queue, + ) + + events = await self._drain(queue) + assert [e.type for e in events] == ["final_answer", "complete"] + assert events[0].data.content == "Final answer" + assert events[-1].data.run_id == config["run_id"] + + mock_database.create_message.assert_called_once() + message = mock_database.create_message.call_args[0][0] + assert isinstance(message, MessageCreate) + assert message.status == MessageStatus.SUCCESS + assert message.content == "Final answer" + + async def test_unexpected_exception_persists_error_row( + self, + mock_database: MagicMock, + mock_user_message: Message, + config: ConfigDict, + thread_id: str, + ): + """Test unexpected exceptions are handled properly.""" + agent = MagicMock() + + async def astream(*args, **kwargs): + raise RuntimeError("error") + yield # make this a generator + + agent.astream = astream + queue: asyncio.Queue[StreamEvent] = asyncio.Queue() + + await run_agent( + agent=agent, + config=config, + thread_id=thread_id, + user_message=mock_user_message, + model_uri=MODEL_URI, + queue=queue, + ) + + events = await self._drain(queue) + assert [e.type for e in events] == ["error", "complete"] + assert events[0].data.content == ErrorMessage.UNEXPECTED + assert events[0].data.error_details == {"reason": "agent_failed"} + + 
mock_database.create_message.assert_called_once() + message = mock_database.create_message.call_args[0][0] + assert isinstance(message, MessageCreate) + assert message.status == MessageStatus.ERROR + assert message.content == ErrorMessage.UNEXPECTED + + async def test_model_call_limit_persists_with_dedicated_status( + self, + mock_database: MagicMock, + mock_user_message: Message, + config: ConfigDict, + thread_id: str, + ): + """Test ModelCallLimit error is handled properly.""" + agent = MagicMock() + + async def astream(*args, **kwargs): + yield ( + "updates", + {"ModelCallLimitMiddleware.before_model": {"jump_to": "end"}}, + ) + + agent.astream = astream + queue: asyncio.Queue[StreamEvent] = asyncio.Queue() + + await run_agent( + agent=agent, + config=config, + thread_id=thread_id, + user_message=mock_user_message, + model_uri=MODEL_URI, + queue=queue, + ) + + events = await self._drain(queue) + assert [e.type for e in events] == ["final_answer", "complete"] + assert events[0].data.content == ErrorMessage.MODEL_CALL_LIMIT_REACHED + assert events[-1].data.run_id == config["run_id"] + + mock_database.create_message.assert_called_once() + message = mock_database.create_message.call_args[0][0] + assert isinstance(message, MessageCreate) + assert message.status == MessageStatus.MODEL_CALL_LIMIT + assert message.content == ErrorMessage.MODEL_CALL_LIMIT_REACHED + + async def test_complete_still_emitted_when_db_write_fails( + self, + monkeypatch: pytest.MonkeyPatch, + mock_user_message: Message, + config: ConfigDict, + thread_id: str, + ): + """If `database.create_message` raises, the consumer must still + receive a `complete` event - otherwise it hangs on `queue.get()`. 
+ """ + db = MagicMock() + db.create_message = AsyncMock(side_effect=RuntimeError("db down")) + + @asynccontextmanager + async def mock_sessionmaker(): + yield None + + monkeypatch.setattr( + "app.api.streaming.agent_runner.sessionmaker", mock_sessionmaker + ) + + monkeypatch.setattr( + "app.api.streaming.agent_runner.AsyncDatabase", lambda session: db + ) + + agent = MagicMock() + + async def astream(*args, **kwargs): + yield ( + "updates", + {"model": {"messages": [AIMessage(content="Final answer")]}}, + ) + + agent.astream = astream + queue: asyncio.Queue[StreamEvent] = asyncio.Queue() + + await run_agent( + agent=agent, + config=config, + thread_id=thread_id, + user_message=mock_user_message, + model_uri=MODEL_URI, + queue=queue, + ) + + events = await self._drain(queue) + assert [e.type for e in events] == ["final_answer", "complete"] + assert events[0].data.content == "Final answer" + + complete = events[-1] + assert complete.type == "complete" + assert complete.data.run_id is None + assert complete.data.error_details == {"reason": "persistence_failed"} + assert "db down" not in str(complete.data.error_details) + + async def test_consumer_cancel_does_not_cancel_producer( + self, + mock_database: MagicMock, + mock_user_message: Message, + config: ConfigDict, + thread_id: str, + ): + """Test the producer task runs to completion even if no one drains the queue.""" + agent = MagicMock() + + async def astream(*args, **kwargs): + yield ( + "updates", + {"model": {"messages": [AIMessage(content="Final answer")]}}, + ) + + agent.astream = astream + queue: asyncio.Queue[StreamEvent] = asyncio.Queue() + + task = asyncio.create_task( + run_agent( + agent=agent, + config=config, + thread_id=thread_id, + user_message=mock_user_message, + model_uri=MODEL_URI, + queue=queue, + ) + ) + + # Simulate the consumer never attaching: just await the producer + await asyncio.wait_for(task, timeout=2.0) + + # Producer persisted the message regardless of consumer presence + 
mock_database.create_message.assert_called_once() + message = mock_database.create_message.call_args[0][0] + assert isinstance(message, MessageCreate) + assert message.status == MessageStatus.SUCCESS + assert message.content == "Final answer" + + # The complete event is sitting in the queue waiting + events = await self._drain(queue) + assert events[-1].type == "complete" + + async def test_cancellation_before_final_answer_persists_interrupted( + self, + mock_database: MagicMock, + mock_user_message: Message, + config: ConfigDict, + thread_id: str, + ): + """Cancelling the producer before any final_answer persists row with + INTERRUPTED content + INTERRUPTED status, and re-raises CancelledError.""" + agent = MagicMock() + started = asyncio.Event() + + async def astream(*args, **kwargs): + started.set() + # Block until cancelled — never yields a chunk + await asyncio.sleep(60) + yield # make this a generator + + agent.astream = astream + queue: asyncio.Queue[StreamEvent] = asyncio.Queue() + + task = asyncio.create_task( + run_agent( + agent=agent, + config=config, + thread_id=thread_id, + user_message=mock_user_message, + model_uri=MODEL_URI, + queue=queue, + ) + ) + + # Wait until the producer reach the await inside agent.astream + await started.wait() + + task.cancel() + + with pytest.raises(asyncio.CancelledError): + await task + + assert task.cancelled() + + mock_database.create_message.assert_called_once() + message = mock_database.create_message.call_args[0][0] + assert isinstance(message, MessageCreate) + assert message.status == MessageStatus.INTERRUPTED + assert message.content == ErrorMessage.INTERRUPTED + + async def test_cancellation_after_final_answer_preserves_success( + self, + mock_database: MagicMock, + mock_user_message: Message, + config: ConfigDict, + thread_id: str, + ): + """Cancelling the producer after a final_answer has been processed + preserves the SUCCESS status — the CancelledError branch only sets + INTERRUPTED when no status has been 
observed yet.""" + agent = MagicMock() + processed = asyncio.Event() + + async def astream(*args, **kwargs): + yield ( + "updates", + {"model": {"messages": [AIMessage(content="Final answer")]}}, + ) + # Reached when the producer asks for the next chunk, i.e., after it + # has set status=SUCCESS for the final_answer above. + processed.set() + await asyncio.sleep(60) + + agent.astream = astream + queue: asyncio.Queue[StreamEvent] = asyncio.Queue() + + task = asyncio.create_task( + run_agent( + agent=agent, + config=config, + thread_id=thread_id, + user_message=mock_user_message, + model_uri=MODEL_URI, + queue=queue, + ) + ) + + # Wait until the final_answer event is processed + await processed.wait() + + task.cancel() + + with pytest.raises(asyncio.CancelledError): + await task + + assert task.cancelled() + + mock_database.create_message.assert_called_once() + message = mock_database.create_message.call_args[0][0] + assert isinstance(message, MessageCreate) + assert message.status == MessageStatus.SUCCESS + assert message.content == "Final answer" diff --git a/tests/app/api/streaming/test_stream.py b/tests/app/api/streaming/test_stream.py index b6ab95f..b952ab7 100644 --- a/tests/app/api/streaming/test_stream.py +++ b/tests/app/api/streaming/test_stream.py @@ -1,678 +1,36 @@ -import json +import asyncio import uuid -from typing import Any, AsyncIterator -from unittest.mock import AsyncMock, MagicMock -import pytest -from langchain_core.messages import AIMessage, ToolMessage +from app.api.streaming.schemas import EventData, StreamEvent +from app.api.streaming.stream import stream_events -from app.api.schemas import ConfigDict -from app.api.streaming.stream import ( - ErrorMessage, - _parse_thinking, - _process_chunk, - _truncate_json, - stream_response, -) -from app.db.models import Message, MessageRole, MessageStatus +class TestStreamEvents: + async def test_stream_until_complete_then_exits(self): + queue: asyncio.Queue[StreamEvent] = asyncio.Queue() + run_id = 
str(uuid.uuid4()) + await queue.put(StreamEvent(type="final_answer", data=EventData(content="ok"))) + await queue.put(StreamEvent(type="complete", data=EventData(run_id=run_id))) -class TestTruncateJSON: - """Tests for _truncate_json function.""" - - STR_MAX_LEN = 300 - STR_LONG_LEN = 400 - STR_REMAINING = STR_LONG_LEN - STR_MAX_LEN - - LIST_MAX_LEN = 10 - LIST_LONG_LEN = 15 - LIST_REMAINING = LIST_LONG_LEN - LIST_MAX_LEN - - @staticmethod - def _format_json(data: Any) -> str: - return json.dumps(data, ensure_ascii=False, indent=2) - - def test_truncate_json_long_string(self): - """Test that long strings are truncated with a remaining count.""" - data = {"long_string": "a" * self.STR_LONG_LEN} - json_string = json.dumps(data) - truncated = _truncate_json(json_string, max_str_len=self.STR_MAX_LEN) - expected_str = ( - "a" * self.STR_MAX_LEN + f"... ({self.STR_REMAINING} more characters)" - ) - expected_json = self._format_json({"long_string": expected_str}) - assert truncated == expected_json - - def test_truncate_json_long_list(self): - """Test that long lists are truncated with a remaining count.""" - data = {"long_list": list(range(self.LIST_LONG_LEN))} - json_string = json.dumps(data) - truncated = _truncate_json(json_string, max_list_len=self.LIST_MAX_LEN) - expected_list = list(range(self.LIST_MAX_LEN)) + [ - f"... 
({self.LIST_REMAINING} more items)" - ] - expected_json = self._format_json({"long_list": expected_list}) - assert truncated == expected_json - - def test_truncate_json_nested(self): - """Test that nested structures have both strings and lists truncated.""" - data = { - "short_string": "a" * 100, - "nested_list": [ - { - "short_string": "b" * 100, - "long_string": "c" * self.STR_LONG_LEN, - "int": 1, - "float": 1.0, - } - for _ in range(self.LIST_LONG_LEN) - ], - "nested_dict": {"long_string": "d" * self.STR_LONG_LEN}, - } - json_string = json.dumps(data) - truncated = _truncate_json( - json_string, max_list_len=self.LIST_MAX_LEN, max_str_len=self.STR_MAX_LEN - ) - expected_data = { - "short_string": "a" * 100, - "nested_list": [ - { - "short_string": "b" * 100, - "long_string": "c" * self.STR_MAX_LEN - + f"... ({self.STR_REMAINING} more characters)", - "int": 1, - "float": 1.0, - } - for _ in range(self.LIST_MAX_LEN) - ] - + [f"... ({self.LIST_REMAINING} more items)"], - "nested_dict": { - "long_string": "d" * self.STR_MAX_LEN - + f"... 
({self.STR_REMAINING} more characters)" - }, - } - expected_json = self._format_json(expected_data) - assert truncated == expected_json - - def test_truncate_json_not_dict(self): - """Test that non-dict JSON is returned as-is.""" - data = list(range(self.LIST_LONG_LEN)) - json_string = json.dumps(data) - truncated = _truncate_json(json_string) - assert truncated == json_string - - def test_truncate_json_not_needed(self): - """Test that short strings and lists are not truncated.""" - data = { - "short_string": "hello", - "short_list": [1, 2, 3], - } - json_string = json.dumps(data) - expected_json = self._format_json(data) - assert _truncate_json(json_string) == expected_json - - def test_truncate_json_invalid(self): - """Test that invalid JSON is returned as-is.""" - invalid_json_string = '{"key": "value"' - assert _truncate_json(invalid_json_string) == invalid_json_string - - -class TestParseThinking: - """Tests for _parse_thinking function.""" - - def test_string_content_returns_none(self): - """Test that plain string content returns None.""" - message = AIMessage(content="Hello, world!") - assert _parse_thinking(message) is None - - def test_single_thinking_block(self): - """Test extraction of a single thinking block.""" - message = AIMessage( - content=[ - {"type": "thinking", "thinking": "Let me reason about this."}, - {"type": "text", "text": "Here is my answer."}, - ] - ) - assert _parse_thinking(message) == "Let me reason about this." - - def test_multiple_thinking_blocks_are_concatenated(self): - """Test that multiple thinking blocks are concatenated.""" - message = AIMessage( - content=[ - {"type": "thinking", "thinking": "First thought. "}, - {"type": "text", "text": "Some text."}, - {"type": "thinking", "thinking": "Second thought."}, - ] - ) - assert _parse_thinking(message) == "First thought. Second thought." 
- - def test_no_thinking_blocks_returns_none(self): - """Test that content with no thinking blocks returns None.""" - message = AIMessage( - content=[ - {"type": "text", "text": "Just text."}, - ] - ) - assert _parse_thinking(message) is None - - def test_empty_thinking_block_returns_none(self): - """Test that an empty thinking string returns None.""" - message = AIMessage( - content=[ - {"type": "thinking", "thinking": ""}, - ] - ) - assert _parse_thinking(message) is None - - def test_non_dict_blocks_are_skipped(self): - """Test that non-dict items in content are safely skipped.""" - message = AIMessage( - content=[ - "plain string block", - {"type": "thinking", "thinking": "Actual thinking."}, - ] - ) - assert _parse_thinking(message) == "Actual thinking." - - -class TestProcessChunk: - """Tests for _process_chunk function.""" - - def test_agent_chunk_with_tool_calls(self): - """Test agent chunk with tool calls returns tool_call event.""" - chunk = { - "model": { - "messages": [ - AIMessage( - content="Let me search for that.", - tool_calls=[ - { - "id": "call_123", - "name": "search", - "args": {"query": "foo"}, - } - ], - ) - ] - } - } - - event = _process_chunk(chunk) - - assert event is not None - assert event.type == "tool_call" - assert event.data.run_id is None - assert event.data.tool_outputs is None - assert event.data.error_details is None - assert event.data.content == "Let me search for that." 
- assert len(event.data.tool_calls) == 1 - - tool_call = event.data.tool_calls[0] - - assert tool_call.id == "call_123" - assert tool_call.name == "search" - assert tool_call.args == {"query": "foo"} - - def test_agent_chunk_with_multiple_tool_calls(self): - """Test agent chunk with multiple parallel tool calls.""" - chunk = { - "model": { - "messages": [ - AIMessage( - content="I'll search both.", - tool_calls=[ - { - "id": "call_1", - "name": "search", - "args": {"query": "foo"}, - }, - {"id": "call_2", "name": "lookup", "args": {"id": "123"}}, - ], - ) - ] - } - } - - event = _process_chunk(chunk) - - assert event is not None - assert event.type == "tool_call" - assert len(event.data.tool_calls) == 2 - assert event.data.tool_calls[0].name == "search" - assert event.data.tool_calls[1].name == "lookup" - - def test_agent_chunk_final_answer(self): - """Test agent chunk without tool calls returns final_answer event.""" - chunk = {"model": {"messages": [AIMessage(content="Here is your answer.")]}} - - event = _process_chunk(chunk) - - assert event is not None - assert event.type == "final_answer" - assert event.data.run_id is None - assert event.data.tool_calls is None - assert event.data.tool_outputs is None - assert event.data.error_details is None - assert event.data.content == "Here is your answer." 
- - def test_agent_chunk_empty_messages(self): - """Test agent chunk with empty messages list returns empty final_answer.""" - chunk = {"model": {"messages": []}} - - event = _process_chunk(chunk) - - assert event is not None - assert event.type == "final_answer" - assert event.data.content == "" - - def test_tools_chunk_single_tool(self): - """Test tools chunk with single tool output (dict format).""" - chunk = { - "tools": { - "messages": [ - ToolMessage( - content='{"result": "found"}', - tool_call_id="call_123", - name="search", - status="success", - artifact={"url": "http://example.com"}, - ) - ] - } - } - - event = _process_chunk(chunk) - - assert event is not None - assert event.type == "tool_output" - assert len(event.data.tool_outputs) == 1 - - tool_output = event.data.tool_outputs[0] - - assert tool_output.status == "success" - assert tool_output.tool_call_id == "call_123" - assert tool_output.tool_name == "search" - assert tool_output.content == '{\n "result": "found"\n}' - assert tool_output.artifact == {"url": "http://example.com"} - assert tool_output.metadata is None - - def test_tools_chunk_multiple_parallel_tools(self): - """Test tools chunk with multiple parallel tool outputs (list format).""" - chunk = { - "tools": [ - { - "messages": [ - ToolMessage( - content='{"data": "foo"}', - tool_call_id="call_1", - name="search", - status="success", - ) - ] - }, - { - "messages": [ - ToolMessage( - content='{"data": "bar"}', - tool_call_id="call_2", - name="lookup", - status="success", - ) - ] - }, - ] - } - - event = _process_chunk(chunk) - - assert event is not None - assert event.type == "tool_output" - assert len(event.data.tool_outputs) == 2 - assert event.data.tool_outputs[0].tool_call_id == "call_1" - assert event.data.tool_outputs[1].tool_call_id == "call_2" - - def test_tools_chunk_with_error_status(self): - """Test tools chunk with error status.""" - chunk = { - "tools": { - "messages": [ - ToolMessage( - content="Tool execution failed", - 
tool_call_id="call_123", - name="search", - status="error", - ) - ] - } - } - - event = _process_chunk(chunk) - - assert event is not None - assert event.type == "tool_output" - assert event.data.tool_outputs[0].status == "error" - - def test_tools_chunk_unexpected_format(self): - """Test tools chunk with unexpected format returns empty tool_outputs.""" - chunk = {"tools": "unexpected string"} - - event = _process_chunk(chunk) - - assert event is not None - assert event.type == "tool_output" - assert event.data.tool_outputs == [] - - def test_model_call_limit_triggered_chunk(self): - """Test before_model chunk with jump_to=end yields final_answer event.""" - chunk = { - "ModelCallLimitMiddleware.before_model": { - "jump_to": "end", - "messages": [AIMessage(content="Model call limits exceeded: ...")], - } - } - - event = _process_chunk(chunk) - - assert event is not None - assert event.type == "final_answer" - assert event.data.content == ErrorMessage.MODEL_CALL_LIMIT_REACHED - - def test_model_call_limit_passthrough_chunk_returns_none(self): - """Test before_model passthrough chunk (None payload) returns None.""" - chunk = {"ModelCallLimitMiddleware.before_model": None} - assert _process_chunk(chunk) is None - - def test_model_call_limit_passthrough_chunk_no_jump_returns_none(self): - """Test before_model chunk without jump_to returns None.""" - chunk = {"ModelCallLimitMiddleware.before_model": {"messages": []}} - assert _process_chunk(chunk) is None - - def test_unrecognized_chunk_returns_none(self): - """Test unrecognized chunk returns None.""" - chunk = {"unknown_node": {"data": "something"}} - event = _process_chunk(chunk) - assert event is None - - def test_empty_chunk_returns_none(self): - """Test empty chunk returns None.""" - chunk = {} - event = _process_chunk(chunk) - assert event is None - - -class TestStreamResponse: - """Tests for stream_response function.""" - - @pytest.fixture - def mock_thread_id(self) -> str: - """Generate a random thread ID.""" - 
return str(uuid.uuid4()) - - @pytest.fixture - def mock_model_uri(self) -> str: - """Return a mock model URI.""" - return "mock-model" - - @pytest.fixture - def mock_config(self, mock_thread_id: str) -> ConfigDict: - """Create a mock config dict.""" - return { - "run_id": uuid.uuid4(), - "configurable": {"thread_id": mock_thread_id}, - } - - @pytest.fixture - def mock_user_message(self, mock_thread_id: str, mock_model_uri: str) -> Message: - """Create a mock user message.""" - return Message( - thread_id=mock_thread_id, - model_uri=mock_model_uri, - role=MessageRole.USER, - content="Hello", - status=MessageStatus.SUCCESS, - ) - - @pytest.fixture - def mock_database(self, mock_config: ConfigDict, mock_user_message: Message): - """Create a mock database with stubbed create_message.""" - db = MagicMock() - - created_message = Message( - id=mock_config["run_id"], - thread_id=mock_user_message.thread_id, - user_message_id=mock_user_message.id, - model_uri=mock_user_message.model_uri, - role=MessageRole.ASSISTANT, - content="Mock response", - status=MessageStatus.SUCCESS, - ) - - db.create_message = AsyncMock(return_value=created_message) - - return db - - @staticmethod - async def _collect_events(async_gen: AsyncIterator[str]) -> list[str]: - """Helper to collect all events from async generator.""" events = [] - async for event in async_gen: - events.append(event) - return events - - async def test_stream_response_happy_path( - self, - mock_database, - mock_user_message, - mock_config, - mock_thread_id, - mock_model_uri, - ): - """Test successful streaming: skips 'values' mode, collects artifacts, yields all events.""" - mock_agent = MagicMock() - - async def mock_astream(*args, **kwargs): - yield ( - "updates", - { - "model": { - "messages": [ - AIMessage( - content="Let me search.", - tool_calls=[ - {"id": "call_1", "name": "search", "args": {}}, - {"id": "call_2", "name": "lookup", "args": {}}, - ], - ) - ] - } - }, - ) - yield "values", {"messages": ["msg1"]} - 
yield ( - "updates", - {"unknown_node": {}}, - ) # Unrecognized chunk, _process_chunk returns None - yield ( - "updates", - { - "tools": [ - { - "messages": [ - ToolMessage( - content='{"result": "data"}', - tool_call_id="call_1", - name="search", - status="success", - artifact={"url": "http://example.com"}, - ) - ] - }, - { - "messages": [ - ToolMessage( - content='{"id": "123"}', - tool_call_id="call_2", - name="lookup", - status="success", - artifact=None, # No artifact - ) - ] - }, - ] - }, - ) - yield "values", {"messages": ["msg1", "msg2"]} - yield ( - "updates", - {"model": {"messages": [AIMessage(content="Here is your answer.")]}}, - ) - yield "values", {"messages": ["msg1", "msg2", "msg3"]} - - mock_agent.astream = mock_astream - - events = await self._collect_events( - stream_response( - database=mock_database, - agent=mock_agent, - user_message=mock_user_message, - config=mock_config, - thread_id=mock_thread_id, - model_uri=mock_model_uri, - ) - ) - - assert len(events) == 4 - assert '"type":"tool_call"' in events[0] - assert '"type":"tool_output"' in events[1] - assert '"type":"final_answer"' in events[2] - assert '"type":"complete"' in events[3] - - mock_database.create_message.assert_called_once() - call_args = mock_database.create_message.call_args[0][0] - assert call_args.artifacts == [ - {"url": "http://example.com"} - ] # Only one artifact collected - - async def test_stream_response_generic_exception( - self, - mock_database, - mock_user_message, - mock_config, - mock_thread_id, - mock_model_uri, - ): - """Test generic exception yields error event.""" - mock_agent = MagicMock() - - async def mock_astream(*args, **kwargs): - raise Exception("Something went wrong") - yield # Makes this an async generator - - mock_agent.astream = mock_astream - - events = await self._collect_events( - stream_response( - database=mock_database, - agent=mock_agent, - user_message=mock_user_message, - config=mock_config, - thread_id=mock_thread_id, - 
model_uri=mock_model_uri, - ) - ) - - assert len(events) == 2 - assert '"type":"error"' in events[0] - assert ErrorMessage.UNEXPECTED in events[0] - assert '"type":"complete"' in events[1] - - call_args = mock_database.create_message.call_args[0][0] - assert call_args.status == MessageStatus.ERROR - assert call_args.content == ErrorMessage.UNEXPECTED - - async def test_stream_response_model_call_limit_reached( - self, - mock_database, - mock_user_message, - mock_config, - mock_thread_id, - mock_model_uri, - ): - """Test ModelCallLimitMiddleware sets graceful message without error status.""" - mock_agent = MagicMock() - - async def mock_astream(*args, **kwargs): - yield ( - "updates", - { - "ModelCallLimitMiddleware.before_model": { - "jump_to": "end", - "messages": [ - AIMessage(content="Model call limits exceeded: ...") - ], - } - }, - ) - - mock_agent.astream = mock_astream - - events = await self._collect_events( - stream_response( - database=mock_database, - agent=mock_agent, - user_message=mock_user_message, - config=mock_config, - thread_id=mock_thread_id, - model_uri=mock_model_uri, - ) - ) + async for sse in stream_events(queue): + events.append(sse) assert len(events) == 2 assert '"type":"final_answer"' in events[0] - assert ErrorMessage.MODEL_CALL_LIMIT_REACHED in events[0] assert '"type":"complete"' in events[1] + assert run_id in events[1] - call_args = mock_database.create_message.call_args[0][0] - assert call_args.status == MessageStatus.SUCCESS - assert call_args.content == ErrorMessage.MODEL_CALL_LIMIT_REACHED - - async def test_stream_response_before_model_passthrough_chunks_ignored( - self, - mock_database, - mock_user_message, - mock_config, - mock_thread_id, - mock_model_uri, - ): - """Test before_model passthrough chunks do not produce final_answer events.""" - mock_agent = MagicMock() - - async def mock_astream(*args, **kwargs): - yield ("updates", {"ModelCallLimitMiddleware.before_model": None}) - yield ( - "updates", - {"model": {"messages": 
[AIMessage(content="Real answer.")]}}, - ) - - mock_agent.astream = mock_astream - - events = await self._collect_events( - stream_response( - database=mock_database, - agent=mock_agent, - user_message=mock_user_message, - config=mock_config, - thread_id=mock_thread_id, - model_uri=mock_model_uri, - ) + async def test_stream_error_event_before_complete(self): + queue: asyncio.Queue[StreamEvent] = asyncio.Queue() + await queue.put(StreamEvent(type="error", data=EventData(content="bad"))) + await queue.put( + StreamEvent(type="complete", data=EventData(run_id=str(uuid.uuid4()))) ) - assert len(events) == 2 - assert '"type":"final_answer"' in events[0] - assert "Real answer." in events[0] - assert ErrorMessage.MODEL_CALL_LIMIT_REACHED not in events[0] - assert '"type":"complete"' in events[1] + events = [] + async for sse in stream_events(queue): + events.append(sse) - call_args = mock_database.create_message.call_args[0][0] - assert call_args.status == MessageStatus.SUCCESS - assert call_args.content == "Real answer." + assert '"type":"error"' in events[0] + assert '"type":"complete"' in events[1] diff --git a/tests/conftest.py b/tests/conftest.py index ed3f264..a3fc38a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -127,7 +127,7 @@ def assistant_message_create(user_message: Message) -> MessageCreate: return MessageCreate( thread_id=user_message.thread_id, user_message_id=user_message.id, - model_uri="mock_model", + model_uri="mock-model", role=MessageRole.ASSISTANT, content="Mock assistant message", artifacts=[{"mock_artifact": "artifact"}], @@ -162,7 +162,7 @@ async def factory() -> tuple[Message, Message]: assistant_message_create = MessageCreate( thread_id=user_message.thread_id, user_message_id=user_message.id, - model_uri="mock_model", + model_uri="mock-model", role=MessageRole.ASSISTANT, content="Mock assistant message", artifacts=[{"mock_artifact": "artifact"}],