Skip to content

Commit d3d7f77

Browse files
authored
feat: New top-level scaffold (#1613)
* Initial commit of new top-level object, no tests yet
* Consolidate into one file, add all but internal implementation-detail methods to Guardrails object
* Add Guardrails top-level tests
* Compacting tests
* Use NEMO_USE_GUARDRAILS_WRAPPER to select new wrapper on top of LLMRails
* Clean up init method
* Change env var from NEMO_USE_GUARDRAILS_WRAPPER to NEMO_GUARDRAILS_IORAILS_ENGINE
1 parent f4824e2 commit d3d7f77

File tree

5 files changed

+882
-3
lines changed

5 files changed

+882
-3
lines changed

nemoguardrails/__init__.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
# NOTE(review): `os` is referenced below but no `import os` was visible in the
# diff hunk -- it is added here; if the file already imports it earlier, this
# re-import is a harmless no-op. Confirm against the full file.
import os
import warnings

import nemoguardrails.patch_asyncio
from nemoguardrails.rails import RailsConfig

# Apply the asyncio patch before anything else touches the event loop.
nemoguardrails.patch_asyncio.apply()

# Ignore a warning message from torch.
warnings.filterwarnings("ignore", category=UserWarning, message="TypedStorage is deprecated")

# Use the Guardrails top-level wrapper if this environment variable is set
# to a truthy value ("true", "1", or "yes", case-insensitive).
_use_guardrails_wrapper = os.environ.get("NEMO_GUARDRAILS_IORAILS_ENGINE", "").lower() in (
    "true",
    "1",
    "yes",
)

if _use_guardrails_wrapper:
    # Use the Guardrails wrapper class (aliased as LLMRails for compatibility)
    from nemoguardrails.guardrails.guardrails import Guardrails as LLMRails
else:
    # Use the original LLMRails class
    from nemoguardrails.rails import LLMRails

# NOTE(review): `version` is presumably `importlib.metadata.version`, imported
# above the visible hunk -- confirm.
__version__ = version("nemoguardrails")

__all__ = ["LLMRails", "RailsConfig"]
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Top-level Guardrails interface module.
17+
18+
This module provides a simplified, user-friendly interface for interacting with
19+
NeMo Guardrails. The Guardrails class wraps the LLMRails functionality and provides
20+
a streamlined API for generating LLM responses with programmable guardrails.
21+
"""
22+
23+
from enum import Enum
24+
from typing import AsyncIterator, Optional, Tuple, TypeAlias, Union, overload
25+
26+
from langchain_core.language_models import BaseChatModel, BaseLLM
27+
28+
from nemoguardrails.logging.explain import ExplainInfo
29+
from nemoguardrails.rails.llm.config import RailsConfig
30+
from nemoguardrails.rails.llm.llmrails import LLMRails
31+
from nemoguardrails.rails.llm.options import GenerationResponse
32+
33+
34+
class MessageRole(str, Enum):
    """Enumeration of message roles in a conversation.

    Inherits from ``str`` so members compare equal to their plain string
    values (e.g. ``MessageRole.USER == "user"``); the values match the
    ``"role"`` field used in LLMRails-style message dicts.
    """

    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"
    CONTEXT = "context"
    EVENT = "event"
    TOOL = "tool"


# A conversation as a list of message dicts, e.g.
# [{"role": "user", "content": "..."}].
LLMMessages: TypeAlias = list[dict[str, str]]
46+
47+
48+
class Guardrails:
    """Top-level interface for NeMo Guardrails functionality.

    Thin wrapper around :class:`LLMRails` that accepts either a plain string
    prompt or a list of messages and delegates all generation work to the
    wrapped ``LLMRails`` instance.
    """

    def __init__(
        self,
        config: RailsConfig,
        llm: Optional[Union[BaseLLM, BaseChatModel]] = None,
        verbose: bool = False,
    ):
        """Initialize a Guardrails instance.

        Args:
            config: The rails configuration to apply.
            llm: Optional main LLM instance; forwarded to ``LLMRails``.
            verbose: Whether to enable verbose mode in the underlying
                ``LLMRails``.
        """
        self.config = config
        self.llm = llm
        self.verbose = verbose

        # All generation calls below delegate to this wrapped instance.
        self.llmrails = LLMRails(config, llm, verbose)

    @staticmethod
    def _convert_to_messages(prompt: str | None = None, messages: LLMMessages | None = None) -> LLMMessages:
        """Convert a prompt or simplified messages to LLMRails standard format.

        Conversions performed:
        - A string prompt becomes ``[{"role": "user", "content": prompt}]``.
        - A simplified message ``{"user": "text"}`` becomes
          ``{"role": "user", "content": "text"}``.
        - Messages already in standard format (containing a ``"role"`` key)
          are passed through unchanged.

        Args:
            prompt: A plain-text user prompt.
            messages: A list of messages in simplified or standard format.
                Takes precedence over ``prompt`` when both are given.

        Returns:
            Messages in the LLMRails standard format.

        Raises:
            ValueError: If neither ``prompt`` nor ``messages`` is provided,
                or a message is in neither recognized format.
        """
        # Priority: messages first, then prompt
        if messages:
            converted: LLMMessages = []
            for message in messages:
                if "role" in message:
                    # Already in standard format; pass through unchanged.
                    converted.append(message)
                elif len(message) == 1:
                    # Simplified format: a single {role: content} pair.
                    role, content = next(iter(message.items()))
                    converted.append({"role": role, "content": content})
                else:
                    raise ValueError(f"Unrecognized message format: {message}")
            return converted

        if prompt:
            # Convert string prompt to standard format
            return [{"role": "user", "content": prompt}]

        raise ValueError("Neither prompt nor messages provided for generation")

    def generate(
        self, prompt: str | None = None, messages: LLMMessages | None = None, **kwargs
    ) -> Union[str, dict, GenerationResponse, Tuple[dict, dict]]:
        """Generate an LLM response synchronously with guardrails applied.

        Args:
            prompt: A plain-text user prompt.
            messages: Conversation messages; takes precedence over ``prompt``.
            **kwargs: Forwarded to ``LLMRails.generate``.

        Returns:
            The response in whichever form ``LLMRails.generate`` produces.
        """
        messages = self._convert_to_messages(prompt, messages)
        return self.llmrails.generate(messages=messages, **kwargs)

    # NOTE: the original four @overload declarations for generate_async had
    # byte-identical parameter signatures, differing only in return type --
    # type checkers flag such overloads as overlapping and only ever match the
    # first. They were type-only and are removed; the runtime behavior and the
    # declared union return type are unchanged.
    async def generate_async(
        self, prompt: str | None = None, messages: LLMMessages | None = None, **kwargs
    ) -> str | dict | GenerationResponse | tuple[dict, dict]:
        """Generate an LLM response asynchronously with guardrails applied.

        Args:
            prompt: A plain-text user prompt.
            messages: Conversation messages; takes precedence over ``prompt``.
            **kwargs: Forwarded to ``LLMRails.generate_async``.

        Returns:
            The response in whichever form ``LLMRails.generate_async``
            produces.
        """
        messages = self._convert_to_messages(prompt, messages)
        response = await self.llmrails.generate_async(messages=messages, **kwargs)
        return response

    def stream_async(
        self, prompt: str | None = None, messages: LLMMessages | None = None, **kwargs
    ) -> AsyncIterator[str | dict]:
        """Generate an LLM response asynchronously with streaming support.

        Args:
            prompt: A plain-text user prompt.
            messages: Conversation messages; takes precedence over ``prompt``.
            **kwargs: Forwarded to ``LLMRails.stream_async``.

        Returns:
            An async iterator over response chunks.
        """
        messages = self._convert_to_messages(prompt, messages)
        return self.llmrails.stream_async(messages=messages, **kwargs)

    def explain(self) -> ExplainInfo:
        """Get the latest ExplainInfo object for debugging."""
        return self.llmrails.explain()

    def update_llm(self, llm: Union[BaseLLM, BaseChatModel]) -> None:
        """Replace the main LLM with a new one.

        Args:
            llm: The new main LLM instance.
        """
        self.llm = llm
        self.llmrails.update_llm(llm)

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@ include = [
173173
"nemoguardrails/server/**",
174174
"tests/test_callbacks.py",
175175
"nemoguardrails/benchmark/**",
176+
"nemoguardrails/guardrails/**"
176177
]
177178
exclude = [
178179
"nemoguardrails/llm/providers/trtllm/**",

0 commit comments

Comments
 (0)