Skip to content

Commit 64ac9d4

Browse files
committed
Add support for request-time specification of model API keys
Signed-off-by: Rob Geada <rob@geada.net>
1 parent 0287a38 commit 64ac9d4

File tree

8 files changed

+494
-11
lines changed

8 files changed

+494
-11
lines changed
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
# Example configuration showing how to extract API keys from HTTP request headers
#
# This is useful when you want to use different API keys for different users/requests,
# rather than a single API key for all requests.
#
# To use, add the api_key_header to your model config e.g.:
#   api_key_header: "X-API-Key"
# Then, when a request comes in with a matching header:
#   X-API-Key: sk-abc123xyz
#
# The value "sk-abc123xyz" will be used as the Bearer token for the LLM API call.

models:
  - type: main
    engine: openai
    model: gpt-4
    # Instead of using api_key_env_var (which loads from environment variable),
    # use api_key_header to extract the API key from each request's headers
    api_key_header: "X-API-Key" # The name of the HTTP header containing the API key

  # You can also use this with other model types
  # - type: content_safety
  #   engine: openai
  #   model: gpt-3.5-turbo
  #   api_key_header: "X-API-Key"

rails:
  input:
    flows:
      - self check input

  output:
    flows:
      - self check output

nemoguardrails/context.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,7 @@
1818

1919
from nemoguardrails.logging.explain import LLMCallInfo
2020
from nemoguardrails.rails.llm.options import GenerationOptions
21-
from nemoguardrails.streaming import StreamingHandler
2221

23-
streaming_handler_var: contextvars.ContextVar[Optional[StreamingHandler]] = contextvars.ContextVar(
24-
"streaming_handler", default=None
25-
)
2622
if TYPE_CHECKING:
2723
from nemoguardrails.logging.explain import ExplainInfo
2824
from nemoguardrails.logging.stats import LLMStats
@@ -62,3 +58,7 @@
6258
llm_response_metadata_var: contextvars.ContextVar[Optional[dict]] = contextvars.ContextVar(
6359
"llm_response_metadata", default=None
6460
)
61+
62+
# The HTTP request headers for API requests.
63+
# This is used to extract API keys or other authentication tokens from headers.
64+
api_request_headers: contextvars.ContextVar = contextvars.ContextVar("api_request_headers", default=None)
Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
# SPDX-FileCopyrightText: Copyright (c) 2023-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Chat-model wrapper that injects per-request API keys read from HTTP request headers."""

import logging
from typing import Any, List, Optional

from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage
from pydantic import ConfigDict, Field

from nemoguardrails.context import api_request_headers

log = logging.getLogger(__name__)


