diff --git a/eval_protocol/adapters/fireworks_tracing.py b/eval_protocol/adapters/fireworks_tracing.py index 62a632e6..0ff86f71 100644 --- a/eval_protocol/adapters/fireworks_tracing.py +++ b/eval_protocol/adapters/fireworks_tracing.py @@ -15,9 +15,8 @@ import os from eval_protocol.models import EvaluationRow, InputMetadata, ExecutionMetadata, Message +from eval_protocol.tracing import PayloadType, decode_payloads from .base import BaseAdapter -from .lp_deserializer import decompress_and_parse_lp -from .r3_deserializer import decompress_and_parse_r3 from .utils import extract_messages_from_data from ..common_utils import get_user_agent @@ -102,45 +101,42 @@ def convert_trace_dict_to_evaluation_row( ): break # Break early if we've found all the metadata we need - # Extract router replay payloads when present + # Decoding lives in eval_protocol.tracing; here we only map results onto the row. payloads = trace.get("payloads") if isinstance(payloads, dict): - router_replay = payloads.get("router_replay") - if isinstance(router_replay, dict) and router_replay.get("data"): - try: - matrices, r3_meta = decompress_and_parse_r3(router_replay["data"]) - if execution_metadata.extra is None: - execution_metadata.extra = {} - execution_metadata.extra["routing_matrices"] = matrices - execution_metadata.extra["routing_metadata"] = r3_meta - except Exception as e: - logger.warning("Failed to decompress R3 payload for trace %s: %s", trace.get("id"), e) - - logprobs_payload = payloads.get("logprobs") - if isinstance(logprobs_payload, dict) and logprobs_payload.get("data"): - try: - logprobs, token_ids, lp_meta = decompress_and_parse_lp(logprobs_payload["data"]) - if execution_metadata.extra is None: - execution_metadata.extra = {} - execution_metadata.extra["completion_logprobs"] = logprobs - if token_ids is not None: - execution_metadata.extra["completion_token_ids"] = token_ids - execution_metadata.extra["logprobs_metadata"] = lp_meta - - for i in range(len(messages) - 1, -1, -1): - if messages[i].role == "assistant": - content_entries = [{"logprob": lp} for lp in logprobs] - if token_ids is not None: - for entry, tid in zip(content_entries, token_ids): - entry["token_id"] = tid - messages[i].logprobs = {"content": content_entries} - break - except Exception as e: - logger.warning( - "Failed to decompress logprobs payload for trace %s: %s", - trace.get("id"), - e, - ) + decoded = decode_payloads( + payloads, + on_error=lambda pt, e: logger.warning( + "Failed to decode %s payload for trace %s: %s", pt.value, trace.get("id"), e + ), + ) + if decoded and execution_metadata.extra is None: + execution_metadata.extra = {} + + if (dp := decoded.get(PayloadType.ROUTER_REPLAY)) is not None: + execution_metadata.extra["routing_matrices"] = dp.value + execution_metadata.extra["routing_metadata"] = dp.metadata + + if (dp := decoded.get(PayloadType.LOGPROBS)) is not None: + logprobs = dp.value + token_ids = dp.token_ids + execution_metadata.extra["completion_logprobs"] = logprobs + if token_ids is not None: + execution_metadata.extra["completion_token_ids"] = token_ids + execution_metadata.extra["logprobs_metadata"] = dp.metadata + + for i in range(len(messages) - 1, -1, -1): + if messages[i].role == "assistant": + content_entries = [{"logprob": lp} for lp in logprobs] + if token_ids is not None: + for entry, tid in zip(content_entries, token_ids): + entry["token_id"] = tid + messages[i].logprobs = {"content": content_entries} + break + + if (dp := decoded.get(PayloadType.PROMPT_TOKEN_IDS)) is not None: + execution_metadata.extra["prompt_token_ids"] = dp.value + execution_metadata.extra["prompt_token_ids_metadata"] = dp.metadata return EvaluationRow( messages=messages, diff --git a/eval_protocol/adapters/lp_deserializer.py b/eval_protocol/adapters/lp_deserializer.py index 57aa4f46..ffd79ff5 100644 --- a/eval_protocol/adapters/lp_deserializer.py +++ b/eval_protocol/adapters/lp_deserializer.py @@ -1,109 +1,41 @@ -"""LP/v1 binary deserializer for per-token logprobs payloads. +"""Deprecated compatibility shim for ``eval_protocol.tracing.logprobs``. -Implements the inverse of the tracing gateway's ``logprobs_serializer.serialize_logprobs``. -See that module for the full header specification. +Import from ``eval_protocol.tracing.logprobs`` (or ``decode_payloads`` from +``eval_protocol.tracing``) instead. This module re-exports the LP/v1 helpers +that lived here before the tracing package refactor. """ from __future__ import annotations -import base64 -import struct -from typing import Any, Dict, List, Optional, Tuple - -import zstandard as zstd - -MAGIC = b"LP01" -HEADER_VERSION = 1 -MISSING_TOKEN_ID = -1 -ENTRY_FORMAT = " Dict[str, Any]: - if len(raw) < HEADER_SIZE: - raise ValueError(f"Payload too short for lp/v1 header: {len(raw)} < {HEADER_SIZE}") - - ( - magic, - version, - flags, - reserved_u16, - token_count, - body_byte_length, - reserved_u64, - ) = struct.unpack(HEADER_FORMAT, raw[:HEADER_SIZE]) - - if magic != MAGIC: - raise ValueError(f"Bad LP/v1 magic: {magic!r}") - if version != HEADER_VERSION: - raise ValueError(f"Unsupported lp/v1 header version: {version}") - - return { - "flags": flags, - "reserved_u16": reserved_u16, - "token_count": token_count, - "body_byte_length": body_byte_length, - "reserved_u64": reserved_u64, - } - - -def parse_logprobs(raw: bytes) -> Tuple[List[float], Optional[List[int]], Dict[str, Any]]: - """Parse uncompressed LP/v1 bytes into logprobs, optional token ids, and metadata.""" - header = _parse_header(raw) - token_count = header["token_count"] - body_byte_length = header["body_byte_length"] - - if token_count == 0: - raise ValueError("LP/v1 token_count must be > 0") - if body_byte_length != token_count * ENTRY_SIZE: - raise ValueError( - f"body_byte_length ({body_byte_length}) != token_count * {ENTRY_SIZE} " - f"({token_count * ENTRY_SIZE})" - ) - - expected_len = HEADER_SIZE + body_byte_length - if len(raw) != expected_len: - raise ValueError(f"LP/v1 payload length mismatch: {len(raw)} != {expected_len}") - - logprobs: List[float] = [] - token_ids: List[int] = [] - all_token_ids_valid = True - offset = HEADER_SIZE - for _ in range(token_count): - wire_id, logprob = struct.unpack(ENTRY_FORMAT, raw[offset : offset + ENTRY_SIZE]) - offset += ENTRY_SIZE - logprobs.append(logprob) - if wire_id == MISSING_TOKEN_ID: - all_token_ids_valid = False - token_ids.append(wire_id) - else: - token_ids.append(wire_id) - - metadata: Dict[str, Any] = { - "scope": "completion_only", - "completion_token_count": token_count, - "all_token_ids_valid": all_token_ids_valid, - } - header.update(metadata) - ids_out: Optional[List[int]] = token_ids if all_token_ids_valid else None - return logprobs, ids_out, header - - -def decompress_and_parse_lp(data_b64: str) -> Tuple[List[float], Optional[List[int]], Dict[str, Any]]: - """Decompress and unpack an LP/v1 payload into completion logprobs and token ids. - - Args: - data_b64: Base64-encoded zstd-compressed LP binary blob from - ``payloads.logprobs.data``. - - Returns: - ``(logprobs, token_ids, metadata)`` where ``logprobs`` is per-completion-token - scalars, ``token_ids`` is ``None`` if any wire id was ``MISSING_TOKEN_ID``, - and ``metadata`` includes ``all_token_ids_valid`` and ``completion_token_count``. - """ - compressed = base64.b64decode(data_b64) - decompressor = zstd.ZstdDecompressor() - raw = decompressor.decompress(compressed) - return parse_logprobs(raw) +import warnings + +warnings.warn( + "eval_protocol.adapters.lp_deserializer is deprecated; " + "import from eval_protocol.tracing.logprobs instead.", + DeprecationWarning, + stacklevel=2, +) + +from eval_protocol.tracing.logprobs import ( # noqa: E402 + ENTRY_FORMAT, + ENTRY_SIZE, + HEADER_FORMAT, + HEADER_SIZE, + HEADER_VERSION, + MAGIC, + MISSING_TOKEN_ID, + decompress_and_parse_lp, + parse_logprobs, +) + +__all__ = [ + "ENTRY_FORMAT", + "ENTRY_SIZE", + "HEADER_FORMAT", + "HEADER_SIZE", + "HEADER_VERSION", + "MAGIC", + "MISSING_TOKEN_ID", + "decompress_and_parse_lp", + "parse_logprobs", +] diff --git a/eval_protocol/adapters/r3_deserializer.py b/eval_protocol/adapters/r3_deserializer.py index 1e3c1a8c..c2a0bd0f 100644 --- a/eval_protocol/adapters/r3_deserializer.py +++ b/eval_protocol/adapters/r3_deserializer.py @@ -1,187 +1,41 @@ -"""R3/v1 binary deserializer for router-replay payloads. +"""Deprecated compatibility shim for ``eval_protocol.tracing.router_replay``. -Implements the inverse of the packed binary format produced by the tracing -gateway's ``r3_serializer.serialize_r3``. See that module for the full -header specification. - -The main entry point is :func:`decompress_and_parse_r3`, which accepts the -base64-encoded compressed blob returned by the gateway's -``/v1/traces/pointwise?include_payloads=true`` endpoint and produces -per-token routing matrices in the same ``List[Optional[str]]`` format used -by the direct inference path (``DeploymentSampler.sample_with_tokens()``). +Import from ``eval_protocol.tracing.router_replay`` (or ``decode_payloads`` from +``eval_protocol.tracing``) instead. This module re-exports the R3/v1 helpers +that lived here before the tracing package refactor. """ from __future__ import annotations -import base64 -import struct -from enum import IntEnum -from typing import Any, Dict, List, Optional, Tuple - -import zstandard as zstd - -MAGIC = b"R3V1" -HEADER_FORMAT = "<4sBBBBIIIIQ" -HEADER_SIZE = struct.calcsize(HEADER_FORMAT) # 32 bytes -BITS_PER_BYTE = 8 - - -class _SelectorMode(IntEnum): - ALL = 0 - SUFFIX = 1 - BITMAP = 2 - - -class _RoutingDtype(IntEnum): - UINT8 = 1 - UINT16 = 2 - - -_SELECTOR_MODE_NAMES = {v: v.name.lower() for v in _SelectorMode} -_ROUTING_DTYPE_NAMES = {v: v.name.lower() for v in _RoutingDtype} - - -def _parse_header(raw: bytes) -> Dict[str, Any]: - if len(raw) < HEADER_SIZE: - raise ValueError( - f"Payload too short for r3/v1 header: {len(raw)} < {HEADER_SIZE}" - ) - - ( - magic, - version, - selector_mode, - routing_dtype, - flags, - total_token_count, - replayed_token_count, - replay_start_token, - selector_byte_length, - matrix_byte_length, - ) = struct.unpack(HEADER_FORMAT, raw[:HEADER_SIZE]) - - if magic != MAGIC: - raise ValueError(f"Bad R3 magic: {magic!r}") - if version != 1: - raise ValueError(f"Unsupported R3 header version: {version}") - - return { - "selector_mode": selector_mode, - "routing_dtype": routing_dtype, - "flags": flags, - "total_token_count": total_token_count, - "replayed_token_count": replayed_token_count, - "replay_start_token": replay_start_token, - "selector_byte_length": selector_byte_length, - "matrix_byte_length": matrix_byte_length, - } - - -def _read_bitmap_positions( - selector_bytes: bytes, total_token_count: int -) -> List[int]: - """Return sorted token indices where the bitmap bit is set.""" - positions: List[int] = [] - for i in range(total_token_count): - byte_idx = i // BITS_PER_BYTE - bit_idx = i % BITS_PER_BYTE - if byte_idx < len(selector_bytes) and (selector_bytes[byte_idx] >> bit_idx) & 1: - positions.append(i) - return positions - - -def decompress_and_parse_r3( - data_b64: str, -) -> Tuple[List[Optional[str]], Dict[str, Any]]: - """Decompress and unpack an R3/v1 payload into per-token routing matrices. - - Args: - data_b64: Base64-encoded zstd-compressed R3 binary blob, as returned - by the tracing gateway in ``payloads.router_replay.data``. - - Returns: - A tuple of ``(routing_matrices, metadata)`` where: - - - ``routing_matrices`` is a ``List[Optional[str]]`` of length - ``total_token_count``. Each present position contains a - base64-encoded routing matrix (matching the format returned by - the direct inference path); absent positions are ``None``. - - ``metadata`` is a dict with keys ``routing_dtype``, - ``selector_mode``, ``total_token_count``, ``replayed_token_count``, - ``replay_start_token``. - """ - compressed = base64.b64decode(data_b64) - - # ZstdCompressor.compress() embeds the uncompressed size in the frame - # header by default, so the library can auto-allocate the output buffer. - decompressor = zstd.ZstdDecompressor() - raw = decompressor.decompress(compressed) - - header = _parse_header(raw) - - selector_mode = header["selector_mode"] - routing_dtype = header["routing_dtype"] - total_token_count = header["total_token_count"] - replayed_token_count = header["replayed_token_count"] - replay_start_token = header["replay_start_token"] - selector_byte_length = header["selector_byte_length"] - matrix_byte_length = header["matrix_byte_length"] - - metadata: Dict[str, Any] = { - "routing_dtype": _ROUTING_DTYPE_NAMES.get(routing_dtype, str(routing_dtype)), - "selector_mode": _SELECTOR_MODE_NAMES.get(selector_mode, str(selector_mode)), - "total_token_count": total_token_count, - "replayed_token_count": replayed_token_count, - "replay_start_token": replay_start_token, - } - - if replayed_token_count == 0: - return [None] * total_token_count, metadata - - # Per-token matrix byte size is implicit in the payload: all replayed - # tokens share the same matrix length, so we can recover it from the - # matrix section total length divided by the replayed-token count. - if matrix_byte_length % replayed_token_count != 0: - raise ValueError( - f"matrix_byte_length ({matrix_byte_length}) is not a multiple of " - f"replayed_token_count ({replayed_token_count}); cannot split " - "into per-token matrices" - ) - matrix_elem_size = matrix_byte_length // replayed_token_count - - body = raw[HEADER_SIZE:] - expected_body_length = selector_byte_length + matrix_byte_length - if len(body) < expected_body_length: - raise ValueError( - f"Payload body too short for selector and matrix sections: " - f"{len(body)} < {expected_body_length}" - ) - - selector_bytes = body[:selector_byte_length] - matrix_bytes = body[selector_byte_length : selector_byte_length + matrix_byte_length] - - if selector_mode == _SelectorMode.ALL: - replayed_positions = list(range(total_token_count)) - elif selector_mode == _SelectorMode.SUFFIX: - replayed_positions = list( - range(replay_start_token, replay_start_token + replayed_token_count) - ) - elif selector_mode == _SelectorMode.BITMAP: - replayed_positions = _read_bitmap_positions(selector_bytes, total_token_count) - else: - raise ValueError(f"Unknown selector_mode: {selector_mode}") - - if len(replayed_positions) != replayed_token_count: - raise ValueError( - f"Selector produced {len(replayed_positions)} replayed positions, " - f"but header replayed_token_count is {replayed_token_count}" - ) - - # Split matrix bytes into per-token chunks and base64-encode each one - matrices: List[Optional[str]] = [None] * total_token_count - for idx, pos in enumerate(replayed_positions): - start = idx * matrix_elem_size - end = start + matrix_elem_size - matrices[pos] = base64.b64encode(matrix_bytes[start:end]).decode("ascii") - - return matrices, metadata +import warnings + +warnings.warn( + "eval_protocol.adapters.r3_deserializer is deprecated; " + "import from eval_protocol.tracing.router_replay instead.", + DeprecationWarning, + stacklevel=2, +) + +from eval_protocol.tracing.router_replay import ( # noqa: E402 + BITS_PER_BYTE, + HEADER_FORMAT, + HEADER_SIZE, + MAGIC, + _RoutingDtype, + _SelectorMode, + _parse_header, + _read_bitmap_positions, + decompress_and_parse_r3, +) + +__all__ = [ + "BITS_PER_BYTE", + "HEADER_FORMAT", + "HEADER_SIZE", + "MAGIC", + "_RoutingDtype", + "_SelectorMode", + "_parse_header", + "_read_bitmap_positions", + "decompress_and_parse_r3", +] diff --git a/eval_protocol/pytest/tracing_utils.py b/eval_protocol/pytest/tracing_utils.py index 279d1055..e0d5db73 100644 --- a/eval_protocol/pytest/tracing_utils.py +++ b/eval_protocol/pytest/tracing_utils.py @@ -63,6 +63,8 @@ def _merge_payloads_into_longest_row(longest_row: EvaluationRow, rows: List[Eval for key in ( "completion_logprobs", "completion_token_ids", + "prompt_token_ids", + "prompt_token_ids_metadata", "logprobs_metadata", "routing_matrices", "routing_metadata", diff --git a/eval_protocol/tracing/README.md b/eval_protocol/tracing/README.md new file mode 100644 index 00000000..10a45109 --- /dev/null +++ b/eval_protocol/tracing/README.md @@ -0,0 +1,109 @@ +# `eval_protocol.tracing` — Fireworks tracing-gateway payload decoders + +Standalone helpers for decoding the out-of-band **payloads** the Fireworks +tracing gateway stores alongside a trace (prompt token IDs, completion logprobs, +router-replay routing matrices). + +This package is intentionally self-contained: it depends only on the stdlib and +`zstandard`. It does **not** import `EvaluationRow`, rollout processors, or any +other Eval Protocol machinery, so you can use it even if you are not using EP for +rollouts — just point at it for extracting gateway payloads. + +## What is a "payload"? + +When you read a trace with payloads included: + +``` +GET {gateway}/v1/traces?rollout_id=...&include_payloads=true +``` + +each trace carries a `payloads` object like: + +```json +{ + "payloads": { + "prompt_token_ids": { + "manifest": { "PayloadVersion": "pti/v1", "...": "..." }, + "data": "" + }, + "logprobs": { "manifest": { "PayloadVersion": "lp/v1" }, "data": "..." }, + "router_replay": { "manifest": { "PayloadVersion": "r3/v1" }, "data": "..." } + } +} +``` + +The `data` field is `base64(zstd(raw_bytes))`. Each payload type has its own +`raw_bytes` encoding (`pti/v1` is a JSON int array; `lp/v1` and `r3/v1` are packed +binary). This package hides all of that. + +## Usage + +Decode everything at once (the common case): + +```python +from eval_protocol.tracing import decode_payloads, PayloadType + +decoded = decode_payloads(trace["payloads"]) + +if PayloadType.PROMPT_TOKEN_IDS in decoded: + token_ids = decoded[PayloadType.PROMPT_TOKEN_IDS].value # List[int] + +if PayloadType.LOGPROBS in decoded: + lp = decoded[PayloadType.LOGPROBS] + logprobs = lp.value # List[float] + token_ids = lp.token_ids # Optional[List[int]] + +if PayloadType.ROUTER_REPLAY in decoded: + matrices = decoded[PayloadType.ROUTER_REPLAY].value # List[Optional[str]] +``` + +If you have the whole trace dict, `decode_trace(trace)` reaches into +`trace["payloads"]` for you. + +Decode a single payload: + +```python +from eval_protocol.tracing import decode_payload, PayloadType + +dp = decode_payload(PayloadType.PROMPT_TOKEN_IDS, trace["payloads"]["prompt_token_ids"]["data"]) +dp.value # List[int] +``` + +### Error handling + +`decode_payloads` isolates per-payload failures: if one payload fails to decode, +the others are still returned. Pass `on_error=callback(payload_type, exc)` to +control logging (defaults to a warning): + +```python +decode_payloads(payloads, on_error=lambda pt, e: print(f"{pt} failed: {e}")) +``` + +## Return type + +`decode_payloads` / `decode_trace` return `Dict[PayloadType, DecodedPayload]`. + +`DecodedPayload` fields: + +| field | meaning | +|----------------|-------------------------------------------------------------------| +| `payload_type` | `PayloadType` enum member | +| `value` | decoded value (type depends on `payload_type`, see below) | +| `metadata` | decoded header/manifest metadata (token counts, scope, etc.) | +| `token_ids` | `Optional[List[int]]` — LOGPROBS per-token ids (else `None`) | + +`value` by type: + +| `PayloadType` | `value` | notes | +|---------------------|--------------------------|----------------------------------------------| +| `PROMPT_TOKEN_IDS` | `List[int]` | prompt token ids | +| `LOGPROBS` | `List[float]` | per completion token; ids in `token_ids` (or `None`) | +| `ROUTER_REPLAY` | `List[Optional[str]]` | per-token base64 routing matrices; `None` where absent | + +## Adding a new payload type + +1. Add a member to `PayloadType` in `types.py`. +2. Add a `decode_(data_b64) -> DecodedPayload` function in a new module. +3. Register it in `PAYLOAD_DECODERS` in `registry.py`. + +`decode_payloads` picks it up automatically. diff --git a/eval_protocol/tracing/__init__.py b/eval_protocol/tracing/__init__.py new file mode 100644 index 00000000..7ad9d436 --- /dev/null +++ b/eval_protocol/tracing/__init__.py @@ -0,0 +1,31 @@ +"""Decode Fireworks tracing-gateway payloads. + +Standalone, dependency-light helpers (stdlib + ``zstandard`` only) for turning +the binary/JSON ``payloads`` returned by the Fireworks tracing gateway +(``GET /traces?include_payloads=true``) into Python values. No EvaluationRow or +rollout machinery required -- usable on its own. + +Typical use:: + + from eval_protocol.tracing import decode_payloads, PayloadType + + decoded = decode_payloads(trace["payloads"]) + decoded[PayloadType.PROMPT_TOKEN_IDS].value # List[int] + decoded[PayloadType.LOGPROBS].value # List[float] + decoded[PayloadType.ROUTER_REPLAY].value # List[Optional[str]] + +See ``README.md`` in this package for details. +""" + +from __future__ import annotations + +from .registry import decode_payload, decode_payloads, decode_trace +from .types import DecodedPayload, PayloadType + +__all__ = [ + "PayloadType", + "DecodedPayload", + "decode_payloads", + "decode_payload", + "decode_trace", +] diff --git a/eval_protocol/tracing/logprobs.py b/eval_protocol/tracing/logprobs.py new file mode 100644 index 00000000..718f24e5 --- /dev/null +++ b/eval_protocol/tracing/logprobs.py @@ -0,0 +1,126 @@ +"""LP/v1 binary deserializer for per-token logprobs payloads. + +Implements the inverse of the tracing gateway's ``logprobs_serializer.serialize_logprobs``. +See that module for the full header specification. +""" + +from __future__ import annotations + +import base64 +import struct +from typing import Any, Dict, List, Optional, Tuple + +import zstandard as zstd + +from .types import DecodedPayload, PayloadType + +MAGIC = b"LP01" +HEADER_VERSION = 1 +MISSING_TOKEN_ID = -1 +ENTRY_FORMAT = " Dict[str, Any]: + if len(raw) < HEADER_SIZE: + raise ValueError(f"Payload too short for lp/v1 header: {len(raw)} < {HEADER_SIZE}") + + ( + magic, + version, + flags, + reserved_u16, + token_count, + body_byte_length, + reserved_u64, + ) = struct.unpack(HEADER_FORMAT, raw[:HEADER_SIZE]) + + if magic != MAGIC: + raise ValueError(f"Bad LP/v1 magic: {magic!r}") + if version != HEADER_VERSION: + raise ValueError(f"Unsupported lp/v1 header version: {version}") + + return { + "flags": flags, + "reserved_u16": reserved_u16, + "token_count": token_count, + "body_byte_length": body_byte_length, + "reserved_u64": reserved_u64, + } + + +def parse_logprobs(raw: bytes) -> Tuple[List[float], Optional[List[int]], Dict[str, Any]]: + """Parse uncompressed LP/v1 bytes into logprobs, optional token ids, and metadata.""" + header = _parse_header(raw) + token_count = header["token_count"] + body_byte_length = header["body_byte_length"] + + if token_count == 0: + raise ValueError("LP/v1 token_count must be > 0") + if body_byte_length != token_count * ENTRY_SIZE: + raise ValueError( + f"body_byte_length ({body_byte_length}) != token_count * {ENTRY_SIZE} " + f"({token_count * ENTRY_SIZE})" + ) + + expected_len = HEADER_SIZE + body_byte_length + if len(raw) != expected_len: + raise ValueError(f"LP/v1 payload length mismatch: {len(raw)} != {expected_len}") + + logprobs: List[float] = [] + token_ids: List[int] = [] + all_token_ids_valid = True + offset = HEADER_SIZE + for _ in range(token_count): + wire_id, logprob = struct.unpack(ENTRY_FORMAT, raw[offset : offset + ENTRY_SIZE]) + offset += ENTRY_SIZE + logprobs.append(logprob) + if wire_id == MISSING_TOKEN_ID: + all_token_ids_valid = False + token_ids.append(wire_id) + else: + token_ids.append(wire_id) + + metadata: Dict[str, Any] = { + "scope": "completion_only", + "completion_token_count": token_count, + "all_token_ids_valid": all_token_ids_valid, + } + header.update(metadata) + ids_out: Optional[List[int]] = token_ids if all_token_ids_valid else None + return logprobs, ids_out, header + + +def decompress_and_parse_lp(data_b64: str) -> Tuple[List[float], Optional[List[int]], Dict[str, Any]]: + """Decompress and unpack an LP/v1 payload into completion logprobs and token ids. + + Args: + data_b64: Base64-encoded zstd-compressed LP binary blob from + ``payloads.logprobs.data``. + + Returns: + ``(logprobs, token_ids, metadata)`` where ``logprobs`` is per-completion-token + scalars, ``token_ids`` is ``None`` if any wire id was ``MISSING_TOKEN_ID``, + and ``metadata`` includes ``all_token_ids_valid`` and ``completion_token_count``. + """ + compressed = base64.b64decode(data_b64) + decompressor = zstd.ZstdDecompressor() + raw = decompressor.decompress(compressed) + return parse_logprobs(raw) + + +def decode_logprobs(data_b64: str) -> DecodedPayload: + """Decode a gateway ``payloads.logprobs.data`` blob into a ``DecodedPayload``. + + ``value`` is the per-completion-token logprob list; per-token ids (when all + valid) are available under ``token_ids``. + """ + logprobs, token_ids, metadata = decompress_and_parse_lp(data_b64) + return DecodedPayload( + payload_type=PayloadType.LOGPROBS, + value=logprobs, + metadata=metadata, + token_ids=token_ids, + ) diff --git a/eval_protocol/tracing/prompt_token_ids.py b/eval_protocol/tracing/prompt_token_ids.py new file mode 100644 index 00000000..ff9d68de --- /dev/null +++ b/eval_protocol/tracing/prompt_token_ids.py @@ -0,0 +1,34 @@ +"""``pti/v1`` decoder for prompt token ID payloads. + +Inverse of the tracing gateway's ``serialize_prompt_token_ids``: the gateway +stores prompt token IDs as ``base64(zstd(json.dumps(token_ids)))`` -- a compact +JSON int array, no bespoke binary header. +""" + +from __future__ import annotations + +import base64 +import json +from typing import Any, Dict, List, Tuple + +import zstandard as zstd + +from .types import DecodedPayload, PayloadType + + +def parse_prompt_token_ids(raw: bytes) -> Tuple[List[int], Dict[str, Any]]: + """Parse uncompressed ``pti/v1`` bytes (a JSON int array) into ids + metadata.""" + token_ids = json.loads(raw) + metadata: Dict[str, Any] = {"scope": "prompt_only", "token_count": len(token_ids)} + return token_ids, metadata + + +def decode_prompt_token_ids(data_b64: str) -> DecodedPayload: + """Decode a gateway ``payloads.prompt_token_ids.data`` blob.""" + raw = zstd.ZstdDecompressor().decompress(base64.b64decode(data_b64)) + token_ids, metadata = parse_prompt_token_ids(raw) + return DecodedPayload( + payload_type=PayloadType.PROMPT_TOKEN_IDS, + value=token_ids, + metadata=metadata, + ) diff --git a/eval_protocol/tracing/registry.py b/eval_protocol/tracing/registry.py new file mode 100644 index 00000000..62027376 --- /dev/null +++ b/eval_protocol/tracing/registry.py @@ -0,0 +1,88 @@ +"""Decoder registry + master decode for tracing-gateway payloads. + +Adding a new payload type is a single entry in ``PAYLOAD_DECODERS`` (plus its +decoder module). Callers use the master :func:`decode_payloads` / +:func:`decode_trace` and never stitch per-type decoders together. +""" + +from __future__ import annotations + +import logging +from typing import Any, Callable, Dict, Optional + +from .logprobs import decode_logprobs +from .prompt_token_ids import decode_prompt_token_ids +from .router_replay import decode_router_replay +from .types import DecodedPayload, PayloadType + +logger = logging.getLogger(__name__) + +# Callback invoked when a single payload fails to decode: (payload_type, exc). +OnError = Callable[[PayloadType, Exception], None] + +PAYLOAD_DECODERS: Dict[PayloadType, Callable[[str], DecodedPayload]] = { + PayloadType.PROMPT_TOKEN_IDS: decode_prompt_token_ids, + PayloadType.LOGPROBS: decode_logprobs, + PayloadType.ROUTER_REPLAY: decode_router_replay, +} + + +def decode_payload(payload_type: PayloadType | str, data_b64: str) -> DecodedPayload: + """Decode a single payload by type. + + ``payload_type`` accepts a ``PayloadType`` or its string value, so external + callers can pass either. + """ + ptype = PayloadType(payload_type) + decoder = PAYLOAD_DECODERS.get(ptype) + if decoder is None: + raise ValueError(f"No decoder registered for payload type: {ptype!r}") + return decoder(data_b64) + + +def decode_payloads( + payloads: Dict[str, Any], + *, + on_error: Optional[OnError] = None, +) -> Dict[PayloadType, DecodedPayload]: + """Master decode: run every registered decoder over a gateway ``payloads`` dict. + + Args: + payloads: The ``payloads`` object from a gateway trace (i.e. + ``trace["payloads"]``), mapping payload-type name -> ``{"manifest", "data"}``. + on_error: Optional callback ``(payload_type, exc)`` invoked when a present + payload fails to decode. Defaults to logging a warning. A failure in + one payload never blocks the others. + + Returns: + ``{PayloadType: DecodedPayload}`` for every payload that is present and + decodes successfully. Only known ``PayloadType`` members are considered, + so unknown payload types from a newer gateway are ignored rather than + raising. + """ + if not isinstance(payloads, dict): + return {} + + decoded: Dict[PayloadType, DecodedPayload] = {} + for ptype, decoder in PAYLOAD_DECODERS.items(): + entry = payloads.get(ptype) + if not isinstance(entry, dict) or not entry.get("data"): + continue + try: + decoded[ptype] = decoder(entry["data"]) + except Exception as exc: # noqa: BLE001 - isolate per-payload failures + if on_error is not None: + on_error(ptype, exc) + else: + logger.warning("Failed to decode %s payload: %s", ptype.value, exc) + return decoded + + +def decode_trace( + trace: Dict[str, Any], + *, + on_error: Optional[OnError] = None, +) -> Dict[PayloadType, DecodedPayload]: + """Convenience wrapper around :func:`decode_payloads` for a raw trace dict.""" + payloads = trace.get("payloads") if isinstance(trace, dict) else None + return decode_payloads(payloads or {}, on_error=on_error) diff --git a/eval_protocol/tracing/router_replay.py b/eval_protocol/tracing/router_replay.py new file mode 100644 index 00000000..49c4be8e --- /dev/null +++ b/eval_protocol/tracing/router_replay.py @@ -0,0 +1,202 @@ +"""R3/v1 binary deserializer for router-replay payloads. + +Implements the inverse of the packed binary format produced by the tracing +gateway's ``r3_serializer.serialize_r3``. See that module for the full +header specification. + +The main entry point is :func:`decompress_and_parse_r3`, which accepts the +base64-encoded compressed blob returned by the gateway's +``/v1/traces/pointwise?include_payloads=true`` endpoint and produces +per-token routing matrices in the same ``List[Optional[str]]`` format used +by the direct inference path (``DeploymentSampler.sample_with_tokens()``). +""" + +from __future__ import annotations + +import base64 +import struct +from enum import IntEnum +from typing import Any, Dict, List, Optional, Tuple + +import zstandard as zstd + +from .types import DecodedPayload, PayloadType + +MAGIC = b"R3V1" +HEADER_FORMAT = "<4sBBBBIIIIQ" +HEADER_SIZE = struct.calcsize(HEADER_FORMAT) # 32 bytes +BITS_PER_BYTE = 8 + + +class _SelectorMode(IntEnum): + ALL = 0 + SUFFIX = 1 + BITMAP = 2 + + +class _RoutingDtype(IntEnum): + UINT8 = 1 + UINT16 = 2 + + +_SELECTOR_MODE_NAMES = {v: v.name.lower() for v in _SelectorMode} +_ROUTING_DTYPE_NAMES = {v: v.name.lower() for v in _RoutingDtype} + + +def _parse_header(raw: bytes) -> Dict[str, Any]: + if len(raw) < HEADER_SIZE: + raise ValueError( + f"Payload too short for r3/v1 header: {len(raw)} < {HEADER_SIZE}" + ) + + ( + magic, + version, + selector_mode, + routing_dtype, + flags, + total_token_count, + replayed_token_count, + replay_start_token, + selector_byte_length, + matrix_byte_length, + ) = struct.unpack(HEADER_FORMAT, raw[:HEADER_SIZE]) + + if magic != MAGIC: + raise ValueError(f"Bad R3 magic: {magic!r}") + if version != 1: + raise ValueError(f"Unsupported R3 header version: {version}") + + return { + "selector_mode": selector_mode, + "routing_dtype": routing_dtype, + "flags": flags, + "total_token_count": total_token_count, + "replayed_token_count": replayed_token_count, + "replay_start_token": replay_start_token, + "selector_byte_length": selector_byte_length, + "matrix_byte_length": matrix_byte_length, + } + + +def _read_bitmap_positions( + selector_bytes: bytes, total_token_count: int +) -> List[int]: + """Return sorted token indices where the bitmap bit is set.""" + positions: List[int] = [] + for i in range(total_token_count): + byte_idx = i // BITS_PER_BYTE + bit_idx = i % BITS_PER_BYTE + if byte_idx < len(selector_bytes) and (selector_bytes[byte_idx] >> bit_idx) & 1: + positions.append(i) + return positions + + +def decompress_and_parse_r3( + data_b64: str, +) -> Tuple[List[Optional[str]], Dict[str, Any]]: + """Decompress and unpack an R3/v1 payload into per-token routing matrices. + + Args: + data_b64: Base64-encoded zstd-compressed R3 binary blob, as returned + by the tracing gateway in ``payloads.router_replay.data``. + + Returns: + A tuple of ``(routing_matrices, metadata)`` where: + + - ``routing_matrices`` is a ``List[Optional[str]]`` of length + ``total_token_count``. Each present position contains a + base64-encoded routing matrix (matching the format returned by + the direct inference path); absent positions are ``None``. + - ``metadata`` is a dict with keys ``routing_dtype``, + ``selector_mode``, ``total_token_count``, ``replayed_token_count``, + ``replay_start_token``. + """ + compressed = base64.b64decode(data_b64) + + # ZstdCompressor.compress() embeds the uncompressed size in the frame + # header by default, so the library can auto-allocate the output buffer. + decompressor = zstd.ZstdDecompressor() + raw = decompressor.decompress(compressed) + + header = _parse_header(raw) + + selector_mode = header["selector_mode"] + routing_dtype = header["routing_dtype"] + total_token_count = header["total_token_count"] + replayed_token_count = header["replayed_token_count"] + replay_start_token = header["replay_start_token"] + selector_byte_length = header["selector_byte_length"] + matrix_byte_length = header["matrix_byte_length"] + + metadata: Dict[str, Any] = { + "routing_dtype": _ROUTING_DTYPE_NAMES.get(routing_dtype, str(routing_dtype)), + "selector_mode": _SELECTOR_MODE_NAMES.get(selector_mode, str(selector_mode)), + "total_token_count": total_token_count, + "replayed_token_count": replayed_token_count, + "replay_start_token": replay_start_token, + } + + if replayed_token_count == 0: + return [None] * total_token_count, metadata + + # Per-token matrix byte size is implicit in the payload: all replayed + # tokens share the same matrix length, so we can recover it from the + # matrix section total length divided by the replayed-token count. + if matrix_byte_length % replayed_token_count != 0: + raise ValueError( + f"matrix_byte_length ({matrix_byte_length}) is not a multiple of " + f"replayed_token_count ({replayed_token_count}); cannot split " + "into per-token matrices" + ) + matrix_elem_size = matrix_byte_length // replayed_token_count + + body = raw[HEADER_SIZE:] + expected_body_length = selector_byte_length + matrix_byte_length + if len(body) < expected_body_length: + raise ValueError( + f"Payload body too short for selector and matrix sections: " + f"{len(body)} < {expected_body_length}" + ) + + selector_bytes = body[:selector_byte_length] + matrix_bytes = body[selector_byte_length : selector_byte_length + matrix_byte_length] + + if selector_mode == _SelectorMode.ALL: + replayed_positions = list(range(total_token_count)) + elif selector_mode == _SelectorMode.SUFFIX: + replayed_positions = list( + range(replay_start_token, replay_start_token + replayed_token_count) + ) + elif selector_mode == _SelectorMode.BITMAP: + replayed_positions = _read_bitmap_positions(selector_bytes, total_token_count) + else: + raise ValueError(f"Unknown selector_mode: {selector_mode}") + + if len(replayed_positions) != replayed_token_count: + raise ValueError( + f"Selector produced {len(replayed_positions)} replayed positions, " + f"but header replayed_token_count is {replayed_token_count}" + ) + + # Split matrix bytes into per-token chunks and base64-encode each one + matrices: List[Optional[str]] = [None] * total_token_count + for idx, pos in enumerate(replayed_positions): + start = idx * matrix_elem_size + end = start + matrix_elem_size + matrices[pos] = base64.b64encode(matrix_bytes[start:end]).decode("ascii") + + return matrices, metadata + + +def decode_router_replay(data_b64: str) -> DecodedPayload: + """Decode a gateway ``payloads.router_replay.data`` blob into a ``DecodedPayload``. + + ``value`` is the per-token ``List[Optional[str]]`` of base64 routing matrices. + """ + matrices, metadata = decompress_and_parse_r3(data_b64) + return DecodedPayload( + payload_type=PayloadType.ROUTER_REPLAY, + value=matrices, + metadata=metadata, + ) diff --git a/eval_protocol/tracing/types.py b/eval_protocol/tracing/types.py new file mode 100644 index 00000000..a11f6412 --- /dev/null +++ b/eval_protocol/tracing/types.py @@ -0,0 +1,41 @@ +"""Shared types for the Fireworks tracing-gateway payload decoders.""" + +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum +from typing import Any, Dict, List, Optional + + +class PayloadType(str, Enum): + """Known out-of-band trace payload types emitted by the tracing gateway. + + Canonical source of truth for payload-type names on the EP side. Mirrors the + gateway's ``rft_tracing.schemas.PayloadType`` but is defined locally so this + package has no dependency on the gateway/mono codebase. Being a ``str`` enum, + members compare and hash equal to their string value, so they can be used + directly against the gateway's string-keyed ``payloads`` JSON. + """ + + PROMPT_TOKEN_IDS = "prompt_token_ids" + LOGPROBS = "logprobs" + ROUTER_REPLAY = "router_replay" + + +@dataclass(frozen=True) +class DecodedPayload: + """A single decoded gateway payload. + + ``value`` shape depends on ``payload_type``: + - ``PROMPT_TOKEN_IDS`` -> ``List[int]`` + - ``LOGPROBS`` -> ``List[float]`` (per completion token); the + optional per-token ids are in ``token_ids`` + - ``ROUTER_REPLAY`` -> ``List[Optional[str]]`` (per-token base64 routing + matrices, ``None`` where absent) + """ + + payload_type: PayloadType + value: Any + metadata: Dict[str, Any] + # LOGPROBS only: per-completion-token ids, or None if any were missing. + token_ids: Optional[List[int]] = None diff --git a/scripts/test_remote_rollout_prompt_token_ids.py b/scripts/test_remote_rollout_prompt_token_ids.py new file mode 100644 index 00000000..a143772f --- /dev/null +++ b/scripts/test_remote_rollout_prompt_token_ids.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +"""E2E check: RemoteRolloutProcessor reads prompt_token_ids trace payloads. + +This starts a tiny local `/init` server, sends one chat completion through the +Fireworks tracing gateway with `return_token_ids`, and verifies that +RemoteRolloutProcessor hydrates `assistant_turn_payloads[*].prompt_token_ids`. +""" + +from __future__ import annotations + +import argparse +import asyncio +import logging +import os +import sys +import socket +import threading +import time +from pathlib import Path +from typing import Any + +sys.path.insert(0, str(Path(__file__).resolve().parents[1])) + +import uvicorn +from fastapi import FastAPI +from openai import OpenAI + +from eval_protocol import FireworksTracingHttpHandler, InitRequest, RolloutIdFilter, Status +from eval_protocol.models import EvaluationRow, Message +from eval_protocol.pytest.remote_rollout_processor import RemoteRolloutProcessor +from eval_protocol.pytest.types import RolloutProcessorConfig + +logger = logging.getLogger("remote_rollout_prompt_token_ids") +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") + + +def _free_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(("127.0.0.1", 0)) + return int(sock.getsockname()[1]) + + +def _message_to_dict(message: Message | dict[str, Any]) -> dict[str, Any]: + if isinstance(message, Message): + return message.dump_mdoel_for_chat_completion_request() + return {k: v for k, v in dict(message).items() if v is not None} + + +def _make_app(gateway_url: str) -> FastAPI: + app = FastAPI() + app_logger = logging.getLogger(f"{__name__}.server") + app_logger.setLevel(logging.INFO) + + @app.get("/") + def health() -> dict[str, str]: + return {"status": "ok"} + + @app.post("/init") + def init(req: InitRequest) -> dict[str, str]: + rollout_logger = logging.getLogger(f"{__name__}.{req.metadata.rollout_id}") + rollout_logger.addFilter(RolloutIdFilter(req.metadata.rollout_id)) + if not any(isinstance(handler, FireworksTracingHttpHandler) for handler in rollout_logger.handlers): + rollout_logger.addHandler(FireworksTracingHttpHandler(gateway_base_url=gateway_url)) + rollout_logger.setLevel(logging.INFO) + + def _worker() -> None: + try: + conversation = [_message_to_dict(message) for message in (req.messages or [])] + params = dict(req.completion_params or {}) + params.pop("base_url", None) + params["extra_body"] = { + **dict(params.get("extra_body") or {}), + "return_token_ids": True, + } + params.setdefault("temperature", 0) + params.setdefault("max_tokens", 8) + + if not req.model_base_url: + raise ValueError("model_base_url is required") + if not params.get("model"): + raise ValueError("completion_params.model is required") + + client = OpenAI(base_url=req.model_base_url, api_key=req.api_key) + response = client.chat.completions.create(messages=conversation, **params) + content = response.choices[0].message.content or "" + logger.info("remote server generated content=%r", content) + + rollout_logger.info( + "rollout %s finished", + req.metadata.rollout_id, + extra={"status": Status.rollout_finished()}, + ) + except Exception as exc: + rollout_logger.exception( + "rollout %s failed", + req.metadata.rollout_id, + extra={"status": Status.rollout_unknown_error(str(exc))}, + ) + + threading.Thread(target=_worker, daemon=True).start() + return {"status": "started"} + + return app + + +def _wait_ready(url: str, timeout_seconds: float = 30.0) -> None: + import requests + + deadline = time.time() + timeout_seconds + while time.time() < deadline: + try: + resp = requests.get(url, timeout=2) + if resp.status_code == 200: + return + except Exception: + pass + time.sleep(0.2) + raise TimeoutError(f"server not ready: {url}") + + +async def _run(args: argparse.Namespace) -> None: + api_key = args.api_key or os.getenv("FIREWORKS_DEV_API_KEY") or os.getenv("FIREWORKS_API_KEY") + if not api_key: + raise ValueError("Set FIREWORKS_DEV_API_KEY or FIREWORKS_API_KEY") + + # FireworksTracingHttpHandler reads FIREWORKS_API_KEY. + os.environ["FIREWORKS_API_KEY"] = api_key + os.environ["EP_REMOTE_API_KEY"] = api_key + + port = args.port or _free_port() + remote_base_url = f"http://127.0.0.1:{port}" + app = _make_app(args.gateway_url) + config = uvicorn.Config(app, host="127.0.0.1", port=port, log_level="warning") + server = uvicorn.Server(config) + thread = threading.Thread(target=server.run, daemon=True) + thread.start() + _wait_ready(f"{remote_base_url}/") + + rollout_id = f"rrp-prompt-ids-{int(time.time())}" + row = EvaluationRow( + messages=[Message(role="user", content="Reply with exactly: ok")], + ) + row.input_metadata.row_id = "row-0" + row.input_metadata.completion_params = { + "model": args.model, + "base_url": args.api_base_url, + "temperature": 0, + "max_tokens": 8, + } + row.execution_metadata.rollout_id = rollout_id + row.execution_metadata.invocation_id = "inv-0" + row.execution_metadata.experiment_id = "fir2-1747-rrp-e2e" + row.execution_metadata.run_id = "run-0" + + processor = RemoteRolloutProcessor( + remote_base_url=remote_base_url, + model_base_url=args.gateway_url, + include_payloads=True, + timeout_seconds=args.timeout_seconds, + poll_interval=args.poll_interval, + ) + try: + task = processor( + [row], + RolloutProcessorConfig( + completion_params=row.input_metadata.completion_params, + mcp_config_path="", + semaphore=asyncio.Semaphore(1), + steps=1, + ), + )[0] + completed = await task + finally: + await processor.acleanup() + server.should_exit = True + thread.join(timeout=5) + + extra = completed.execution_metadata.extra or {} + turn_payloads = extra.get("assistant_turn_payloads") or [] + prompt_ids = None + if turn_payloads: + prompt_ids = turn_payloads[0].get("prompt_token_ids") + if prompt_ids is None: + prompt_ids = extra.get("prompt_token_ids") + + print(f"rollout_id={rollout_id}") + print(f"messages={len(completed.messages)}") + print(f"assistant_turn_payloads={turn_payloads}") + print(f"prompt_token_ids_len={len(prompt_ids) if isinstance(prompt_ids, list) else None}") + print(f"prompt_token_ids_head={prompt_ids[:8] if isinstance(prompt_ids, list) else None}") + + if not isinstance(prompt_ids, list) or not prompt_ids: + raise AssertionError("RemoteRolloutProcessor did not hydrate prompt_token_ids") + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--gateway-url", default=os.getenv("EP_MODEL_BASE_URL", "https://litellm-gateway-dev-j4kzagdteq-uc.a.run.app")) + parser.add_argument("--api-base-url", default=os.getenv("FIREWORKS_API_BASE_URL", "https://dev.api.fireworks.ai/inference/v1")) + parser.add_argument("--model", default=os.getenv("TRACING_E2E_MODEL", "accounts/pyroworks-dev/deployments/malaysia2-intended-butterfly")) + parser.add_argument("--api-key", default=None) + parser.add_argument("--port", type=int, default=0) + parser.add_argument("--timeout-seconds", type=float, default=180.0) + parser.add_argument("--poll-interval", type=float, default=2.0) + asyncio.run(_run(parser.parse_args())) + + +if __name__ == "__main__": + main() diff --git a/tests/adapters/test_deserializer_shims.py b/tests/adapters/test_deserializer_shims.py new file mode 100644 index 00000000..dbf5c999 --- /dev/null +++ b/tests/adapters/test_deserializer_shims.py @@ -0,0 +1,38 @@ +"""Backward-compat shims for moved adapter deserializers.""" + +from __future__ import annotations + +import warnings + +import pytest + +from tests.adapters.test_lp_deserializer import GOLDEN_RAW_HEX + + +def test_lp_deserializer_shim_reexports_and_warns(): + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always", DeprecationWarning) + from eval_protocol.adapters import lp_deserializer as shim + + assert any( + "eval_protocol.adapters.lp_deserializer is deprecated" in str(w.message) + for w in caught + ) + raw = bytes.fromhex(GOLDEN_RAW_HEX) + logprobs, token_ids, metadata = shim.parse_logprobs(raw) + assert logprobs == [-0.25, -0.5] + assert token_ids == [7, 8] + assert metadata["completion_token_count"] == 2 + + +def test_r3_deserializer_shim_reexports_and_warns(): + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always", DeprecationWarning) + from eval_protocol.adapters import r3_deserializer as shim + + assert any( + "eval_protocol.adapters.r3_deserializer is deprecated" in str(w.message) + for w in caught + ) + assert shim.MAGIC == b"R3V1" + assert shim._SelectorMode.ALL == 0 diff --git a/tests/adapters/test_fireworks_tracing_logprobs.py b/tests/adapters/test_fireworks_tracing_logprobs.py index 08dab60b..c38c7acd 100644 --- a/tests/adapters/test_fireworks_tracing_logprobs.py +++ b/tests/adapters/test_fireworks_tracing_logprobs.py @@ -11,7 +11,7 @@ pytest.importorskip("mcp") from eval_protocol.adapters.fireworks_tracing import convert_trace_dict_to_evaluation_row -from eval_protocol.adapters.lp_deserializer import ( +from eval_protocol.tracing.logprobs import ( ENTRY_FORMAT, ENTRY_SIZE, HEADER_FORMAT, diff --git a/tests/adapters/test_fireworks_tracing_prompt_token_ids.py b/tests/adapters/test_fireworks_tracing_prompt_token_ids.py new file mode 100644 index 00000000..16bebe3f --- /dev/null +++ b/tests/adapters/test_fireworks_tracing_prompt_token_ids.py @@ -0,0 +1,55 @@ +"""Tests for prompt token ID payload handling in fireworks_tracing adapter.""" + +from __future__ import annotations + +import base64 +import json + +import pytest +import zstandard as zstd + +pytest.importorskip("mcp") + +from eval_protocol.adapters.fireworks_tracing import convert_trace_dict_to_evaluation_row +from eval_protocol.tracing.prompt_token_ids import decode_prompt_token_ids + + +def _pti_b64(token_ids: list[int]) -> str: + """Build a gateway pti/v1 payload: base64(zstd(json int array)).""" + raw = json.dumps(token_ids).encode("utf-8") + return base64.b64encode(zstd.ZstdCompressor().compress(raw)).decode("ascii") + + +def test_decode_prompt_token_ids_round_trip(): + decoded = decode_prompt_token_ids(_pti_b64([101, 102, 103])) + + assert decoded.value == [101, 102, 103] + assert decoded.metadata["scope"] == "prompt_only" + assert decoded.metadata["token_count"] == 3 + + +def test_trace_adapter_attaches_prompt_token_ids_metadata(): + trace = { + "id": "trace-pti", + "input": { + "messages": [ + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "hello"}, + ], + }, + "output": {"role": "assistant", "content": "hello"}, + "payloads": { + "prompt_token_ids": { + "data": _pti_b64([201, 202, 203]), + "manifest": {"PayloadVersion": "pti/v1"}, + }, + }, + } + + row = convert_trace_dict_to_evaluation_row(trace) + + assert row is not None + extra = row.execution_metadata.extra + assert extra is not None + assert extra["prompt_token_ids"] == [201, 202, 203] + assert extra["prompt_token_ids_metadata"]["token_count"] == 3 diff --git a/tests/adapters/test_lp_deserializer.py b/tests/adapters/test_lp_deserializer.py index 52e04417..c5460614 100644 --- a/tests/adapters/test_lp_deserializer.py +++ b/tests/adapters/test_lp_deserializer.py @@ -8,7 +8,7 @@ import pytest import zstandard as zstd -from eval_protocol.adapters.lp_deserializer import ( +from eval_protocol.tracing.logprobs import ( ENTRY_FORMAT, ENTRY_SIZE, HEADER_FORMAT, diff --git a/tests/adapters/test_r3_deserializer.py b/tests/adapters/test_r3_deserializer.py index 31f058f6..980d3934 100644 --- a/tests/adapters/test_r3_deserializer.py +++ b/tests/adapters/test_r3_deserializer.py @@ -10,7 +10,7 @@ import pytest import zstandard as zstd -from eval_protocol.adapters.r3_deserializer import ( +from eval_protocol.tracing.router_replay import ( HEADER_FORMAT, HEADER_SIZE, MAGIC, diff --git a/tests/pytest/test_tracing_utils.py b/tests/pytest/test_tracing_utils.py index 58ec55c1..98246ab1 100644 --- a/tests/pytest/test_tracing_utils.py +++ b/tests/pytest/test_tracing_utils.py @@ -13,6 +13,7 @@ def test_merge_payloads_into_longest_row_preserves_each_assistant_turn(): execution_metadata=ExecutionMetadata( extra={ "completion_logprobs": [-0.1, -0.2], + "prompt_token_ids": [101, 102], "routing_matrices": ["first-matrix"], "routing_metadata": {"total_token_count": 1}, }, @@ -28,6 +29,7 @@ def test_merge_payloads_into_longest_row_preserves_each_assistant_turn(): execution_metadata=ExecutionMetadata( extra={ "completion_logprobs": [-0.3], + "prompt_token_ids": [101, 102, 103, 104], "routing_matrices": ["second-matrix"], "routing_metadata": {"total_token_count": 1}, }, @@ -45,12 +47,14 @@ def test_merge_payloads_into_longest_row_preserves_each_assistant_turn(): { "assistant_turn_index": 0, "completion_logprobs": [-0.1, -0.2], + "prompt_token_ids": [101, 102], "routing_matrices": ["first-matrix"], "routing_metadata": {"total_token_count": 1}, }, { "assistant_turn_index": 1, "completion_logprobs": [-0.3], + "prompt_token_ids": [101, 102, 103, 104], "routing_matrices": ["second-matrix"], "routing_metadata": {"total_token_count": 1}, }, diff --git a/tests/tracing/test_registry.py b/tests/tracing/test_registry.py new file mode 100644 index 00000000..e32492db --- /dev/null +++ b/tests/tracing/test_registry.py @@ -0,0 +1,124 @@ +"""Tests for the standalone tracing-gateway payload decoder registry.""" + +from __future__ import annotations + +import base64 +import json +import struct + +import zstandard as zstd + +from eval_protocol.tracing import ( + DecodedPayload, + PayloadType, + decode_payload, + decode_payloads, + decode_trace, +) +from eval_protocol.tracing import logprobs as lp_mod +from eval_protocol.tracing import router_replay as r3_mod + + +def _b64_zstd(raw: bytes) -> str: + return base64.b64encode(zstd.ZstdCompressor().compress(raw)).decode("ascii") + + +def _pti_data(token_ids: list[int]) -> str: + return _b64_zstd(json.dumps(token_ids).encode("utf-8")) + + +def _lp_data(tokens: list[tuple[int, float]]) -> str: + body = b"".join(struct.pack(lp_mod.ENTRY_FORMAT, tid, lp) for tid, lp in tokens) + header = struct.pack( + lp_mod.HEADER_FORMAT, lp_mod.MAGIC, 1, 0, 0, len(tokens), len(body), 0 + ) + return _b64_zstd(header + body) + + +def _r3_data_all_mode(matrices: list[bytes]) -> str: + matrix_data = b"".join(matrices) + header = struct.pack( + r3_mod.HEADER_FORMAT, + r3_mod.MAGIC, + 1, # version + r3_mod._SelectorMode.ALL, + r3_mod._RoutingDtype.UINT8, + 0x01, # flags + len(matrices), # total_token_count + len(matrices), # replayed_token_count + 0, # replay_start_token + 0, # selector_byte_length + len(matrix_data), + ) + return _b64_zstd(header + matrix_data) + + +def _all_payloads() -> dict: + return { + "prompt_token_ids": {"manifest": {"PayloadVersion": "pti/v1"}, "data": _pti_data([1, 2, 3])}, + "logprobs": {"manifest": {"PayloadVersion": "lp/v1"}, "data": _lp_data([(7, -0.25), (8, -0.5)])}, + "router_replay": { + "manifest": {"PayloadVersion": "r3/v1"}, + "data": _r3_data_all_mode([b"\x01\x02\x03\x04", b"\x05\x06\x07\x08"]), + }, + } + + +def test_decode_payloads_all_types(): + decoded = decode_payloads(_all_payloads()) + + assert set(decoded) == { + PayloadType.PROMPT_TOKEN_IDS, + PayloadType.LOGPROBS, + PayloadType.ROUTER_REPLAY, + } + assert all(isinstance(dp, DecodedPayload) for dp in decoded.values()) + + assert decoded[PayloadType.PROMPT_TOKEN_IDS].value == [1, 2, 3] + + lp = decoded[PayloadType.LOGPROBS] + assert lp.value == [-0.25, -0.5] + assert lp.token_ids == [7, 8] + + r3 = decoded[PayloadType.ROUTER_REPLAY] + assert len(r3.value) == 2 + assert base64.b64decode(r3.value[0]) == b"\x01\x02\x03\x04" + + +def test_decode_payload_accepts_str_and_enum(): + data = _pti_data([10, 20]) + via_enum = decode_payload(PayloadType.PROMPT_TOKEN_IDS, data) + via_str = decode_payload("prompt_token_ids", data) + assert via_enum.value == via_str.value == [10, 20] + + +def test_decode_trace_reaches_into_payloads(): + trace = {"id": "t1", "payloads": {"prompt_token_ids": {"data": _pti_data([5, 6])}}} + decoded = decode_trace(trace) + assert decoded[PayloadType.PROMPT_TOKEN_IDS].value == [5, 6] + + +def test_unknown_and_empty_types_are_skipped(): + payloads = { + "some_future_type": {"data": "ignored"}, # unknown -> ignored + "logprobs": {"data": ""}, # present but empty -> skipped + "prompt_token_ids": {"data": _pti_data([9])}, + } + decoded = decode_payloads(payloads) + assert set(decoded) == {PayloadType.PROMPT_TOKEN_IDS} + + +def test_on_error_fires_on_bad_data(): + errors = [] + payloads = {"prompt_token_ids": {"data": "not-valid-base64-zstd-json!!"}} + + decoded = decode_payloads(payloads, on_error=lambda pt, e: errors.append((pt, e))) + + assert decoded == {} + assert len(errors) == 1 + assert errors[0][0] == PayloadType.PROMPT_TOKEN_IDS + + +def test_decode_payloads_non_dict_returns_empty(): + assert decode_payloads(None) == {} + assert decode_trace({"id": "no-payloads"}) == {}