Skip to content

Commit e999078

Browse files
authored
feat(integration): add GuardrailsMiddleware for LangChain agent (#1606)
* feat(langchain): add GuardrailsMiddleware for LangChain agent integration * feat(middleware): add explicit rail_types to check_async and fix message replacement - Add RailType enum (INPUT, OUTPUT) to options.py - Add optional rail_types parameter to check_async/check to override auto-detection - Middleware now passes rail_types=[RailType.INPUT] from abefore_model and rail_types=[RailType.OUTPUT] from aafter_model - Fix _replace_last_ai_message to find actual AIMessage index instead of assuming messages[-1] - Add unit tests for explicit rail type passing and message replacement
1 parent b0701bd commit e999078

File tree

8 files changed

+2692
-75
lines changed

8 files changed

+2692
-75
lines changed
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
from typing import Optional
17+
18+
from nemoguardrails.rails.llm.options import RailsResult
19+
20+
21+
class GuardrailViolation(Exception):
    """Raised when a guardrail blocks content (or rail execution fails) and
    the caller is configured to raise instead of substituting a canned reply.

    Attributes:
        result: The rails check result that triggered the violation, if any.
        rail_type: Which rail stage was involved, e.g. ``"input"`` or
            ``"output"``; ``None`` when unknown.
    """

    def __init__(
        self,
        message: str,
        result: Optional["RailsResult"] = None,
        rail_type: Optional[str] = None,
    ):
        super().__init__(message)
        # Keep the structured context so callers can inspect why the
        # violation happened without parsing the message string.
        self.result = result
        self.rail_type = rail_type

    # NOTE: the previous no-op ``__str__`` override (``return
    # super().__str__()``) was removed — the inherited behavior is identical.
Lines changed: 269 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,269 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
from __future__ import annotations
17+
18+
import logging
19+
from typing import TYPE_CHECKING, Any, Dict, List, Optional
20+
21+
from langchain.agents.middleware.types import AgentMiddleware, AgentState, hook_config
22+
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
23+
24+
if TYPE_CHECKING:
25+
from langgraph.runtime import Runtime as LangGraphRuntime
26+
from nemoguardrails.integrations.langchain.exceptions import GuardrailViolation
27+
from nemoguardrails.integrations.langchain.message_utils import (
28+
create_ai_message,
29+
is_ai_message,
30+
is_human_message,
31+
messages_to_dicts,
32+
)
33+
from nemoguardrails.rails.llm.config import RailsConfig
34+
from nemoguardrails.rails.llm.llmrails import LLMRails
35+
from nemoguardrails.rails.llm.options import RailsResult, RailStatus, RailType
36+
from nemoguardrails.utils import get_or_create_event_loop
37+
38+
log = logging.getLogger(__name__)
39+
40+
41+
class GuardrailsMiddleware(AgentMiddleware):
    """Run NeMo Guardrails input/output rails around a LangChain agent's model calls.

    Input rails run in ``before_model``/``abefore_model``; when they block,
    the agent is short-circuited (``jump_to: "end"``) with
    ``blocked_input_message`` appended as an AI message.  Output rails run in
    ``after_model``/``aafter_model``; when they block, the last AI message is
    replaced with ``blocked_output_message``.  Rail execution errors fail
    closed (the canned message is substituted) unless ``raise_on_violation``
    is True, in which case a :class:`GuardrailViolation` is raised.
    """

    def __init__(
        self,
        config_path: Optional[str] = None,
        config_yaml: Optional[str] = None,
        raise_on_violation: bool = False,
        blocked_input_message: str = "I cannot process this request due to content policy.",
        blocked_output_message: str = "I cannot provide this response due to content policy.",
        enable_input_rails: bool = True,
        enable_output_rails: bool = True,
    ):
        """Build the middleware from a config path or inline YAML content.

        Args:
            config_path: Path to a NeMo Guardrails configuration; takes
                precedence over ``config_yaml``.
            config_yaml: Inline YAML configuration content.
            raise_on_violation: Raise :class:`GuardrailViolation` instead of
                substituting canned messages on block/error.
            blocked_input_message: AI reply used when input rails block.
            blocked_output_message: AI reply used when output rails block.
            enable_input_rails: Master switch for the before-model hooks.
            enable_output_rails: Master switch for the after-model hooks.

        Raises:
            ValueError: If neither ``config_path`` nor ``config_yaml`` is given.
        """
        if config_path is not None:
            config = RailsConfig.from_path(config_path)
        elif config_yaml is not None:
            config = RailsConfig.from_content(config_yaml)
        else:
            raise ValueError(
                "Either 'config_path' or 'config_yaml' must be provided to GuardrailsMiddleware"
            )

        # NOTE(review): super().__init__() is not invoked here — confirm
        # AgentMiddleware does not require base initialization.
        self.rails = LLMRails(config=config)
        self.raise_on_violation = raise_on_violation
        self.blocked_input_message = blocked_input_message
        self.blocked_output_message = blocked_output_message
        self.enable_input_rails = enable_input_rails
        self.enable_output_rails = enable_output_rails

    def _has_input_rails(self) -> bool:
        """Return True if the rails config defines any input flows."""
        return len(self.rails.config.rails.input.flows) > 0

    def _has_output_rails(self) -> bool:
        """Return True if the rails config defines any output flows."""
        return len(self.rails.config.rails.output.flows) > 0

    def _convert_to_rails_messages(self, messages: List[BaseMessage]) -> List[Dict[str, Any]]:
        """Convert LangChain messages to the dict format expected by rails."""
        return messages_to_dicts(messages)

    def _get_last_user_message(self, messages: List[BaseMessage]) -> Optional[HumanMessage]:
        """Return the most recent human message, or None if there is none."""
        for msg in reversed(messages):
            if is_human_message(msg):
                return msg
        return None

    def _get_last_ai_message(self, messages: List[BaseMessage]) -> Optional[AIMessage]:
        """Return the most recent AI message, or None if there is none."""
        for msg in reversed(messages):
            if is_ai_message(msg):
                return msg
        return None

    def _handle_guardrail_failure(
        self,
        result: RailsResult,
        rail_type: str,
    ) -> None:
        """React to a blocked rail result.

        Raises :class:`GuardrailViolation` when ``raise_on_violation`` is
        set; otherwise logs a warning and returns so the caller can
        substitute its canned message.  (The previously unused
        ``blocked_message`` parameter was removed from this private helper.)

        Raises:
            GuardrailViolation: When the result is BLOCKED and
                ``raise_on_violation`` is True.
        """
        if result.status == RailStatus.BLOCKED:
            failure_message = f"{rail_type.capitalize()} blocked by {result.rail or 'unknown rail'}"

            if self.raise_on_violation:
                raise GuardrailViolation(
                    message=failure_message,
                    result=result,
                    rail_type=rail_type,
                )

            log.warning(failure_message)

    @hook_config(can_jump_to=["end"])
    async def abefore_model(self, state: AgentState, runtime: LangGraphRuntime) -> Optional[Dict[str, Any]]:
        """Run input rails before the model call.

        Returns a state update that appends a refusal AI message and jumps to
        "end" when the input is blocked (or a rail errors while
        ``raise_on_violation`` is False); otherwise returns None.
        """
        if not self.enable_input_rails or not self._has_input_rails():
            return None

        messages = state.get("messages", [])
        if not messages:
            return None

        rails_messages = self._convert_to_rails_messages(messages)

        try:
            # Explicitly request input rails only, overriding auto-detection.
            result = await self.rails.check_async(rails_messages, rail_types=[RailType.INPUT])

            if result.status == RailStatus.BLOCKED:
                self._handle_guardrail_failure(result=result, rail_type="input")
                blocked_msg = create_ai_message(self.blocked_input_message)
                return {"messages": messages + [blocked_msg], "jump_to": "end"}

            return None

        except GuardrailViolation:
            raise
        except Exception as e:
            log.error("Error checking input rails: %s", e, exc_info=True)

            if self.raise_on_violation:
                raise GuardrailViolation(
                    message=f"Input rail execution error: {str(e)}",
                    rail_type="input",
                ) from e

            # Fail closed: treat a rail execution error like a block.
            blocked_msg = create_ai_message(self.blocked_input_message)
            return {"messages": messages + [blocked_msg], "jump_to": "end"}

    def _replace_last_ai_message(self, messages: list, replacement: AIMessage) -> list:
        """Return a copy of ``messages`` with the last AI message replaced.

        Scans from the end to find the actual AI message (there may be
        trailing non-AI messages); appends if no AI message exists.
        """
        for i in range(len(messages) - 1, -1, -1):
            if is_ai_message(messages[i]):
                return messages[:i] + [replacement] + messages[i + 1 :]
        return messages + [replacement]

    async def aafter_model(self, state: AgentState, runtime: LangGraphRuntime) -> Optional[Dict[str, Any]]:
        """Run output rails after the model call.

        Replaces the last AI message with ``blocked_output_message`` when the
        output is blocked (or a rail errors while ``raise_on_violation`` is
        False); otherwise returns None.
        """
        if not self.enable_output_rails or not self._has_output_rails():
            return None

        messages = state.get("messages", [])
        if not messages:
            return None

        last_ai_message = self._get_last_ai_message(messages)
        if not last_ai_message:
            return None

        rails_messages = self._convert_to_rails_messages(messages)

        try:
            # Explicitly request output rails only, overriding auto-detection.
            result = await self.rails.check_async(rails_messages, rail_types=[RailType.OUTPUT])

            if result.status == RailStatus.BLOCKED:
                self._handle_guardrail_failure(result=result, rail_type="output")
                blocked_msg = create_ai_message(self.blocked_output_message)
                return {"messages": self._replace_last_ai_message(messages, blocked_msg)}

            return None

        except GuardrailViolation:
            raise
        except Exception as e:
            log.error("Error checking output rails: %s", e, exc_info=True)

            if self.raise_on_violation:
                raise GuardrailViolation(
                    message=f"Output rail execution error: {str(e)}",
                    rail_type="output",
                ) from e

            # Fail closed: treat a rail execution error like a block.
            blocked_msg = create_ai_message(self.blocked_output_message)
            return {"messages": self._replace_last_ai_message(messages, blocked_msg)}

    @hook_config(can_jump_to=["end"])
    def before_model(self, state: AgentState, runtime: LangGraphRuntime) -> Optional[Dict[str, Any]]:
        """Synchronous wrapper around :meth:`abefore_model`."""
        if not self.enable_input_rails or not self._has_input_rails():
            return None

        messages = state.get("messages", [])
        if not messages:
            return None

        # The cheap guards above avoid touching an event loop when disabled.
        loop = get_or_create_event_loop()
        return loop.run_until_complete(self.abefore_model(state, runtime))

    def after_model(self, state: AgentState, runtime: LangGraphRuntime) -> Optional[Dict[str, Any]]:
        """Synchronous wrapper around :meth:`aafter_model`."""
        if not self.enable_output_rails or not self._has_output_rails():
            return None

        messages = state.get("messages", [])
        if not messages:
            return None

        last_ai_message = self._get_last_ai_message(messages)
        if not last_ai_message:
            return None

        loop = get_or_create_event_loop()
        return loop.run_until_complete(self.aafter_model(state, runtime))
218+
219+
220+
class InputRailsMiddleware(GuardrailsMiddleware):
    """Convenience middleware that applies input rails only.

    Configures the base :class:`GuardrailsMiddleware` with output rails
    disabled and overrides the post-model hooks as no-ops.
    """

    def __init__(
        self,
        config_path: Optional[str] = None,
        config_yaml: Optional[str] = None,
        raise_on_violation: bool = False,
        blocked_input_message: str = "I cannot process this request due to content policy.",
    ):
        """Initialize the base middleware with output rails switched off."""
        super().__init__(
            config_path=config_path,
            config_yaml=config_yaml,
            raise_on_violation=raise_on_violation,
            blocked_input_message=blocked_input_message,
            blocked_output_message="",
            enable_input_rails=True,
            enable_output_rails=False,
        )

    async def aafter_model(
        self, state: AgentState, runtime: LangGraphRuntime
    ) -> Optional[Dict[str, Any]]:
        """No-op: output rails are disabled for this middleware."""
        return None

    def after_agent(
        self, state: AgentState, runtime: LangGraphRuntime
    ) -> Optional[Dict[str, Any]]:
        """No-op: output rails are disabled for this middleware."""
        return None
243+
244+
245+
class OutputRailsMiddleware(GuardrailsMiddleware):
    """Convenience middleware that applies output rails only.

    Configures the base :class:`GuardrailsMiddleware` with input rails
    disabled and overrides the pre-model hooks as no-ops.
    """

    def __init__(
        self,
        config_path: Optional[str] = None,
        config_yaml: Optional[str] = None,
        raise_on_violation: bool = False,
        blocked_output_message: str = "I cannot provide this response due to content policy.",
    ):
        """Initialize the base middleware with input rails switched off."""
        super().__init__(
            config_path=config_path,
            config_yaml=config_yaml,
            raise_on_violation=raise_on_violation,
            blocked_input_message="",
            blocked_output_message=blocked_output_message,
            enable_input_rails=False,
            enable_output_rails=True,
        )

    @hook_config(can_jump_to=["end"])
    async def abefore_model(
        self, state: AgentState, runtime: LangGraphRuntime
    ) -> Optional[Dict[str, Any]]:
        """No-op: input rails are disabled for this middleware."""
        return None

    @hook_config(can_jump_to=["end"])
    def before_agent(
        self, state: AgentState, runtime: LangGraphRuntime
    ) -> Optional[Dict[str, Any]]:
        """No-op: input rails are disabled for this middleware."""
        return None

0 commit comments

Comments
 (0)