class HeaderAPIKeyWrapper(BaseChatModel):
    """Wrapper that injects API keys from request headers at runtime.

    This wrapper intercepts LLM calls and reads the API key from the HTTP request
    headers (via the api_request_headers context variable from server.api) on every request.

    From testing, this adds negligible time to each LLM call (~1e-06 seconds)
    """

    # The underlying chat model that every call is delegated to.
    wrapped_llm: BaseChatModel = Field(description="The LangChain LLM to wrap")
    # Name of the HTTP header whose value is used as the per-request API key.
    api_key_header: str = Field(description="The name of the HTTP header containing the API key")
    model_config = ConfigDict(arbitrary_types_allowed=True)

    def __init__(self, llm: BaseChatModel, api_key_header: str, **kwargs):
        """Initialize the wrapper.

        Args:
            llm: The LangChain LLM to wrap (must be a BaseChatModel)
            api_key_header: The name of the HTTP header containing the API key
        """
        # Initialize with the data dict so Pydantic validates both declared fields.
        super().__init__(**{"wrapped_llm": llm, "api_key_header": api_key_header, **kwargs})

    def _get_api_key_from_headers(self) -> Optional[str]:
        """Extract API key from the current request headers, or None if unavailable."""
        # ContextVar.get(default) never raises LookupError, so no try/except is
        # needed here: outside a server request context the default (None) is returned.
        headers = api_request_headers.get(None)
        if headers and self.api_key_header in headers:
            # NOTE(review): this lookup is case-sensitive; if the stored headers are a
            # plain dict, the configured name must match the wire casing exactly — confirm.
            return headers[self.api_key_header]
        return None

    def _get_llm_with_api_key(self, api_key: Optional[str]) -> BaseChatModel:
        """Get LLM instance with the specified API key.

        Creates a new LLM instance if api_key is provided, otherwise returns
        the wrapped LLM. This ensures thread-safety by avoiding shared state mutation.
        """
        if not api_key:
            return self.wrapped_llm

        # Try ChatOpenAI-specific approach (most common). A missing langchain_openai
        # package is expected on some installs and must not warn on every LLM call,
        # so the import failure is handled separately from construction failures.
        try:
            from langchain_openai import ChatOpenAI  # type: ignore[import-not-found]
        except ImportError:
            ChatOpenAI = None

        if ChatOpenAI is not None and isinstance(self.wrapped_llm, ChatOpenAI):
            try:
                # Create a new instance with the runtime API key: model_dump() captures
                # all current settings, then the key is overridden. Both parameter
                # names are set for compatibility across langchain_openai versions.
                config = self.wrapped_llm.model_dump()
                config["openai_api_key"] = api_key
                config["api_key"] = api_key
                return ChatOpenAI(**config)
            except Exception as e:
                log.warning(f"Failed to create ChatOpenAI with custom API key: {e}")

        # Fallback: try a generic model_copy for other providers.
        if hasattr(self.wrapped_llm, "model_copy"):
            for key_param in ["api_key", "anthropic_api_key", "cohere_api_key"]:
                try:
                    # NOTE(review): pydantic's model_copy(update=...) does not validate,
                    # so the first name typically "succeeds" even if the provider
                    # ignores that field — verify per-provider before relying on this.
                    return self.wrapped_llm.model_copy(update={key_param: api_key})
                except (TypeError, ValueError):
                    continue

        # If all fails, log a warning and fall back to the default credentials.
        log.warning(
            f"Unable to create new instance for {type(self.wrapped_llm).__name__}. "
            f"Using default API key. Multi-tenant isolation not available for this provider."
        )
        return self.wrapped_llm

    def _generate(self, messages, stop=None, run_manager=None, **kwargs):
        """Generate response using the wrapped LLM with runtime API key."""
        llm = self._get_llm_with_api_key(self._get_api_key_from_headers())
        return llm._generate(messages, stop=stop, run_manager=run_manager, **kwargs)

    async def _agenerate(self, messages, stop=None, run_manager=None, **kwargs):
        """Async generate response using the wrapped LLM with runtime API key."""
        llm = self._get_llm_with_api_key(self._get_api_key_from_headers())
        return await llm._agenerate(messages, stop=stop, run_manager=run_manager, **kwargs)

    def invoke(
        self,
        input: Any,
        config: Optional[Any] = None,
        *,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> AIMessage:
        """Invoke the LLM with runtime API key from headers."""
        llm = self._get_llm_with_api_key(self._get_api_key_from_headers())
        return llm.invoke(input, config=config, stop=stop, **kwargs)

    async def ainvoke(
        self,
        input: Any,
        config: Optional[Any] = None,
        *,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> AIMessage:
        """Async invoke the LLM with runtime API key from headers."""
        llm = self._get_llm_with_api_key(self._get_api_key_from_headers())
        return await llm.ainvoke(input, config=config, stop=stop, **kwargs)

    def _stream(self, messages, stop=None, run_manager=None, **kwargs):
        """Stream response using the wrapped LLM with runtime API key."""
        llm = self._get_llm_with_api_key(self._get_api_key_from_headers())
        yield from llm._stream(messages, stop=stop, run_manager=run_manager, **kwargs)

    async def _astream(self, messages, stop=None, run_manager=None, **kwargs):
        """Async stream response using the wrapped LLM with runtime API key."""
        llm = self._get_llm_with_api_key(self._get_api_key_from_headers())
        async for chunk in llm._astream(messages, stop=stop, run_manager=run_manager, **kwargs):
            yield chunk

    @property
    def _llm_type(self) -> str:
        """Return the LLM type."""
        return f"header_api_key_wrapper_{self.wrapped_llm._llm_type}"


__all__ = ["HeaderAPIKeyWrapper"]

nemoguardrails/rails/llm/config.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,10 @@ class Model(BaseModel):
118118
default=None,
119119
description='Optional environment variable with model\'s API Key. Do not include "$".',
120120
)
121+
api_key_header: Optional[str] = Field(
122+
default=None,
123+
description="Optional HTTP header name from which to extract the API key. The header value will be used as a Bearer token.",
124+
)
121125
parameters: Dict[str, Any] = Field(default_factory=dict)
122126

123127
mode: Literal["chat", "text"] = Field(

nemoguardrails/rails/llm/llmrails.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@
7878
)
7979
from nemoguardrails.kb.kb import KnowledgeBase
8080
from nemoguardrails.llm.cache import CacheInterface, LFUCache
81+
from nemoguardrails.llm.models.header_api_key_wrapper import HeaderAPIKeyWrapper
8182
from nemoguardrails.llm.models.initializer import (
8283
ModelInitializationError,
8384
init_llm_model,
@@ -420,6 +421,13 @@ def _init_llms(self):
420421
mode="chat",
421422
kwargs=kwargs,
422423
)
424+
425+
# Wrap with header-based API key wrapper if configured
426+
if main_model.api_key_header and isinstance(self.llm, BaseChatModel):
427+
log.info(
428+
f"Wrapping main LLM with header-based API key extraction from header: {main_model.api_key_header}"
429+
)
430+
self.llm = HeaderAPIKeyWrapper(self.llm, main_model.api_key_header)
423431
self.runtime.register_action_param("llm", self.llm)
424432

425433
else:
@@ -453,6 +461,13 @@ def _init_llms(self):
453461
kwargs=kwargs,
454462
)
455463

