Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions docs/reference/api-server-endpoints/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,69 @@ Guardrails-specific fields are nested under the `guardrails` object in the reque
- A state object to continue a previous interaction. Must contain an `events` or `state` key, or be an empty dict `{}` to start a new conversation.
```

### Authentication Headers

The server supports per-request API key injection via custom HTTP headers. This allows different requests to use different API keys for the configured LLM models, without modifying the server configuration or environment variables.

#### Header Format

For each model in your guardrails configuration, you can provide a custom API key using a header in the format:

```
X-{model-name}-Authorization: your-api-key-here
```

The header name is matched **case-insensitively**. The model name must match the `model` field in your configuration; spaces and special characters in the model name are preserved as-is when building the header name.

#### Examples

**Single Model Configuration**

If your configuration uses `gpt-3.5-turbo` as the main model:

```bash
curl -X POST http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-H "X-Gpt-3.5-Turbo-Authorization: sk-custom-key-123" \
-d '{
"model": "gpt-3.5-turbo",
"messages": [{"role": "user", "content": "Hello"}],
"guardrails": {"config_id": "my-config"}
}'
```

**Multi-Model Configuration**

If your configuration uses multiple models (e.g., `gpt-3.5-turbo` for main generation and `gpt-4` for self-check), you can provide separate keys for each:

```bash
curl -X POST http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-H "X-Gpt-3.5-Turbo-Authorization: sk-main-key-789" \
-H "X-Gpt-4-Authorization: sk-selfcheck-key-012" \
-d '{
"model": "gpt-3.5-turbo",
"messages": [{"role": "user", "content": "Hello"}],
"guardrails": {"config_id": "my-config"}
}'
```

#### Behavior

- Headers are matched to models by comparing the model name (case-insensitive)
- If a header is provided for a model, it **overrides** the API key configured in the guardrails configuration or environment variables for that specific request only
- If no header is provided for a model, the default API key from the configuration is used
- API keys are automatically reset to their original values after each request completes, preventing leakage between requests
- This works for both streaming and non-streaming requests

#### Use Cases

This feature is particularly useful for:
- **Multi-tenant applications**: Different users can use their own API keys without server reconfiguration
- **Cost tracking**: Route different requests to different API accounts for billing purposes
- **A/B testing**: Test different API keys or accounts within the same deployment
- **Development**: Test with personal API keys without modifying shared configurations

### Generation Options

The `guardrails.options` field controls which rails are applied and what information is returned.
Expand Down
69 changes: 66 additions & 3 deletions nemoguardrails/server/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,59 @@ def _update_models_in_config(config: RailsConfig, main_model: Model) -> RailsCon
return config.model_copy(update={"models": models})


def _set_api_keys(llm_rails: LLMRails, headers: dict):
"""Create temporary versions of all LLMRails models that use header API keys, if needed

Args:
llm_rails (LLMRails): LLMRails object used in request
headers (dict): API headers received in request
"""
original_model_config = [config.model_copy(deep=True) for config in llm_rails.config.models]
headers_lower = {k.lower(): v for k, v in headers.items()}
any_matched = False

for i, model in enumerate(original_model_config):
if model.model is None:
continue
target_header = f"x-{model.model.lower()}-authorization"
if headers_lower.get(target_header):
any_matched = True
llm_rails.config.models[i].parameters["api_key"] = headers_lower[target_header]
llm_rails.config.models[i].api_key_env_var = None
model_name = f"{model.type}_llm"
if hasattr(llm_rails, model_name) and model.type != "main":
delattr(llm_rails, model_name) # clear the initialized LLMs to force a reinit

if any_matched:
llm_rails.llm = None # clear the initialized LLMs to force a reinit
setattr(llm_rails, "original_config", original_model_config) # store backup of original config
llm_rails._init_llms()


def _reset_api_keys(llm_rails: LLMRails):
"""Reset API keys to their original values after request completes.

Args:
llm_rails (LLMRails): LLMRails object used in request
"""

if hasattr(llm_rails, "original_config"):
# restore backup config
llm_rails.config.models = getattr(llm_rails, "original_config")
llm_rails.llm = None

# Delete all task-specific LLMs so they get reinitialized with original API keys
for model_config in getattr(llm_rails, "original_config", []):
if model_config.type != "main":
model_name = f"{model_config.type}_llm"
if hasattr(llm_rails, model_name):
delattr(llm_rails, model_name)

# Remove the config backup so we don't unneccesarily call a reset
delattr(llm_rails, "original_config")
llm_rails._init_llms()


def _get_rails(config_ids: List[str], model_name: Optional[str] = None) -> LLMRails:
"""Returns the rails instance for the given config id and model.

Expand Down Expand Up @@ -381,14 +434,15 @@ class ChunkError(BaseModel):


async def _format_streaming_response(
stream_iterator: AsyncIterator[Union[str, dict]], model_name: str
stream_iterator: AsyncIterator[Union[str, dict]], model_name: str, llm_rails: Optional[LLMRails] = None
) -> AsyncIterator[str]:
"""
Format streaming chunks from LLMRails.stream_async() as SSE events.

Args:
stream_iterator: AsyncIterator from stream_async() that yields str or dict chunks
model_name: The model name to include in the chunks
llm_rails: Optional LLMRails instance to reset API keys after streaming completes

Yields:
SSE-formatted strings (data: {...}\n\n)
Expand All @@ -409,6 +463,10 @@ async def _format_streaming_response(
yield format_streaming_chunk_as_sse(processed_chunk, model, chunk_id)

finally:
# Reset API keys to original values after streaming completes
if llm_rails is not None:
_reset_api_keys(llm_rails)

# Always send [DONE] event when stream ends
yield "data: [DONE]\n\n"

Expand Down Expand Up @@ -488,6 +546,8 @@ async def chat_completion(body: GuardrailsChatCompletionRequest, request: Reques
)

try:
_set_api_keys(llm_rails, dict(request.headers))

messages = body.messages or []
if body.guardrails.context:
messages.insert(0, {"role": "context", "content": body.guardrails.context})
Expand Down Expand Up @@ -551,7 +611,7 @@ async def chat_completion(body: GuardrailsChatCompletionRequest, request: Reques
)

return StreamingResponse(
_format_streaming_response(stream_iterator, model_name=body.model),
_format_streaming_response(stream_iterator, model_name=body.model, llm_rails=llm_rails),
media_type="text/event-stream",
)
else:
Expand Down Expand Up @@ -595,7 +655,6 @@ async def chat_completion(body: GuardrailsChatCompletionRequest, request: Reques
)
],
)

except HTTPException:
raise
except Exception as ex:
Expand All @@ -605,6 +664,10 @@ async def chat_completion(body: GuardrailsChatCompletionRequest, request: Reques
error_message="Internal server error",
config_id=config_ids[0] if config_ids else None,
)
finally:
# Reset API keys to original values after generation completes - stream cleanup handled separately
if not body.stream:
_reset_api_keys(llm_rails)


# By default, there are no challenges
Expand Down
Loading