diff --git a/src/openlayer/lib/integrations/langchain_callback.py b/src/openlayer/lib/integrations/langchain_callback.py index 2066b84c..25dcc7d7 100644 --- a/src/openlayer/lib/integrations/langchain_callback.py +++ b/src/openlayer/lib/integrations/langchain_callback.py @@ -81,6 +81,10 @@ def __init__(self, **kwargs: Any) -> None: self.root_steps: set[UUID] = set() # Track which steps are root # Track standalone traces (consistent with async handler) self._traces_by_root: Dict[UUID, traces.Trace] = {} + # Map every active run_id (root or nested) to the trace its step lives in, + # so callbacks can find the right trace to update metadata on even when + # nothing has set tracer._current_trace. + self._run_id_to_trace: Dict[UUID, traces.Trace] = {} # Extract inference_id from kwargs if provided self._inference_id = kwargs.get("inference_id") # Extract metadata_transformer from kwargs if provided @@ -119,6 +123,10 @@ def _start_step( # This step has a parent - add it as a nested step parent_step = self.steps[parent_run_id] parent_step.add_nested_step(step) + # Inherit the parent's owning trace so context lookups work. + parent_trace = self._run_id_to_trace.get(parent_run_id) + if parent_trace is not None: + self._run_id_to_trace[run_id] = parent_trace else: # This is a root step - check if we're in an existing trace context current_step = tracer.get_current_step() @@ -127,15 +135,19 @@ def _start_step( if current_step is not None: # We're inside an existing step context - add as nested current_step.add_nested_step(step) + if current_trace is not None: + self._run_id_to_trace[run_id] = current_trace elif current_trace is not None: # Existing trace but no current step - add to trace current_trace.add_step(step) + self._run_id_to_trace[run_id] = current_trace # Don't track in _traces_by_root since we're using external trace else: # No existing context - create standalone trace trace = traces.Trace() trace.add_step(step) self._traces_by_root[run_id] = trace + self._run_id_to_trace[run_id] = trace # Track root steps (those without parent_run_id) if parent_run_id is None: @@ -160,6 +172,7 @@ def _end_step( return step = self.steps.pop(run_id) + self._run_id_to_trace.pop(run_id, None) is_root_step = run_id in self.root_steps if is_root_step: @@ -676,7 +689,7 @@ def _handle_chain_end( context_list.append(str(doc)) if context_list: - current_trace = tracer.get_current_trace() + current_trace = self._find_trace(run_id) if current_trace: current_trace.update_metadata(context=context_list) @@ -748,9 +761,40 @@ def _handle_tool_start( arguments=tool_input, ) + def _find_trace(self, run_id: UUID) -> Optional["traces.Trace"]: + """Find the trace that owns this run_id, preferring an external context. + + Falls back to the handler's own per-run_id map (covers pure-callback + usage where nothing has set tracer._current_trace). + """ + external = tracer.get_current_trace() + if external is not None: + return external + return self._run_id_to_trace.get(run_id) + + def _extract_tool_output(self, output: Any) -> tuple: + """Split a tool result into (content, artifact). + + Handles three shapes LangChain may pass to on_tool_end: + - ToolMessage with .content/.artifact (response_format="content_and_artifact") + - 2-tuple (content, artifact) (older LangChain versions) + - plain string / anything else (no artifact) + """ + if HAVE_LANGCHAIN and isinstance(output, langchain_schema.ToolMessage): + return output.content, getattr(output, "artifact", None) + if isinstance(output, tuple) and len(output) == 2: + return output[0], output[1] + return output, None + + def _is_document_list(self, value: Any) -> bool: + """True iff value is a non-empty sequence of objects with page_content.""" + if not isinstance(value, (list, tuple)) or not value: + return False + return all(getattr(item, "page_content", None) is not None for item in value) + def _handle_tool_end( self, - output: str, + output: Any, *, run_id: UUID, parent_run_id: Optional[UUID] = None, @@ -760,10 +804,22 @@ def _handle_tool_end( if run_id not in self.steps: return + content, artifact = self._extract_tool_output(output) + + if artifact is not None: + step = self.steps[run_id] + step.metadata["artifact"] = artifact + + if self._is_document_list(artifact): + doc_contents = [doc.page_content for doc in artifact] + current_trace = self._find_trace(run_id) + if current_trace: + current_trace.update_metadata(context=doc_contents) + self._end_step( run_id=run_id, parent_run_id=parent_run_id, - output=output, + output=content, ) def _handle_tool_error( @@ -875,7 +931,7 @@ def _handle_retriever_end( else: doc_contents.append(str(doc)) - current_trace = tracer.get_current_trace() + current_trace = self._find_trace(run_id) if current_trace: current_trace.update_metadata(context=doc_contents) @@ -1171,6 +1227,9 @@ def _start_step( # This step has a parent - add as nested step parent_step = self.steps[parent_run_id] parent_step.add_nested_step(step) + parent_trace = self._run_id_to_trace.get(parent_run_id) + if parent_trace is not None: + self._run_id_to_trace[run_id] = parent_trace else: # Check if we're in an existing trace context via ContextVars current_step = tracer.get_current_step() @@ -1179,6 +1238,8 @@ def _start_step( if current_step is not None: # We're inside an existing step context - add as nested current_step.add_nested_step(step) + if current_trace is not None: + self._run_id_to_trace[run_id] = current_trace elif current_trace is not None: # Have trace but no current step # If it's an external trace, we should NOT add at root - external system will integrate @@ -1186,16 +1247,19 @@ def _start_step( if not self._has_external_trace: # ContextVar-detected trace - add directly current_trace.add_step(step) + self._run_id_to_trace[run_id] = current_trace else: # External trace without current step - create temp standalone for later integration trace = traces.Trace() trace.add_step(step) self._traces_by_root[run_id] = trace + self._run_id_to_trace[run_id] = trace else: # No existing context - create standalone trace trace = traces.Trace() trace.add_step(step) self._traces_by_root[run_id] = trace + self._run_id_to_trace[run_id] = trace # Track root steps if parent_run_id is None: @@ -1220,6 +1284,7 @@ def _end_step( return step = self.steps.pop(run_id) + self._run_id_to_trace.pop(run_id, None) is_root_step = run_id in self.root_steps if is_root_step: