diff --git a/ecologits/tracers/anthropic_tracer.py b/ecologits/tracers/anthropic_tracer.py
index 067426a5..57f0f839 100644
--- a/ecologits/tracers/anthropic_tracer.py
+++ b/ecologits/tracers/anthropic_tracer.py
@@ -127,12 +127,13 @@ class MessageStreamManager(Generic[MessageStreamT]):
     """
     Re-writing of Anthropic's `MessageStreamManager` with wrapped `MessageStream`
     """
-    def __init__(self, api_request: Callable[[], MessageStream]) -> None:
+    def __init__(self, api_request: Callable[[], Any], output_format: Any) -> None:
         self.__api_request = api_request
+        self.__output_format = output_format
 
     def __enter__(self) -> MessageStream:
-        self.__stream = self.__api_request()
-        self.__stream = MessageStream(self.__stream)
+        raw_stream = self.__api_request()
+        self.__stream = MessageStream(raw_stream, output_format=self.__output_format)
         return self.__stream
 
     def __exit__(
@@ -149,12 +150,13 @@ class AsyncMessageStreamManager(Generic[AsyncMessageStreamT]):
     """
     Re-writing of Anthropic's `AsyncMessageStreamManager` with wrapped `AsyncMessageStream`
     """
-    def __init__(self, api_request: Awaitable[AsyncMessageStream]) -> None:
+    def __init__(self, api_request: Awaitable[Any], output_format: Any) -> None:
         self.__api_request = api_request
+        self.__output_format = output_format
 
     async def __aenter__(self) -> AsyncMessageStream:
-        self.__stream = await self.__api_request
-        self.__stream = AsyncMessageStream(self.__stream)
+        raw_stream = await self.__api_request
+        self.__stream = AsyncMessageStream(raw_stream, output_format=self.__output_format)
         return self.__stream
 
     async def __aexit__(
@@ -271,7 +273,9 @@ def anthropic_stream_chat_wrapper(
         A wrapped `MessageStreamManager` with impacts
     """
     response = wrapped(*args, **kwargs)
-    return MessageStreamManager(response._MessageStreamManager__api_request)  # noqa: SLF001
+    api_request = response._MessageStreamManager__api_request  # noqa: SLF001
+    output_format = response._MessageStreamManager__output_format  # noqa: SLF001
+    return MessageStreamManager(api_request, output_format)
 
 
 def anthropic_async_stream_chat_wrapper(
@@ -290,7 +294,9 @@
         A wrapped `AsyncMessageStreamManager` with impacts
     """
     response = wrapped(*args, **kwargs)
-    return AsyncMessageStreamManager(response._AsyncMessageStreamManager__api_request)  # noqa: SLF001
+    api_request = response._AsyncMessageStreamManager__api_request  # noqa: SLF001
+    output_format = response._AsyncMessageStreamManager__output_format  # noqa: SLF001
+    return AsyncMessageStreamManager(api_request, output_format)
 
 
 class AnthropicInstrumentor:
diff --git a/ecologits/tracers/huggingface_tracer.py b/ecologits/tracers/huggingface_tracer.py
index 3484a584..37711283 100644
--- a/ecologits/tracers/huggingface_tracer.py
+++ b/ecologits/tracers/huggingface_tracer.py
@@ -13,6 +13,10 @@
 from ecologits.tracers.utils import ImpactsOutput, llm_impacts
 
 PROVIDER = "huggingface_hub"
 
+HF_INFERENCE_URL_PREFIXES = (
+    "https://api-inference.huggingface.co/models/",
+    "https://router.huggingface.co/hf-inference/models/",
+)
 
 
 @dataclass
@@ -31,6 +35,18 @@ class ChatCompletionStreamOutput(_ChatCompletionStreamOutput):
     impacts: Optional[ImpactsOutput] = None
 
 
+def _resolve_model_name(*values: Optional[str]) -> Optional[str]:
+    for value in values:
+        if value is None:
+            continue
+        for prefix in HF_INFERENCE_URL_PREFIXES:
+            if value.startswith(prefix):
+                return value.removeprefix(prefix).removesuffix("/v1/chat/completions")
+        return value
+
+    return None
+
+
 def huggingface_chat_wrapper(
     wrapped: Callable,
     instance: InferenceClient,
@@ -67,7 +83,10 @@
     request_latency = time.perf_counter() - timer_start
     output_tokens = response.usage["completion_tokens"]
     input_tokens = response.usage["prompt_tokens"]
-    model_name = instance.model or kwargs.get("model")
+    model_name = _resolve_model_name(response.model, kwargs.get("model"), instance.model)
+    if model_name is None:
+        return response
+
     impacts = llm_impacts(
         provider=PROVIDER,
         model_name=model_name,
@@ -101,13 +120,18 @@
     encoder = tiktoken.get_encoding("cl100k_base")
     prompt_text = "".join([m["content"] for m in kwargs["messages"]])
     input_tokens = len(encoder.encode(prompt_text))
-    model_name = instance.model or kwargs.get("model")
+    default_model_name = _resolve_model_name(kwargs.get("model"), instance.model)
     timer_start = time.perf_counter()
     stream = wrapped(*args, **kwargs)
     output_tokens = 0
     for chunk in stream:
-        output_tokens += 1  # noqa: SIM113
+        output_tokens += 1
         request_latency = time.perf_counter() - timer_start
+        model_name = _resolve_model_name(getattr(chunk, "model", None), default_model_name)
+        if model_name is None:
+            yield chunk
+            continue
+
         impacts = llm_impacts(
             provider=PROVIDER,
             model_name=model_name,
@@ -169,7 +193,10 @@
     request_latency = time.perf_counter() - timer_start
     output_tokens = response.usage["completion_tokens"]
     input_tokens = response.usage["prompt_tokens"]
-    model_name = instance.model or kwargs.get("model")
+    model_name = _resolve_model_name(response.model, kwargs.get("model"), instance.model)
+    if model_name is None:
+        return response
+
     impacts = llm_impacts(
         provider=PROVIDER,
         model_name=model_name,
@@ -203,13 +230,18 @@
     encoder = tiktoken.get_encoding("cl100k_base")
     prompt_text = "".join([m["content"] for m in kwargs["messages"]])
     input_tokens = len(encoder.encode(prompt_text))
-    model_name = instance.model or kwargs.get("model")
+    default_model_name = _resolve_model_name(kwargs.get("model"), instance.model)
     timer_start = time.perf_counter()
     stream = await wrapped(*args, **kwargs)
     output_tokens = 0
     async for chunk in stream:
         output_tokens += 1
         request_latency = time.perf_counter() - timer_start
+        model_name = _resolve_model_name(getattr(chunk, "model", None), default_model_name)
+        if model_name is None:
+            yield chunk
+            continue
+
         impacts = llm_impacts(
             provider=PROVIDER,
             model_name=model_name,
diff --git a/tests/conftest.py b/tests/conftest.py
index 44a98df3..d27791c4 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -15,6 +15,7 @@ def environment():
     set_envvar_if_unset("OPENAI_API_KEY", "test-api-key")
     set_envvar_if_unset("CO_API_KEY", "test-api-key")
     set_envvar_if_unset("GOOGLE_API_KEY", "test-api-key")
+    set_envvar_if_unset("HF_TOKEN", "hf_test-token")
     set_envvar_if_unset("AZURE_OPENAI_API_KEY", "test-api-key")
     set_envvar_if_unset("AZURE_OPENAI_ENDPOINT", "https://ecologits-azure-openai.openai.azure.com/openai/deployments/gpt-4o-mini/chat/completions?api-version=2024-08-01-preview")
     set_envvar_if_unset("OPENAI_API_VERSION", "2024-06-01")
diff --git a/tests/test_huggingface_hub.py b/tests/test_huggingface_hub.py
index 2de04c9b..c22c32fe 100644
--- a/tests/test_huggingface_hub.py
+++ b/tests/test_huggingface_hub.py
@@ -1,10 +1,15 @@
 import pytest
 from huggingface_hub import AsyncInferenceClient, InferenceClient
 
+HF_INFERENCE_MODEL_URL = (
+    "https://api-inference.huggingface.co/models/"
+    "meta-llama/Meta-Llama-3-8B-Instruct"
+)
+
 
 @pytest.mark.vcr
 def test_huggingface_hub_chat(tracer_init):
-    client = InferenceClient(model="meta-llama/Meta-Llama-3-8B-Instruct")
+    client = InferenceClient(model=HF_INFERENCE_MODEL_URL)
     response = client.chat_completion(
         messages=[{"role": "user", "content": "Hello World!"}],
         max_tokens=15
@@ -16,7 +21,7 @@ def test_huggingface_hub_chat(tracer_init):
 @pytest.mark.vcr
 @pytest.mark.asyncio
 async def test_huggingface_hub_async_chat(tracer_init):
-    client = AsyncInferenceClient(model="meta-llama/Meta-Llama-3-8B-Instruct")
+    client = AsyncInferenceClient(model=HF_INFERENCE_MODEL_URL)
     response = await client.chat_completion(
         messages=[{"role": "user", "content": "Hello World!"}],
         max_tokens=15
@@ -27,7 +32,7 @@
 
 @pytest.mark.vcr
 def test_huggingface_hub_stream_chat(tracer_init):
-    client = InferenceClient(model="meta-llama/Meta-Llama-3-8B-Instruct")
+    client = InferenceClient(model=HF_INFERENCE_MODEL_URL)
     stream = client.chat_completion(
         messages=[{"role": "user", "content": "Hello World!"}],
         max_tokens=15,
@@ -40,7 +45,7 @@
 @pytest.mark.vcr
 @pytest.mark.asyncio
 async def test_huggingface_hub_async_stream_chat(tracer_init):
-    client = AsyncInferenceClient(model="meta-llama/Meta-Llama-3-8B-Instruct")
+    client = AsyncInferenceClient(model=HF_INFERENCE_MODEL_URL)
    stream = await client.chat_completion(
         messages=[{"role": "user", "content": "Hello World!"}],
         max_tokens=15,