Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 69 additions & 4 deletions src/openlayer/lib/integrations/langchain_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@ def __init__(self, **kwargs: Any) -> None:
self.root_steps: set[UUID] = set() # Track which steps are root
# Track standalone traces (consistent with async handler)
self._traces_by_root: Dict[UUID, traces.Trace] = {}
# Map every active run_id (root or nested) to the trace its step lives in,
# so callbacks can find the right trace to update metadata on even when
# nothing has set tracer._current_trace.
self._run_id_to_trace: Dict[UUID, traces.Trace] = {}
# Extract inference_id from kwargs if provided
self._inference_id = kwargs.get("inference_id")
# Extract metadata_transformer from kwargs if provided
Expand Down Expand Up @@ -119,6 +123,10 @@ def _start_step(
# This step has a parent - add it as a nested step
parent_step = self.steps[parent_run_id]
parent_step.add_nested_step(step)
# Inherit the parent's owning trace so context lookups work.
parent_trace = self._run_id_to_trace.get(parent_run_id)
if parent_trace is not None:
self._run_id_to_trace[run_id] = parent_trace
else:
# This is a root step - check if we're in an existing trace context
current_step = tracer.get_current_step()
Expand All @@ -127,15 +135,19 @@ def _start_step(
if current_step is not None:
# We're inside an existing step context - add as nested
current_step.add_nested_step(step)
if current_trace is not None:
self._run_id_to_trace[run_id] = current_trace
elif current_trace is not None:
# Existing trace but no current step - add to trace
current_trace.add_step(step)
self._run_id_to_trace[run_id] = current_trace
# Don't track in _traces_by_root since we're using external trace
else:
# No existing context - create standalone trace
trace = traces.Trace()
trace.add_step(step)
self._traces_by_root[run_id] = trace
self._run_id_to_trace[run_id] = trace

# Track root steps (those without parent_run_id)
if parent_run_id is None:
Expand All @@ -160,6 +172,7 @@ def _end_step(
return

step = self.steps.pop(run_id)
self._run_id_to_trace.pop(run_id, None)
is_root_step = run_id in self.root_steps

if is_root_step:
Expand Down Expand Up @@ -676,7 +689,7 @@ def _handle_chain_end(
context_list.append(str(doc))

if context_list:
current_trace = tracer.get_current_trace()
current_trace = self._find_trace(run_id)
if current_trace:
current_trace.update_metadata(context=context_list)

Expand Down Expand Up @@ -748,9 +761,40 @@ def _handle_tool_start(
arguments=tool_input,
)

def _find_trace(self, run_id: UUID) -> Optional["traces.Trace"]:
    """Resolve the trace that callbacks for ``run_id`` should update.

    The trace reported by ``tracer.get_current_trace()`` wins whenever it is
    not ``None``. Otherwise the handler's own ``_run_id_to_trace`` map is
    consulted, which yields ``None`` when it holds no entry for ``run_id``.
    """
    active_trace = tracer.get_current_trace()
    # The conditional expression is lazy: the per-run map is only consulted
    # when no externally active trace was reported.
    return active_trace if active_trace is not None else self._run_id_to_trace.get(run_id)

def _extract_tool_output(self, output: Any) -> tuple:
    """Normalize a tool result into a ``(content, artifact)`` pair.

    Three input shapes are recognized, checked in this order:
    - a ``langchain_schema.ToolMessage`` (only when ``HAVE_LANGCHAIN`` is
      truthy): its ``.content`` plus its ``.artifact`` attribute, or ``None``
      when that attribute is absent;
    - a tuple of exactly two items: treated as ``(content, artifact)``;
    - anything else: returned unchanged as the content, with no artifact.
    """
    if HAVE_LANGCHAIN and isinstance(output, langchain_schema.ToolMessage):
        content, artifact = output.content, getattr(output, "artifact", None)
    elif isinstance(output, tuple) and len(output) == 2:
        content, artifact = output
    else:
        content, artifact = output, None
    return content, artifact

def _is_document_list(self, value: Any) -> bool:
"""True iff value is a non-empty sequence of objects with page_content."""
if not isinstance(value, (list, tuple)) or not value:
return False
return all(getattr(item, "page_content", None) is not None for item in value)

def _handle_tool_end(
self,
output: str,
output: Any,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
Expand All @@ -760,10 +804,22 @@ def _handle_tool_end(
if run_id not in self.steps:
return

content, artifact = self._extract_tool_output(output)

if artifact is not None:
step = self.steps[run_id]
step.metadata["artifact"] = artifact

if self._is_document_list(artifact):
doc_contents = [doc.page_content for doc in artifact]
current_trace = self._find_trace(run_id)
if current_trace:
current_trace.update_metadata(context=doc_contents)

self._end_step(
run_id=run_id,
parent_run_id=parent_run_id,
output=output,
output=content,
)

def _handle_tool_error(
Expand Down Expand Up @@ -875,7 +931,7 @@ def _handle_retriever_end(
else:
doc_contents.append(str(doc))

current_trace = tracer.get_current_trace()
current_trace = self._find_trace(run_id)
if current_trace:
current_trace.update_metadata(context=doc_contents)

Expand Down Expand Up @@ -1171,6 +1227,9 @@ def _start_step(
# This step has a parent - add as nested step
parent_step = self.steps[parent_run_id]
parent_step.add_nested_step(step)
parent_trace = self._run_id_to_trace.get(parent_run_id)
if parent_trace is not None:
self._run_id_to_trace[run_id] = parent_trace
else:
# Check if we're in an existing trace context via ContextVars
current_step = tracer.get_current_step()
Expand All @@ -1179,23 +1238,28 @@ def _start_step(
if current_step is not None:
# We're inside an existing step context - add as nested
current_step.add_nested_step(step)
if current_trace is not None:
self._run_id_to_trace[run_id] = current_trace
elif current_trace is not None:
# Have trace but no current step
# If it's an external trace, we should NOT add at root - external system will integrate
# If it's a ContextVar trace with no current step, add to trace
if not self._has_external_trace:
# ContextVar-detected trace - add directly
current_trace.add_step(step)
self._run_id_to_trace[run_id] = current_trace
else:
# External trace without current step - create temp standalone for later integration
trace = traces.Trace()
trace.add_step(step)
self._traces_by_root[run_id] = trace
self._run_id_to_trace[run_id] = trace
else:
# No existing context - create standalone trace
trace = traces.Trace()
trace.add_step(step)
self._traces_by_root[run_id] = trace
self._run_id_to_trace[run_id] = trace

# Track root steps
if parent_run_id is None:
Expand All @@ -1220,6 +1284,7 @@ def _end_step(
return

step = self.steps.pop(run_id)
self._run_id_to_trace.pop(run_id, None)
is_root_step = run_id in self.root_steps

if is_root_step:
Expand Down