diff --git a/app/openai_constants.py b/app/openai_constants.py
index bc5a8b5..e8103ad 100644
--- a/app/openai_constants.py
+++ b/app/openai_constants.py
@@ -29,6 +29,9 @@
 GPT_4_1_NANO_2025_04_14_MODEL = "gpt-4.1-nano-2025-04-14"
 GPT_5_CHAT_LATEST_MODEL = "gpt-5-chat-latest"
 
+# Default model used for token counting when none specified
+DEFAULT_TOKEN_COUNT_MODEL = GPT_3_5_TURBO_0613_MODEL
+
 # Tuple: (tokens_per_message, tokens_per_name)
 MODEL_TOKENS = {
     # GPT-3.5
@@ -77,3 +80,50 @@
     GPT_4_1_MINI_MODEL: GPT_4_1_MINI_2025_04_14_MODEL,
     GPT_4_1_NANO_MODEL: GPT_4_1_NANO_2025_04_14_MODEL,
 }
+
+MODEL_CONTEXT_LENGTHS = {
+    # GPT-3.5
+    GPT_3_5_TURBO_0301_MODEL: 4096,
+    GPT_3_5_TURBO_0613_MODEL: 4096,
+    GPT_3_5_TURBO_16K_0613_MODEL: 16384,
+    GPT_3_5_TURBO_1106_MODEL: 16384,
+    GPT_3_5_TURBO_0125_MODEL: 16384,
+    # GPT-4
+    GPT_4_0314_MODEL: 8192,
+    GPT_4_0613_MODEL: 8192,
+    GPT_4_32K_0314_MODEL: 32768,
+    GPT_4_32K_0613_MODEL: 32768,
+    GPT_4_1106_PREVIEW_MODEL: 128000,
+    GPT_4_0125_PREVIEW_MODEL: 128000,
+    GPT_4_TURBO_PREVIEW_MODEL: 128000,  # GPT_4_TURBO_PREVIEW_MODEL is an alias for GPT_4_0125_PREVIEW_MODEL
+    GPT_4_TURBO_2024_04_09_MODEL: 128000,
+    # GPT-4o
+    GPT_4O_2024_05_13_MODEL: 128000,
+    # GPT-4o mini
+    GPT_4O_MINI_2024_07_18_MODEL: 128000,
+    # GPT-4.1 family
+    GPT_4_1_2025_04_14_MODEL: 1048576,
+    GPT_4_1_MINI_2025_04_14_MODEL: 1048576,
+    GPT_4_1_NANO_2025_04_14_MODEL: 1048576,
+    # GPT-5 chat latest
+    GPT_5_CHAT_LATEST_MODEL: 128000,
+}
+
+
+def resolve_model_alias(model: str) -> str:
+    """Resolves a model alias to a concrete version using MODEL_FALLBACKS.
+
+    Raises ValueError on circular dependency.
+    Returns the input when no fallback mapping is found.
+    """
+    if model is None:
+        return model
+    visited = {model}
+    while model in MODEL_FALLBACKS:
+        model = MODEL_FALLBACKS[model]
+        if model in visited:
+            raise ValueError(
+                f"Circular dependency detected in MODEL_FALLBACKS for model {model}"
+            )
+        visited.add(model)
+    return model
diff --git a/app/openai_ops.py b/app/openai_ops.py
index 38ed2df..a5abf45 100644
--- a/app/openai_ops.py
+++ b/app/openai_ops.py
@@ -17,37 +17,10 @@
 from app.markdown_conversion import slack_to_markdown, markdown_to_slack
 from app.openai_constants import (
     MAX_TOKENS,
-    GPT_3_5_TURBO_MODEL,
-    GPT_3_5_TURBO_0301_MODEL,
-    GPT_3_5_TURBO_0613_MODEL,
-    GPT_3_5_TURBO_1106_MODEL,
-    GPT_3_5_TURBO_0125_MODEL,
-    GPT_3_5_TURBO_16K_MODEL,
-    GPT_3_5_TURBO_16K_0613_MODEL,
-    GPT_4_MODEL,
-    GPT_4_0314_MODEL,
-    GPT_4_0613_MODEL,
-    GPT_4_1106_PREVIEW_MODEL,
-    GPT_4_0125_PREVIEW_MODEL,
-    GPT_4_TURBO_PREVIEW_MODEL,
-    GPT_4_TURBO_MODEL,
-    GPT_4_TURBO_2024_04_09_MODEL,
-    GPT_4_32K_MODEL,
-    GPT_4_32K_0314_MODEL,
-    GPT_4_32K_0613_MODEL,
-    GPT_4O_MODEL,
-    GPT_4O_2024_05_13_MODEL,
-    GPT_4O_MINI_MODEL,
-    GPT_4O_MINI_2024_07_18_MODEL,
-    GPT_4_1_MODEL,
-    GPT_4_1_2025_04_14_MODEL,
-    GPT_4_1_MINI_MODEL,
-    GPT_4_1_MINI_2025_04_14_MODEL,
-    GPT_4_1_NANO_MODEL,
-    GPT_4_1_NANO_2025_04_14_MODEL,
-    GPT_5_CHAT_LATEST_MODEL,
     MODEL_TOKENS,
-    MODEL_FALLBACKS,
+    MODEL_CONTEXT_LENGTHS,
+    resolve_model_alias,
+    DEFAULT_TOKEN_COUNT_MODEL,
 )
 
 from app.slack_ops import update_wip_message
