Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 46 additions & 3 deletions libs/core/langchain_core/language_models/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
from operator import itemgetter
from typing import TYPE_CHECKING, Any, Literal, cast

from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import override
from pydantic import BaseModel, ConfigDict, Field, model_validator
from typing_extensions import Self, override

from langchain_core.caches import BaseCache
from langchain_core.callbacks import (
Expand All @@ -32,7 +32,10 @@
LangSmithParams,
LanguageModelInput,
)
from langchain_core.language_models.model_profile import ModelProfile
from langchain_core.language_models.model_profile import (
ModelProfile,
_warn_unknown_profile_keys,
)
from langchain_core.load import dumpd, dumps
from langchain_core.messages import (
AIMessage,
Expand Down Expand Up @@ -357,6 +360,46 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
arbitrary_types_allowed=True,
)

def _resolve_model_profile(self) -> ModelProfile | None:
"""Resolve the default model profile for this model.

Override in subclasses to provide auto-populated profile data.

Subclasses that override this method do not need to define their own
`_set_model_profile` validator — the base class validator will call this
method automatically.

Returns:
A `ModelProfile` dict, or `None` if no default profile is available.
"""
return None

@model_validator(mode="after")
def _set_model_profile(self) -> Self:
"""Set model profile if not overridden.

Subclasses can either:

- Override `_resolve_model_profile` (recommended) and inherit this
validator, or
- Override this validator directly (existing behavior, replaces this
implementation in Pydantic v2).
"""
if self.profile is None:
self.profile = self._resolve_model_profile()
return self

@model_validator(mode="after")
def _check_profile_keys(self) -> Self:
"""Warn on unrecognized profile keys.

Uses a distinct method name so that partner subclasses that override
`_set_model_profile` do not inadvertently suppress this check.
"""
if self.profile:
_warn_unknown_profile_keys(self.profile)
return self

@cached_property
def _serialized(self) -> dict[str, Any]:
# self is always a Serializable object in this case, thus the result is
Expand Down
78 changes: 78 additions & 0 deletions libs/core/langchain_core/language_models/model_profile.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
"""Model profile types and utilities."""

import contextlib
import warnings
from typing import get_type_hints

from pydantic import ConfigDict
from typing_extensions import TypedDict


