Skip to content

Commit 74d52d5

Browse files
authored
feat(iorails): Use single engine with correct lifecycle (#1649)
* Initial checkin of code, no existing tests broken * Move test data into its own file * Clean up tests * Rewrite logic on when IORails is used vs LLMRails * Remove todos (will do in future PR), fix docstring for _last_content_by_role, remove unused guardrails_models.py * Implement a single engine (IORails / LLMRails) and update tests * Standardise start/stop for objects which need lifespan hooks (i.e. creating worker tasks, opening clients, etc) * Fix start/stop ordering and add finally-clause to make sure self._running doesn't get into an inconsistent state * Clean up start/stop exception handling, fix inaccurate docstrings * Clean up docstrings and multiple-worker start/stop code in AsyncWorkQueue and ModelManager * Use _flow_name() to extract the flow name and fix urljoin() in model_engine.py * Fix IORails start/stop exception handling, add new tests to get line coverage back to 100% * Revert changes to model manager start/stop * Fix _has_only_iorails_flows set bug, update test_start_failure_allows_retry to match start/stop methods * Clean up async work queue start() and tests * Unpack GenerationOptions object (passed from server api.py) and pass on llm_params to the main LLM call in ModelManager * Unpack llm parameters to top-level rather than nested under llm_params in HTTP body * Reformat after rebasing onto develop * Rebase-merge cleanups
1 parent 15041b8 commit 74d52d5

File tree

11 files changed

+704
-233
lines changed

11 files changed

+704
-233
lines changed

nemoguardrails/guardrails/async_work_queue.py

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -73,29 +73,42 @@ async def start(self) -> None:
7373
"""Starts the worker pool. Call this during service startup."""
7474
if self._running:
7575
return
76-
self._running = True
76+
7777
self._busy_count = 0
7878
self._workers = []
79-
for i in range(self._max_concurrency):
80-
task = asyncio.create_task(self._worker_loop(), name=f"{self._name}_worker_id{i}")
81-
self._workers.append(task)
79+
# Try to start all workers, cancelling any on failure
80+
try:
81+
for i in range(self._max_concurrency):
82+
task = asyncio.create_task(self._worker_loop(), name=f"{self._name}_worker_id{i}")
83+
self._workers.append(task)
84+
except Exception:
85+
# Cancel any tasks that did start
86+
for task in self._workers:
87+
task.cancel()
88+
await asyncio.gather(*self._workers, return_exceptions=True)
89+
self._workers = []
90+
raise
91+
92+
self._running = True
8293

8394
async def stop(self, wait_for_completion: bool = True) -> None:
8495
"""Stops the worker pool. Call this during service shutdown."""
8596
if not self._running:
8697
return
87-
self._running = False
8898

89-
if wait_for_completion:
90-
await self._queue.join()
99+
try:
100+
if wait_for_completion:
101+
await self._queue.join()
91102

92-
for task in self._workers:
93-
task.cancel()
103+
for task in self._workers:
104+
task.cancel()
94105

95-
# Swallow cancellations to prevent noise during shutdown
96-
await asyncio.gather(*self._workers, return_exceptions=True)
97-
self._workers = []
98-
self._busy_count = 0
106+
# Swallow cancellations to prevent noise during shutdown
107+
await asyncio.gather(*self._workers, return_exceptions=True)
108+
self._workers = []
109+
self._busy_count = 0
110+
finally:
111+
self._running = False
99112

100113
async def submit(self, func: Callable[..., Awaitable[T]], *args: Any, **kwargs: Any) -> T:
101114
"""

nemoguardrails/guardrails/guardrails.py

Lines changed: 37 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,13 @@
1616
"""Top-level Guardrails interface module.
1717
1818
This module provides a simplified, user-friendly interface for interacting with
19-
NeMo Guardrails. The Guardrails class wraps the LLMRails functionality and provides
20-
a streamlined API for generating LLM responses with programmable guardrails.
19+
NeMo Guardrails. The Guardrails class wraps either IORails or LLMRails (chosen
20+
automatically based on config) and provides a streamlined API for generating
21+
LLM responses with programmable guardrails.
2122
"""
2223

2324
import logging
24-
from typing import AsyncIterator, Optional, Tuple, Union, overload
25+
from typing import AsyncIterator, Optional, Tuple, Union, cast, overload
2526

2627
from langchain_core.language_models import BaseChatModel, BaseLLM
2728

@@ -63,10 +64,8 @@ def __init__(
6364
self.verbose = verbose
6465

6566
# Whether to use IORailsEngine for inference requests
66-
self._use_iorails_engine: bool = use_iorails and self._has_only_iorails_flows()
67-
self._iorails = IORails(config)
68-
self._llmrails = LLMRails(config, llm, verbose)
69-
self.rails_engine = self._iorails if self._use_iorails_engine else self._llmrails
67+
use_iorails_engine = use_iorails and self._has_only_iorails_flows()
68+
self._rails_engine = IORails(config) if use_iorails_engine else LLMRails(config, llm, verbose)
7069

7170
# Async work queue for managing concurrent generate_async requests
7271
self._generate_async_queue: AsyncWorkQueue = AsyncWorkQueue(
@@ -80,22 +79,16 @@ def __init__(
8079
self._queues = [self._generate_async_queue]
8180

8281
@property
83-
def iorails(self) -> IORails:
84-
"""Get immutable IORails object"""
85-
return self._iorails
86-
87-
@property
88-
def llmrails(self) -> LLMRails:
82+
def rails_engine(self) -> IORails | LLMRails:
8983
"""Get immutable LLMRails object"""
90-
return self._llmrails
84+
return self._rails_engine
9185

9286
@staticmethod
9387
def _convert_to_messages(prompt: str | None = None, messages: LLMMessages | None = None) -> LLMMessages:
94-
"""Convert prompt or simplified messages to LLMRails standard format.
88+
"""Return messages in standard format, converting a prompt string if needed.
9589
96-
Converts from Guardrails simplified format to LLMRails standard format:
97-
- Simplified: [{"user": "text"}]
98-
- Standard: [{"role": "user", "content": "Hello"}]
90+
If messages is provided, returns it as-is.
91+
If prompt is provided, wraps it as [{"role": "user", "content": prompt}].
9992
"""
10093

10194
# Priority: messages first, then prompt
@@ -113,7 +106,7 @@ def _has_only_iorails_flows(self):
113106

114107
# If we have any rails outside of `input` and `output` we don't support them
115108
rails_set = self.config.rails.model_fields_set
116-
if rails_set > IORAILS_RAILS:
109+
if rails_set - IORAILS_RAILS:
117110
return False
118111

119112
for flow in self.config.rails.input.flows:
@@ -133,11 +126,14 @@ def generate(
133126
) -> Union[str, dict, GenerationResponse, Tuple[dict, dict]]:
134127
"""Generate an LLM response synchronously with guardrails applied."""
135128

136-
if self._use_iorails_engine:
129+
if isinstance(self.rails_engine, IORails):
137130
raise NotImplementedError("IORails doesn't support generate()")
138131

139132
generate_messages = self._convert_to_messages(prompt, messages)
140-
response = self._llmrails.generate(messages=generate_messages, **kwargs)
133+
134+
# self.rails_engine must be LLMRails since we raise above if we're using IORails
135+
llmrails = cast(LLMRails, self.rails_engine)
136+
response = llmrails.generate(messages=generate_messages, **kwargs)
141137
return response
142138

143139
@overload
@@ -178,46 +174,56 @@ def stream_async(
178174
Only supported when using LLMRails
179175
"""
180176

181-
if self._use_iorails_engine:
177+
if isinstance(self.rails_engine, IORails):
182178
raise NotImplementedError("IORails doesn't support stream_async()")
183179

184180
stream_messages = self._convert_to_messages(prompt, messages)
185-
return self._llmrails.stream_async(messages=stream_messages, **kwargs)
181+
# self.rails_engine must be LLMRails since we raise above if we're using IORails
182+
llmrails = cast(LLMRails, self.rails_engine)
183+
return llmrails.stream_async(messages=stream_messages, **kwargs)
186184

187185
def explain(self) -> ExplainInfo:
188186
"""Get the latest ExplainInfo object for debugging.
189187
Only supported for LLMRails
190188
"""
191189

192-
if self._use_iorails_engine:
190+
if isinstance(self.rails_engine, IORails):
193191
raise NotImplementedError("IORails doesn't support explain()")
194192

195-
return self._llmrails.explain()
193+
# self.rails_engine must be LLMRails since we raise above if we're using IORails
194+
llmrails = cast(LLMRails, self.rails_engine)
195+
return llmrails.explain()
196196

197197
def update_llm(self, llm: Union[BaseLLM, BaseChatModel]) -> None:
198198
"""Replace the main LLM with a new one.
199199
Only supported for LLMRails, since IORails doesn't take LLM as argument
200200
"""
201-
if self._use_iorails_engine:
201+
if isinstance(self.rails_engine, IORails):
202202
raise NotImplementedError("IORails doesn't support update_llm()")
203203

204-
self._llmrails.update_llm(llm)
204+
# self.rails_engine must be LLMRails since we raise above if we're using IORails
205+
llmrails = cast(LLMRails, self.rails_engine)
206+
llmrails.update_llm(llm)
205207

206208
async def startup(self) -> None:
207-
"""Lifecycle method to create worker threads and infrastructure"""
209+
"""Lifecycle method to start async worker tasks and the rails engine"""
208210
for queue in self._queues:
209211
await queue.start()
212+
if isinstance(self.rails_engine, IORails):
213+
await self.rails_engine.start()
210214

211215
async def shutdown(self) -> None:
212-
"""Lifecycle method to cleanly shutdown worker threads and infrastructure"""
216+
"""Lifecycle method to stop async worker tasks and the rails engine"""
213217
for queue in self._queues:
214218
await queue.stop()
219+
if isinstance(self.rails_engine, IORails):
220+
await self.rails_engine.stop()
215221

216222
async def __aenter__(self):
217-
"""Async context manager entry - starts the queues."""
223+
"""Async context manager entry."""
218224
await self.startup()
219225
return self
220226

221227
async def __aexit__(self, exc_type, exc_val, exc_tb):
222-
"""Async context manager exit - shuts down the queues."""
228+
"""Async context manager exit."""
223229
await self.shutdown()

nemoguardrails/guardrails/iorails.py

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,8 @@
1616
"""Optimized IORails Engine for specific guardrail configurations.
1717
1818
This module provides an optimized inference path for guardrail configurations that
19-
only use specific supported flows (input/output content safety, topic safety,
20-
jailbreak detection, etc.). For configurations outside this supported set, the
21-
standard LLMRails engine should be used instead.
19+
only use specific supported flows (input/output content safety). For configurations
20+
outside this supported set, the standard LLMRails engine should be used instead.
2221
"""
2322

2423
import logging
@@ -27,6 +26,7 @@
2726
from nemoguardrails.guardrails.model_manager import ModelManager
2827
from nemoguardrails.guardrails.rails_manager import RailsManager
2928
from nemoguardrails.rails.llm.config import RailsConfig
29+
from nemoguardrails.rails.llm.options import GenerationOptions
3030

3131
log = logging.getLogger(__name__)
3232

@@ -37,12 +37,46 @@ class IORails:
3737
"""Workflow engine for accelerated Input/Output rails inference."""
3838

3939
def __init__(self, config: RailsConfig) -> None:
40+
self._running = False
41+
4042
# Model Manager has one or more ModelEngine inside. Each ModelEngine calls a single model or API
4143
self.model_manager = ModelManager(config.models)
4244

4345
# Rails Manager is responsible for running rails by making calls to Model Manager
4446
self.rails_manager = RailsManager(config, self.model_manager)
4547

48+
async def start(self) -> None:
49+
"""Start the IORails engine. Call this during service startup."""
50+
if self._running:
51+
return
52+
53+
# When starting up, make sure self._running is always set to True even on exceptions.
54+
# This allows the stop() method to clean up any state
55+
try:
56+
await self.model_manager.start()
57+
finally:
58+
self._running = True
59+
60+
async def stop(self) -> None:
61+
"""Stop the IORails engine. Call this during service shutdown."""
62+
if not self._running:
63+
return
64+
65+
# If any exceptions are thrown when stopping ModelManager, set the _running to False
66+
try:
67+
await self.model_manager.stop()
68+
finally:
69+
self._running = False
70+
71+
async def __aenter__(self):
72+
"""Context manager (used for testing rather than long-lived instance)"""
73+
await self.start()
74+
return self
75+
76+
async def __aexit__(self, exc_type, exc_val, exc_tb):
77+
"""Context manager (used for testing rather than long-lived instance)"""
78+
await self.stop()
79+
4680
async def generate_async(self, messages: LLMMessages, **kwargs) -> LLMMessage:
4781
"""Run input rails, generation, and output rails. Return response if safe."""
4882

@@ -53,7 +87,14 @@ async def generate_async(self, messages: LLMMessages, **kwargs) -> LLMMessage:
5387
return {"role": "assistant", "content": REFUSAL_MESSAGE}
5488

5589
# Step 2: Generate response from main LLM
56-
response_text = await self.model_manager.generate_async("main", messages)
90+
# If we got an `options=GenerationOptions`, then unpack GenerationOptions.llm_params and add
91+
# that to the main LLM call
92+
llm_kwargs = {}
93+
if kwargs.get("options") and isinstance(kwargs["options"], GenerationOptions):
94+
generation_options = kwargs["options"]
95+
llm_kwargs = generation_options.llm_params if generation_options.llm_params else {}
96+
97+
response_text = await self.model_manager.generate_async("main", messages, **llm_kwargs)
5798

5899
# Step 3: Check output rails
59100
output_result = await self.rails_manager.is_output_safe(messages, response_text)

nemoguardrails/guardrails/model_engine.py

Lines changed: 33 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
import logging
2424
import os
2525
from typing import Any, Optional, cast
26-
from urllib.parse import urljoin
2726

2827
import aiohttp
2928
from aiohttp_retry import ExponentialRetry, RetryClient
@@ -83,20 +82,30 @@ def __init__(self, model_config: Model) -> None:
8382
exceptions={aiohttp.ClientConnectionError},
8483
)
8584
self._client: Optional[RetryClient] = None
85+
self._running = False
8686

8787
async def start(self) -> None:
88-
"""Create this engine's RetryClient."""
89-
if self._client is None:
90-
self._client = RetryClient(
91-
retry_options=self._retry_options,
92-
client_session=aiohttp.ClientSession(timeout=self._timeout),
93-
)
88+
"""Create this engine's RetryClient. Call this during service startup."""
89+
if self._running:
90+
return
91+
92+
self._client = RetryClient(
93+
retry_options=self._retry_options,
94+
client_session=aiohttp.ClientSession(timeout=self._timeout),
95+
)
96+
self._running = True
9497

9598
async def stop(self) -> None:
96-
"""Close this engine's RetryClient."""
97-
if self._client:
98-
await self._client.close()
99-
self._client = None
99+
"""Close this engine's RetryClient. Call this during service shutdown."""
100+
if not self._running:
101+
return
102+
103+
try:
104+
if self._client:
105+
await self._client.close()
106+
self._client = None
107+
finally:
108+
self._running = False
100109

101110
def _resolve_base_url(self) -> str:
102111
"""Resolve the base URL from model parameters or engine type."""
@@ -116,9 +125,7 @@ def _resolve_base_url(self) -> str:
116125
def _get_environment_variable(self, variable_name: str) -> str | None:
117126
"""Return the value stored in environment variable `variable_name`."""
118127
env_value = os.environ.get(variable_name)
119-
if env_value:
120-
return env_value
121-
return None
128+
return env_value
122129

123130
def _resolve_api_key(self, engine: str | None) -> Optional[str]:
124131
"""Resolve the API key from model config or environment."""
@@ -159,17 +166,17 @@ async def call(
159166
The parsed JSON response dict from the API.
160167
161168
Raises:
162-
ModelEngineError: If the request fails after all retries or the client is not started.
169+
ModelEngineError: If the request fails after all retries.
163170
"""
164171

165172
# Lazy-initialize client if `start()` hasn't yet been called.
166-
if self._client is None:
173+
if not self._running:
167174
await self.start()
168175

169176
# Cast as RetryClient so type-checking knows it isn't None
170177
client = cast(RetryClient, self._client)
171178

172-
url = urljoin(self.base_url, _CHAT_COMPLETIONS_ENDPOINT)
179+
url = self.base_url.rstrip("/") + _CHAT_COMPLETIONS_ENDPOINT
173180

174181
headers: dict[str, str] = {"Content-Type": "application/json"}
175182
if self.api_key:
@@ -200,6 +207,15 @@ async def call(
200207
model_name=self.model_name,
201208
) from exc
202209

210+
async def __aenter__(self):
211+
"""Context manager (used for testing rather than long-lived instance)"""
212+
await self.start()
213+
return self
214+
215+
async def __aexit__(self, exc_type, exc_val, exc_tb):
216+
"""Context manager (used for testing rather than long-lived instance)"""
217+
await self.stop()
218+
203219

204220
async def _safe_read_body(response: aiohttp.ClientResponse) -> str:
205221
"""Read response body for error messages, truncating if too large."""

0 commit comments

Comments (0)