Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 35 additions & 39 deletions eval_protocol/adapters/fireworks_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@
import os

from eval_protocol.models import EvaluationRow, InputMetadata, ExecutionMetadata, Message
from eval_protocol.tracing import PayloadType, decode_payloads
from .base import BaseAdapter
from .lp_deserializer import decompress_and_parse_lp
from .r3_deserializer import decompress_and_parse_r3
from .utils import extract_messages_from_data
from ..common_utils import get_user_agent

Expand Down Expand Up @@ -102,45 +101,42 @@
):
break # Break early if we've found all the metadata we need

# Extract router replay payloads when present
# Decoding lives in eval_protocol.tracing; here we only map results onto the row.
payloads = trace.get("payloads")
if isinstance(payloads, dict):
router_replay = payloads.get("router_replay")
if isinstance(router_replay, dict) and router_replay.get("data"):
try:
matrices, r3_meta = decompress_and_parse_r3(router_replay["data"])
if execution_metadata.extra is None:
execution_metadata.extra = {}
execution_metadata.extra["routing_matrices"] = matrices
execution_metadata.extra["routing_metadata"] = r3_meta
except Exception as e:
logger.warning("Failed to decompress R3 payload for trace %s: %s", trace.get("id"), e)

logprobs_payload = payloads.get("logprobs")
if isinstance(logprobs_payload, dict) and logprobs_payload.get("data"):
try:
logprobs, token_ids, lp_meta = decompress_and_parse_lp(logprobs_payload["data"])
if execution_metadata.extra is None:
execution_metadata.extra = {}
execution_metadata.extra["completion_logprobs"] = logprobs
if token_ids is not None:
execution_metadata.extra["completion_token_ids"] = token_ids
execution_metadata.extra["logprobs_metadata"] = lp_meta

for i in range(len(messages) - 1, -1, -1):
if messages[i].role == "assistant":
content_entries = [{"logprob": lp} for lp in logprobs]
if token_ids is not None:
for entry, tid in zip(content_entries, token_ids):
entry["token_id"] = tid
messages[i].logprobs = {"content": content_entries}
break
except Exception as e:
logger.warning(
"Failed to decompress logprobs payload for trace %s: %s",
trace.get("id"),
e,
)
decoded = decode_payloads(
payloads,
on_error=lambda pt, e: logger.warning(
"Failed to decode %s payload for trace %s: %s", pt.value, trace.get("id"), e
),
)
if decoded and execution_metadata.extra is None:
execution_metadata.extra = {}

if (dp := decoded.get(PayloadType.ROUTER_REPLAY)) is not None:
execution_metadata.extra["routing_matrices"] = dp.value

Check failure on line 117 in eval_protocol/adapters/fireworks_tracing.py

View workflow job for this annotation

GitHub Actions / Lint & Type Check

Object of type "None" is not subscriptable (reportOptionalSubscript)
execution_metadata.extra["routing_metadata"] = dp.metadata

Check failure on line 118 in eval_protocol/adapters/fireworks_tracing.py

View workflow job for this annotation

GitHub Actions / Lint & Type Check

Object of type "None" is not subscriptable (reportOptionalSubscript)

if (dp := decoded.get(PayloadType.LOGPROBS)) is not None:
logprobs = dp.value
token_ids = dp.token_ids
execution_metadata.extra["completion_logprobs"] = logprobs

Check failure on line 123 in eval_protocol/adapters/fireworks_tracing.py

View workflow job for this annotation

GitHub Actions / Lint & Type Check

Object of type "None" is not subscriptable (reportOptionalSubscript)
if token_ids is not None:
execution_metadata.extra["completion_token_ids"] = token_ids

Check failure on line 125 in eval_protocol/adapters/fireworks_tracing.py

View workflow job for this annotation

GitHub Actions / Lint & Type Check

Object of type "None" is not subscriptable (reportOptionalSubscript)
execution_metadata.extra["logprobs_metadata"] = dp.metadata

Check failure on line 126 in eval_protocol/adapters/fireworks_tracing.py

View workflow job for this annotation

GitHub Actions / Lint & Type Check

Object of type "None" is not subscriptable (reportOptionalSubscript)

