Skip to content

Commit f4824e2

Browse files
feat(server): make guardrails server OpenAI compatible (#1340)
--------- Co-authored-by: Pouyanpi <13303554+Pouyanpi@users.noreply.github.com>
1 parent 0e4efc6 commit f4824e2

File tree

12 files changed

+2830
-487
lines changed

12 files changed

+2830
-487
lines changed

nemoguardrails/server/api.py

Lines changed: 245 additions & 146 deletions
Large diffs are not rendered by default.
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""OpenAI API schema definitions for the NeMo Guardrails server."""
17+
18+
import os
19+
from typing import Any, List, Optional, Union
20+
21+
from openai.types.chat.chat_completion import ChatCompletion
22+
from pydantic import BaseModel, Field, ValidationInfo, field_validator, model_validator
23+
24+
from nemoguardrails.rails.llm.options import GenerationOptions
25+
26+
27+
class GuardrailsDataOutput(BaseModel):
    """Guardrails-specific output data."""

    # Identifies which guardrails configuration produced this response.
    config_id: Optional[str] = Field(
        default=None,
        description="The guardrails configuration ID associated with this response.",
    )
    # Opaque state blob; the client sends it back to continue the conversation.
    state: Optional[dict] = Field(
        default=None,
        description="State object for continuing the conversation.",
    )
    llm_output: Optional[dict] = Field(
        default=None,
        description="Additional LLM output data.",
    )
    output_data: Optional[dict] = Field(
        default=None,
        description="Additional output data.",
    )
    log: Optional[dict] = Field(
        default=None,
        description="Generation log data.",
    )
40+
class GuardrailsChatCompletion(ChatCompletion):
    """OpenAI API response body with NeMo-Guardrails extensions."""

    # Extra, non-OpenAI payload; None when no guardrails data is attached.
    guardrails: Optional[GuardrailsDataOutput] = Field(
        default=None,
        description="Guardrails specific output data.",
    )
46+
class OpenAIChatCompletionRequest(BaseModel):
    """Standard OpenAI chat completion request parameters."""

    messages: Optional[List[dict]] = Field(
        default=None,
        description="The list of messages in the current conversation.",
    )
    # Required: this is the only mandatory field of the request.
    model: str = Field(
        ...,
        description="The LLM model to use for chat completion (e.g., 'gpt-4o', 'llama-3.1-8b').",
    )
    stream: Optional[bool] = Field(
        default=False,
        description="If set, partial message deltas will be sent as server-sent events.",
    )
    max_tokens: Optional[int] = Field(
        default=None,
        description="The maximum number of tokens to generate.",
    )
    temperature: Optional[float] = Field(
        default=None,
        description="Sampling temperature to use.",
    )
    top_p: Optional[float] = Field(
        default=None,
        description="Top-p sampling parameter.",
    )
    stop: Optional[Union[str, List[str]]] = Field(
        default=None,
        description="Stop sequences.",
    )
    presence_penalty: Optional[float] = Field(
        default=None,
        description="Presence penalty parameter.",
    )
    frequency_penalty: Optional[float] = Field(
        default=None,
        description="Frequency penalty parameter.",
    )
    # The OpenAI API accepts either the string values "none"/"auto" or a
    # {"name": ...} object here; accepting only dict would reject valid
    # OpenAI-compatible requests. Widening to Union[str, dict] is
    # backward compatible with existing dict callers.
    function_call: Optional[Union[str, dict]] = Field(
        default=None,
        description="Function call parameter.",
    )
    logit_bias: Optional[dict] = Field(
        default=None,
        description="Logit bias parameter.",
    )
    logprobs: Optional[bool] = Field(
        default=None,
        description="Log probabilities parameter.",
    )
99+
class GuardrailsDataInput(BaseModel):
    """Guardrails-specific options for the request."""

    # Falls back to the DEFAULT_CONFIG_ID environment variable when unset.
    config_id: Optional[str] = Field(
        default_factory=lambda: os.getenv("DEFAULT_CONFIG_ID", None),
        description="The guardrails configuration ID to use.",
    )
    config_ids: Optional[List[str]] = Field(
        default=None,
        description="List of configuration IDs to combine.",
        # Run ensure_config_ids even when the field is left at its default,
        # so a provided (or env-derived) config_id is folded into config_ids.
        validate_default=True,
    )
    thread_id: Optional[str] = Field(
        default=None,
        min_length=16,
        max_length=255,
        description="The ID of an existing thread to continue.",
    )
    context: Optional[dict] = Field(
        default=None,
        description="Additional context data for the conversation.",
    )
    options: GenerationOptions = Field(
        default_factory=GenerationOptions,
        description="Additional generation options.",
    )
    state: Optional[dict] = Field(
        default=None,
        description="State object to continue the interaction.",
    )

    @model_validator(mode="before")
    @classmethod
    def validate_config_ids(cls, data: Any) -> Any:
        # config_id and config_ids are mutually exclusive on the raw input.
        if not isinstance(data, dict):
            return data
        has_single = data.get("config_id") is not None
        has_many = data.get("config_ids") is not None
        if has_single and has_many:
            raise ValueError("Only one of config_id or config_ids should be specified")
        return data

    @field_validator("config_ids", mode="before")
    @classmethod
    def ensure_config_ids(cls, v: Any, info: ValidationInfo) -> Any:
        # Normalize: when only config_id is present, expose it as a one-item list.
        if v is not None:
            return v
        single = info.data.get("config_id")
        return [single] if single else v
146+
class GuardrailsChatCompletionRequest(OpenAIChatCompletionRequest):
    """OpenAI chat completion request with NeMo Guardrails extensions."""

    # Always materialized: omitting it yields a default GuardrailsDataInput
    # (which in turn may pick up DEFAULT_CONFIG_ID from the environment).
    guardrails: GuardrailsDataInput = Field(
        default_factory=GuardrailsDataInput,
        description="Guardrails specific options for the request.",
    )

0 commit comments

Comments
 (0)