464+
# Wrap with header-based API key wrapper if configured
465+
if llm_config.api_key_header and isinstance(llm_model, BaseChatModel):
466+
log.info(
467+
f"Wrapping {llm_config.type} LLM with header-based API key extraction from header: {llm_config.api_key_header}"
468+
)
469+
llm_model = HeaderAPIKeyWrapper(llm_model, llm_config.api_key_header)
470+
456471
# Configure the model based on its type
457472
if llm_config.type == "main":
458473
# If a main LLM was already injected, skip creating another

nemoguardrails/server/api.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
# limitations under the License.
1515

1616
import asyncio
17-
import contextvars
1817
import importlib.util
1918
import json
2019
import logging
@@ -34,6 +33,7 @@
3433
from starlette.staticfiles import StaticFiles
3534

3635
from nemoguardrails import LLMRails, RailsConfig, utils
36+
from nemoguardrails.context import api_request_headers
3737
from nemoguardrails.rails.llm.config import Model
3838
from nemoguardrails.rails.llm.options import GenerationResponse
3939
from nemoguardrails.server.datastore.datastore import DataStore
@@ -75,9 +75,6 @@ def __init__(self, *args, **kwargs):
7575

7676
api_description = """Guardrails Sever API."""
7777

78-
# The headers for each request
79-
api_request_headers: contextvars.ContextVar = contextvars.ContextVar("headers")
80-
8178
# The datastore that the Server should use.
8279
# This is currently used only for storing threads.
8380
# TODO: refactor to wrap the FastAPI instance inside a RailsServer class
@@ -309,17 +306,33 @@ def _get_rails(config_ids: List[str], model_name: Optional[str] = None) -> LLMRa
309306
raise ValueError("No valid rails configuration found.")
310307

311308
if model_name:
309+
# Get engine from environment or use existing main model's engine
310+
existing_main_model = next((m for m in full_llm_rails_config.models if m.type == "main"), None)
311+
312312
engine = os.environ.get("MAIN_MODEL_ENGINE")
313-
if not engine:
313+
if not engine and existing_main_model:
314+
engine = existing_main_model.engine
315+
elif not engine:
314316
engine = "openai"
315-
log.warning("MAIN_MODEL_ENGINE not set, defaulting to 'openai'. ")
317+
log.warning("No main model in config and MAIN_MODEL_ENGINE not set, defaulting to 'openai'. ")
316318

317319
parameters = {}
318320
base_url = os.environ.get("MAIN_MODEL_BASE_URL")
319321
if base_url:
320322
parameters["base_url"] = base_url
321323

322-
main_model = Model(model=model_name, type="main", engine=engine, parameters=parameters)
324+
# Preserve api_key_header and api_key_env_var from existing config
325+
api_key_header = existing_main_model.api_key_header if existing_main_model else None
326+
api_key_env_var = existing_main_model.api_key_env_var if existing_main_model else None
327+
328+
main_model = Model(
329+
model=model_name,
330+
type="main",
331+
engine=engine,
332+
parameters=parameters,
333+
api_key_header=api_key_header,
334+
api_key_env_var=api_key_env_var,
335+
)
323336
full_llm_rails_config = _update_models_in_config(full_llm_rails_config, main_model)
324337

325338
llm_rails = LLMRails(config=full_llm_rails_config, verbose=True)

0 commit comments

Comments
 (0)