for i in range(len(messages) - 1, -1, -1):
if messages[i].role == "assistant":
content_entries = [{"logprob": lp} for lp in logprobs]
if token_ids is not None:
for entry, tid in zip(content_entries, token_ids):
entry["token_id"] = tid
messages[i].logprobs = {"content": content_entries}
break

if (dp := decoded.get(PayloadType.PROMPT_TOKEN_IDS)) is not None:
execution_metadata.extra["prompt_token_ids"] = dp.value

Check failure on line 138 in eval_protocol/adapters/fireworks_tracing.py

View workflow job for this annotation

GitHub Actions / Lint & Type Check

Object of type "None" is not subscriptable (reportOptionalSubscript)
execution_metadata.extra["prompt_token_ids_metadata"] = dp.metadata

Check failure on line 139 in eval_protocol/adapters/fireworks_tracing.py

View workflow job for this annotation

GitHub Actions / Lint & Type Check

Object of type "None" is not subscriptable (reportOptionalSubscript)

return EvaluationRow(
messages=messages,
Expand Down
140 changes: 36 additions & 104 deletions eval_protocol/adapters/lp_deserializer.py
Original file line number Diff line number Diff line change
@@ -1,109 +1,41 @@
"""LP/v1 binary deserializer for per-token logprobs payloads.
"""Deprecated compatibility shim for ``eval_protocol.tracing.logprobs``.

Implements the inverse of the tracing gateway's ``logprobs_serializer.serialize_logprobs``.
See that module for the full header specification.
Import from ``eval_protocol.tracing.logprobs`` (or ``decode_payloads`` from
``eval_protocol.tracing``) instead. This module re-exports the LP/v1 helpers
that lived here before the tracing package refactor.
"""

from __future__ import annotations

import base64
import struct
from typing import Any, Dict, List, Optional, Tuple

import zstandard as zstd

# LP/v1 wire-format constants. Both struct formats use the "<" prefix, i.e.
# little-endian with no alignment padding.
MAGIC = b"LP01"
HEADER_VERSION = 1
MISSING_TOKEN_ID = -1  # wire id that marks "no token id recorded for this entry"
ENTRY_FORMAT = "<if"  # one entry = (int32 token id, float32 logprob)
ENTRY_SIZE = struct.calcsize(ENTRY_FORMAT)  # 8 bytes
HEADER_FORMAT = "<4sBBHIIQ"
HEADER_SIZE = struct.calcsize(HEADER_FORMAT)  # 24 bytes


def _parse_header(raw: bytes) -> Dict[str, Any]:
    """Validate and unpack the fixed-size LP/v1 header at the start of ``raw``.

    Args:
        raw: Uncompressed LP/v1 payload; only the first ``HEADER_SIZE`` bytes
            are read.

    Returns:
        A dict with ``flags``, ``reserved_u16``, ``token_count``,
        ``body_byte_length`` and ``reserved_u64``. The magic and version fields
        are checked and then dropped.

    Raises:
        ValueError: If ``raw`` is shorter than the header, the magic does not
            match ``MAGIC``, or the version is not ``HEADER_VERSION``.
    """
    if len(raw) < HEADER_SIZE:
        raise ValueError(f"Payload too short for lp/v1 header: {len(raw)} < {HEADER_SIZE}")

    (
        magic,
        version,
        flags,
        reserved_u16,
        token_count,
        body_byte_length,
        reserved_u64,
    ) = struct.unpack(HEADER_FORMAT, raw[:HEADER_SIZE])

    if magic != MAGIC:
        raise ValueError(f"Bad LP/v1 magic: {magic!r}")
    if version != HEADER_VERSION:
        raise ValueError(f"Unsupported lp/v1 header version: {version}")

    return {
        "flags": flags,
        "reserved_u16": reserved_u16,
        "token_count": token_count,
        "body_byte_length": body_byte_length,
        "reserved_u64": reserved_u64,
    }


def parse_logprobs(raw: bytes) -> Tuple[List[float], Optional[List[int]], Dict[str, Any]]:
    """Parse uncompressed LP/v1 bytes into logprobs, optional token ids, and metadata.

    Args:
        raw: Header followed by exactly ``token_count`` fixed-size entries.

    Returns:
        ``(logprobs, token_ids, metadata)``. ``token_ids`` is ``None`` when any
        entry carries ``MISSING_TOKEN_ID``. ``metadata`` is the parsed header
        extended with ``scope``, ``completion_token_count`` and
        ``all_token_ids_valid``.

    Raises:
        ValueError: If the header is invalid, ``token_count`` is zero, or the
            declared body length / total payload length is inconsistent.
    """
    header = _parse_header(raw)
    token_count = header["token_count"]
    body_byte_length = header["body_byte_length"]

    if token_count == 0:
        raise ValueError("LP/v1 token_count must be > 0")
    if body_byte_length != token_count * ENTRY_SIZE:
        raise ValueError(
            f"body_byte_length ({body_byte_length}) != token_count * {ENTRY_SIZE} "
            f"({token_count * ENTRY_SIZE})"
        )

    expected_len = HEADER_SIZE + body_byte_length
    if len(raw) != expected_len:
        raise ValueError(f"LP/v1 payload length mismatch: {len(raw)} != {expected_len}")

    # The checks above guarantee the body is exactly token_count entries long,
    # so iter_unpack yields one (token id, logprob) pair per token. The
    # memoryview slice avoids copying the body.
    entries = list(struct.iter_unpack(ENTRY_FORMAT, memoryview(raw)[HEADER_SIZE:]))
    logprobs: List[float] = [logprob for _, logprob in entries]
    token_ids: List[int] = [wire_id for wire_id, _ in entries]
    all_token_ids_valid = MISSING_TOKEN_ID not in token_ids

    metadata: Dict[str, Any] = {
        "scope": "completion_only",
        "completion_token_count": token_count,
        "all_token_ids_valid": all_token_ids_valid,
    }
    header.update(metadata)
    # A single missing id makes the whole id list unusable for callers.
    ids_out: Optional[List[int]] = token_ids if all_token_ids_valid else None
    return logprobs, ids_out, header


def decompress_and_parse_lp(data_b64: str) -> Tuple[List[float], Optional[List[int]], Dict[str, Any]]:
    """Decompress and unpack an LP/v1 payload into completion logprobs and token ids.

    Args:
        data_b64: Base64-encoded zstd-compressed LP binary blob from
            ``payloads.logprobs.data``.

    Returns:
        ``(logprobs, token_ids, metadata)`` where ``logprobs`` is per-completion-token
        scalars, ``token_ids`` is ``None`` if any wire id was ``MISSING_TOKEN_ID``,
        and ``metadata`` includes ``all_token_ids_valid`` and ``completion_token_count``.

    Raises:
        ValueError: If the decompressed bytes are not a well-formed LP/v1 payload.
            Errors from base64 decoding or zstd decompression propagate unchanged.
    """
    compressed = base64.b64decode(data_b64)
    # NOTE(review): one-shot decompress() presumably relies on the zstd frame
    # recording its content size — confirm the producer always writes it.
    decompressor = zstd.ZstdDecompressor()
    raw = decompressor.decompress(compressed)
    return parse_logprobs(raw)
import warnings

# Warn on import, before the re-export below runs; that ordering is why the
# following import carries a "noqa: E402" marker.
warnings.warn(
    (
        "eval_protocol.adapters.lp_deserializer is deprecated; "
        "import from eval_protocol.tracing.logprobs instead."
    ),
    category=DeprecationWarning,
    stacklevel=2,
)

# Re-export the LP/v1 helpers so existing imports of this module keep working.
from eval_protocol.tracing.logprobs import (  # noqa: E402
    ENTRY_FORMAT,
    ENTRY_SIZE,
    HEADER_FORMAT,
    HEADER_SIZE,
    HEADER_VERSION,
    MAGIC,
    MISSING_TOKEN_ID,
    decompress_and_parse_lp,
    parse_logprobs,
)

# Public surface of the shim: exactly the names re-exported above.
__all__ = [
    "ENTRY_FORMAT",
    "ENTRY_SIZE",
    "HEADER_FORMAT",
    "HEADER_SIZE",
    "HEADER_VERSION",
    "MAGIC",
    "MISSING_TOKEN_ID",
    "decompress_and_parse_lp",
    "parse_logprobs",
]
Loading
Loading