@@ -87,7 +60,11 @@ def messages_within_context_window(
     if context.get("OPENAI_FUNCTION_CALL_MODULE_NAME") is not None:
         max_context_tokens -= calculate_tokens_necessary_for_function_call(context)
     num_context_tokens = 0  # Number of tokens in the context window just before the earliest message is deleted
-    while (num_tokens := calculate_num_tokens(messages)) > max_context_tokens:
+    while (
+        num_tokens := calculate_num_tokens(
+            messages, model=context.get("OPENAI_MODEL")
+        )
+    ) > max_context_tokens:
         removed = False
         for i, message in enumerate(messages):
             if message["role"] in ("user", "assistant", "function"):
@@ -330,72 +307,14 @@ def update_message():
 def context_length(
     model: str,
 ) -> int:
-    if model == GPT_3_5_TURBO_MODEL:
-        # Note that GPT_3_5_TURBO_MODEL may change over time. Return context length assuming GPT_3_5_TURBO_0125_MODEL.
-        return context_length(model=GPT_3_5_TURBO_0125_MODEL)
-    if model == GPT_3_5_TURBO_16K_MODEL:
-        # Note that GPT_3_5_TURBO_16K_MODEL may change over time. Return context length assuming GPT_3_5_TURBO_16K_0613_MODEL.
-        return context_length(model=GPT_3_5_TURBO_16K_0613_MODEL)
-    elif model == GPT_4_MODEL:
-        # Note that GPT_4_MODEL may change over time. Return context length assuming GPT_4_0613_MODEL.
-        return context_length(model=GPT_4_0613_MODEL)
-    elif model == GPT_4_32K_MODEL:
-        # Note that GPT_4_32K_MODEL may change over time. Return context length assuming GPT_4_32K_0613_MODEL.
-        return context_length(model=GPT_4_32K_0613_MODEL)
-    elif model == GPT_4_TURBO_PREVIEW_MODEL:
-        # Note that GPT_4_TURBO_PREVIEW_MODEL may change over time. Return context length assuming GPT_4_0125_PREVIEW_MODEL.
-        return context_length(model=GPT_4_0125_PREVIEW_MODEL)
-    elif model == GPT_4_TURBO_MODEL:
-        # Note that GPT_4_TURBO_MODEL may change over time. Return context length assuming GPT_4_TURBO_2024_04_09_MODEL.
-        return context_length(model=GPT_4_TURBO_2024_04_09_MODEL)
-    elif model == GPT_4O_MODEL:
-        # Note that GPT_4O_MODEL may change over time. Return context length assuming GPT_4O_2024_05_13_MODEL.
-        return context_length(model=GPT_4O_2024_05_13_MODEL)
-    elif model == GPT_4O_MINI_MODEL:
-        # Note that GPT_4O_MINI_MODEL may change over time. Return context length assuming GPT_4O_MINI_2024_07_18_MODEL.
-        return context_length(model=GPT_4O_MINI_2024_07_18_MODEL)
-    elif model == GPT_4_1_MODEL:
-        # Note that GPT_4_1_MODEL may change over time. Return context length assuming GPT_4_1_2025_04_14_MODEL.
-        return context_length(model=GPT_4_1_2025_04_14_MODEL)
-    elif model == GPT_4_1_MINI_MODEL:
-        # Note that GPT_4_1_MINI_MODEL may change over time. Return context length assuming GPT_4_1_MINI_2025_04_14_MODEL.
-        return context_length(model=GPT_4_1_MINI_2025_04_14_MODEL)
-    elif model == GPT_4_1_NANO_MODEL:
-        # Note that GPT_4_1_NANO_MODEL may change over time. Return context length assuming GPT_4_1_NANO_2025_04_14_MODEL.
-        return context_length(model=GPT_4_1_NANO_2025_04_14_MODEL)
-    elif model == GPT_3_5_TURBO_0301_MODEL or model == GPT_3_5_TURBO_0613_MODEL:
-        return 4096
-    elif (
-        model == GPT_3_5_TURBO_16K_0613_MODEL
-        or model == GPT_3_5_TURBO_1106_MODEL
-        or model == GPT_3_5_TURBO_0125_MODEL
-    ):
-        return 16384
-    elif model == GPT_4_0314_MODEL or model == GPT_4_0613_MODEL:
-        return 8192
-    elif model == GPT_4_32K_0314_MODEL or model == GPT_4_32K_0613_MODEL:
-        return 32768
-    elif (
-        model == GPT_4_1_MODEL
-        or model == GPT_4_1_2025_04_14_MODEL
-        or model == GPT_4_1_MINI_MODEL
-        or model == GPT_4_1_MINI_2025_04_14_MODEL
-        or model == GPT_4_1_NANO_MODEL
-        or model == GPT_4_1_NANO_2025_04_14_MODEL
-    ):
-        return 1048576
-    elif (
-        model == GPT_4_1106_PREVIEW_MODEL
-        or model == GPT_4_0125_PREVIEW_MODEL
-        or model == GPT_4_TURBO_2024_04_09_MODEL
-        or model == GPT_4O_2024_05_13_MODEL
-        or model == GPT_4O_MINI_2024_07_18_MODEL
-        or model == GPT_5_CHAT_LATEST_MODEL
-    ):
-        return 128000
-    else:
-        error = f"Calculating the length of the context window for model {model} is not yet supported."
-        raise NotImplementedError(error)
+    """Returns the context length for a given model."""
+    actual_model = resolve_model_alias(model)
+    length = MODEL_CONTEXT_LENGTHS.get(actual_model)
+    if length is not None:
+        return length
+
+    error = f"Calculating the length of the context window for model {actual_model} is not yet supported."
+    raise NotImplementedError(error)
 
 
 def encode_and_count_tokens(
@@ -420,26 +339,19 @@
 # https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
 def calculate_num_tokens(
     messages: List[Dict[str, Union[str, Dict[str, str], List[Dict[str, str]]]]],
-    model: str = GPT_3_5_TURBO_0613_MODEL,
+    model: Optional[str] = None,
 ) -> int:
     """Returns the number of tokens used by a list of messages."""
+    actual_model = resolve_model_alias(model or DEFAULT_TOKEN_COUNT_MODEL)
     try:
-        encoding = tiktoken.encoding_for_model(model)
+        encoding = tiktoken.encoding_for_model(actual_model)
     except KeyError:
         encoding = tiktoken.get_encoding("cl100k_base")
 
-    num_tokens = 0
-    # Handle model-specific tokens per message and name
-    model_tokens: Optional[Tuple[int, int]] = MODEL_TOKENS.get(model, None)
+    model_tokens = MODEL_TOKENS.get(actual_model)
     if model_tokens is None:
-        fallback_result = None
-        if model in MODEL_FALLBACKS:
-            actual_model = MODEL_FALLBACKS[model]
-            fallback_result = calculate_num_tokens(messages, model=actual_model)
-        if fallback_result is not None:
-            return fallback_result
         error = (
-            f"Calculating the number of tokens for model {model} is not yet supported. "
+            f"Calculating the number of tokens for model {actual_model} is not yet supported. "
             "See https://github.com/openai/openai-python/blob/main/chatml.md "
            "for information on how messages are converted to tokens."
         )
@@ -447,6 +359,7 @@
 
     tokens_per_message, tokens_per_name = model_tokens
 
+    num_tokens = 0
     for message in messages:
         num_tokens += tokens_per_message
         for key, value in message.items():
diff --git a/tests/model_constants_test.py b/tests/model_constants_test.py
new file mode 100644
index 0000000..85986f3
--- /dev/null
+++ b/tests/model_constants_test.py
@@ -0,0 +1,48 @@
+import pytest
+from app.openai_constants import (
+    resolve_model_alias,
+    MODEL_FALLBACKS,
+    MODEL_TOKENS,
+    MODEL_CONTEXT_LENGTHS,
+    GPT_4_MODEL,
+    GPT_4_0613_MODEL,
+)
+
+def test_alias_resolution():
+    """Tests that a model alias resolves to its specific version."""
+    assert resolve_model_alias(GPT_4_MODEL) == GPT_4_0613_MODEL
+
+def test_unregistered_model_fails():
+    """Tests that resolving an unregistered model raises NotImplementedError."""
+    # First, test the resolver
+    unregistered_model = "this-model-does-not-exist"
+    assert resolve_model_alias(unregistered_model) == unregistered_model
+
+    # Then, test the functions that use the resolver
+    from app.openai_ops import context_length, calculate_num_tokens
+    with pytest.raises(NotImplementedError):
+        context_length(unregistered_model)
+    with pytest.raises(NotImplementedError):
+        calculate_num_tokens(messages=[], model=unregistered_model)
+
+def test_circular_fallback_fails(monkeypatch):
+    """Tests that a circular dependency in fallbacks raises a ValueError."""
+    # Temporarily introduce a circular dependency for testing
+    monkeypatch.setitem(MODEL_FALLBACKS, "model_a", "model_b")
+    monkeypatch.setitem(MODEL_FALLBACKS, "model_b", "model_a")
+
+    with pytest.raises(ValueError, match="Circular dependency detected"):
+        resolve_model_alias("model_a")
+
+def test_model_coverage():
+    """
+    Tests that all models in FALLBACKS can be resolved to a model
+    with defined tokens and context length.
+    """
+    for alias in MODEL_FALLBACKS.keys():
+        try:
+            resolved_model = resolve_model_alias(alias)
+            assert resolved_model in MODEL_TOKENS
+            assert resolved_model in MODEL_CONTEXT_LENGTHS
+        except Exception as e:
+            pytest.fail(f"Failed to resolve or find definitions for model alias {alias}: {e}")
diff --git a/tests/openai_ops_test.py b/tests/openai_ops_test.py
index 6cfa403..e76d07b 100755
--- a/tests/openai_ops_test.py
+++ b/tests/openai_ops_test.py
@@ -1,7 +1,9 @@
+import app.openai_ops as ops
 from app.openai_ops import (
     format_assistant_reply,
     format_openai_message_content,
 )
+from app.openai_constants import GPT_4O_MODEL
 
 
 def test_format_assistant_reply():
@@ -69,3 +71,28 @@
         ]:
             result = format_openai_message_content(content, False)
             assert result == expected
+
+
+def test_messages_within_context_window_passes_model(monkeypatch):
+    """Ensures token counting receives the actual OPENAI_MODEL from context."""
+    captured = {"model": None, "calls": 0}
+
+    def fake_calculate_num_tokens(messages, model=None):  # type: ignore[no-redef]
+        captured["model"] = model
+        captured["calls"] += 1
+        return 0  # Keep under threshold to avoid loop iterations
+
+    monkeypatch.setattr(ops, "calculate_num_tokens", fake_calculate_num_tokens)
+
+    messages = [{"role": "user", "content": "hi"}]
+    context = {
+        "OPENAI_MODEL": GPT_4O_MODEL,
+        "OPENAI_FUNCTION_CALL_MODULE_NAME": None,
+    }
+
+    # Execute
+    ops.messages_within_context_window(messages, context)  # type: ignore[arg-type]
+
+    # Assert the model used for token counting matches context
+    assert captured["calls"] >= 1
+    assert captured["model"] == GPT_4O_MODEL