Skip to content

Commit d37cfe8

Browse files
committed
Add support for request-time specification of model API keys in the server
Signed-off-by: Rob Geada <rob@geada.net>
1 parent 7882262 commit d37cfe8

File tree

5 files changed

+841
-2
lines changed

5 files changed

+841
-2
lines changed

docs/reference/api-server-endpoints/index.md

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,69 @@ Guardrails-specific fields are nested under the `guardrails` object in the reque
151151
- A state object to continue a previous interaction. Must contain an `events` or `state` key, or be an empty dict `{}` to start a new conversation.
152152
```
153153

154+
### Authentication Headers
155+
156+
The server supports per-request API key injection via custom HTTP headers. This allows different requests to use different API keys for the configured LLM models, without modifying the server configuration or environment variables.
157+
158+
#### Header Format
159+
160+
For each model in your guardrails configuration, you can provide a custom API key using a header in the format:
161+
162+
```
163+
X-{model-name}-Authorization: your-api-key-here
164+
```
165+
166+
The header name is matched **case-insensitively**. The model name must match the `model` field in your configuration; spaces and special characters in the model name are preserved as-is when forming the header.
167+
168+
#### Examples
169+
170+
**Single Model Configuration**
171+
172+
If your configuration uses `gpt-3.5-turbo` as the main model:
173+
174+
```bash
175+
curl -X POST http://localhost:8000/v1/chat/completions \
176+
-H "Content-Type: application/json" \
177+
-H "X-Gpt-3.5-Turbo-Authorization: sk-custom-key-123" \
178+
-d '{
179+
"model": "gpt-3.5-turbo",
180+
"messages": [{"role": "user", "content": "Hello"}],
181+
"guardrails": {"config_id": "my-config"}
182+
}'
183+
```
184+
185+
**Multi-Model Configuration**
186+
187+
If your configuration uses multiple models (e.g., `gpt-3.5-turbo` for main generation and `gpt-4` for self-check), you can provide separate keys for each:
188+
189+
```bash
190+
curl -X POST http://localhost:8000/v1/chat/completions \
191+
-H "Content-Type: application/json" \
192+
-H "X-Gpt-3.5-Turbo-Authorization: sk-main-key-789" \
193+
-H "X-Gpt-4-Authorization: sk-selfcheck-key-012" \
194+
-d '{
195+
"model": "gpt-3.5-turbo",
196+
"messages": [{"role": "user", "content": "Hello"}],
197+
"guardrails": {"config_id": "my-config"}
198+
}'
199+
```
200+
201+
#### Behavior
202+
203+
- Headers are matched to models by comparing the model name (case-insensitive)
204+
- If a header is provided for a model, it **overrides** the API key configured in the guardrails configuration or environment variables for that specific request only
205+
- If no header is provided for a model, the default API key from the configuration is used
206+
- API keys are automatically reset to their original values after each request completes, preventing leakage between requests
207+
- This works for both streaming and non-streaming requests
208+
209+
#### Use Cases
210+
211+
This feature is particularly useful for:
212+
- **Multi-tenant applications**: Different users can use their own API keys without server reconfiguration
213+
- **Cost tracking**: Route different requests to different API accounts for billing purposes
214+
- **A/B testing**: Test different API keys or accounts within the same deployment
215+
- **Development**: Test with personal API keys without modifying shared configurations
216+
154217
### Generation Options
155218

156219
The `guardrails.options` field controls which rails are applied and what information is returned.

nemoguardrails/server/api.py

Lines changed: 66 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,58 @@ def _update_models_in_config(config: RailsConfig, main_model: Model) -> RailsCon
302302
return config.model_copy(update={"models": models})
303303

304304

305+
def _set_api_keys(llm_rails: LLMRails, headers: dict):
    """Inject per-request API keys from HTTP headers into the LLMRails model configs.

    For each configured model, looks for a header named
    ``x-{model-name}-authorization`` (matched case-insensitively). When found,
    the header value replaces the model's ``api_key`` parameter and disables
    its ``api_key_env_var`` fallback for this request, and the cached LLM
    instances are cleared so they are re-created with the injected key.

    A deep copy of the original model configs is stashed on the rails
    instance as ``original_config`` so ``_reset_api_keys`` can restore it
    after the request completes.

    NOTE(review): this mutates a shared LLMRails instance in place, so
    concurrent requests against the same config could race — confirm the
    server serializes access to a rails instance or accepts this trade-off.

    Args:
        llm_rails (LLMRails): LLMRails object used in the request.
        headers (dict): HTTP headers received in the request.
    """
    # Deep-copy the configs up front so later mutations never leak into the backup.
    original_model_config = [config.model_copy(deep=True) for config in llm_rails.config.models]
    # Header names are matched case-insensitively.
    headers_lower = {k.lower(): v for k, v in headers.items()}
    any_matched = False

    for i, model in enumerate(original_model_config):
        if model.model is None:
            continue
        target_header = f"x-{model.model.lower()}-authorization"
        if target_header not in headers_lower:
            continue
        any_matched = True
        llm_rails.config.models[i].parameters["api_key"] = headers_lower[target_header]
        # Clear the env-var fallback so the injected key always wins.
        llm_rails.config.models[i].api_key_env_var = None
        model_name = f"{model.type}_llm"
        if hasattr(llm_rails, model_name) and model.type != "main":
            # Clear the initialized task-specific LLM to force a reinit.
            delattr(llm_rails, model_name)

    if any_matched:
        llm_rails.llm = None  # clear the initialized main LLM to force a reinit
        # Store a backup of the original config for _reset_api_keys.
        llm_rails.original_config = original_model_config
        llm_rails._init_llms()
332+
333+
334+
def _reset_api_keys(llm_rails: LLMRails):
    """Restore the original model configs after a request that injected API keys.

    Undoes the changes made by ``_set_api_keys``: puts back the backed-up
    model configs, clears the cached LLM instances so they are rebuilt with
    the original keys, and removes the backup marker. A no-op when
    ``_set_api_keys`` did not inject any key (no ``original_config`` present),
    which makes it safe to call unconditionally in cleanup paths.

    Args:
        llm_rails (LLMRails): LLMRails object used in the request.
    """
    if not hasattr(llm_rails, "original_config"):
        # Nothing was injected for this request.
        return

    original_config = llm_rails.original_config

    # Restore the backup config and drop the main LLM so it is re-created
    # with the original API key.
    llm_rails.config.models = original_config
    llm_rails.llm = None

    # Delete all task-specific LLMs so they get reinitialized with the
    # original API keys.
    for model_config in original_config:
        if model_config.type != "main":
            model_name = f"{model_config.type}_llm"
            if hasattr(llm_rails, model_name):
                delattr(llm_rails, model_name)

    # Remove the config backup so we don't unnecessarily run the reset again.
    del llm_rails.original_config
355+
356+
305357
def _get_rails(config_ids: List[str], model_name: Optional[str] = None) -> LLMRails:
306358
"""Returns the rails instance for the given config id and model.
307359
@@ -381,14 +433,15 @@ class ChunkError(BaseModel):
381433

382434

383435
async def _format_streaming_response(
384-
stream_iterator: AsyncIterator[Union[str, dict]], model_name: str
436+
stream_iterator: AsyncIterator[Union[str, dict]], model_name: str, llm_rails: Optional[LLMRails] = None
385437
) -> AsyncIterator[str]:
386438
"""
387439
Format streaming chunks from LLMRails.stream_async() as SSE events.
388440
389441
Args:
390442
stream_iterator: AsyncIterator from stream_async() that yields str or dict chunks
391443
model_name: The model name to include in the chunks
444+
llm_rails: Optional LLMRails instance to reset API keys after streaming completes
392445
393446
Yields:
394447
SSE-formatted strings (data: {...}\n\n)
@@ -412,6 +465,10 @@ async def _format_streaming_response(
412465
# Always send [DONE] event when stream ends
413466
yield "data: [DONE]\n\n"
414467

468+
# Reset API keys to original values after streaming completes
469+
if llm_rails is not None:
470+
_reset_api_keys(llm_rails)
471+
415472

416473
def process_chunk(chunk: Any) -> Union[Any, ChunkError]:
417474
"""
@@ -487,6 +544,8 @@ async def chat_completion(body: GuardrailsChatCompletionRequest, request: Reques
487544
config_id=config_ids[0] if config_ids else None,
488545
)
489546

547+
_set_api_keys(llm_rails, dict(request.headers))
548+
490549
try:
491550
messages = body.messages or []
492551
if body.guardrails.context:
@@ -551,7 +610,7 @@ async def chat_completion(body: GuardrailsChatCompletionRequest, request: Reques
551610
)
552611

553612
return StreamingResponse(
554-
_format_streaming_response(stream_iterator, model_name=body.model),
613+
_format_streaming_response(stream_iterator, model_name=body.model, llm_rails=llm_rails),
555614
media_type="text/event-stream",
556615
)
557616
else:
@@ -569,6 +628,9 @@ async def chat_completion(body: GuardrailsChatCompletionRequest, request: Reques
569628
if body.guardrails.thread_id and datastore is not None and datastore_key is not None:
570629
await datastore.set(datastore_key, json.dumps(messages + [bot_message]))
571630

631+
# clear injected api keys
632+
_reset_api_keys(llm_rails)
633+
572634
# Build the response with OpenAI-compatible format using utility function
573635
if isinstance(res, GenerationResponse):
574636
return generation_response_to_chat_completion(
@@ -597,8 +659,10 @@ async def chat_completion(body: GuardrailsChatCompletionRequest, request: Reques
597659
)
598660

599661
except HTTPException:
662+
_reset_api_keys(llm_rails)
600663
raise
601664
except Exception as ex:
665+
_reset_api_keys(llm_rails)
602666
log.exception(ex)
603667
return create_error_chat_completion(
604668
model=body.model,

0 commit comments

Comments
 (0)