Expand All @@ -14,6 +19,25 @@ class ModelProfile(TypedDict, total=False):
and supported features.
"""

__pydantic_config__ = ConfigDict(extra="allow") # type: ignore[misc]

# --- Model metadata ---

name: str
"""Human-readable model name."""

status: str
"""Model status (e.g., `'active'`, `'deprecated'`)."""

release_date: str
"""Model release date (ISO 8601 format, e.g., `'2025-06-01'`)."""

last_updated: str
"""Date the model was last updated (ISO 8601 format)."""

open_weights: bool
"""Whether the model weights are openly available."""

# --- Input constraints ---

max_input_tokens: int
Expand Down Expand Up @@ -86,6 +110,60 @@ class ModelProfile(TypedDict, total=False):
"""Whether the model supports a native [structured output](https://docs.langchain.com/oss/python/langchain/models#structured-outputs)
feature"""

# --- Other capabilities ---

attachment: bool
"""Whether the model supports file attachments."""

temperature: bool
"""Whether the model supports a temperature parameter."""


ModelProfileRegistry = dict[str, ModelProfile]
"""Registry mapping model identifiers or names to their ModelProfile."""


# Cache for ModelProfile's declared field names. Populated lazily because
# _warn_unknown_profile_keys runs on every chat model construction and
# get_type_hints is not free.
_DECLARED_PROFILE_KEYS: frozenset[str] | None = None


def _get_declared_profile_keys() -> frozenset[str]:
"""Return the declared `ModelProfile` field names, cached after first call."""
global _DECLARED_PROFILE_KEYS # noqa: PLW0603
if _DECLARED_PROFILE_KEYS is None:
_DECLARED_PROFILE_KEYS = frozenset(get_type_hints(ModelProfile).keys())
return _DECLARED_PROFILE_KEYS


def _warn_unknown_profile_keys(profile: ModelProfile) -> None:
"""Emit a warning if a profile dict contains keys not declared in `ModelProfile`.

This function must never raise — it is called during model construction and
a failure here would prevent all chat model instantiation.

Args:
profile: Model profile dict to check.
"""
try:
declared = _get_declared_profile_keys()
except Exception:
# If introspection fails (e.g. forward ref issues), skip rather than
# crash model construction.
return

extra = sorted(set(profile) - declared)
if extra:
# warnings.warn() raises when the user (or a test framework like
# pytest) configures warnings-as-errors (-W error /
# warnings.simplefilter("error")). Suppress so we honour the
# "must never raise" contract — this runs during every chat model
# construction.
with contextlib.suppress(Exception):
warnings.warn(
f"Unrecognized keys in model profile: {extra}. "
f"This may indicate a version mismatch between langchain-core "
f"and your provider package. Consider upgrading langchain-core.",
stacklevel=2,
)
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from typing import TYPE_CHECKING, Any, Literal

import pytest
from typing_extensions import override
from pydantic import model_validator
from typing_extensions import Self, override

from langchain_core.callbacks import (
CallbackManagerForLLMRun,
Expand All @@ -22,6 +23,7 @@
FakeListChatModelError,
GenericFakeChatModel,
)
from langchain_core.language_models.model_profile import ModelProfile
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
Expand Down Expand Up @@ -1230,6 +1232,109 @@ def test_model_profiles() -> None:
assert model_with_profile.profile == {"max_input_tokens": 100}


def test_model_profile_extra_keys_accepted() -> None:
"""extra='allow' on ModelProfile means unknown keys don't crash."""
model = GenericFakeChatModel(
messages=iter([]),
profile={"max_input_tokens": 100, "unknown_future_field": True},
)
assert model.profile is not None
assert model.profile.get("unknown_future_field") is True


def test_check_profile_keys_warns_on_unknown() -> None:
"""_check_profile_keys validator warns for undeclared profile keys."""
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
GenericFakeChatModel(
messages=iter([]),
profile={"max_input_tokens": 100, "unknown_field": True},
)

profile_warnings = [x for x in w if "Unrecognized keys" in str(x.message)]
assert len(profile_warnings) == 1
assert "unknown_field" in str(profile_warnings[0].message)


def test_check_profile_keys_silent_on_valid() -> None:
"""_check_profile_keys validator does not warn for declared keys."""
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
GenericFakeChatModel(
messages=iter([]),
profile={"max_input_tokens": 100, "tool_calling": True},
)

profile_warnings = [x for x in w if "Unrecognized keys" in str(x.message)]
assert len(profile_warnings) == 0


def test_check_profile_keys_runs_despite_partner_override() -> None:
"""Verify _check_profile_keys fires even when _set_model_profile is overridden.

Uses a distinct validator name so partner overrides do not suppress it.
"""

class PartnerModel(GenericFakeChatModel):
"""Simulates a partner that overrides _set_model_profile."""

@model_validator(mode="after")
def _set_model_profile(self) -> Self:
if self.profile is None:
profile: dict[str, Any] = {
"max_input_tokens": 100,
"partner_only_field": True,
}
self.profile = profile # type: ignore[assignment]
return self

with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
model = PartnerModel(messages=iter([]))

assert model.profile is not None
assert model.profile.get("partner_only_field") is True
profile_warnings = [x for x in w if "Unrecognized keys" in str(x.message)]
assert len(profile_warnings) == 1
assert "partner_only_field" in str(profile_warnings[0].message)


def test_resolve_model_profile_auto_populates() -> None:
"""Base _set_model_profile validator auto-populates from _resolve_model_profile."""

class AutoProfileModel(GenericFakeChatModel):
def _resolve_model_profile(self) -> ModelProfile | None:
return {"max_input_tokens": 42, "tool_calling": True}

model = AutoProfileModel(messages=iter([]))
assert model.profile is not None
assert model.profile["max_input_tokens"] == 42
assert model.profile["tool_calling"] is True


def test_explicit_profile_not_overwritten_by_resolve() -> None:
"""Explicit profile= kwarg takes precedence over _resolve_model_profile."""

class AutoProfileModel(GenericFakeChatModel):
def _resolve_model_profile(self) -> ModelProfile | None:
return {"max_input_tokens": 42}

model = AutoProfileModel(messages=iter([]), profile={"max_input_tokens": 999})
assert model.profile is not None
assert model.profile["max_input_tokens"] == 999


def test_resolve_model_profile_none_leaves_profile_none() -> None:
"""Subclass returning None from _resolve_model_profile leaves profile as None."""

class NoProfileModel(GenericFakeChatModel):
def _resolve_model_profile(self) -> ModelProfile | None:
return None

model = NoProfileModel(messages=iter([]))
assert model.profile is None


class MockResponse:
"""Mock response for testing _generate_response_from_error."""

Expand Down
Loading
Loading