Skip to content

Commit 2231475

Browse files
authored
fix(llm): filter stop parameter for OpenAI reasoning models (#1653)
* Fix OpenAI GPT-5 and other reasoning models that do not allow stop tokens as parameters in API requests. * Ensure llm_params in llm_call is not modified; also added a test.
1 parent c9228e9 commit 2231475

File tree

5 files changed

+96
-36
lines changed

5 files changed

+96
-36
lines changed

nemoguardrails/actions/llm/utils.py

Lines changed: 44 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -144,27 +144,41 @@ def _infer_model_name(llm: BaseLanguageModel):
144144
def _filter_params_for_openai_reasoning_models(llm: BaseLanguageModel, llm_params: Optional[dict]) -> Optional[dict]:
145145
"""Filter out unsupported parameters for OpenAI reasoning models.
146146
147-
OpenAI reasoning models (o1, o3, gpt-5 excluding gpt-5-chat) only support
148-
temperature=1. When using .bind() with other temperature values, the API
149-
returns an error. This function removes the temperature parameter for these
150-
models to allow the API default to apply.
147+
OpenAI reasoning models (o1, o3, gpt-5 excluding gpt-5-chat) only allow
148+
specific parameters (e.g. temperature, which is always fixed at 1, or stop).
149+
When using .bind() with different values for these parameters, the API
150+
returns an error. This function removes the unsupported parameters for specific
151+
OpenAI reasoning models to ensure correct functionality for the API calls.
151152
152-
See: https://github.com/langchain-ai/langchain/blob/master/libs/partners/openai/langchain_openai/chat_models/base.py
153+
See also: https://github.com/langchain-ai/langchain/blob/master/libs/partners/openai/langchain_openai/chat_models/base.py
154+
155+
The stop parameter is not supported in the following models (as of Jan 26):
156+
gpt-5+ (only gpt-5-chat-latest works), o3, o3-pro (but o3-mini works), o4-mini
153157
"""
154-
if not llm_params or "temperature" not in llm_params:
158+
if not llm_params or ("temperature" not in llm_params and "stop" not in llm_params):
155159
return llm_params
156160

157161
model_name = _infer_model_name(llm).lower()
158162

159-
is_openai_reasoning_model = (
163+
# Models that do not support temperature as a param, or changing its default value
164+
is_temperature_not_supported = (
160165
model_name.startswith("o1")
161166
or model_name.startswith("o3")
162167
or (model_name.startswith("gpt-5") and "chat" not in model_name)
163168
)
169+
# Models that do not support stop as a param
170+
is_stop_not_supported = (
171+
(model_name.startswith("o3") and "o3-mini" not in model_name)
172+
or model_name.startswith("o4-mini")
173+
or (model_name.startswith("gpt-5") and "gpt-5-chat" not in model_name)
174+
)
164175

165-
if is_openai_reasoning_model:
176+
if is_temperature_not_supported or is_stop_not_supported:
166177
filtered = llm_params.copy()
167-
filtered.pop("temperature", None)
178+
if is_temperature_not_supported:
179+
filtered.pop("temperature", None)
180+
if is_stop_not_supported:
181+
filtered.pop("stop", None)
168182
return filtered
169183

170184
return llm_params
@@ -202,18 +216,25 @@ async def llm_call(
202216
raise LLMCallException(ValueError("No LLM provided to llm_call()"))
203217
_setup_llm_call_info(llm, model_name, model_provider)
204218

205-
filtered_params = _filter_params_for_openai_reasoning_models(llm, llm_params)
219+
llm_params_with_stop: Optional[dict]
220+
if stop:
221+
llm_params_with_stop = llm_params.copy() if llm_params else {}
222+
llm_params_with_stop["stop"] = stop
223+
else:
224+
llm_params_with_stop = llm_params
225+
226+
filtered_params = _filter_params_for_openai_reasoning_models(llm, llm_params_with_stop)
206227
generation_llm: Union[BaseLanguageModel, Runnable] = llm.bind(**filtered_params) if filtered_params else llm
207228

208229
if streaming_handler:
209-
return await _stream_llm_call(generation_llm, prompt, streaming_handler, stop)
230+
return await _stream_llm_call(generation_llm, prompt, streaming_handler)
210231
else:
211232
all_callbacks = _prepare_callbacks(custom_callback_handlers)
212233

213234
if isinstance(prompt, str):
214-
response = await _invoke_with_string_prompt(generation_llm, prompt, all_callbacks, stop)
235+
response = await _invoke_with_string_prompt(generation_llm, prompt, all_callbacks)
215236
else:
216-
response = await _invoke_with_message_list(generation_llm, prompt, all_callbacks, stop)
237+
response = await _invoke_with_message_list(generation_llm, prompt, all_callbacks)
217238

218239
_store_reasoning_traces(response)
219240
_store_tool_calls(response)
@@ -225,7 +246,6 @@ async def _stream_llm_call(
225246
llm: Union[BaseLanguageModel, Runnable],
226247
prompt: Union[str, List[dict]],
227248
handler: "StreamingHandler",
228-
stop: Optional[List[str]],
229249
) -> str:
230250
"""Stream LLM response using astream().
231251
@@ -237,11 +257,17 @@ async def _stream_llm_call(
237257
else:
238258
messages = prompt
239259

240-
handler.stop = stop or []
260+
stop = []
261+
if hasattr(llm, "kwargs"):
262+
current_params = getattr(llm, "kwargs", {})
263+
stop = current_params.get("stop", [])
264+
if not stop:
265+
stop = getattr(llm, "stop", [])
266+
handler.stop = stop
241267
accumulated_metadata: Dict[str, Any] = {}
242268

243269
try:
244-
async for chunk in llm.astream(messages, stop=stop, config=RunnableConfig(callbacks=logging_callbacks)):
270+
async for chunk in llm.astream(messages, config=RunnableConfig(callbacks=logging_callbacks)):
245271
if hasattr(chunk, "content"):
246272
content = chunk.content
247273
else:
@@ -351,11 +377,10 @@ async def _invoke_with_string_prompt(
351377
llm: Union[BaseLanguageModel, Runnable],
352378
prompt: str,
353379
callbacks: BaseCallbackManager,
354-
stop: Optional[List[str]],
355380
):
356381
"""Invoke LLM with string prompt."""
357382
try:
358-
return await llm.ainvoke(prompt, config=RunnableConfig(callbacks=callbacks), stop=stop)
383+
return await llm.ainvoke(prompt, config=RunnableConfig(callbacks=callbacks))
359384
except Exception as e:
360385
_raise_llm_call_exception(e, llm)
361386

@@ -364,13 +389,12 @@ async def _invoke_with_message_list(
364389
llm: Union[BaseLanguageModel, Runnable],
365390
prompt: List[dict],
366391
callbacks: BaseCallbackManager,
367-
stop: Optional[List[str]],
368392
):
369393
"""Invoke LLM with message list after converting to LangChain format."""
370394
messages = _convert_messages_to_langchain_format(prompt)
371395

372396
try:
373-
return await llm.ainvoke(messages, config=RunnableConfig(callbacks=callbacks), stop=stop)
397+
return await llm.ainvoke(messages, config=RunnableConfig(callbacks=callbacks))
374398
except Exception as e:
375399
_raise_llm_call_exception(e, llm)
376400

tests/test_actions_llm_utils.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
)
3434
from nemoguardrails.context import reasoning_trace_var, tool_calls_var
3535
from nemoguardrails.exceptions import LLMCallException
36+
from tests.utils import get_bound_llm_magic_mock
3637

3738

3839
@pytest.fixture(autouse=True)
@@ -547,18 +548,23 @@ def test_store_tool_calls_with_real_aimessage_multiple_tool_calls():
547548

548549
@pytest.mark.asyncio
549550
@pytest.mark.parametrize("llm_params", [None, {}])
550-
async def test_llm_call_stop_tokens_passed_without_llm_params(llm_params):
551-
"""Stop tokens must be passed to ainvoke even when llm_params is None or empty."""
552-
from unittest.mock import AsyncMock, MagicMock
553-
551+
@pytest.mark.parametrize("stop", [None, ["User:"]])
552+
async def test_llm_call_stop_tokens_passed_without_llm_params(llm_params, stop):
553+
"""Stop tokens must be passed to bind or ainvoke even when llm_params is None or empty."""
554554
from nemoguardrails.actions.llm.utils import llm_call
555555

556-
mock_llm = AsyncMock()
557-
mock_llm.ainvoke.return_value = MagicMock(content="response")
556+
mock_llm = get_bound_llm_magic_mock(ainvoke_return_value={"content": "response"})
558557

559-
await llm_call(mock_llm, "prompt", stop=["User:"], llm_params=llm_params)
558+
await llm_call(mock_llm, "prompt", stop=stop, llm_params=llm_params)
560559

561-
assert mock_llm.ainvoke.call_args[1]["stop"] == ["User:"]
560+
if mock_llm.bind.called:
561+
# Option A: Check if .bind() was called with the stop tokens
562+
args, kwargs = mock_llm.bind.call_args
563+
assert kwargs.get("stop", None) == stop
564+
else:
565+
# Option B: Check if it fell back to passing stop to .ainvoke
566+
args, kwargs = mock_llm.ainvoke.call_args
567+
assert kwargs.get("stop", None) == stop
562568

563569

564570
@pytest.mark.asyncio
@@ -677,6 +683,11 @@ class TestFilterParamsForOpenAIReasoningModels:
677683
("gpt-5-nano", {"temperature": 0.001}, {}),
678684
("o1-preview", {"max_tokens": 100}, {"max_tokens": 100}),
679685
("o1-preview", {}, {}),
686+
("gpt-5", {"stop": "stop"}, {}),
687+
("gpt-5-mini", {"temperature": 0.5, "max_tokens": 100, "stop": "stop"}, {"max_tokens": 100}),
688+
("o4-mini", {"stop": "stop"}, {}),
689+
("o3", {"stop": "stop"}, {}),
690+
("o3-pro", {"temperature": 0.5, "stop": "stop"}, {}),
680691
],
681692
)
682693
def test_filter_params(self, model, params, expected):
@@ -694,3 +705,10 @@ def test_does_not_modify_original_params(self):
694705
params = {"temperature": 0.5, "max_tokens": 100}
695706
_filter_params_for_openai_reasoning_models(llm, params)
696707
assert params == {"temperature": 0.5, "max_tokens": 100}
708+
709+
@pytest.mark.asyncio
710+
async def test_llm_call_does_not_mutate_llm_params(self):
711+
mock_llm = get_bound_llm_magic_mock(ainvoke_return_value={"content": "response"})
712+
original_params = {"max_tokens": 100}
713+
await llm_call(mock_llm, "prompt", stop=["User:"], llm_params=original_params)
714+
assert original_params == {"max_tokens": 100}

tests/test_llmrails.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
import os
1717
from typing import Optional
18-
from unittest.mock import MagicMock, patch
18+
from unittest.mock import patch
1919

2020
import pytest
2121
from langchain_core.language_models import BaseChatModel
@@ -24,7 +24,7 @@
2424
from nemoguardrails.logging.explain import ExplainInfo
2525
from nemoguardrails.rails.llm.config import Model
2626
from tests.conftest import REASONING_TRACE_MOCK_PATH
27-
from tests.utils import FakeLLM, clean_events, event_sequence_conforms
27+
from tests.utils import FakeLLM, clean_events, event_sequence_conforms, get_bound_llm_magic_mock
2828

2929

3030
@pytest.fixture
@@ -1059,7 +1059,7 @@ def test_explain_calls_ensure_explain_info():
10591059
"""Make sure if no `explain_info` attribute is present in LLMRails it's populated with
10601060
an empty ExplainInfo object"""
10611061

1062-
mock_llm = MagicMock(spec=BaseChatModel)
1062+
mock_llm = get_bound_llm_magic_mock(ainvoke_return_value={"spec": BaseChatModel})
10631063
config = RailsConfig.from_content(config={"models": []})
10641064
rails = LLMRails(config=config, llm=mock_llm)
10651065
rails.generate(messages=[{"role": "user", "content": "Hi!"}])

tests/test_tool_calling_passthrough_only.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,21 +15,20 @@
1515

1616
"""Test that tool calling ONLY works in passthrough mode."""
1717

18-
from unittest.mock import AsyncMock, MagicMock
18+
from unittest.mock import MagicMock
1919

2020
import pytest
2121
from langchain_core.messages import AIMessage
2222

2323
from nemoguardrails import LLMRails, RailsConfig
2424
from nemoguardrails.actions.llm.generation import LLMGenerationActions
2525
from nemoguardrails.context import tool_calls_var
26+
from tests.utils import get_bound_llm_magic_mock
2627

2728

2829
@pytest.fixture
2930
def mock_llm_with_tool_calls():
3031
"""Mock LLM that returns tool calls."""
31-
llm = AsyncMock()
32-
3332
mock_response = AIMessage(
3433
content="",
3534
tool_calls=[
@@ -41,8 +40,7 @@ def mock_llm_with_tool_calls():
4140
}
4241
],
4342
)
44-
llm.ainvoke.return_value = mock_response
45-
llm.invoke.return_value = mock_response
43+
llm = get_bound_llm_magic_mock(ainvoke_return_value=mock_response)
4644
return llm
4745

4846

tests/utils.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,14 @@
1919
import sys
2020
from datetime import datetime, timedelta, timezone
2121
from typing import Any, Dict, Iterable, List, Mapping, Optional, Union
22+
from unittest.mock import AsyncMock, MagicMock
2223

2324
from langchain_core.callbacks.manager import (
2425
AsyncCallbackManagerForLLMRun,
2526
CallbackManagerForLLMRun,
2627
)
2728
from langchain_core.language_models import LLM
29+
from langchain_core.messages import AIMessage
2830

2931
from nemoguardrails import LLMRails, RailsConfig
3032
from nemoguardrails.colang import parse_colang_file
@@ -414,3 +416,21 @@ def _init_state(colang_content, yaml_content: Optional[str] = None) -> State:
414416
json.dump(state.flow_configs, sys.stdout, indent=4, cls=EnhancedJsonEncoder)
415417

416418
return state
419+
420+
421+
def get_bound_llm_magic_mock(ainvoke_return_value: Union[AIMessage, dict]) -> MagicMock:
422+
mock_llm = MagicMock()
423+
mock_llm.return_value = mock_llm
424+
425+
bound_llm_mock = AsyncMock()
426+
if isinstance(ainvoke_return_value, dict):
427+
bound_llm_mock.ainvoke.return_value = MagicMock(**ainvoke_return_value)
428+
else:
429+
bound_llm_mock.ainvoke.return_value = ainvoke_return_value
430+
431+
mock_llm.bind.return_value = bound_llm_mock
432+
if isinstance(ainvoke_return_value, dict):
433+
mock_llm.ainvoke = AsyncMock(return_value=MagicMock(**ainvoke_return_value))
434+
else:
435+
mock_llm.ainvoke = AsyncMock(return_value=ainvoke_return_value)
436+
return mock_llm

0 commit comments

Comments
 (0)