Skip to content

Commit 87ce7d1

Browse files
authored
fix: fix langchain parent spans (#296)
1 parent bd16166 commit 87ce7d1

File tree

2 files changed

+66
-31
lines changed

2 files changed

+66
-31
lines changed

langfuse/callback.py

Lines changed: 27 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,31 @@ def on_chain_start(
234234
kwargs=kwargs,
235235
version=self.version,
236236
)
237+
if parent_run_id is None:
238+
if self.root_span is None:
239+
self.runs[run_id] = self.trace.span(
240+
id=self.next_span_id,
241+
trace_id=self.trace.id,
242+
name=serialized.get(
243+
"name", serialized.get("id", ["<unknown>"])[-1]
244+
),
245+
metadata=self.__join_tags_and_metadata(tags, metadata),
246+
input=inputs,
247+
version=self.version,
248+
)
249+
250+
else:
251+
self.runs[run_id] = self.root_span.span(
252+
id=self.next_span_id,
253+
trace_id=self.trace.id,
254+
name=serialized.get(
255+
"name", serialized.get("id", ["<unknown>"])[-1]
256+
),
257+
metadata=self.__join_tags_and_metadata(tags, metadata),
258+
input=inputs,
259+
version=self.version,
260+
)
261+
237262
if parent_run_id is not None:
238263
self.runs[run_id] = self.runs[parent_run_id].span(
239264
id=self.next_span_id,
@@ -293,40 +318,16 @@ def __generate_trace_and_parent(
293318

294319
self.trace = trace
295320

296-
self.runs[run_id] = self.trace.span(
297-
id=self.next_span_id,
298-
trace_id=self.trace.id,
299-
name=class_name,
300-
metadata=self.__join_tags_and_metadata(tags, metadata),
301-
input=inputs,
302-
version=self.version,
303-
)
304-
return
305-
306-
# if we are at root, and root was provided by user,
307-
# create a span for the trace or span provided
308-
if self.langfuse is None and parent_run_id is None:
309-
self.runs[run_id] = (
310-
self.trace.span(
321+
if parent_run_id is not None and parent_run_id in self.runs:
322+
self.runs[run_id] = self.trace.span(
311323
id=self.next_span_id,
312324
trace_id=self.trace.id,
313325
name=class_name,
314326
metadata=self.__join_tags_and_metadata(tags, metadata),
315327
input=inputs,
316328
version=self.version,
317329
)
318-
if self.root_span is None
319-
else self.root_span.span(
320-
id=self.next_span_id,
321-
trace_id=self.trace.id,
322-
name=class_name,
323-
metadata=self.__join_tags_and_metadata(tags, metadata),
324-
input=inputs,
325-
version=self.version,
326-
)
327-
)
328330

329-
self.next_span_id = None
330331
return
331332

332333
except Exception as e:

tests/test_langchain.py

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def test_mistral():
134134
from langchain_mistralai.chat_models import ChatMistralAI
135135

136136
api = get_api()
137-
callback = CallbackHandler(debug=True)
137+
callback = CallbackHandler(debug=False)
138138

139139
chat = ChatMistralAI(model="mistral-small", callbacks=[callback])
140140
messages = [HumanMessage(content="say a brief hello")]
@@ -242,7 +242,41 @@ def test_basic_chat_openai():
242242
trace = api.trace.get(trace_id)
243243

244244
assert trace.id == trace_id
245-
assert len(trace.observations) == 2
245+
assert len(trace.observations) == 1
246+
247+
248+
def test_basic_chat_openai_based_on_trace():
249+
from langchain.schema import HumanMessage, SystemMessage
250+
251+
trace_id = create_uuid()
252+
253+
langfuse = Langfuse(debug=False)
254+
trace = langfuse.trace(id=trace_id)
255+
256+
callback = trace.get_langchain_handler()
257+
258+
chat = ChatOpenAI(temperature=0)
259+
260+
messages = [
261+
SystemMessage(
262+
content="You are a helpful assistant that translates English to French."
263+
),
264+
HumanMessage(
265+
content="Translate this sentence from English to French. I love programming."
266+
),
267+
]
268+
269+
chat(messages, callbacks=[callback])
270+
callback.flush()
271+
272+
trace_id = callback.get_trace_id()
273+
274+
api = get_api()
275+
276+
trace = api.trace.get(trace_id)
277+
278+
assert trace.id == trace_id
279+
assert len(trace.observations) == 1
246280

247281

248282
def test_callback_from_trace_simple_chain():
@@ -630,7 +664,7 @@ def test_callback_simple_openai():
630664

631665
trace = api.trace.get(trace_id)
632666

633-
assert len(trace.observations) == 2
667+
assert len(trace.observations) == 1
634668
assert trace.input == trace.observations[0].input
635669
for observation in trace.observations:
636670
if observation.type == "GENERATION":
@@ -670,7 +704,7 @@ def test_callback_multiple_invocations_on_different_traces():
670704
{"trace": trace_one, "expected_trace_id": trace_id_one},
671705
{"trace": trace_two, "expected_trace_id": trace_id_two},
672706
]:
673-
assert len(test_data["trace"].observations) == 2
707+
assert len(test_data["trace"].observations) == 1
674708
assert test_data["trace"].id == test_data["expected_trace_id"]
675709
for observation in test_data["trace"].observations:
676710
if observation.type == "GENERATION":
@@ -688,7 +722,7 @@ def test_callback_simple_openai_streaming():
688722
api_wrapper = LangfuseAPI()
689723
handler = CallbackHandler(debug=False)
690724

691-
llm = OpenAI(openai_api_key=os.environ.get("OPENAI_API_KEY"), streaming=True)
725+
llm = OpenAI(openai_api_key=os.environ.get("OPENAI_API_KEY"), streaming=False)
692726

693727
text = "What would be a good company name for a company that makes laptops?"
694728

0 commit comments

Comments
 (0)