Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 14 additions & 8 deletions ecologits/tracers/anthropic_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,12 +127,13 @@ class MessageStreamManager(Generic[MessageStreamT]):
Re-writing of Anthropic's `MessageStreamManager` with wrapped `MessageStream`
"""

def __init__(self, api_request: Callable[[], MessageStream]) -> None:
def __init__(self, api_request: Callable[[], Any], output_format: Any) -> None:
self.__api_request = api_request
self.__output_format = output_format

def __enter__(self) -> MessageStream:
self.__stream = self.__api_request()
self.__stream = MessageStream(self.__stream)
raw_stream = self.__api_request()
self.__stream = MessageStream(raw_stream, output_format=self.__output_format)
return self.__stream

def __exit__(
Expand All @@ -149,12 +150,13 @@ class AsyncMessageStreamManager(Generic[AsyncMessageStreamT]):
"""
Re-writing of Anthropic's `AsyncMessageStreamManager` with wrapped `AsyncMessageStream`
"""
def __init__(self, api_request: Awaitable[AsyncMessageStream]) -> None:
def __init__(self, api_request: Awaitable[Any], output_format: Any) -> None:
self.__api_request = api_request
self.__output_format = output_format

async def __aenter__(self) -> AsyncMessageStream:
self.__stream = await self.__api_request
self.__stream = AsyncMessageStream(self.__stream)
raw_stream = await self.__api_request
self.__stream = AsyncMessageStream(raw_stream, output_format=self.__output_format)
return self.__stream

async def __aexit__(
Expand Down Expand Up @@ -271,7 +273,9 @@ def anthropic_stream_chat_wrapper(
A wrapped `MessageStreamManager` with impacts
"""
response = wrapped(*args, **kwargs)
return MessageStreamManager(response._MessageStreamManager__api_request) # noqa: SLF001
api_request = response._MessageStreamManager__api_request # noqa: SLF001
output_format = response._MessageStreamManager__output_format # noqa: SLF001
return MessageStreamManager(api_request, output_format)


