Skip to content

Commit c00240f

Browse files
authored
fix(streaming): handle None stop tokens in streaming handler (#1685)
1 parent 7882262 commit c00240f

File tree

3 files changed

+90
-3
lines changed

3 files changed

+90
-3
lines changed

nemoguardrails/streaming.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import asyncio
1717
import logging
1818
import warnings
19-
from typing import Any, AsyncIterator, Dict, Optional, Union
19+
from typing import Any, AsyncIterator, Dict, List, Optional, Union
2020

2121
from nemoguardrails.utils import new_uuid
2222

@@ -83,12 +83,19 @@ def __init__(
8383
# If set, the chunk will be piped to the specified handler rather than added to the queue or printed
8484
self.pipe_to = None
8585

86-
# The stop chunks
87-
self.stop = []
86+
self._stop = []
8887

8988
self.include_metadata = include_metadata
9089
self.current_metadata = {}
9190

91+
@property
92+
def stop(self) -> List[str]:
93+
return self._stop
94+
95+
@stop.setter
96+
def stop(self, value: Optional[List[str]]) -> None:
97+
self._stop = value or []
98+
9299
def set_pattern(self, prefix: Optional[str] = None, suffix: Optional[str] = None):
93100
"""Sets the pattern that is expected.
94101

tests/test_actions_llm_utils.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
_infer_provider_from_module,
3030
_store_reasoning_traces,
3131
_store_tool_calls,
32+
_stream_llm_call,
3233
llm_call,
3334
)
3435
from nemoguardrails.context import reasoning_trace_var, tool_calls_var
@@ -712,3 +713,48 @@ async def test_llm_call_does_not_mutate_llm_params(self):
712713
original_params = {"max_tokens": 100}
713714
await llm_call(mock_llm, "prompt", stop=["User:"], llm_params=original_params)
714715
assert original_params == {"max_tokens": 100}
716+
717+
718+
async def _empty_astream(*args, **kwargs):
719+
return
720+
yield
721+
722+
723+
class _FakeLLM:
724+
def __init__(self, stop=None, kwargs=None):
725+
self.stop = stop
726+
if kwargs is not None:
727+
self.kwargs = kwargs
728+
self.astream = _empty_astream
729+
730+
731+
class TestStreamLlmCallStopCoercion:
732+
@pytest.mark.asyncio
733+
async def test_llm_stop_attr_none_coerced_to_list(self):
734+
from nemoguardrails.streaming import StreamingHandler
735+
736+
llm = _FakeLLM(stop=None)
737+
handler = StreamingHandler()
738+
await _stream_llm_call(llm, "prompt", handler)
739+
740+
assert handler.stop == []
741+
742+
@pytest.mark.asyncio
743+
async def test_llm_kwargs_stop_none_coerced_to_list(self):
744+
from nemoguardrails.streaming import StreamingHandler
745+
746+
llm = _FakeLLM(kwargs={"stop": None})
747+
handler = StreamingHandler()
748+
await _stream_llm_call(llm, "prompt", handler)
749+
750+
assert handler.stop == []
751+
752+
@pytest.mark.asyncio
753+
async def test_llm_with_valid_stop_preserved(self):
754+
from nemoguardrails.streaming import StreamingHandler
755+
756+
llm = _FakeLLM(stop=["User:"])
757+
handler = StreamingHandler()
758+
await _stream_llm_call(llm, "prompt", handler)
759+
760+
assert handler.stop == ["User:"]

tests/test_streaming_handler.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,40 @@ async def test_multiple_stop_tokens():
435435
await consumer.cancel()
436436

437437

438+
@pytest.mark.asyncio
439+
async def test_stop_none_does_not_raise():
440+
handler = StreamingHandler()
441+
consumer = StreamingConsumer(handler)
442+
443+
try:
444+
handler.stop = None
445+
await handler.push_chunk("Hello world")
446+
await handler.push_chunk(END_OF_STREAM)
447+
448+
chunks = await consumer.get_chunks()
449+
assert chunks == ["Hello world"]
450+
assert handler.completion == "Hello world"
451+
finally:
452+
await consumer.cancel()
453+
454+
455+
@pytest.mark.asyncio
456+
async def test_stop_none_with_pattern():
457+
await _test_pattern_case(
458+
prefix='Bot message: "',
459+
suffix='"',
460+
stop=None,
461+
chunks=[
462+
"Bot",
463+
" message: ",
464+
'"',
465+
"This is a message",
466+
'"',
467+
],
468+
final_chunks=["This is a message"],
469+
)
470+
471+
438472
@pytest.mark.asyncio
439473
async def test_enable_print_functionality():
440474
"""Test the enable_print functionality."""

0 commit comments

Comments (0)