def anthropic_async_stream_chat_wrapper(
Expand All @@ -290,7 +294,9 @@ def anthropic_async_stream_chat_wrapper(
A wrapped `AsyncMessageStreamManager` with impacts
"""
response = wrapped(*args, **kwargs)
return AsyncMessageStreamManager(response._AsyncMessageStreamManager__api_request) # noqa: SLF001
api_request = response._AsyncMessageStreamManager__api_request # noqa: SLF001
output_format = response._AsyncMessageStreamManager__output_format # noqa: SLF001
return AsyncMessageStreamManager(api_request, output_format)


class AnthropicInstrumentor:
Expand Down
42 changes: 37 additions & 5 deletions ecologits/tracers/huggingface_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@
from ecologits.tracers.utils import ImpactsOutput, llm_impacts

# Provider identifier used when looking up impact factors for this tracer.
PROVIDER = "huggingface_hub"
# Known Hugging Face Inference endpoint URL prefixes. When a "model" value is
# actually a full endpoint URL, these prefixes are stripped to recover the bare
# model id (e.g. "meta-llama/Meta-Llama-3-8B-Instruct").
# NOTE(review): the router prefix presumably covers the newer
# router.huggingface.co endpoints — confirm against current HF docs.
HF_INFERENCE_URL_PREFIXES = (
    "https://api-inference.huggingface.co/models/",
    "https://router.huggingface.co/hf-inference/models/",
)


@dataclass
Expand All @@ -31,6 +35,18 @@ class ChatCompletionStreamOutput(_ChatCompletionStreamOutput):
impacts: Optional[ImpactsOutput] = None


def _resolve_model_name(*values: Optional[str]) -> Optional[str]:
    """Return the first non-None candidate, normalized to a bare model id.

    Candidates are examined in priority order; the first one that is not
    None wins. If that candidate is a HF Inference endpoint URL (one of
    ``HF_INFERENCE_URL_PREFIXES``), the prefix and any trailing
    ``/v1/chat/completions`` path are stripped; otherwise the candidate is
    returned unchanged. Returns None when every candidate is None (or no
    candidates were given).
    """
    # Pick the first usable candidate, if any.
    candidate = next((v for v in values if v is not None), None)
    if candidate is None:
        return None

    # Normalize endpoint URLs down to the bare model id.
    matched_prefix = next(
        (p for p in HF_INFERENCE_URL_PREFIXES if candidate.startswith(p)),
        None,
    )
    if matched_prefix is None:
        return candidate
    return candidate.removeprefix(matched_prefix).removesuffix("/v1/chat/completions")


def huggingface_chat_wrapper(
wrapped: Callable,
instance: InferenceClient,
Expand Down Expand Up @@ -67,7 +83,10 @@ def huggingface_chat_wrapper_non_stream(
request_latency = time.perf_counter() - timer_start
output_tokens = response.usage["completion_tokens"]
input_tokens = response.usage["prompt_tokens"]
model_name = instance.model or kwargs.get("model")
model_name = _resolve_model_name(response.model, kwargs.get("model"), instance.model)
if model_name is None:
return response

impacts = llm_impacts(
provider=PROVIDER,
model_name=model_name,
Expand Down Expand Up @@ -101,13 +120,18 @@ def huggingface_chat_wrapper_stream(
encoder = tiktoken.get_encoding("cl100k_base")
prompt_text = "".join([m["content"] for m in kwargs["messages"]])
input_tokens = len(encoder.encode(prompt_text))
model_name = instance.model or kwargs.get("model")
default_model_name = _resolve_model_name(kwargs.get("model"), instance.model)
timer_start = time.perf_counter()
stream = wrapped(*args, **kwargs)
output_tokens = 0
for chunk in stream:
output_tokens += 1 # noqa: SIM113
output_tokens += 1
request_latency = time.perf_counter() - timer_start
model_name = _resolve_model_name(getattr(chunk, "model", None), default_model_name)
if model_name is None:
yield chunk
continue

impacts = llm_impacts(
provider=PROVIDER,
model_name=model_name,
Expand Down Expand Up @@ -169,7 +193,10 @@ async def huggingface_async_chat_wrapper_non_stream(
request_latency = time.perf_counter() - timer_start
output_tokens = response.usage["completion_tokens"]
input_tokens = response.usage["prompt_tokens"]
model_name = instance.model or kwargs.get("model")
model_name = _resolve_model_name(response.model, kwargs.get("model"), instance.model)
if model_name is None:
return response

impacts = llm_impacts(
provider=PROVIDER,
model_name=model_name,
Expand Down Expand Up @@ -203,13 +230,18 @@ async def huggingface_async_chat_wrapper_stream(
encoder = tiktoken.get_encoding("cl100k_base")
prompt_text = "".join([m["content"] for m in kwargs["messages"]])
input_tokens = len(encoder.encode(prompt_text))
model_name = instance.model or kwargs.get("model")
default_model_name = _resolve_model_name(kwargs.get("model"), instance.model)
timer_start = time.perf_counter()
stream = await wrapped(*args, **kwargs)
output_tokens = 0
async for chunk in stream:
output_tokens += 1
request_latency = time.perf_counter() - timer_start
model_name = _resolve_model_name(getattr(chunk, "model", None), default_model_name)
if model_name is None:
yield chunk
continue

impacts = llm_impacts(
provider=PROVIDER,
model_name=model_name,
Expand Down
1 change: 1 addition & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def environment():
set_envvar_if_unset("OPENAI_API_KEY", "test-api-key")
set_envvar_if_unset("CO_API_KEY", "test-api-key")
set_envvar_if_unset("GOOGLE_API_KEY", "test-api-key")
set_envvar_if_unset("HF_TOKEN", "hf_test-token")
set_envvar_if_unset("AZURE_OPENAI_API_KEY", "test-api-key")
set_envvar_if_unset("AZURE_OPENAI_ENDPOINT", "https://ecologits-azure-openai.openai.azure.com/openai/deployments/gpt-4o-mini/chat/completions?api-version=2024-08-01-preview")
set_envvar_if_unset("OPENAI_API_VERSION", "2024-06-01")
Expand Down
13 changes: 9 additions & 4 deletions tests/test_huggingface_hub.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
import pytest
from huggingface_hub import AsyncInferenceClient, InferenceClient

# Full HF Inference endpoint URL for the model under test; the tracer is
# expected to strip the endpoint prefix back down to the bare model id.
HF_INFERENCE_MODEL_URL = (
    "https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3-8B-Instruct"
)


@pytest.mark.vcr
def test_huggingface_hub_chat(tracer_init):
client = InferenceClient(model="meta-llama/Meta-Llama-3-8B-Instruct")
client = InferenceClient(model=HF_INFERENCE_MODEL_URL)
response = client.chat_completion(
messages=[{"role": "user", "content": "Hello World!"}],
max_tokens=15
Expand All @@ -16,7 +21,7 @@ def test_huggingface_hub_chat(tracer_init):
@pytest.mark.vcr
@pytest.mark.asyncio
async def test_huggingface_hub_async_chat(tracer_init):
client = AsyncInferenceClient(model="meta-llama/Meta-Llama-3-8B-Instruct")
client = AsyncInferenceClient(model=HF_INFERENCE_MODEL_URL)
response = await client.chat_completion(
messages=[{"role": "user", "content": "Hello World!"}],
max_tokens=15
Expand All @@ -27,7 +32,7 @@ async def test_huggingface_hub_async_chat(tracer_init):

@pytest.mark.vcr
def test_huggingface_hub_stream_chat(tracer_init):
client = InferenceClient(model="meta-llama/Meta-Llama-3-8B-Instruct")
client = InferenceClient(model=HF_INFERENCE_MODEL_URL)
stream = client.chat_completion(
messages=[{"role": "user", "content": "Hello World!"}],
max_tokens=15,
Expand All @@ -40,7 +45,7 @@ def test_huggingface_hub_stream_chat(tracer_init):
@pytest.mark.vcr
@pytest.mark.asyncio
async def test_huggingface_hub_async_stream_chat(tracer_init):
client = AsyncInferenceClient(model="meta-llama/Meta-Llama-3-8B-Instruct")
client = AsyncInferenceClient(model=HF_INFERENCE_MODEL_URL)
stream = await client.chat_completion(
messages=[{"role": "user", "content": "Hello World!"}],
max_tokens=15,
Expand Down
Loading