diff --git a/.github/workflows/ci-release-uds-tokenizer.yaml b/.github/workflows/ci-release-uds-tokenizer.yaml index 5cc7f988..544971b5 100644 --- a/.github/workflows/ci-release-uds-tokenizer.yaml +++ b/.github/workflows/ci-release-uds-tokenizer.yaml @@ -34,8 +34,7 @@ jobs: image-name: llm-d-uds-tokenizer registry: ghcr.io/${{ github.repository_owner }} github-token: ${{ secrets.GHCR_TOKEN }} - context: services/uds_tokenizer - dockerfile: services/uds_tokenizer/Dockerfile + dockerfile: Dockerfile.tokenizer platform: linux/amd64 - name: Run Trivy scan diff --git a/Dockerfile.tokenizer b/Dockerfile.tokenizer new file mode 100644 index 00000000..30ab247d --- /dev/null +++ b/Dockerfile.tokenizer @@ -0,0 +1,69 @@ +# Copyright 2025 The llm-d Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# Build stage +FROM python:3.12-slim AS python-builder + +ARG TARGETOS=linux +ARG TARGETARCH=amd64 + +WORKDIR /workspace + +RUN apt-get update && apt-get install -y --no-install-recommends build-essential + +COPY Makefile Makefile +COPY pkg/preprocessing/chat_completions/ pkg/preprocessing/chat_completions/ +COPY services/uds_tokenizer/pyproject.toml services/uds_tokenizer/pyproject.toml +RUN TARGETOS=${TARGETOS} TARGETARCH=${TARGETARCH} make install-python-deps + +# Runtime stage +FROM python:3.12-slim + +# Set working directory +WORKDIR /app + +RUN apt-get update && apt-get upgrade -y && rm -rf /var/cache/apt/ + +# Copy installed dependencies from python-builder stage +COPY --from=python-builder /workspace/build/venv /app/venv +ENV PATH="/app/venv/bin:$PATH" +ENV PYTHONPATH="/app/preprocessing/chat_completions:/app/venv/lib/python3.12/site-packages" + +# Copy project files into the image +COPY services/uds_tokenizer/run_grpc_server.py /app/ +COPY services/uds_tokenizer/pyproject.toml /app/pyproject.toml +COPY services/uds_tokenizer/tokenizer_grpc_service.py /app/tokenizer_grpc_service.py +COPY services/uds_tokenizer/utils/ /app/utils/ +COPY services/uds_tokenizer/tokenizer_service/ /app/tokenizer_service/ +COPY services/uds_tokenizer/tokenizerpb/ /app/tokenizerpb/ + +# Copy the shared Python code for chat completion preprocessing from the project structure +RUN mkdir -p /app/preprocessing/chat_completions +COPY pkg/preprocessing/chat_completions/tokenizer_wrapper.py /app/preprocessing/chat_completions/ + +# Create directory for UDS socket +RUN mkdir -p /tmp/tokenizer && chown 65532:65532 /tmp/tokenizer + +# Create model cache directories and set permissions +RUN mkdir -p /app/models && chown -R 65532:65532 /app/models +# Create and set permissions for ModelScope directory +RUN mkdir -p /.modelscope && chown -R 65532:65532 /.modelscope +# Create and set permissions for Hugging Face cache directory +RUN mkdir -p /.cache && chown -R 65532:65532 /.cache + +# 
Switch to non-root user +USER 65532:65532 + +# Startup command: run direct gRPC server +CMD ["python", "/app/run_grpc_server.py"] \ No newline at end of file diff --git a/Dockerfile.uds b/Dockerfile.uds new file mode 100644 index 00000000..dae68337 --- /dev/null +++ b/Dockerfile.uds @@ -0,0 +1,59 @@ +# Copyright 2025 The llm-d Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Build Stage: using Go 1.24 image +FROM quay.io/projectquay/golang:1.24 AS builder +ARG TARGETOS +ARG TARGETARCH + +WORKDIR /workspace + +# Install system-level dependencies first. This layer is very stable. +USER root +# Install EPEL repository directly and then ZeroMQ, as epel-release is not in default repos. +# Install all necessary dependencies including Python 3.12 for chat-completions templating. +# The builder is based on UBI8, so we need epel-release-8. +RUN dnf install -y 'https://dl.fedoraproject.org/pub/epel/epel-release-latest-8.noarch.rpm' && \ + dnf install -y gcc-c++ libstdc++ libstdc++-devel clang zeromq-devel pkgconfig && \ + dnf clean all + +# Copy the Go Modules manifests +COPY go.mod go.mod +COPY go.sum go.sum +# cache deps before building and copying source so that we don't need to re-download as much +# and so that source changes don't invalidate our downloaded layer +RUN go mod download + +# Copy the source code. +COPY . . 
+ +RUN make build-uds + +# Use distroless as minimal base image to package the manager binary +# Refer to https://github.com/GoogleContainerTools/distroless for more details +FROM registry.access.redhat.com/ubi9/ubi:latest +WORKDIR / +# Install zeromq runtime library needed by the manager. +# The final image is UBI9, so we need epel-release-9. +USER root +RUN dnf install -y 'https://dl.fedoraproject.org/pub/epel/epel-release-latest-9.noarch.rpm' && \ + dnf install -y zeromq libxcrypt-compat && \ + dnf clean all + +# Copy the compiled Go application +COPY --from=builder /workspace/bin/llm-d-kv-cache /app/kv-cache-manager +USER 65532:65532 + +# Set the entrypoint to the kv-cache-manager binary +ENTRYPOINT ["/app/kv-cache-manager"] diff --git a/Makefile b/Makefile index 813fb43f..ad1d38d1 100644 --- a/Makefile +++ b/Makefile @@ -31,6 +31,10 @@ PYTHON_VERSION := 3.12 VENV_DIR := $(shell pwd)/build/venv VENV_BIN := $(VENV_DIR)/bin +UDS_TOKENIZER_DIR := services/uds_tokenizer +UDS_TOKENIZER_VENV_DIR := $(UDS_TOKENIZER_DIR)/.venv +UDS_TOKENIZER_VENV_BIN := $(UDS_TOKENIZER_VENV_DIR)/bin + # Attempt to find Python 3.9 executable. PYTHON_EXE := $(shell command -v python$(PYTHON_VERSION) || command -v python3) @@ -117,6 +121,8 @@ install-python-deps: setup-venv ## installs dependencies. echo "ERROR: Virtual environment not found. Run 'make setup-venv' first."; \ exit 1; \ fi + @echo "Installing UDS tokenizer Python dependencies..."; \ + $(VENV_BIN)/pip install "${UDS_TOKENIZER_DIR}" @if $(VENV_BIN)/python -c "import vllm" 2>/dev/null; then \ echo "vllm is already installed, skipping..."; \ exit 0; \ @@ -242,28 +248,22 @@ e2e-test-uds: check-go download-zmq image-build-uds ## Run UDS tokenizer e2e tes go test -v -count=1 -timeout 10m ./tests/e2e/uds_tokenizer/... 
##@ UDS Tokenizer Python Tests -UDS_TOKENIZER_DIR := services/uds_tokenizer -UDS_TOKENIZER_VENV_DIR := $(UDS_TOKENIZER_DIR)/.venv -UDS_TOKENIZER_VENV_BIN := $(UDS_TOKENIZER_VENV_DIR)/bin - .PHONY: uds-tokenizer-install-deps -uds-tokenizer-install-deps: detect-python ## Set up venv and install UDS tokenizer dependencies - @printf "\033[33;1m==== Setting up UDS tokenizer venv and dependencies ====\033[0m\n" - @if [ ! -f "$(UDS_TOKENIZER_VENV_BIN)/python" ]; then \ - echo "Creating virtual environment in $(UDS_TOKENIZER_VENV_DIR)..."; \ - $(PYTHON_EXE) -m venv $(UDS_TOKENIZER_VENV_DIR); \ - echo "Upgrading pip..."; \ - $(UDS_TOKENIZER_VENV_BIN)/pip install --upgrade pip; \ +uds-tokenizer-install-deps: install-python-deps ## Set up venv and install UDS tokenizer dependencies + @printf "\033[33;1m==== Detecting UDS tokenizer venv and dependencies ====\033[0m\n" + @if [ ! -f "$(VENV_BIN)/python" ]; then \ + echo "Virtual environment not exist"; \ + exit 1; \ else \ echo "Virtual environment already exists"; \ fi - @echo "Installing dependencies..." - @$(UDS_TOKENIZER_VENV_BIN)/pip install "$(UDS_TOKENIZER_DIR)[test]" + @echo "Installing UDS tokenizer test dependencies..." + @$(VENV_BIN)/pip install "$(UDS_TOKENIZER_DIR)[test]" .PHONY: uds-tokenizer-service-test uds-tokenizer-service-test: uds-tokenizer-install-deps ## Run UDS tokenizer integration tests (starts server automatically) @printf "\033[33;1m==== Running UDS tokenizer integration tests ====\033[0m\n" - @$(UDS_TOKENIZER_VENV_BIN)/python -m pytest \ + @$(VENV_BIN)/python -m pytest \ $(UDS_TOKENIZER_DIR)/tests/test_integration.py \ -v --timeout=60 @@ -286,8 +286,8 @@ build: build-uds build-embedded ## Build both UDS-only and embedded binaries .PHONY: build-uds build-uds: check-go download-zmq ## Build without embedded tokenizers (no Python required) - @printf "\033[33;1m==== Building (UDS-only, no embedded tokenizers) ====\033[0m\n" - @go build ./pkg/... 
+ @printf "\033[33;1m==== Building application binary (with uds tokenizers) ====\033[0m\n" + @go build -o bin/$(PROJECT_NAME) examples/kv_events/online_uds/main.go @echo "✅ UDS-only build succeeded" .PHONY: build-embedded diff --git a/api/tokenizerpb/tokenizer.pb.go b/api/tokenizerpb/tokenizer.pb.go index 76f00dc4..3e8d9603 100644 --- a/api/tokenizerpb/tokenizer.pb.go +++ b/api/tokenizerpb/tokenizer.pb.go @@ -35,291 +35,6 @@ const ( _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) ) -// TokenizeRequest represents a request to tokenize a text input -type TokenizeRequest struct { - state protoimpl.MessageState `protogen:"open.v1"` - Input string `protobuf:"bytes,1,opt,name=input,proto3" json:"input,omitempty"` // The text input to tokenize - ModelName string `protobuf:"bytes,2,opt,name=model_name,json=modelName,proto3" json:"model_name,omitempty"` // The name of the model to use for tokenization - AddSpecialTokens bool `protobuf:"varint,3,opt,name=add_special_tokens,json=addSpecialTokens,proto3" json:"add_special_tokens,omitempty"` // Whether to add special tokens during tokenization - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache -} - -func (x *TokenizeRequest) Reset() { - *x = TokenizeRequest{} - mi := &file_api_tokenizerpb_tokenizer_proto_msgTypes[0] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) -} - -func (x *TokenizeRequest) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*TokenizeRequest) ProtoMessage() {} - -func (x *TokenizeRequest) ProtoReflect() protoreflect.Message { - mi := &file_api_tokenizerpb_tokenizer_proto_msgTypes[0] - if x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use TokenizeRequest.ProtoReflect.Descriptor instead. 
-func (*TokenizeRequest) Descriptor() ([]byte, []int) { - return file_api_tokenizerpb_tokenizer_proto_rawDescGZIP(), []int{0} -} - -func (x *TokenizeRequest) GetInput() string { - if x != nil { - return x.Input - } - return "" -} - -func (x *TokenizeRequest) GetModelName() string { - if x != nil { - return x.ModelName - } - return "" -} - -func (x *TokenizeRequest) GetAddSpecialTokens() bool { - if x != nil { - return x.AddSpecialTokens - } - return false -} - -// TokenizeResponse represents the response from tokenization -type TokenizeResponse struct { - state protoimpl.MessageState `protogen:"open.v1"` - InputIds []uint32 `protobuf:"varint,1,rep,packed,name=input_ids,json=inputIds,proto3" json:"input_ids,omitempty"` // Token IDs for the input - Success bool `protobuf:"varint,2,opt,name=success,proto3" json:"success,omitempty"` // Whether the request was successful - ErrorMessage string `protobuf:"bytes,3,opt,name=error_message,json=errorMessage,proto3" json:"error_message,omitempty"` // Error message if the request failed - // Direct array of [start, end] pairs - OffsetPairs []uint32 `protobuf:"varint,4,rep,packed,name=offset_pairs,json=offsetPairs,proto3" json:"offset_pairs,omitempty"` // Flattened array of [start, end, start, end, ...] 
- unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache -} - -func (x *TokenizeResponse) Reset() { - *x = TokenizeResponse{} - mi := &file_api_tokenizerpb_tokenizer_proto_msgTypes[1] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) -} - -func (x *TokenizeResponse) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*TokenizeResponse) ProtoMessage() {} - -func (x *TokenizeResponse) ProtoReflect() protoreflect.Message { - mi := &file_api_tokenizerpb_tokenizer_proto_msgTypes[1] - if x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use TokenizeResponse.ProtoReflect.Descriptor instead. -func (*TokenizeResponse) Descriptor() ([]byte, []int) { - return file_api_tokenizerpb_tokenizer_proto_rawDescGZIP(), []int{1} -} - -func (x *TokenizeResponse) GetInputIds() []uint32 { - if x != nil { - return x.InputIds - } - return nil -} - -func (x *TokenizeResponse) GetSuccess() bool { - if x != nil { - return x.Success - } - return false -} - -func (x *TokenizeResponse) GetErrorMessage() string { - if x != nil { - return x.ErrorMessage - } - return "" -} - -func (x *TokenizeResponse) GetOffsetPairs() []uint32 { - if x != nil { - return x.OffsetPairs - } - return nil -} - -// ConversationTurn represents a single turn in a conversation (a single message or multiple messages per turn) -type ConversationTurn struct { - state protoimpl.MessageState `protogen:"open.v1"` - Messages []*ChatMessage `protobuf:"bytes,1,rep,name=messages,proto3" json:"messages,omitempty"` // The messages in this turn - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache -} - -func (x *ConversationTurn) Reset() { - *x = ConversationTurn{} - mi := &file_api_tokenizerpb_tokenizer_proto_msgTypes[2] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) -} - 
-func (x *ConversationTurn) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*ConversationTurn) ProtoMessage() {} - -func (x *ConversationTurn) ProtoReflect() protoreflect.Message { - mi := &file_api_tokenizerpb_tokenizer_proto_msgTypes[2] - if x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use ConversationTurn.ProtoReflect.Descriptor instead. -func (*ConversationTurn) Descriptor() ([]byte, []int) { - return file_api_tokenizerpb_tokenizer_proto_rawDescGZIP(), []int{2} -} - -func (x *ConversationTurn) GetMessages() []*ChatMessage { - if x != nil { - return x.Messages - } - return nil -} - -// ChatTemplateRequest represents a request to render a chat template -type ChatTemplateRequest struct { - state protoimpl.MessageState `protogen:"open.v1"` - ConversationTurns []*ConversationTurn `protobuf:"bytes,1,rep,name=conversation_turns,json=conversationTurns,proto3" json:"conversation_turns,omitempty"` // The conversation turns (batches of messages) - Tools []*ToolDescription `protobuf:"bytes,2,rep,name=tools,proto3" json:"tools,omitempty"` // Tools available to the conversation - Documents []*Document `protobuf:"bytes,3,rep,name=documents,proto3" json:"documents,omitempty"` // Documents related to the conversation - ChatTemplate string `protobuf:"bytes,4,opt,name=chat_template,json=chatTemplate,proto3" json:"chat_template,omitempty"` // The chat template to use - ReturnAssistantTokensMask bool `protobuf:"varint,5,opt,name=return_assistant_tokens_mask,json=returnAssistantTokensMask,proto3" json:"return_assistant_tokens_mask,omitempty"` // Whether to return assistant token mask - ContinueFinalMessage bool `protobuf:"varint,6,opt,name=continue_final_message,json=continueFinalMessage,proto3" json:"continue_final_message,omitempty"` // Whether to continue the final message - AddGenerationPrompt bool 
`protobuf:"varint,7,opt,name=add_generation_prompt,json=addGenerationPrompt,proto3" json:"add_generation_prompt,omitempty"` // Whether to add generation prompt - ChatTemplateKwargs map[string]*Value `protobuf:"bytes,8,rep,name=chat_template_kwargs,json=chatTemplateKwargs,proto3" json:"chat_template_kwargs,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"` // Additional chat template arguments - ModelName string `protobuf:"bytes,9,opt,name=model_name,json=modelName,proto3" json:"model_name,omitempty"` // The name of the model to use for tokenization - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache -} - -func (x *ChatTemplateRequest) Reset() { - *x = ChatTemplateRequest{} - mi := &file_api_tokenizerpb_tokenizer_proto_msgTypes[3] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) -} - -func (x *ChatTemplateRequest) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*ChatTemplateRequest) ProtoMessage() {} - -func (x *ChatTemplateRequest) ProtoReflect() protoreflect.Message { - mi := &file_api_tokenizerpb_tokenizer_proto_msgTypes[3] - if x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use ChatTemplateRequest.ProtoReflect.Descriptor instead. 
-func (*ChatTemplateRequest) Descriptor() ([]byte, []int) { - return file_api_tokenizerpb_tokenizer_proto_rawDescGZIP(), []int{3} -} - -func (x *ChatTemplateRequest) GetConversationTurns() []*ConversationTurn { - if x != nil { - return x.ConversationTurns - } - return nil -} - -func (x *ChatTemplateRequest) GetTools() []*ToolDescription { - if x != nil { - return x.Tools - } - return nil -} - -func (x *ChatTemplateRequest) GetDocuments() []*Document { - if x != nil { - return x.Documents - } - return nil -} - -func (x *ChatTemplateRequest) GetChatTemplate() string { - if x != nil { - return x.ChatTemplate - } - return "" -} - -func (x *ChatTemplateRequest) GetReturnAssistantTokensMask() bool { - if x != nil { - return x.ReturnAssistantTokensMask - } - return false -} - -func (x *ChatTemplateRequest) GetContinueFinalMessage() bool { - if x != nil { - return x.ContinueFinalMessage - } - return false -} - -func (x *ChatTemplateRequest) GetAddGenerationPrompt() bool { - if x != nil { - return x.AddGenerationPrompt - } - return false -} - -func (x *ChatTemplateRequest) GetChatTemplateKwargs() map[string]*Value { - if x != nil { - return x.ChatTemplateKwargs - } - return nil -} - -func (x *ChatTemplateRequest) GetModelName() string { - if x != nil { - return x.ModelName - } - return "" -} - // ChatMessage represents a single message in a conversation type ChatMessage struct { state protoimpl.MessageState `protogen:"open.v1"` @@ -331,7 +46,7 @@ type ChatMessage struct { func (x *ChatMessage) Reset() { *x = ChatMessage{} - mi := &file_api_tokenizerpb_tokenizer_proto_msgTypes[4] + mi := &file_api_tokenizerpb_tokenizer_proto_msgTypes[0] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -342,105 +57,8 @@ func (x *ChatMessage) String() string { func (*ChatMessage) ProtoMessage() {} -func (x *ChatMessage) ProtoReflect() protoreflect.Message { - mi := &file_api_tokenizerpb_tokenizer_proto_msgTypes[4] - if x != nil { - ms := 
protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use ChatMessage.ProtoReflect.Descriptor instead. -func (*ChatMessage) Descriptor() ([]byte, []int) { - return file_api_tokenizerpb_tokenizer_proto_rawDescGZIP(), []int{4} -} - -func (x *ChatMessage) GetRole() string { - if x != nil { - return x.Role - } - return "" -} - -func (x *ChatMessage) GetContent() string { - if x != nil { - return x.Content - } - return "" -} - -// ToolDescription represents a description of a tool -type ToolDescription struct { - state protoimpl.MessageState `protogen:"open.v1"` - Tool map[string]*Value `protobuf:"bytes,1,rep,name=tool,proto3" json:"tool,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"` // Tool definition - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache -} - -func (x *ToolDescription) Reset() { - *x = ToolDescription{} - mi := &file_api_tokenizerpb_tokenizer_proto_msgTypes[5] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) -} - -func (x *ToolDescription) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*ToolDescription) ProtoMessage() {} - -func (x *ToolDescription) ProtoReflect() protoreflect.Message { - mi := &file_api_tokenizerpb_tokenizer_proto_msgTypes[5] - if x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use ToolDescription.ProtoReflect.Descriptor instead. 
-func (*ToolDescription) Descriptor() ([]byte, []int) { - return file_api_tokenizerpb_tokenizer_proto_rawDescGZIP(), []int{5} -} - -func (x *ToolDescription) GetTool() map[string]*Value { - if x != nil { - return x.Tool - } - return nil -} - -// Document represents a document -type Document struct { - state protoimpl.MessageState `protogen:"open.v1"` - Document map[string]*Value `protobuf:"bytes,1,rep,name=document,proto3" json:"document,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"` // Document definition - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache -} - -func (x *Document) Reset() { - *x = Document{} - mi := &file_api_tokenizerpb_tokenizer_proto_msgTypes[6] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) -} - -func (x *Document) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*Document) ProtoMessage() {} - -func (x *Document) ProtoReflect() protoreflect.Message { - mi := &file_api_tokenizerpb_tokenizer_proto_msgTypes[6] +func (x *ChatMessage) ProtoReflect() protoreflect.Message { + mi := &file_api_tokenizerpb_tokenizer_proto_msgTypes[0] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -451,16 +69,23 @@ func (x *Document) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use Document.ProtoReflect.Descriptor instead. -func (*Document) Descriptor() ([]byte, []int) { - return file_api_tokenizerpb_tokenizer_proto_rawDescGZIP(), []int{6} +// Deprecated: Use ChatMessage.ProtoReflect.Descriptor instead. 
+func (*ChatMessage) Descriptor() ([]byte, []int) { + return file_api_tokenizerpb_tokenizer_proto_rawDescGZIP(), []int{0} +} + +func (x *ChatMessage) GetRole() string { + if x != nil { + return x.Role + } + return "" } -func (x *Document) GetDocument() map[string]*Value { +func (x *ChatMessage) GetContent() string { if x != nil { - return x.Document + return x.Content } - return nil + return "" } // Value represents a generic value that can be string, number, bool, or list @@ -480,7 +105,7 @@ type Value struct { func (x *Value) Reset() { *x = Value{} - mi := &file_api_tokenizerpb_tokenizer_proto_msgTypes[7] + mi := &file_api_tokenizerpb_tokenizer_proto_msgTypes[1] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -492,7 +117,7 @@ func (x *Value) String() string { func (*Value) ProtoMessage() {} func (x *Value) ProtoReflect() protoreflect.Message { - mi := &file_api_tokenizerpb_tokenizer_proto_msgTypes[7] + mi := &file_api_tokenizerpb_tokenizer_proto_msgTypes[1] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -505,7 +130,7 @@ func (x *Value) ProtoReflect() protoreflect.Message { // Deprecated: Use Value.ProtoReflect.Descriptor instead. 
func (*Value) Descriptor() ([]byte, []int) { - return file_api_tokenizerpb_tokenizer_proto_rawDescGZIP(), []int{7} + return file_api_tokenizerpb_tokenizer_proto_rawDescGZIP(), []int{1} } func (x *Value) GetValue() isValue_Value { @@ -604,7 +229,7 @@ type ListValue struct { func (x *ListValue) Reset() { *x = ListValue{} - mi := &file_api_tokenizerpb_tokenizer_proto_msgTypes[8] + mi := &file_api_tokenizerpb_tokenizer_proto_msgTypes[2] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -616,7 +241,7 @@ func (x *ListValue) String() string { func (*ListValue) ProtoMessage() {} func (x *ListValue) ProtoReflect() protoreflect.Message { - mi := &file_api_tokenizerpb_tokenizer_proto_msgTypes[8] + mi := &file_api_tokenizerpb_tokenizer_proto_msgTypes[2] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -629,7 +254,7 @@ func (x *ListValue) ProtoReflect() protoreflect.Message { // Deprecated: Use ListValue.ProtoReflect.Descriptor instead. 
func (*ListValue) Descriptor() ([]byte, []int) { - return file_api_tokenizerpb_tokenizer_proto_rawDescGZIP(), []int{8} + return file_api_tokenizerpb_tokenizer_proto_rawDescGZIP(), []int{2} } func (x *ListValue) GetValues() []*Value { @@ -649,7 +274,7 @@ type StructValue struct { func (x *StructValue) Reset() { *x = StructValue{} - mi := &file_api_tokenizerpb_tokenizer_proto_msgTypes[9] + mi := &file_api_tokenizerpb_tokenizer_proto_msgTypes[3] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -661,7 +286,7 @@ func (x *StructValue) String() string { func (*StructValue) ProtoMessage() {} func (x *StructValue) ProtoReflect() protoreflect.Message { - mi := &file_api_tokenizerpb_tokenizer_proto_msgTypes[9] + mi := &file_api_tokenizerpb_tokenizer_proto_msgTypes[3] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -674,7 +299,7 @@ func (x *StructValue) ProtoReflect() protoreflect.Message { // Deprecated: Use StructValue.ProtoReflect.Descriptor instead. 
func (*StructValue) Descriptor() ([]byte, []int) { - return file_api_tokenizerpb_tokenizer_proto_rawDescGZIP(), []int{9} + return file_api_tokenizerpb_tokenizer_proto_rawDescGZIP(), []int{3} } func (x *StructValue) GetFields() map[string]*Value { @@ -684,31 +309,33 @@ func (x *StructValue) GetFields() map[string]*Value { return nil } -// ChatTemplateResponse represents the response from rendering a chat template -type ChatTemplateResponse struct { - state protoimpl.MessageState `protogen:"open.v1"` - RenderedPrompt string `protobuf:"bytes,1,opt,name=rendered_prompt,json=renderedPrompt,proto3" json:"rendered_prompt,omitempty"` // The rendered chat template prompt - Success bool `protobuf:"varint,2,opt,name=success,proto3" json:"success,omitempty"` // Whether the request was successful - ErrorMessage string `protobuf:"bytes,3,opt,name=error_message,json=errorMessage,proto3" json:"error_message,omitempty"` // Error message if the request failed - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache +// InitializeTokenizerRequest represents a request to initialize tokenizer for a model +type InitializeTokenizerRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + IsLocal bool `protobuf:"varint,1,opt,name=is_local,json=isLocal,proto3" json:"is_local,omitempty"` // Whether the model is local (default: true) + Model string `protobuf:"bytes,2,opt,name=model,proto3" json:"model,omitempty"` // The model ID or path (HF model ID, local directory path, or path to tokenizer file) + Revision *string `protobuf:"bytes,3,opt,name=revision,proto3,oneof" json:"revision,omitempty"` // Model revision (optional) + Token *string `protobuf:"bytes,4,opt,name=token,proto3,oneof" json:"token,omitempty"` // Hugging Face token for private models (optional) + DownloadDir *string `protobuf:"bytes,5,opt,name=download_dir,json=downloadDir,proto3,oneof" json:"download_dir,omitempty"` // Directory to download the model (optional) + unknownFields 
protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } -func (x *ChatTemplateResponse) Reset() { - *x = ChatTemplateResponse{} - mi := &file_api_tokenizerpb_tokenizer_proto_msgTypes[10] +func (x *InitializeTokenizerRequest) Reset() { + *x = InitializeTokenizerRequest{} + mi := &file_api_tokenizerpb_tokenizer_proto_msgTypes[4] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } -func (x *ChatTemplateResponse) String() string { +func (x *InitializeTokenizerRequest) String() string { return protoimpl.X.MessageStringOf(x) } -func (*ChatTemplateResponse) ProtoMessage() {} +func (*InitializeTokenizerRequest) ProtoMessage() {} -func (x *ChatTemplateResponse) ProtoReflect() protoreflect.Message { - mi := &file_api_tokenizerpb_tokenizer_proto_msgTypes[10] +func (x *InitializeTokenizerRequest) ProtoReflect() protoreflect.Message { + mi := &file_api_tokenizerpb_tokenizer_proto_msgTypes[4] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -719,57 +346,124 @@ func (x *ChatTemplateResponse) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use ChatTemplateResponse.ProtoReflect.Descriptor instead. -func (*ChatTemplateResponse) Descriptor() ([]byte, []int) { - return file_api_tokenizerpb_tokenizer_proto_rawDescGZIP(), []int{10} +// Deprecated: Use InitializeTokenizerRequest.ProtoReflect.Descriptor instead. 
+func (*InitializeTokenizerRequest) Descriptor() ([]byte, []int) { + return file_api_tokenizerpb_tokenizer_proto_rawDescGZIP(), []int{4} +} + +func (x *InitializeTokenizerRequest) GetIsLocal() bool { + if x != nil { + return x.IsLocal + } + return false } -func (x *ChatTemplateResponse) GetRenderedPrompt() string { +func (x *InitializeTokenizerRequest) GetModel() string { if x != nil { - return x.RenderedPrompt + return x.Model + } + return "" +} + +func (x *InitializeTokenizerRequest) GetRevision() string { + if x != nil && x.Revision != nil { + return *x.Revision + } + return "" +} + +func (x *InitializeTokenizerRequest) GetToken() string { + if x != nil && x.Token != nil { + return *x.Token + } + return "" +} + +func (x *InitializeTokenizerRequest) GetDownloadDir() string { + if x != nil && x.DownloadDir != nil { + return *x.DownloadDir } return "" } -func (x *ChatTemplateResponse) GetSuccess() bool { +// InitializeTokenizerResponse represents the response from tokenizer initialization +type InitializeTokenizerResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Success bool `protobuf:"varint,1,opt,name=success,proto3" json:"success,omitempty"` // Whether the initialization was successful + ErrorMessage string `protobuf:"bytes,2,opt,name=error_message,json=errorMessage,proto3" json:"error_message,omitempty"` // Error message if initialization failed + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *InitializeTokenizerResponse) Reset() { + *x = InitializeTokenizerResponse{} + mi := &file_api_tokenizerpb_tokenizer_proto_msgTypes[5] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *InitializeTokenizerResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*InitializeTokenizerResponse) ProtoMessage() {} + +func (x *InitializeTokenizerResponse) ProtoReflect() protoreflect.Message { + mi := &file_api_tokenizerpb_tokenizer_proto_msgTypes[5] + 
if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use InitializeTokenizerResponse.ProtoReflect.Descriptor instead. +func (*InitializeTokenizerResponse) Descriptor() ([]byte, []int) { + return file_api_tokenizerpb_tokenizer_proto_rawDescGZIP(), []int{5} +} + +func (x *InitializeTokenizerResponse) GetSuccess() bool { if x != nil { return x.Success } return false } -func (x *ChatTemplateResponse) GetErrorMessage() string { +func (x *InitializeTokenizerResponse) GetErrorMessage() string { if x != nil { return x.ErrorMessage } return "" } -// InitializeTokenizerRequest represents a request to initialize tokenizer for a model -type InitializeTokenizerRequest struct { - state protoimpl.MessageState `protogen:"open.v1"` - ModelName string `protobuf:"bytes,1,opt,name=model_name,json=modelName,proto3" json:"model_name,omitempty"` // The name of the model to initialize - EnableThinking bool `protobuf:"varint,2,opt,name=enable_thinking,json=enableThinking,proto3" json:"enable_thinking,omitempty"` // Whether to enable thinking tokens - AddGenerationPrompt bool `protobuf:"varint,3,opt,name=add_generation_prompt,json=addGenerationPrompt,proto3" json:"add_generation_prompt,omitempty"` // Whether to add generation prompt - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache +// RenderRequest represents a request to render (tokenize) a text input +type RenderRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + Text string `protobuf:"bytes,1,opt,name=text,proto3" json:"text,omitempty"` // The text input to render/tokenize + ModelName string `protobuf:"bytes,2,opt,name=model_name,json=modelName,proto3" json:"model_name,omitempty"` // The name of the model to use for tokenization + AddSpecialTokens bool `protobuf:"varint,3,opt,name=add_special_tokens,json=addSpecialTokens,proto3" 
json:"add_special_tokens,omitempty"` // Whether to add special tokens during tokenization + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } -func (x *InitializeTokenizerRequest) Reset() { - *x = InitializeTokenizerRequest{} - mi := &file_api_tokenizerpb_tokenizer_proto_msgTypes[11] +func (x *RenderRequest) Reset() { + *x = RenderRequest{} + mi := &file_api_tokenizerpb_tokenizer_proto_msgTypes[6] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } -func (x *InitializeTokenizerRequest) String() string { +func (x *RenderRequest) String() string { return protoimpl.X.MessageStringOf(x) } -func (*InitializeTokenizerRequest) ProtoMessage() {} +func (*RenderRequest) ProtoMessage() {} -func (x *InitializeTokenizerRequest) ProtoReflect() protoreflect.Message { - mi := &file_api_tokenizerpb_tokenizer_proto_msgTypes[11] +func (x *RenderRequest) ProtoReflect() protoreflect.Message { + mi := &file_api_tokenizerpb_tokenizer_proto_msgTypes[6] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -780,56 +474,168 @@ func (x *InitializeTokenizerRequest) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use InitializeTokenizerRequest.ProtoReflect.Descriptor instead. -func (*InitializeTokenizerRequest) Descriptor() ([]byte, []int) { - return file_api_tokenizerpb_tokenizer_proto_rawDescGZIP(), []int{11} +// Deprecated: Use RenderRequest.ProtoReflect.Descriptor instead. 
+func (*RenderRequest) Descriptor() ([]byte, []int) { + return file_api_tokenizerpb_tokenizer_proto_rawDescGZIP(), []int{6} +} + +func (x *RenderRequest) GetText() string { + if x != nil { + return x.Text + } + return "" } -func (x *InitializeTokenizerRequest) GetModelName() string { +func (x *RenderRequest) GetModelName() string { if x != nil { return x.ModelName } return "" } -func (x *InitializeTokenizerRequest) GetEnableThinking() bool { +func (x *RenderRequest) GetAddSpecialTokens() bool { if x != nil { - return x.EnableThinking + return x.AddSpecialTokens } return false } -func (x *InitializeTokenizerRequest) GetAddGenerationPrompt() bool { +// RenderChatRequest represents a request to render a chat conversation +type RenderChatRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + Conversation []*ChatMessage `protobuf:"bytes,1,rep,name=conversation,proto3" json:"conversation,omitempty"` // The conversation messages (array of role/content pairs) + Tools []*Value `protobuf:"bytes,2,rep,name=tools,proto3" json:"tools,omitempty"` // Tools available to the conversation (arbitrary values) + Documents []*Value `protobuf:"bytes,3,rep,name=documents,proto3" json:"documents,omitempty"` // Documents related to the conversation (arbitrary values) + ChatTemplate *string `protobuf:"bytes,4,opt,name=chat_template,json=chatTemplate,proto3,oneof" json:"chat_template,omitempty"` // The chat template to use + ReturnAssistantTokensMask *bool `protobuf:"varint,5,opt,name=return_assistant_tokens_mask,json=returnAssistantTokensMask,proto3,oneof" json:"return_assistant_tokens_mask,omitempty"` // Whether to return assistant token mask + ContinueFinalMessage *bool `protobuf:"varint,6,opt,name=continue_final_message,json=continueFinalMessage,proto3,oneof" json:"continue_final_message,omitempty"` // Whether to continue the final message + AddGenerationPrompt *bool `protobuf:"varint,7,opt,name=add_generation_prompt,json=addGenerationPrompt,proto3,oneof" 
json:"add_generation_prompt,omitempty"` // Whether to add generation prompt + ChatTemplateKwargs map[string]*Value `protobuf:"bytes,8,rep,name=chat_template_kwargs,json=chatTemplateKwargs,proto3" json:"chat_template_kwargs,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"` // Additional chat template arguments + ModelName string `protobuf:"bytes,9,opt,name=model_name,json=modelName,proto3" json:"model_name,omitempty"` // The name of the model to use for tokenization + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *RenderChatRequest) Reset() { + *x = RenderChatRequest{} + mi := &file_api_tokenizerpb_tokenizer_proto_msgTypes[7] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *RenderChatRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*RenderChatRequest) ProtoMessage() {} + +func (x *RenderChatRequest) ProtoReflect() protoreflect.Message { + mi := &file_api_tokenizerpb_tokenizer_proto_msgTypes[7] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use RenderChatRequest.ProtoReflect.Descriptor instead. 
+func (*RenderChatRequest) Descriptor() ([]byte, []int) { + return file_api_tokenizerpb_tokenizer_proto_rawDescGZIP(), []int{7} +} + +func (x *RenderChatRequest) GetConversation() []*ChatMessage { + if x != nil { + return x.Conversation + } + return nil +} + +func (x *RenderChatRequest) GetTools() []*Value { + if x != nil { + return x.Tools + } + return nil +} + +func (x *RenderChatRequest) GetDocuments() []*Value { if x != nil { - return x.AddGenerationPrompt + return x.Documents + } + return nil +} + +func (x *RenderChatRequest) GetChatTemplate() string { + if x != nil && x.ChatTemplate != nil { + return *x.ChatTemplate + } + return "" +} + +func (x *RenderChatRequest) GetReturnAssistantTokensMask() bool { + if x != nil && x.ReturnAssistantTokensMask != nil { + return *x.ReturnAssistantTokensMask } return false } -// InitializeTokenizerResponse represents the response from tokenizer initialization -type InitializeTokenizerResponse struct { - state protoimpl.MessageState `protogen:"open.v1"` - Success bool `protobuf:"varint,1,opt,name=success,proto3" json:"success,omitempty"` // Whether the initialization was successful - ErrorMessage string `protobuf:"bytes,2,opt,name=error_message,json=errorMessage,proto3" json:"error_message,omitempty"` // Error message if initialization failed +func (x *RenderChatRequest) GetContinueFinalMessage() bool { + if x != nil && x.ContinueFinalMessage != nil { + return *x.ContinueFinalMessage + } + return false +} + +func (x *RenderChatRequest) GetAddGenerationPrompt() bool { + if x != nil && x.AddGenerationPrompt != nil { + return *x.AddGenerationPrompt + } + return false +} + +func (x *RenderChatRequest) GetChatTemplateKwargs() map[string]*Value { + if x != nil { + return x.ChatTemplateKwargs + } + return nil +} + +func (x *RenderChatRequest) GetModelName() string { + if x != nil { + return x.ModelName + } + return "" +} + +// RenderResponse represents the response from rendering/tokenizing +type RenderResponse struct { + state 
protoimpl.MessageState `protogen:"open.v1"` + InputIds []uint32 `protobuf:"varint,1,rep,packed,name=input_ids,json=inputIds,proto3" json:"input_ids,omitempty"` // Token IDs for the input + Success bool `protobuf:"varint,2,opt,name=success,proto3" json:"success,omitempty"` // Whether the request was successful + ErrorMessage string `protobuf:"bytes,3,opt,name=error_message,json=errorMessage,proto3" json:"error_message,omitempty"` // Error message if the request failed + // Direct array of [start, end] pairs + OffsetPairs []uint32 `protobuf:"varint,4,rep,packed,name=offset_pairs,json=offsetPairs,proto3" json:"offset_pairs,omitempty"` // Flattened array of [start, end, start, end, ...] unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } -func (x *InitializeTokenizerResponse) Reset() { - *x = InitializeTokenizerResponse{} - mi := &file_api_tokenizerpb_tokenizer_proto_msgTypes[12] +func (x *RenderResponse) Reset() { + *x = RenderResponse{} + mi := &file_api_tokenizerpb_tokenizer_proto_msgTypes[8] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } -func (x *InitializeTokenizerResponse) String() string { +func (x *RenderResponse) String() string { return protoimpl.X.MessageStringOf(x) } -func (*InitializeTokenizerResponse) ProtoMessage() {} +func (*RenderResponse) ProtoMessage() {} -func (x *InitializeTokenizerResponse) ProtoReflect() protoreflect.Message { - mi := &file_api_tokenizerpb_tokenizer_proto_msgTypes[12] +func (x *RenderResponse) ProtoReflect() protoreflect.Message { + mi := &file_api_tokenizerpb_tokenizer_proto_msgTypes[8] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -840,69 +646,47 @@ func (x *InitializeTokenizerResponse) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use InitializeTokenizerResponse.ProtoReflect.Descriptor instead. 
-func (*InitializeTokenizerResponse) Descriptor() ([]byte, []int) { - return file_api_tokenizerpb_tokenizer_proto_rawDescGZIP(), []int{12} +// Deprecated: Use RenderResponse.ProtoReflect.Descriptor instead. +func (*RenderResponse) Descriptor() ([]byte, []int) { + return file_api_tokenizerpb_tokenizer_proto_rawDescGZIP(), []int{8} } -func (x *InitializeTokenizerResponse) GetSuccess() bool { +func (x *RenderResponse) GetInputIds() []uint32 { + if x != nil { + return x.InputIds + } + return nil +} + +func (x *RenderResponse) GetSuccess() bool { if x != nil { return x.Success } return false } -func (x *InitializeTokenizerResponse) GetErrorMessage() string { +func (x *RenderResponse) GetErrorMessage() string { if x != nil { return x.ErrorMessage } return "" } +func (x *RenderResponse) GetOffsetPairs() []uint32 { + if x != nil { + return x.OffsetPairs + } + return nil +} + var File_api_tokenizerpb_tokenizer_proto protoreflect.FileDescriptor const file_api_tokenizerpb_tokenizer_proto_rawDesc = "" + "\n" + - "\x1fapi/tokenizerpb/tokenizer.proto\x12\ftokenization\"t\n" + - "\x0fTokenizeRequest\x12\x14\n" + - "\x05input\x18\x01 \x01(\tR\x05input\x12\x1d\n" + - "\n" + - "model_name\x18\x02 \x01(\tR\tmodelName\x12,\n" + - "\x12add_special_tokens\x18\x03 \x01(\bR\x10addSpecialTokens\"\x91\x01\n" + - "\x10TokenizeResponse\x12\x1b\n" + - "\tinput_ids\x18\x01 \x03(\rR\binputIds\x12\x18\n" + - "\asuccess\x18\x02 \x01(\bR\asuccess\x12#\n" + - "\rerror_message\x18\x03 \x01(\tR\ferrorMessage\x12!\n" + - "\foffset_pairs\x18\x04 \x03(\rR\voffsetPairs\"I\n" + - "\x10ConversationTurn\x125\n" + - "\bmessages\x18\x01 \x03(\v2\x19.tokenization.ChatMessageR\bmessages\"\x87\x05\n" + - "\x13ChatTemplateRequest\x12M\n" + - "\x12conversation_turns\x18\x01 \x03(\v2\x1e.tokenization.ConversationTurnR\x11conversationTurns\x123\n" + - "\x05tools\x18\x02 \x03(\v2\x1d.tokenization.ToolDescriptionR\x05tools\x124\n" + - "\tdocuments\x18\x03 \x03(\v2\x16.tokenization.DocumentR\tdocuments\x12#\n" + - 
"\rchat_template\x18\x04 \x01(\tR\fchatTemplate\x12?\n" + - "\x1creturn_assistant_tokens_mask\x18\x05 \x01(\bR\x19returnAssistantTokensMask\x124\n" + - "\x16continue_final_message\x18\x06 \x01(\bR\x14continueFinalMessage\x122\n" + - "\x15add_generation_prompt\x18\a \x01(\bR\x13addGenerationPrompt\x12k\n" + - "\x14chat_template_kwargs\x18\b \x03(\v29.tokenization.ChatTemplateRequest.ChatTemplateKwargsEntryR\x12chatTemplateKwargs\x12\x1d\n" + - "\n" + - "model_name\x18\t \x01(\tR\tmodelName\x1aZ\n" + - "\x17ChatTemplateKwargsEntry\x12\x10\n" + - "\x03key\x18\x01 \x01(\tR\x03key\x12)\n" + - "\x05value\x18\x02 \x01(\v2\x13.tokenization.ValueR\x05value:\x028\x01\";\n" + + "\x1fapi/tokenizerpb/tokenizer.proto\x12\ftokenization\";\n" + "\vChatMessage\x12\x12\n" + "\x04role\x18\x01 \x01(\tR\x04role\x12\x18\n" + - "\acontent\x18\x02 \x01(\tR\acontent\"\x9c\x01\n" + - "\x0fToolDescription\x12;\n" + - "\x04tool\x18\x01 \x03(\v2'.tokenization.ToolDescription.ToolEntryR\x04tool\x1aL\n" + - "\tToolEntry\x12\x10\n" + - "\x03key\x18\x01 \x01(\tR\x03key\x12)\n" + - "\x05value\x18\x02 \x01(\v2\x13.tokenization.ValueR\x05value:\x028\x01\"\x9e\x01\n" + - "\bDocument\x12@\n" + - "\bdocument\x18\x01 \x03(\v2$.tokenization.Document.DocumentEntryR\bdocument\x1aP\n" + - "\rDocumentEntry\x12\x10\n" + - "\x03key\x18\x01 \x01(\tR\x03key\x12)\n" + - "\x05value\x18\x02 \x01(\v2\x13.tokenization.ValueR\x05value:\x028\x01\"\xf5\x01\n" + + "\acontent\x18\x02 \x01(\tR\acontent\"\xf5\x01\n" + "\x05Value\x12#\n" + "\fstring_value\x18\x01 \x01(\tH\x00R\vstringValue\x12#\n" + "\fnumber_value\x18\x02 \x01(\x01H\x00R\vnumberValue\x12\x1f\n" + @@ -918,23 +702,52 @@ const file_api_tokenizerpb_tokenizer_proto_rawDesc = "" + "\x06fields\x18\x01 \x03(\v2%.tokenization.StructValue.FieldsEntryR\x06fields\x1aN\n" + "\vFieldsEntry\x12\x10\n" + "\x03key\x18\x01 \x01(\tR\x03key\x12)\n" + - "\x05value\x18\x02 \x01(\v2\x13.tokenization.ValueR\x05value:\x028\x01\"~\n" + - "\x14ChatTemplateResponse\x12'\n" + - 
"\x0frendered_prompt\x18\x01 \x01(\tR\x0erenderedPrompt\x12\x18\n" + - "\asuccess\x18\x02 \x01(\bR\asuccess\x12#\n" + - "\rerror_message\x18\x03 \x01(\tR\ferrorMessage\"\x98\x01\n" + - "\x1aInitializeTokenizerRequest\x12\x1d\n" + - "\n" + - "model_name\x18\x01 \x01(\tR\tmodelName\x12'\n" + - "\x0fenable_thinking\x18\x02 \x01(\bR\x0eenableThinking\x122\n" + - "\x15add_generation_prompt\x18\x03 \x01(\bR\x13addGenerationPrompt\"\\\n" + + "\x05value\x18\x02 \x01(\v2\x13.tokenization.ValueR\x05value:\x028\x01\"\xd9\x01\n" + + "\x1aInitializeTokenizerRequest\x12\x19\n" + + "\bis_local\x18\x01 \x01(\bR\aisLocal\x12\x14\n" + + "\x05model\x18\x02 \x01(\tR\x05model\x12\x1f\n" + + "\brevision\x18\x03 \x01(\tH\x00R\brevision\x88\x01\x01\x12\x19\n" + + "\x05token\x18\x04 \x01(\tH\x01R\x05token\x88\x01\x01\x12&\n" + + "\fdownload_dir\x18\x05 \x01(\tH\x02R\vdownloadDir\x88\x01\x01B\v\n" + + "\t_revisionB\b\n" + + "\x06_tokenB\x0f\n" + + "\r_download_dir\"\\\n" + "\x1bInitializeTokenizerResponse\x12\x18\n" + "\asuccess\x18\x01 \x01(\bR\asuccess\x12#\n" + - "\rerror_message\x18\x02 \x01(\tR\ferrorMessage2\xa9\x02\n" + - "\x13TokenizationService\x12I\n" + - "\bTokenize\x12\x1d.tokenization.TokenizeRequest\x1a\x1e.tokenization.TokenizeResponse\x12[\n" + - "\x12RenderChatTemplate\x12!.tokenization.ChatTemplateRequest\x1a\".tokenization.ChatTemplateResponse\x12j\n" + - "\x13InitializeTokenizer\x12(.tokenization.InitializeTokenizerRequest\x1a).tokenization.InitializeTokenizerResponseB=Z;github.com/llm-d/llm-d-kv-cache/api/tokenizerpb;tokenizerpbb\x06proto3" + "\rerror_message\x18\x02 \x01(\tR\ferrorMessage\"p\n" + + "\rRenderRequest\x12\x12\n" + + "\x04text\x18\x01 \x01(\tR\x04text\x12\x1d\n" + + "\n" + + "model_name\x18\x02 \x01(\tR\tmodelName\x12,\n" + + "\x12add_special_tokens\x18\x03 \x01(\bR\x10addSpecialTokens\"\xe2\x05\n" + + "\x11RenderChatRequest\x12=\n" + + "\fconversation\x18\x01 \x03(\v2\x19.tokenization.ChatMessageR\fconversation\x12)\n" + + "\x05tools\x18\x02 
\x03(\v2\x13.tokenization.ValueR\x05tools\x121\n" + + "\tdocuments\x18\x03 \x03(\v2\x13.tokenization.ValueR\tdocuments\x12(\n" + + "\rchat_template\x18\x04 \x01(\tH\x00R\fchatTemplate\x88\x01\x01\x12D\n" + + "\x1creturn_assistant_tokens_mask\x18\x05 \x01(\bH\x01R\x19returnAssistantTokensMask\x88\x01\x01\x129\n" + + "\x16continue_final_message\x18\x06 \x01(\bH\x02R\x14continueFinalMessage\x88\x01\x01\x127\n" + + "\x15add_generation_prompt\x18\a \x01(\bH\x03R\x13addGenerationPrompt\x88\x01\x01\x12i\n" + + "\x14chat_template_kwargs\x18\b \x03(\v27.tokenization.RenderChatRequest.ChatTemplateKwargsEntryR\x12chatTemplateKwargs\x12\x1d\n" + + "\n" + + "model_name\x18\t \x01(\tR\tmodelName\x1aZ\n" + + "\x17ChatTemplateKwargsEntry\x12\x10\n" + + "\x03key\x18\x01 \x01(\tR\x03key\x12)\n" + + "\x05value\x18\x02 \x01(\v2\x13.tokenization.ValueR\x05value:\x028\x01B\x10\n" + + "\x0e_chat_templateB\x1f\n" + + "\x1d_return_assistant_tokens_maskB\x19\n" + + "\x17_continue_final_messageB\x18\n" + + "\x16_add_generation_prompt\"\x8f\x01\n" + + "\x0eRenderResponse\x12\x1b\n" + + "\tinput_ids\x18\x01 \x03(\rR\binputIds\x12\x18\n" + + "\asuccess\x18\x02 \x01(\bR\asuccess\x12#\n" + + "\rerror_message\x18\x03 \x01(\tR\ferrorMessage\x12!\n" + + "\foffset_pairs\x18\x04 \x03(\rR\voffsetPairs2\x93\x02\n" + + "\x13TokenizationService\x12j\n" + + "\x13InitializeTokenizer\x12(.tokenization.InitializeTokenizerRequest\x1a).tokenization.InitializeTokenizerResponse\x12C\n" + + "\x06Render\x12\x1b.tokenization.RenderRequest\x1a\x1c.tokenization.RenderResponse\x12K\n" + + "\n" + + "RenderChat\x12\x1f.tokenization.RenderChatRequest\x1a\x1c.tokenization.RenderResponseB=Z;github.com/llm-d/llm-d-kv-cache/api/tokenizerpb;tokenizerpbb\x06proto3" var ( file_api_tokenizerpb_tokenizer_proto_rawDescOnce sync.Once @@ -948,53 +761,42 @@ func file_api_tokenizerpb_tokenizer_proto_rawDescGZIP() []byte { return file_api_tokenizerpb_tokenizer_proto_rawDescData } -var file_api_tokenizerpb_tokenizer_proto_msgTypes = 
make([]protoimpl.MessageInfo, 17) +var file_api_tokenizerpb_tokenizer_proto_msgTypes = make([]protoimpl.MessageInfo, 11) var file_api_tokenizerpb_tokenizer_proto_goTypes = []any{ - (*TokenizeRequest)(nil), // 0: tokenization.TokenizeRequest - (*TokenizeResponse)(nil), // 1: tokenization.TokenizeResponse - (*ConversationTurn)(nil), // 2: tokenization.ConversationTurn - (*ChatTemplateRequest)(nil), // 3: tokenization.ChatTemplateRequest - (*ChatMessage)(nil), // 4: tokenization.ChatMessage - (*ToolDescription)(nil), // 5: tokenization.ToolDescription - (*Document)(nil), // 6: tokenization.Document - (*Value)(nil), // 7: tokenization.Value - (*ListValue)(nil), // 8: tokenization.ListValue - (*StructValue)(nil), // 9: tokenization.StructValue - (*ChatTemplateResponse)(nil), // 10: tokenization.ChatTemplateResponse - (*InitializeTokenizerRequest)(nil), // 11: tokenization.InitializeTokenizerRequest - (*InitializeTokenizerResponse)(nil), // 12: tokenization.InitializeTokenizerResponse - nil, // 13: tokenization.ChatTemplateRequest.ChatTemplateKwargsEntry - nil, // 14: tokenization.ToolDescription.ToolEntry - nil, // 15: tokenization.Document.DocumentEntry - nil, // 16: tokenization.StructValue.FieldsEntry + (*ChatMessage)(nil), // 0: tokenization.ChatMessage + (*Value)(nil), // 1: tokenization.Value + (*ListValue)(nil), // 2: tokenization.ListValue + (*StructValue)(nil), // 3: tokenization.StructValue + (*InitializeTokenizerRequest)(nil), // 4: tokenization.InitializeTokenizerRequest + (*InitializeTokenizerResponse)(nil), // 5: tokenization.InitializeTokenizerResponse + (*RenderRequest)(nil), // 6: tokenization.RenderRequest + (*RenderChatRequest)(nil), // 7: tokenization.RenderChatRequest + (*RenderResponse)(nil), // 8: tokenization.RenderResponse + nil, // 9: tokenization.StructValue.FieldsEntry + nil, // 10: tokenization.RenderChatRequest.ChatTemplateKwargsEntry } var file_api_tokenizerpb_tokenizer_proto_depIdxs = []int32{ - 4, // 0: 
tokenization.ConversationTurn.messages:type_name -> tokenization.ChatMessage - 2, // 1: tokenization.ChatTemplateRequest.conversation_turns:type_name -> tokenization.ConversationTurn - 5, // 2: tokenization.ChatTemplateRequest.tools:type_name -> tokenization.ToolDescription - 6, // 3: tokenization.ChatTemplateRequest.documents:type_name -> tokenization.Document - 13, // 4: tokenization.ChatTemplateRequest.chat_template_kwargs:type_name -> tokenization.ChatTemplateRequest.ChatTemplateKwargsEntry - 14, // 5: tokenization.ToolDescription.tool:type_name -> tokenization.ToolDescription.ToolEntry - 15, // 6: tokenization.Document.document:type_name -> tokenization.Document.DocumentEntry - 8, // 7: tokenization.Value.list_value:type_name -> tokenization.ListValue - 9, // 8: tokenization.Value.struct_value:type_name -> tokenization.StructValue - 7, // 9: tokenization.ListValue.values:type_name -> tokenization.Value - 16, // 10: tokenization.StructValue.fields:type_name -> tokenization.StructValue.FieldsEntry - 7, // 11: tokenization.ChatTemplateRequest.ChatTemplateKwargsEntry.value:type_name -> tokenization.Value - 7, // 12: tokenization.ToolDescription.ToolEntry.value:type_name -> tokenization.Value - 7, // 13: tokenization.Document.DocumentEntry.value:type_name -> tokenization.Value - 7, // 14: tokenization.StructValue.FieldsEntry.value:type_name -> tokenization.Value - 0, // 15: tokenization.TokenizationService.Tokenize:input_type -> tokenization.TokenizeRequest - 3, // 16: tokenization.TokenizationService.RenderChatTemplate:input_type -> tokenization.ChatTemplateRequest - 11, // 17: tokenization.TokenizationService.InitializeTokenizer:input_type -> tokenization.InitializeTokenizerRequest - 1, // 18: tokenization.TokenizationService.Tokenize:output_type -> tokenization.TokenizeResponse - 10, // 19: tokenization.TokenizationService.RenderChatTemplate:output_type -> tokenization.ChatTemplateResponse - 12, // 20: 
tokenization.TokenizationService.InitializeTokenizer:output_type -> tokenization.InitializeTokenizerResponse - 18, // [18:21] is the sub-list for method output_type - 15, // [15:18] is the sub-list for method input_type - 15, // [15:15] is the sub-list for extension type_name - 15, // [15:15] is the sub-list for extension extendee - 0, // [0:15] is the sub-list for field type_name + 2, // 0: tokenization.Value.list_value:type_name -> tokenization.ListValue + 3, // 1: tokenization.Value.struct_value:type_name -> tokenization.StructValue + 1, // 2: tokenization.ListValue.values:type_name -> tokenization.Value + 9, // 3: tokenization.StructValue.fields:type_name -> tokenization.StructValue.FieldsEntry + 0, // 4: tokenization.RenderChatRequest.conversation:type_name -> tokenization.ChatMessage + 1, // 5: tokenization.RenderChatRequest.tools:type_name -> tokenization.Value + 1, // 6: tokenization.RenderChatRequest.documents:type_name -> tokenization.Value + 10, // 7: tokenization.RenderChatRequest.chat_template_kwargs:type_name -> tokenization.RenderChatRequest.ChatTemplateKwargsEntry + 1, // 8: tokenization.StructValue.FieldsEntry.value:type_name -> tokenization.Value + 1, // 9: tokenization.RenderChatRequest.ChatTemplateKwargsEntry.value:type_name -> tokenization.Value + 4, // 10: tokenization.TokenizationService.InitializeTokenizer:input_type -> tokenization.InitializeTokenizerRequest + 6, // 11: tokenization.TokenizationService.Render:input_type -> tokenization.RenderRequest + 7, // 12: tokenization.TokenizationService.RenderChat:input_type -> tokenization.RenderChatRequest + 5, // 13: tokenization.TokenizationService.InitializeTokenizer:output_type -> tokenization.InitializeTokenizerResponse + 8, // 14: tokenization.TokenizationService.Render:output_type -> tokenization.RenderResponse + 8, // 15: tokenization.TokenizationService.RenderChat:output_type -> tokenization.RenderResponse + 13, // [13:16] is the sub-list for method output_type + 10, // [10:13] is the 
sub-list for method input_type + 10, // [10:10] is the sub-list for extension type_name + 10, // [10:10] is the sub-list for extension extendee + 0, // [0:10] is the sub-list for field type_name } func init() { file_api_tokenizerpb_tokenizer_proto_init() } @@ -1002,20 +804,22 @@ func file_api_tokenizerpb_tokenizer_proto_init() { if File_api_tokenizerpb_tokenizer_proto != nil { return } - file_api_tokenizerpb_tokenizer_proto_msgTypes[7].OneofWrappers = []any{ + file_api_tokenizerpb_tokenizer_proto_msgTypes[1].OneofWrappers = []any{ (*Value_StringValue)(nil), (*Value_NumberValue)(nil), (*Value_BoolValue)(nil), (*Value_ListValue)(nil), (*Value_StructValue)(nil), } + file_api_tokenizerpb_tokenizer_proto_msgTypes[4].OneofWrappers = []any{} + file_api_tokenizerpb_tokenizer_proto_msgTypes[7].OneofWrappers = []any{} type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_api_tokenizerpb_tokenizer_proto_rawDesc), len(file_api_tokenizerpb_tokenizer_proto_rawDesc)), NumEnums: 0, - NumMessages: 17, + NumMessages: 11, NumExtensions: 0, NumServices: 1, }, diff --git a/api/tokenizerpb/tokenizer.proto b/api/tokenizerpb/tokenizer.proto index 56fa0561..889d8402 100644 --- a/api/tokenizerpb/tokenizer.proto +++ b/api/tokenizerpb/tokenizer.proto @@ -18,56 +18,12 @@ package tokenization; option go_package = "github.com/llm-d/llm-d-kv-cache/api/tokenizerpb;tokenizerpb"; -// TokenizeRequest represents a request to tokenize a text input -message TokenizeRequest { - string input = 1; // The text input to tokenize - string model_name = 2; // The name of the model to use for tokenization - bool add_special_tokens = 3; // Whether to add special tokens during tokenization -} - -// TokenizeResponse represents the response from tokenization -message TokenizeResponse { - repeated uint32 input_ids = 1; // Token IDs for the input - bool success = 2; // Whether the request was 
successful - string error_message = 3; // Error message if the request failed - // Direct array of [start, end] pairs - repeated uint32 offset_pairs = 4; // Flattened array of [start, end, start, end, ...] -} - - -// ConversationTurn represents a single turn in a conversation (a single message or multiple messages per turn) -message ConversationTurn { - repeated ChatMessage messages = 1; // The messages in this turn -} - -// ChatTemplateRequest represents a request to render a chat template -message ChatTemplateRequest { - repeated ConversationTurn conversation_turns = 1; // The conversation turns (batches of messages) - repeated ToolDescription tools = 2; // Tools available to the conversation - repeated Document documents = 3; // Documents related to the conversation - string chat_template = 4; // The chat template to use - bool return_assistant_tokens_mask = 5; // Whether to return assistant token mask - bool continue_final_message = 6; // Whether to continue the final message - bool add_generation_prompt = 7; // Whether to add generation prompt - map chat_template_kwargs = 8; // Additional chat template arguments - string model_name = 9; // The name of the model to use for tokenization -} - // ChatMessage represents a single message in a conversation message ChatMessage { string role = 1; // Role of the message (e.g., "user", "assistant", "system") string content = 2; // Content of the message } -// ToolDescription represents a description of a tool -message ToolDescription { - map tool = 1; // Tool definition -} - -// Document represents a document -message Document { - map document = 1; // Document definition -} // Value represents a generic value that can be string, number, bool, or list message Value { @@ -90,18 +46,13 @@ message StructValue { map fields = 1; } -// ChatTemplateResponse represents the response from rendering a chat template -message ChatTemplateResponse { - string rendered_prompt = 1; // The rendered chat template prompt - bool success = 2; 
// Whether the request was successful - string error_message = 3; // Error message if the request failed -} - // InitializeTokenizerRequest represents a request to initialize tokenizer for a model message InitializeTokenizerRequest { - string model_name = 1; // The name of the model to initialize - bool enable_thinking = 2; // Whether to enable thinking tokens - bool add_generation_prompt = 3; // Whether to add generation prompt + bool is_local = 1; // Whether the model is local (default: true) + string model = 2; // The model ID or path (HF model ID, local directory path, or path to tokenizer file) + optional string revision = 3; // Model revision (optional) + optional string token = 4; // Hugging Face token for private models (optional) + optional string download_dir = 5; // Directory to download the model (optional) } // InitializeTokenizerResponse represents the response from tokenizer initialization @@ -110,14 +61,43 @@ message InitializeTokenizerResponse { string error_message = 2; // Error message if initialization failed } -// TokenizationService defines the gRPC service for tokenization -service TokenizationService { - // Tokenize converts a text input to token IDs - rpc Tokenize(TokenizeRequest) returns (TokenizeResponse); +// RenderRequest represents a request to render (tokenize) a text input +message RenderRequest { + string text = 1; // The text input to render/tokenize + string model_name = 2; // The name of the model to use for tokenization + bool add_special_tokens = 3; // Whether to add special tokens during tokenization +} - // RenderChatTemplate renders a chat template with the given messages - rpc RenderChatTemplate(ChatTemplateRequest) returns (ChatTemplateResponse); +// RenderChatRequest represents a request to render a chat conversation +message RenderChatRequest { + repeated ChatMessage conversation = 1; // The conversation messages (array of role/content pairs) + repeated Value tools = 2; // Tools available to the conversation (arbitrary 
values) + repeated Value documents = 3; // Documents related to the conversation (arbitrary values) + optional string chat_template = 4; // The chat template to use + optional bool return_assistant_tokens_mask = 5; // Whether to return assistant token mask + optional bool continue_final_message = 6; // Whether to continue the final message + optional bool add_generation_prompt = 7; // Whether to add generation prompt + map chat_template_kwargs = 8; // Additional chat template arguments + string model_name = 9; // The name of the model to use for tokenization +} +// RenderResponse represents the response from rendering/tokenizing +message RenderResponse { + repeated uint32 input_ids = 1; // Token IDs for the input + bool success = 2; // Whether the request was successful + string error_message = 3; // Error message if the request failed + // Direct array of [start, end] pairs + repeated uint32 offset_pairs = 4; // Flattened array of [start, end, start, end, ...] +} + +// TokenizationService defines the gRPC service for tokenization +service TokenizationService { // InitializeTokenizer initializes the tokenizer for a specific model rpc InitializeTokenizer(InitializeTokenizerRequest) returns (InitializeTokenizerResponse); + + // Render renders (tokenizes) a text input + rpc Render(RenderRequest) returns (RenderResponse); + + // RenderChat renders a chat conversation to tokens and offsets + rpc RenderChat(RenderChatRequest) returns (RenderResponse); } \ No newline at end of file diff --git a/api/tokenizerpb/tokenizer_grpc.pb.go b/api/tokenizerpb/tokenizer_grpc.pb.go index cfc03070..ec502e8f 100644 --- a/api/tokenizerpb/tokenizer_grpc.pb.go +++ b/api/tokenizerpb/tokenizer_grpc.pb.go @@ -33,9 +33,9 @@ import ( const _ = grpc.SupportPackageIsVersion9 const ( - TokenizationService_Tokenize_FullMethodName = "/tokenization.TokenizationService/Tokenize" - TokenizationService_RenderChatTemplate_FullMethodName = "/tokenization.TokenizationService/RenderChatTemplate" 
TokenizationService_InitializeTokenizer_FullMethodName = "/tokenization.TokenizationService/InitializeTokenizer" + TokenizationService_Render_FullMethodName = "/tokenization.TokenizationService/Render" + TokenizationService_RenderChat_FullMethodName = "/tokenization.TokenizationService/RenderChat" ) // TokenizationServiceClient is the client API for TokenizationService service. @@ -44,12 +44,12 @@ const ( // // TokenizationService defines the gRPC service for tokenization type TokenizationServiceClient interface { - // Tokenize converts a text input to token IDs - Tokenize(ctx context.Context, in *TokenizeRequest, opts ...grpc.CallOption) (*TokenizeResponse, error) - // RenderChatTemplate renders a chat template with the given messages - RenderChatTemplate(ctx context.Context, in *ChatTemplateRequest, opts ...grpc.CallOption) (*ChatTemplateResponse, error) // InitializeTokenizer initializes the tokenizer for a specific model InitializeTokenizer(ctx context.Context, in *InitializeTokenizerRequest, opts ...grpc.CallOption) (*InitializeTokenizerResponse, error) + // Render renders (tokenizes) a text input + Render(ctx context.Context, in *RenderRequest, opts ...grpc.CallOption) (*RenderResponse, error) + // RenderChat renders a chat conversation to tokens and offsets + RenderChat(ctx context.Context, in *RenderChatRequest, opts ...grpc.CallOption) (*RenderResponse, error) } type tokenizationServiceClient struct { @@ -60,30 +60,30 @@ func NewTokenizationServiceClient(cc grpc.ClientConnInterface) TokenizationServi return &tokenizationServiceClient{cc} } -func (c *tokenizationServiceClient) Tokenize(ctx context.Context, in *TokenizeRequest, opts ...grpc.CallOption) (*TokenizeResponse, error) { +func (c *tokenizationServiceClient) InitializeTokenizer(ctx context.Context, in *InitializeTokenizerRequest, opts ...grpc.CallOption) (*InitializeTokenizerResponse, error) { cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) 
- out := new(TokenizeResponse) - err := c.cc.Invoke(ctx, TokenizationService_Tokenize_FullMethodName, in, out, cOpts...) + out := new(InitializeTokenizerResponse) + err := c.cc.Invoke(ctx, TokenizationService_InitializeTokenizer_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } return out, nil } -func (c *tokenizationServiceClient) RenderChatTemplate(ctx context.Context, in *ChatTemplateRequest, opts ...grpc.CallOption) (*ChatTemplateResponse, error) { +func (c *tokenizationServiceClient) Render(ctx context.Context, in *RenderRequest, opts ...grpc.CallOption) (*RenderResponse, error) { cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) - out := new(ChatTemplateResponse) - err := c.cc.Invoke(ctx, TokenizationService_RenderChatTemplate_FullMethodName, in, out, cOpts...) + out := new(RenderResponse) + err := c.cc.Invoke(ctx, TokenizationService_Render_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } return out, nil } -func (c *tokenizationServiceClient) InitializeTokenizer(ctx context.Context, in *InitializeTokenizerRequest, opts ...grpc.CallOption) (*InitializeTokenizerResponse, error) { +func (c *tokenizationServiceClient) RenderChat(ctx context.Context, in *RenderChatRequest, opts ...grpc.CallOption) (*RenderResponse, error) { cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) - out := new(InitializeTokenizerResponse) - err := c.cc.Invoke(ctx, TokenizationService_InitializeTokenizer_FullMethodName, in, out, cOpts...) + out := new(RenderResponse) + err := c.cc.Invoke(ctx, TokenizationService_RenderChat_FullMethodName, in, out, cOpts...) 
if err != nil { return nil, err } @@ -96,12 +96,12 @@ func (c *tokenizationServiceClient) InitializeTokenizer(ctx context.Context, in // // TokenizationService defines the gRPC service for tokenization type TokenizationServiceServer interface { - // Tokenize converts a text input to token IDs - Tokenize(context.Context, *TokenizeRequest) (*TokenizeResponse, error) - // RenderChatTemplate renders a chat template with the given messages - RenderChatTemplate(context.Context, *ChatTemplateRequest) (*ChatTemplateResponse, error) // InitializeTokenizer initializes the tokenizer for a specific model InitializeTokenizer(context.Context, *InitializeTokenizerRequest) (*InitializeTokenizerResponse, error) + // Render renders (tokenizes) a text input + Render(context.Context, *RenderRequest) (*RenderResponse, error) + // RenderChat renders a chat conversation to tokens and offsets + RenderChat(context.Context, *RenderChatRequest) (*RenderResponse, error) mustEmbedUnimplementedTokenizationServiceServer() } @@ -112,15 +112,15 @@ type TokenizationServiceServer interface { // pointer dereference when methods are called. 
type UnimplementedTokenizationServiceServer struct{} -func (UnimplementedTokenizationServiceServer) Tokenize(context.Context, *TokenizeRequest) (*TokenizeResponse, error) { - return nil, status.Error(codes.Unimplemented, "method Tokenize not implemented") -} -func (UnimplementedTokenizationServiceServer) RenderChatTemplate(context.Context, *ChatTemplateRequest) (*ChatTemplateResponse, error) { - return nil, status.Error(codes.Unimplemented, "method RenderChatTemplate not implemented") -} func (UnimplementedTokenizationServiceServer) InitializeTokenizer(context.Context, *InitializeTokenizerRequest) (*InitializeTokenizerResponse, error) { return nil, status.Error(codes.Unimplemented, "method InitializeTokenizer not implemented") } +func (UnimplementedTokenizationServiceServer) Render(context.Context, *RenderRequest) (*RenderResponse, error) { + return nil, status.Error(codes.Unimplemented, "method Render not implemented") +} +func (UnimplementedTokenizationServiceServer) RenderChat(context.Context, *RenderChatRequest) (*RenderResponse, error) { + return nil, status.Error(codes.Unimplemented, "method RenderChat not implemented") +} func (UnimplementedTokenizationServiceServer) mustEmbedUnimplementedTokenizationServiceServer() {} func (UnimplementedTokenizationServiceServer) testEmbeddedByValue() {} @@ -142,56 +142,56 @@ func RegisterTokenizationServiceServer(s grpc.ServiceRegistrar, srv Tokenization s.RegisterService(&TokenizationService_ServiceDesc, srv) } -func _TokenizationService_Tokenize_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { - in := new(TokenizeRequest) +func _TokenizationService_InitializeTokenizer_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(InitializeTokenizerRequest) if err := dec(in); err != nil { return nil, err } if interceptor == nil { - return 
srv.(TokenizationServiceServer).Tokenize(ctx, in) + return srv.(TokenizationServiceServer).InitializeTokenizer(ctx, in) } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: TokenizationService_Tokenize_FullMethodName, + FullMethod: TokenizationService_InitializeTokenizer_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(TokenizationServiceServer).Tokenize(ctx, req.(*TokenizeRequest)) + return srv.(TokenizationServiceServer).InitializeTokenizer(ctx, req.(*InitializeTokenizerRequest)) } return interceptor(ctx, in, info, handler) } -func _TokenizationService_RenderChatTemplate_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { - in := new(ChatTemplateRequest) +func _TokenizationService_Render_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(RenderRequest) if err := dec(in); err != nil { return nil, err } if interceptor == nil { - return srv.(TokenizationServiceServer).RenderChatTemplate(ctx, in) + return srv.(TokenizationServiceServer).Render(ctx, in) } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: TokenizationService_RenderChatTemplate_FullMethodName, + FullMethod: TokenizationService_Render_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(TokenizationServiceServer).RenderChatTemplate(ctx, req.(*ChatTemplateRequest)) + return srv.(TokenizationServiceServer).Render(ctx, req.(*RenderRequest)) } return interceptor(ctx, in, info, handler) } -func _TokenizationService_InitializeTokenizer_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { - in := new(InitializeTokenizerRequest) +func _TokenizationService_RenderChat_Handler(srv interface{}, ctx context.Context, dec 
func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(RenderChatRequest) if err := dec(in); err != nil { return nil, err } if interceptor == nil { - return srv.(TokenizationServiceServer).InitializeTokenizer(ctx, in) + return srv.(TokenizationServiceServer).RenderChat(ctx, in) } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: TokenizationService_InitializeTokenizer_FullMethodName, + FullMethod: TokenizationService_RenderChat_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(TokenizationServiceServer).InitializeTokenizer(ctx, req.(*InitializeTokenizerRequest)) + return srv.(TokenizationServiceServer).RenderChat(ctx, req.(*RenderChatRequest)) } return interceptor(ctx, in, info, handler) } @@ -204,16 +204,16 @@ var TokenizationService_ServiceDesc = grpc.ServiceDesc{ HandlerType: (*TokenizationServiceServer)(nil), Methods: []grpc.MethodDesc{ { - MethodName: "Tokenize", - Handler: _TokenizationService_Tokenize_Handler, + MethodName: "InitializeTokenizer", + Handler: _TokenizationService_InitializeTokenizer_Handler, }, { - MethodName: "RenderChatTemplate", - Handler: _TokenizationService_RenderChatTemplate_Handler, + MethodName: "Render", + Handler: _TokenizationService_Render_Handler, }, { - MethodName: "InitializeTokenizer", - Handler: _TokenizationService_InitializeTokenizer_Handler, + MethodName: "RenderChat", + Handler: _TokenizationService_RenderChat_Handler, }, }, Streams: []grpc.StreamDesc{}, diff --git a/examples/kv_events/online/main.go b/examples/kv_events/online/main.go index 18525b31..7a22f341 100644 --- a/examples/kv_events/online/main.go +++ b/examples/kv_events/online/main.go @@ -62,12 +62,6 @@ const ( envExternalTokenization = "EXTERNAL_TOKENIZATION" ) -// ChatCompletionsRequest holds the fields needed for chat-completions rendering. 
-type ChatCompletionsRequest struct { - Model string `json:"model"` - *types.RenderChatRequest -} - func main() { baseLogger := zap.New(zap.UseDevMode(true)) log.SetLogger(baseLogger) @@ -319,13 +313,20 @@ func setupUnifiedHTTPEndpoints( return } - var req ChatCompletionsRequest + var req struct { + Model string `json:"model"` + Messages []types.Conversation `json:"messages"` + } if err := json.NewDecoder(r.Body).Decode(&req); err != nil { http.Error(w, "Invalid request body", http.StatusBadRequest) return } - pods, err := kvCacheIndexer.GetPodScores(ctx, req.RenderChatRequest, "", req.Model, nil) + renderChatReq := &types.RenderChatRequest{ + Conversation: req.Messages, + } + + pods, err := kvCacheIndexer.GetPodScores(ctx, renderChatReq, "", req.Model, nil) if err != nil { http.Error(w, fmt.Sprintf("Failed to get score request: %v", err), http.StatusInternalServerError) return diff --git a/examples/kv_events/online_uds/main.go b/examples/kv_events/online_uds/main.go new file mode 100644 index 00000000..56f4d8fe --- /dev/null +++ b/examples/kv_events/online_uds/main.go @@ -0,0 +1,316 @@ +/* +Copyright 2025 The llm-d Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package main + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "os" + "os/signal" + "strconv" + "syscall" + "time" + + "github.com/llm-d/llm-d-kv-cache/examples/testdata" + "github.com/llm-d/llm-d-kv-cache/pkg/kvcache" + "github.com/llm-d/llm-d-kv-cache/pkg/kvcache/kvblock" + "github.com/llm-d/llm-d-kv-cache/pkg/kvevents" + "github.com/llm-d/llm-d-kv-cache/pkg/tokenization/types" + "github.com/prometheus/client_golang/prometheus/promhttp" + "sigs.k8s.io/controller-runtime/pkg/log" + "sigs.k8s.io/controller-runtime/pkg/log/zap" + ctrlmetrics "sigs.k8s.io/controller-runtime/pkg/metrics" +) + +const ( + envHFToken = "HF_TOKEN" + envZMQEndpoint = "ZMQ_ENDPOINT" + envZMQTopic = "ZMQ_TOPIC" + envModelName = "MODEL_NAME" + + envPoolConcurrency = "POOL_CONCURRENCY" + defaultZMQEndpoint = "tcp://localhost:5557" + defaultZMQTopic = "kv@" + defaultConcurrency = 4 + + pythonHashSeed = "PYTHONHASHSEED" + blockSizeEnvVar = "BLOCK_SIZE" + + envHTTPPort = "HTTP_PORT" + defaultHTTPPort = "8080" +) + +func main() { + baseLogger := zap.New(zap.UseDevMode(true)) + log.SetLogger(baseLogger) + + ctxBase := log.IntoContext(context.Background(), baseLogger) + ctx, cancel := context.WithCancel(ctxBase) + defer cancel() + + logger := log.FromContext(ctx) + + // Setup graceful shutdown + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + go func() { + <-sigChan + logger.Info("Received shutdown signal") + cancel() + }() + + if err := run(ctx); err != nil { + logger.Error(err, "Failed to run unified KV-cache service") + return + } +} + +func run(ctx context.Context) error { + logger := log.FromContext(ctx) + + // Setup KV Cache Indexer + kvCacheIndexer, err := setupKVCacheIndexer(ctx) + if err != nil { + logger.Error(err, "failed to setup KVCacheIndexer") + return err + } + + // Setup events pool + eventsPool := setupEventsPool(ctx, kvCacheIndexer.KVBlockIndex()) + eventsPool.Start(ctx) + logger.Info("Events pool 
started and listening for ZMQ messages") + + // Setup HTTP server + httpServer := setupUnifiedHTTPEndpoints(ctx, kvCacheIndexer) + + logger.Info("=== Online KV Events Example Started ===") + logger.Info("HTTP server running on http://localhost:8080") + logger.Info("Available endpoints:") + logger.Info(" - POST /score_completions - Score /v1/completions requests") + logger.Info(" - POST /score_chat_completions - Score /v1/chat_completions requests") + logger.Info(" - GET /metrics - Prometheus metrics endpoint") + + // Wait for shutdown + <-ctx.Done() + logger.Info("Shutting down KV-cache service...") + + // Graceful shutdown with timeout + shutdownCtx, shutdownCancel := context.WithTimeout(ctx, 30*time.Second) + defer shutdownCancel() + + if err := httpServer.Shutdown(shutdownCtx); err != nil { + logger.Error(err, "HTTP server shutdown error") + } + + return nil +} + +func getKVCacheIndexerConfig() (*kvcache.Config, error) { + config, err := kvcache.NewDefaultConfig() + if err != nil { + return nil, err + } + + huggingFaceToken := os.Getenv(envHFToken) + if huggingFaceToken != "" { + config.TokenizersPoolConfig.UdsTokenizerConfig.HuggingFaceToken = huggingFaceToken + } + + modelName := os.Getenv(envModelName) + if modelName == "" { + modelName = testdata.ModelName + } + + config.TokenizersPoolConfig.ModelName = modelName + config.KVBlockIndexConfig.EnableMetrics = true + config.KVBlockIndexConfig.MetricsLoggingInterval = 30 * time.Second + + return config, nil +} + +func getTokenProcessorConfig() *kvblock.TokenProcessorConfig { + config := kvblock.DefaultTokenProcessorConfig() + hashSeed := os.Getenv(pythonHashSeed) + if hashSeed != "" { + config.HashSeed = hashSeed + } + + blockSize, err := strconv.Atoi(os.Getenv(blockSizeEnvVar)) + if err == nil && blockSize >= 0 { + config.BlockSize = blockSize + } + return config +} + +func getEventsPoolConfig() *kvevents.Config { + concurrency := defaultConcurrency + if envConcurrency := os.Getenv(envPoolConcurrency); 
envConcurrency != "" { + if c, err := strconv.Atoi(envConcurrency); err == nil && c > 0 { + concurrency = c + } + } + + zmqEndpoint := os.Getenv(envZMQEndpoint) + if zmqEndpoint == "" { + zmqEndpoint = defaultZMQEndpoint + } + + zmqTopic := os.Getenv(envZMQTopic) + if zmqTopic == "" { + zmqTopic = defaultZMQTopic + } + + return &kvevents.Config{ + Concurrency: concurrency, + ZMQEndpoint: zmqEndpoint, + TopicFilter: zmqTopic, + } +} + +func setupKVCacheIndexer(ctx context.Context) (*kvcache.Indexer, error) { + logger := log.FromContext(ctx) + + cfg, err := getKVCacheIndexerConfig() + if err != nil { + return nil, err + } + + kvCacheIndexer, err := kvcache.NewKVCacheIndexer(ctx, cfg, + kvblock.NewChunkedTokenDatabase(getTokenProcessorConfig())) + if err != nil { + return nil, err + } + + logger.Info("Created Indexer") + + go kvCacheIndexer.Run(ctx) + logger.Info("Started Indexer") + + return kvCacheIndexer, nil +} + +func setupEventsPool(ctx context.Context, kvBlockIndex kvblock.Index) *kvevents.Pool { + logger := log.FromContext(ctx) + + cfg := getEventsPoolConfig() + + logger.Info("Creating events pool", "config", cfg) + tokenProcessor := kvblock.NewChunkedTokenDatabase(kvblock.DefaultTokenProcessorConfig()) + pool := kvevents.NewPool(cfg, kvBlockIndex, tokenProcessor) + + return pool +} + +func setupUnifiedHTTPEndpoints( + ctx context.Context, + kvCacheIndexer *kvcache.Indexer, +) *http.Server { + logger := log.FromContext(ctx) + + mux := http.NewServeMux() + + mux.Handle("/metrics", promhttp.HandlerFor(ctrlmetrics.Registry, promhttp.HandlerOpts{ + EnableOpenMetrics: true, + })) + + mux.HandleFunc("/score_completions", func(w http.ResponseWriter, r *http.Request) { + var req struct { + Prompt string `json:"prompt"` + Model string `json:"model"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, "invalid JSON body", http.StatusBadRequest) + return + } + if req.Prompt == "" { + http.Error(w, "field 'prompt' required", 
http.StatusBadRequest) + return + } + + pods, err := kvCacheIndexer.GetPodScores(ctx, nil, req.Prompt, req.Model, nil) + if err != nil { + http.Error(w, fmt.Sprintf("error: %v", err), http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(pods); err != nil { + logger.Error(err, "failed to encode response") + } + }) + + mux.HandleFunc("/score_chat_completions", func(w http.ResponseWriter, r *http.Request) { + logger.Info("Received request for /score_chat_completions", "body", r.Body) + + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + var req struct { + Model string `json:"model"` + Messages []types.Conversation `json:"messages"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, "Invalid request body", http.StatusBadRequest) + return + } + + renderChatReq := &types.RenderChatRequest{ + Conversation: req.Messages, + } + + pods, err := kvCacheIndexer.GetPodScores(ctx, renderChatReq, "", req.Model, nil) + if err != nil { + http.Error(w, fmt.Sprintf("Failed to get score request: %v", err), http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(pods); err != nil { + logger.Error(err, "Failed to encode score response") + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + }) + + // Get HTTP port + httpPort := os.Getenv(envHTTPPort) + if httpPort == "" { + httpPort = defaultHTTPPort + } + + server := &http.Server{ + Addr: ":" + httpPort, + Handler: mux, + ReadHeaderTimeout: 20 * time.Second, + ReadTimeout: 1 * time.Minute, + WriteTimeout: 1 * time.Minute, + } + + // Start HTTP server in goroutine + go func() { + if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { + logger.Error(err, "HTTP server error") + } + }() + + return server +} diff 
--git a/pkg/kvcache/kvblock/instrumented_index.go b/pkg/kvcache/kvblock/instrumented_index.go index 70d6ab21..6b72583d 100644 --- a/pkg/kvcache/kvblock/instrumented_index.go +++ b/pkg/kvcache/kvblock/instrumented_index.go @@ -59,7 +59,14 @@ func (m *instrumentedIndex) Lookup( return nil, err } - go recordHitMetrics(pods) + // Create a deep copy of the pods map to avoid race conditions when accessed by the goroutine + podsCopy := make(map[BlockHash][]PodEntry, len(pods)) + for k, v := range pods { + podsCopy[k] = make([]PodEntry, len(v)) + copy(podsCopy[k], v) + } + + go recordHitMetrics(podsCopy) return pods, nil } diff --git a/pkg/preprocessing/chat_completions/setup.sh b/pkg/preprocessing/chat_completions/setup.sh index 06405ba3..4d4e7383 100755 --- a/pkg/preprocessing/chat_completions/setup.sh +++ b/pkg/preprocessing/chat_completions/setup.sh @@ -1,4 +1,18 @@ #!/bin/bash +# Copyright 2025 The llm-d Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ # vLLM installation script for macOS/Apple Silicon # https://docs.vllm.ai/en/stable/getting_started/installation/cpu.html diff --git a/pkg/preprocessing/chat_completions/tokenizer_wrapper.py b/pkg/preprocessing/chat_completions/tokenizer_wrapper.py index c63b6be5..03bd92c2 100644 --- a/pkg/preprocessing/chat_completions/tokenizer_wrapper.py +++ b/pkg/preprocessing/chat_completions/tokenizer_wrapper.py @@ -118,6 +118,12 @@ def render_chat(request_json): if tokenizer is None: raise RuntimeError(f"Tokenizer with key {key} not found in cache") + # Handle empty conversation specially + conversation = request.get("conversation", []) + if not conversation: + result = {"input_ids": [], "offset_mapping": []} + return json.dumps(result) + # Get template_vars and spread them as individual arguments template_vars = request.pop("chat_template_kwargs", {}) request.update(template_vars) diff --git a/pkg/tokenization/uds_tokenizer.go b/pkg/tokenization/uds_tokenizer.go index 2f0c7d74..191519e9 100644 --- a/pkg/tokenization/uds_tokenizer.go +++ b/pkg/tokenization/uds_tokenizer.go @@ -29,10 +29,12 @@ import ( ) // UdsTokenizerConfig represents the configuration for the UDS-based tokenizer, -// including the socket file path or TCP address (for testing only). +// including the socket file path and other settings or TCP address (for testing only). 
type UdsTokenizerConfig struct { - SocketFile string `json:"socketFile"` // UDS socket path (production) or host:port for TCP (testing only) - UseTCP bool `json:"useTCP"` // If true, use TCP instead of UDS (for testing only, default: false) + SocketFile string `json:"socketFile"` // Path to the UDS socket file + HuggingFaceToken string `json:"huggingFaceToken"` // Hugging Face token for private models + TokenizersCacheDir string `json:"tokenizersCacheDir"` // Directory for caching tokenizers + UseTCP bool `json:"useTCP"` // If true, use TCP instead of UDS (for testing only, default: false) } func (cfg *UdsTokenizerConfig) IsEnabled() bool { @@ -46,6 +48,7 @@ type UdsTokenizer struct { model string conn *grpc.ClientConn client tokenizerpb.TokenizationServiceClient + config *UdsTokenizerConfig } const ( @@ -96,6 +99,7 @@ func NewUdsTokenizer(ctx context.Context, config *UdsTokenizerConfig, modelName conn: conn, client: client, model: modelName, + config: config, } // Start a goroutine to monitor the context and close the connection when the context ends @@ -114,11 +118,22 @@ func NewUdsTokenizer(ctx context.Context, config *UdsTokenizerConfig, modelName // initializeTokenizerForModel initializes the tokenizer service for a specific model. 
func (u *UdsTokenizer) initializeTokenizerForModel(ctx context.Context) error { - // Use default configuration values for now + config := u.config // Access the stored config + + // Use configuration values from the config - align with tokenizer_wrapper.py parameters req := &tokenizerpb.InitializeTokenizerRequest{ - ModelName: u.model, - EnableThinking: false, // Can be made configurable later - AddGenerationPrompt: true, // Can be made configurable later + IsLocal: true, // Default to true per proto definition + Model: u.model, + Token: nil, // Optional - will use environment variable if needed + DownloadDir: nil, // Optional - defaults to HF cache + } + + if config.HuggingFaceToken != "" { + req.Token = &config.HuggingFaceToken + } + + if config.TokenizersCacheDir != "" { + req.DownloadDir = &config.TokenizersCacheDir } // Retry logic with exponential backoff @@ -154,44 +169,49 @@ func (u *UdsTokenizer) initializeTokenizerForModel(ctx context.Context) error { return fmt.Errorf("tokenizer initialization failed after %d attempts: %w", maxRetries, lastErr) } -func (u *UdsTokenizer) Render(prompt string) ([]uint32, []types.Offset, error) { - return u.Encode(prompt, true) +// parseOffsetPairs parses the flattened array of offset pairs [start, end, start, end, ...] +// into a slice of types.Offset structs. +func parseOffsetPairs(offsetPairs []uint32) ([]types.Offset, error) { + var tokenizersOffsets []types.Offset + + if len(offsetPairs) > 0 && len(offsetPairs)%2 == 0 { + // Use offset_pairs field in format [start, end, start, end, ...] 
+ pairCount := len(offsetPairs) / 2 + tokenizersOffsets = make([]types.Offset, pairCount) + for i := 0; i < pairCount; i++ { + start := offsetPairs[2*i] + end := offsetPairs[2*i+1] + tokenizersOffsets[i] = types.Offset{uint(start), uint(end)} + } + } else { + return nil, fmt.Errorf("invalid offset_pairs field in response") + } + + return tokenizersOffsets, nil } -// Encode tokenizes the input string and returns the token IDs and offsets. -func (u *UdsTokenizer) Encode(prompt string, addSpecialTokens bool) ([]uint32, []types.Offset, error) { +func (u *UdsTokenizer) Render(prompt string) ([]uint32, []types.Offset, error) { ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout) defer cancel() - pbReq := &tokenizerpb.TokenizeRequest{ - Input: prompt, + pbReq := &tokenizerpb.RenderRequest{ + Text: prompt, ModelName: u.model, - AddSpecialTokens: addSpecialTokens, + AddSpecialTokens: true, } - resp, err := u.client.Tokenize(ctx, pbReq) + resp, err := u.client.Render(ctx, pbReq) if err != nil { - return nil, nil, fmt.Errorf("gRPC tokenize request failed: %w", err) + return nil, nil, fmt.Errorf("gRPC render request failed: %w", err) } if !resp.Success { - return nil, nil, fmt.Errorf("tokenization failed: %s", resp.ErrorMessage) + return nil, nil, fmt.Errorf("render failed: %s", resp.ErrorMessage) } - // Use offset_pairs field in format [start, end, start, end, ...] - var tokenizersOffsets []types.Offset - - if len(resp.OffsetPairs) > 0 && len(resp.OffsetPairs)%2 == 0 { - // Use offset_pairs field in format [start, end, start, end, ...] 
- pairCount := len(resp.OffsetPairs) / 2 - tokenizersOffsets = make([]types.Offset, pairCount) - for i := 0; i < pairCount; i++ { - start := resp.OffsetPairs[2*i] - end := resp.OffsetPairs[2*i+1] - tokenizersOffsets[i] = types.Offset{uint(start), uint(end)} - } - } else { - return nil, nil, fmt.Errorf("invalid offset_pairs field in response") + tokenizersOffsets, err := parseOffsetPairs(resp.OffsetPairs) + if err != nil { + return nil, nil, err } return resp.InputIds, tokenizersOffsets, nil @@ -212,9 +232,6 @@ func (u *UdsTokenizer) RenderChat( Content: msg.Content, }) } - conversationTurns := []*tokenizerpb.ConversationTurn{ - {Messages: messages}, - } // Convert ChatTemplateKWArgs chatTemplateKwargs := make(map[string]*tokenizerpb.Value) @@ -222,26 +239,48 @@ func (u *UdsTokenizer) RenderChat( chatTemplateKwargs[k] = ConvertToProtoValue(v) } - req := &tokenizerpb.ChatTemplateRequest{ - ConversationTurns: conversationTurns, - ChatTemplate: renderReq.ChatTemplate, - ReturnAssistantTokensMask: renderReq.ReturnAssistantTokensMask, - ContinueFinalMessage: renderReq.ContinueFinalMessage, - AddGenerationPrompt: renderReq.AddGenerationPrompt, + // Convert tools from interface{} array to protobuf Value array + tools := make([]*tokenizerpb.Value, 0, len(renderReq.Tools)) + for _, tool := range renderReq.Tools { + tools = append(tools, ConvertToProtoValue(tool)) + } + + // Convert documents from interface{} array to protobuf Value array + documents := make([]*tokenizerpb.Value, 0, len(renderReq.Documents)) + for _, doc := range renderReq.Documents { + documents = append(documents, ConvertToProtoValue(doc)) + } + + req := &tokenizerpb.RenderChatRequest{ + Conversation: messages, + Tools: tools, + Documents: documents, + ReturnAssistantTokensMask: &renderReq.ReturnAssistantTokensMask, + ContinueFinalMessage: &renderReq.ContinueFinalMessage, + AddGenerationPrompt: &renderReq.AddGenerationPrompt, ChatTemplateKwargs: chatTemplateKwargs, ModelName: u.model, } - resp, err := 
u.client.RenderChatTemplate(ctx, req) + if renderReq.ChatTemplate != "" { + req.ChatTemplate = &renderReq.ChatTemplate + } + + resp, err := u.client.RenderChat(ctx, req) if err != nil { - return nil, nil, fmt.Errorf("gRPC chat-template request failed: %w", err) + return nil, nil, fmt.Errorf("gRPC render-chat request failed: %w", err) } if !resp.Success { - return nil, nil, fmt.Errorf("chat template rendering failed: %s", resp.ErrorMessage) + return nil, nil, fmt.Errorf("render-chat failed: %s", resp.ErrorMessage) } - return u.Encode(resp.RenderedPrompt, false) + tokenizersOffsets, err := parseOffsetPairs(resp.OffsetPairs) + if err != nil { + return nil, nil, err + } + + return resp.InputIds, tokenizersOffsets, nil } // ConvertToProtoValue converts a Go interface{} value to a protobuf Value. diff --git a/pkg/tokenization/uds_tokenizer_test.go b/pkg/tokenization/uds_tokenizer_test.go index eb196445..9092820e 100644 --- a/pkg/tokenization/uds_tokenizer_test.go +++ b/pkg/tokenization/uds_tokenizer_test.go @@ -37,7 +37,7 @@ import ( type mockTokenizationServer struct { tokenizerpb.UnimplementedTokenizationServiceServer initializeError bool - tokenizeError bool + renderError bool chatError bool initialized map[string]bool } @@ -59,26 +59,26 @@ func (m *mockTokenizationServer) InitializeTokenizer( }, nil } - m.initialized[req.ModelName] = true + m.initialized[req.Model] = true return &tokenizerpb.InitializeTokenizerResponse{ Success: true, }, nil } -func (m *mockTokenizationServer) Tokenize( +func (m *mockTokenizationServer) Render( ctx context.Context, - req *tokenizerpb.TokenizeRequest, -) (*tokenizerpb.TokenizeResponse, error) { - if m.tokenizeError { - return &tokenizerpb.TokenizeResponse{ + req *tokenizerpb.RenderRequest, +) (*tokenizerpb.RenderResponse, error) { + if m.renderError { + return &tokenizerpb.RenderResponse{ Success: false, - ErrorMessage: "mock tokenization error", + ErrorMessage: "mock render error", }, nil } // Check if model was initialized (matches 
real service behavior) if !m.initialized[req.ModelName] { - return &tokenizerpb.TokenizeResponse{ + return &tokenizerpb.RenderResponse{ Success: false, ErrorMessage: fmt.Sprintf("model %s not initialized", req.ModelName), }, nil @@ -86,7 +86,7 @@ func (m *mockTokenizationServer) Tokenize( // Simple deterministic mock tokenization: convert each rune to a token ID // This makes tests more realistic - different inputs produce different tokens - input := req.Input + input := req.Text tokens := make([]uint32, 0, len(input)) offsets := make([]uint32, 0, len(input)*2) @@ -96,7 +96,7 @@ func (m *mockTokenizationServer) Tokenize( offsets = append(offsets, uint32(i), uint32(i+1)) } - return &tokenizerpb.TokenizeResponse{ + return &tokenizerpb.RenderResponse{ InputIds: tokens, Success: true, OffsetPairs: offsets, @@ -104,12 +104,12 @@ func (m *mockTokenizationServer) Tokenize( }, nil } -func (m *mockTokenizationServer) RenderChatTemplate( +func (m *mockTokenizationServer) RenderChat( ctx context.Context, - req *tokenizerpb.ChatTemplateRequest, -) (*tokenizerpb.ChatTemplateResponse, error) { + req *tokenizerpb.RenderChatRequest, +) (*tokenizerpb.RenderResponse, error) { if m.chatError { - return &tokenizerpb.ChatTemplateResponse{ + return &tokenizerpb.RenderResponse{ Success: false, ErrorMessage: "mock chat template error", }, nil @@ -117,7 +117,7 @@ func (m *mockTokenizationServer) RenderChatTemplate( // Check if model was initialized (matches real service behavior) if !m.initialized[req.ModelName] { - return &tokenizerpb.ChatTemplateResponse{ + return &tokenizerpb.RenderResponse{ Success: false, ErrorMessage: fmt.Sprintf("model %s not initialized", req.ModelName), }, nil @@ -125,15 +125,25 @@ func (m *mockTokenizationServer) RenderChatTemplate( // Mock chat template rendering by concatenating messages rendered := "" - for _, turn := range req.ConversationTurns { - for _, msg := range turn.Messages { - rendered += fmt.Sprintf("%s: %s\n", msg.Role, msg.Content) - } + for _, 
msg := range req.Conversation { + rendered += fmt.Sprintf("%s: %s\n", msg.Role, msg.Content) } - return &tokenizerpb.ChatTemplateResponse{ - RenderedPrompt: rendered, - Success: true, + // Generate tokens from the rendered prompt + tokens := make([]uint32, 0, len(rendered)) + offsets := make([]uint32, 0, len(rendered)*2) + + for i, r := range rendered { + tokens = append(tokens, uint32(r)) + // #nosec G115 -- i is bounded by string length, safe conversion + offsets = append(offsets, uint32(i), uint32(i+1)) + } + + return &tokenizerpb.RenderResponse{ + InputIds: tokens, + Success: true, + OffsetPairs: offsets, + ErrorMessage: "", }, nil } @@ -202,7 +212,7 @@ func (s *UdsTokenizerTestSuite) TearDownSuite() { func (s *UdsTokenizerTestSuite) SetupTest() { // Reset error flags for each test s.mockServer.initializeError = false - s.mockServer.tokenizeError = false + s.mockServer.renderError = false s.mockServer.chatError = false // Clear initialized models to ensure test isolation s.mockServer.initialized = make(map[string]bool) @@ -284,11 +294,11 @@ func (s *UdsTokenizerTestSuite) TestUdsTokenizer_Type() { } func (s *UdsTokenizerTestSuite) TestUdsTokenizer_TokenizeError() { - s.mockServer.tokenizeError = true + s.mockServer.renderError = true _, _, err := s.tokenizer.Render("test") s.Assert().Error(err) - s.Assert().Contains(err.Error(), "tokenization failed") + s.Assert().Contains(err.Error(), "render failed") } func (s *UdsTokenizerTestSuite) TestUdsTokenizer_ChatTemplateError() { @@ -302,7 +312,7 @@ func (s *UdsTokenizerTestSuite) TestUdsTokenizer_ChatTemplateError() { _, _, err := s.tokenizer.RenderChat(renderReq) s.Assert().Error(err) - s.Assert().Contains(err.Error(), "chat template rendering failed") + s.Assert().Contains(err.Error(), "render-chat failed") } // convertFromProtoValue converts a proto Value back to a Go interface{} value. 
diff --git a/services/uds_tokenizer/Dockerfile b/services/uds_tokenizer/Dockerfile deleted file mode 100644 index 170d6be1..00000000 --- a/services/uds_tokenizer/Dockerfile +++ /dev/null @@ -1,70 +0,0 @@ -# Copyright 2025 The llm-d Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Build stage -FROM --platform=$TARGETPLATFORM python:3.12-slim as builder - -# Set build arguments -ARG TARGETPLATFORM - -# Set working directory -WORKDIR /app - -# Copy project metadata and install dependencies -COPY pyproject.toml /app/pyproject.toml -RUN pip install --no-cache-dir . 
- -# Runtime stage -FROM --platform=$TARGETPLATFORM python:3.12-slim - -# Set build arguments -ARG HEALTH_PORT=8082 - -# Set working directory -WORKDIR /app - -RUN apt-get update && apt-get upgrade -y && apt-get clean && rm -rf /var/lib/apt/lists/* /var/cache/apt/* - -# Copy installed dependencies from build stage -COPY --from=builder /usr/local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages -# Copy executables from build stage -COPY --from=builder /usr/local/bin/ /usr/local/bin/ - -# Copy project files into the image -COPY run_grpc_server.py /app/ -COPY tokenizer_grpc_service.py /app/tokenizer_grpc_service.py -COPY tokenizers/ /app/tokenizers/ -COPY utils/ /app/utils/ -COPY tokenizer_service/ /app/tokenizer_service/ -COPY tokenizerpb/ /app/tokenizerpb/ - -# Create directory for UDS socket -RUN mkdir -p /tmp/tokenizer && chown 65532:65532 /tmp/tokenizer - -# Create tokenizer cache directories and set permissions -ENV TOKENIZERS_DIR=/app/tokenizers -RUN mkdir -p /app/tokenizers && chown -R 65532:65532 /app/tokenizers -# Create and set permissions for ModelScope directory -RUN mkdir -p /.modelscope && chown -R 65532:65532 /.modelscope -# Create and set permissions for Hugging Face cache directory -RUN mkdir -p /.cache && chown -R 65532:65532 /.cache - -# Switch to non-root user -USER 65532:65532 - -# Expose health check port (configurable via build arg) -EXPOSE ${HEALTH_PORT} - -# Startup command: run direct gRPC server -CMD ["python", "/app/run_grpc_server.py"] diff --git a/services/uds_tokenizer/pyproject.toml b/services/uds_tokenizer/pyproject.toml index 887aa113..ecdb2f83 100644 --- a/services/uds_tokenizer/pyproject.toml +++ b/services/uds_tokenizer/pyproject.toml @@ -4,19 +4,15 @@ version = "0.1.0" description = "UDS Tokenizer Service - gRPC tokenization over Unix Domain Socket" requires-python = ">=3.12" dependencies = [ - "pydantic==2.11.7", "shortuuid==1.0.13", - "transformers==4.53.0", - "safetensors==0.5.3", - "Jinja2==3.1.6", 
"modelscope", "huggingface-hub", "aiohttp==3.9.5", "protobuf==6.31.1", - "tiktoken>=0.7.0", "grpcio==1.76.0", "grpcio-tools==1.76.0", "grpcio-reflection==1.76.0", + "readerwriterlock>=1.0.9", ] [project.optional-dependencies] diff --git a/services/uds_tokenizer/run_grpc_server.py b/services/uds_tokenizer/run_grpc_server.py index 6095b9f0..3cbd4705 100644 --- a/services/uds_tokenizer/run_grpc_server.py +++ b/services/uds_tokenizer/run_grpc_server.py @@ -27,7 +27,6 @@ import sys from aiohttp import web -from tokenizer_service.tokenizer import TokenizerService, TokenizerConfig from tokenizer_grpc_service import create_grpc_server from utils.thread_pool_utils import get_thread_pool @@ -53,7 +52,6 @@ probe_loop = None # Store the probe event loop for later use probe_started_event = threading.Event() # Event to signal when probe server has started current_config = None -tokenizer_service = None tokenizer_ready = False shutdown_event = threading.Event() # Event to signal shutdown @@ -67,19 +65,6 @@ def _signal_handler(signum, frame): signal.signal(signal.SIGINT, _signal_handler) -def initialize_tokenizer(): - """Initialize the tokenizer service without pre-loading a specific model""" - global tokenizer_service, current_config, tokenizer_ready - try: - # Initialize tokenizer service without pre-loading any model - tokenizer_service = TokenizerService() # Empty constructor - tokenizer_ready = True - logging.info("Tokenizer service initialized successfully") - except Exception as e: - logging.error(f"Failed to initialize tokenizer service: {e}") - raise - - async def health_handler(request): """Health check endpoint""" global tokenizer_ready @@ -160,14 +145,7 @@ def run_probe_server(): def run_server(): """Run the synchronous gRPC server with background probe server""" - global tokenizer_service, grpc_server - - # Initialize tokenizer - try: - initialize_tokenizer() - except Exception as e: - logging.error(f"Failed to initialize tokenizer, exiting: {e}") - return + global 
grpc_server, tokenizer_ready # Remove old socket file if it exists if os.path.exists(UDS_SOCKET_PATH): @@ -177,9 +155,10 @@ def run_server(): os.makedirs(os.path.dirname(UDS_SOCKET_PATH), mode=0o700, exist_ok=True) thread_pool = get_thread_pool() - grpc_server = create_grpc_server(tokenizer_service, UDS_SOCKET_PATH, thread_pool, GRPC_PORT) + grpc_server = create_grpc_server(UDS_SOCKET_PATH, thread_pool, GRPC_PORT) grpc_server.start() logging.info(f"Synchronous gRPC server started on {UDS_SOCKET_PATH}" + (f" and TCP port {GRPC_PORT}" if GRPC_PORT else "")) + tokenizer_ready = True # Start probe server in background start_probe_server_in_background() diff --git a/services/uds_tokenizer/tests/conftest.py b/services/uds_tokenizer/tests/conftest.py index c7b1a95b..57579165 100644 --- a/services/uds_tokenizer/tests/conftest.py +++ b/services/uds_tokenizer/tests/conftest.py @@ -22,7 +22,6 @@ import pytest import tokenizerpb.tokenizer_pb2_grpc as tokenizer_pb2_grpc -from tokenizer_service.tokenizer import TokenizerService from tokenizer_grpc_service import create_grpc_server from utils.thread_pool_utils import get_thread_pool @@ -39,9 +38,9 @@ def test_model() -> str: @pytest.fixture(scope="session") def uds_socket_path() -> Iterator[str]: """Return a unique UDS socket path with cleanup. - + Uses /tmp with a short name to avoid macOS 103-char limit. 
- """ + """ # Create temp directory - auto-cleanup on exit with tempfile.TemporaryDirectory(prefix="tok-") as socket_dir: socket_path = f"{socket_dir}/uds.sock" @@ -49,14 +48,13 @@ def uds_socket_path() -> Iterator[str]: @pytest.fixture(scope="session") -def tokenizer_service(uds_socket_path: str) -> Iterator[TokenizerService]: - """Provide the TokenizerService instance used by the gRPC server.""" - service = TokenizerService() +def grpc_server(uds_socket_path: str) -> Iterator[None]: + """Start the gRPC server for testing.""" thread_pool = get_thread_pool() - server = create_grpc_server(service, uds_socket_path, thread_pool) + server = create_grpc_server(uds_socket_path, thread_pool) server.start() - yield service + yield # Graceful shutdown with matching timeout stop_future = server.stop(grace=5) @@ -64,20 +62,20 @@ def tokenizer_service(uds_socket_path: str) -> Iterator[TokenizerService]: @pytest.fixture(scope="session") -def grpc_channel(tokenizer_service: TokenizerService, uds_socket_path: str) -> Iterator[grpc.Channel]: +def grpc_channel(grpc_server, uds_socket_path: str) -> Iterator[grpc.Channel]: """Create a gRPC channel connected to the test server. - + Uses wait_for_ready to automatically retry connection until server is ready. 
""" channel = grpc.insecure_channel(f"unix://{uds_socket_path}") - + # Verify channel can connect by waiting for it to be ready try: grpc.channel_ready_future(channel).result(timeout=10.0) except grpc.FutureTimeoutError: channel.close() raise RuntimeError(f"gRPC channel to {uds_socket_path} not ready within 10s") - + yield channel channel.close() diff --git a/services/uds_tokenizer/tests/test_integration.py b/services/uds_tokenizer/tests/test_integration.py index 1e4e2223..074ac69c 100644 --- a/services/uds_tokenizer/tests/test_integration.py +++ b/services/uds_tokenizer/tests/test_integration.py @@ -28,7 +28,6 @@ import pytest import tokenizerpb.tokenizer_pb2 as tokenizer_pb2 -from tokenizer_service.tokenizer import TokenizerService # --------------------------------------------------------------------------- @@ -42,7 +41,7 @@ class TestInitializeTokenizer: def test_initialize_valid_model(self, grpc_stub, test_model): """InitializeTokenizer succeeds for a valid model.""" resp = grpc_stub.InitializeTokenizer( - tokenizer_pb2.InitializeTokenizerRequest(model_name=test_model) + tokenizer_pb2.InitializeTokenizerRequest(model=test_model) ) assert resp.success assert not resp.error_message @@ -51,7 +50,7 @@ def test_initialize_nonexistent_model(self, grpc_stub): """InitializeTokenizer returns an error for a non-existent model.""" resp = grpc_stub.InitializeTokenizer( tokenizer_pb2.InitializeTokenizerRequest( - model_name="non-existent/model-that-does-not-exist-12345" + model="non-existent/model-that-does-not-exist-12345" ) ) assert not resp.success @@ -60,7 +59,7 @@ def test_initialize_nonexistent_model(self, grpc_stub): def test_initialize_empty_model_name(self, grpc_stub): """InitializeTokenizer handles an empty model name.""" resp = grpc_stub.InitializeTokenizer( - tokenizer_pb2.InitializeTokenizerRequest(model_name="") + tokenizer_pb2.InitializeTokenizerRequest(model="") ) assert not resp.success @@ -68,30 +67,29 @@ def test_initialize_with_enable_thinking(self, 
grpc_stub, test_model): """InitializeTokenizer respects the enable_thinking flag.""" resp = grpc_stub.InitializeTokenizer( tokenizer_pb2.InitializeTokenizerRequest( - model_name=test_model, - enable_thinking=True, - add_generation_prompt=True, + model=test_model, + is_local=True, ) ) assert resp.success # --------------------------------------------------------------------------- -# Tokenize +# Render # --------------------------------------------------------------------------- -class TestTokenize: - """Tests for the Tokenize RPC.""" +class TestRender: + """Tests for the Render RPC.""" - def test_tokenize_simple_text(self, grpc_stub, test_model): - """Tokenize returns token IDs for simple text.""" + def test_render_simple_text(self, grpc_stub, test_model): + """Render returns token IDs for simple text.""" grpc_stub.InitializeTokenizer( - tokenizer_pb2.InitializeTokenizerRequest(model_name=test_model) + tokenizer_pb2.InitializeTokenizerRequest(model=test_model) ) - resp = grpc_stub.Tokenize( - tokenizer_pb2.TokenizeRequest( - input="Hello, how are you?", + resp = grpc_stub.Render( + tokenizer_pb2.RenderRequest( + text="Hello, how are you?", model_name=test_model, add_special_tokens=True, ) @@ -99,14 +97,14 @@ def test_tokenize_simple_text(self, grpc_stub, test_model): assert resp.success assert len(resp.input_ids) > 0 - def test_tokenize_returns_offset_pairs(self, grpc_stub, test_model, tokenizer_service: TokenizerService): - """Tokenize returns offset_pairs alongside token IDs.""" + def test_render_returns_offset_pairs(self, grpc_stub, test_model): + """Render returns offset_pairs alongside token IDs.""" grpc_stub.InitializeTokenizer( - tokenizer_pb2.InitializeTokenizerRequest(model_name=test_model) + tokenizer_pb2.InitializeTokenizerRequest(model=test_model) ) - resp = grpc_stub.Tokenize( - tokenizer_pb2.TokenizeRequest( - input="Hello world", + resp = grpc_stub.Render( + tokenizer_pb2.RenderRequest( + text="Hello world", model_name=test_model, 
add_special_tokens=True, ) @@ -114,30 +112,25 @@ def test_tokenize_returns_offset_pairs(self, grpc_stub, test_model, tokenizer_se assert resp.success # offset_pairs is a flat list of [start, end, start, end, ...] assert len(resp.offset_pairs) == 2 * len(resp.input_ids) - - # Verify token count matches tokenizer - tokenizer, _ = tokenizer_service.get_tokenizer_for_model(test_model) - expected_tokens = tokenizer.encode("Hello world", add_special_tokens=True) - assert list(resp.input_ids) == expected_tokens - def test_tokenize_without_special_tokens(self, grpc_stub, tokenizer_service: TokenizerService): - """Tokenize with add_special_tokens=False omits special tokens.""" + def test_render_without_special_tokens(self, grpc_stub, test_model): + """Render with add_special_tokens=False omits special tokens.""" - model_name = "google-bert/bert-base-uncased" + model_name = "deepseek-ai/DeepSeek-R1" grpc_stub.InitializeTokenizer( - tokenizer_pb2.InitializeTokenizerRequest(model_name=model_name) + tokenizer_pb2.InitializeTokenizerRequest(model=model_name) ) - with_special = grpc_stub.Tokenize( - tokenizer_pb2.TokenizeRequest( - input="test", + with_special = grpc_stub.Render( + tokenizer_pb2.RenderRequest( + text="test", model_name=model_name, add_special_tokens=True, ) ) - without_special = grpc_stub.Tokenize( - tokenizer_pb2.TokenizeRequest( - input="test", + without_special = grpc_stub.Render( + tokenizer_pb2.RenderRequest( + text="test", model_name=model_name, add_special_tokens=False, ) @@ -145,25 +138,14 @@ def test_tokenize_without_special_tokens(self, grpc_stub, tokenizer_service: Tok assert with_special.success and without_special.success # With special tokens should produce > tokens as without. 
assert len(with_special.input_ids) > len(without_special.input_ids) - - # Verify special tokens using actual tokenizer - tokenizer, _ = tokenizer_service.get_tokenizer_for_model(model_name) - - # BERT adds [CLS] at start and [SEP] at end - assert with_special.input_ids[0] == tokenizer.cls_token_id - assert with_special.input_ids[-1] == tokenizer.sep_token_id - # Without special tokens should not have [CLS] or [SEP] - assert without_special.input_ids[0] != tokenizer.cls_token_id - assert without_special.input_ids[-1] != tokenizer.sep_token_id - - def test_tokenize_empty_input(self, grpc_stub, test_model): + def test_render_empty_input(self, grpc_stub, test_model): grpc_stub.InitializeTokenizer( - tokenizer_pb2.InitializeTokenizerRequest(model_name=test_model) + tokenizer_pb2.InitializeTokenizerRequest(model=test_model) ) - resp = grpc_stub.Tokenize( - tokenizer_pb2.TokenizeRequest( - input="", + resp = grpc_stub.Render( + tokenizer_pb2.RenderRequest( + text="", model_name=test_model, add_special_tokens=False, ) @@ -171,15 +153,15 @@ def test_tokenize_empty_input(self, grpc_stub, test_model): # An empty input should still succeed (may return 0 or only special tokens). assert resp.success - def test_tokenize_long_input(self, grpc_stub, test_model): - """Tokenize handles a long input string.""" + def test_render_long_input(self, grpc_stub, test_model): + """Render handles a long input string.""" grpc_stub.InitializeTokenizer( - tokenizer_pb2.InitializeTokenizerRequest(model_name=test_model) + tokenizer_pb2.InitializeTokenizerRequest(model=test_model) ) long_text = "Hello world. " * 100_000 - resp = grpc_stub.Tokenize( - tokenizer_pb2.TokenizeRequest( - input=long_text, + resp = grpc_stub.Render( + tokenizer_pb2.RenderRequest( + text=long_text, model_name=test_model, add_special_tokens=True, ) @@ -187,63 +169,57 @@ def test_tokenize_long_input(self, grpc_stub, test_model): assert resp.success assert len(resp.input_ids) > 100 # Should have many tokens. 
- def test_tokenize_special_characters(self, grpc_stub, test_model, tokenizer_service: TokenizerService): - """Tokenize handles special / unicode characters.""" + def test_render_special_characters(self, grpc_stub, test_model): + """Render handles special / unicode characters.""" grpc_stub.InitializeTokenizer( - tokenizer_pb2.InitializeTokenizerRequest(model_name=test_model) + tokenizer_pb2.InitializeTokenizerRequest(model=test_model) ) test_input = "Hello 你好 مرحبا 🌍 <|special|>" - resp = grpc_stub.Tokenize( - tokenizer_pb2.TokenizeRequest( - input=test_input, + resp = grpc_stub.Render( + tokenizer_pb2.RenderRequest( + text=test_input, model_name=test_model, add_special_tokens=True, ) ) assert resp.success assert len(resp.input_ids) > 0 - - # Verify tokenization matches actual tokenizer - tokenizer, _ = tokenizer_service.get_tokenizer_for_model(test_model) - expected_tokens = tokenizer.encode(test_input, add_special_tokens=True) - assert list(resp.input_ids) == expected_tokens - - def test_tokenize_uninitialized_model(self, grpc_stub): - """Tokenize for a model that was never initialized returns an error.""" + def test_render_uninitialized_model(self, grpc_stub): + """Render for a model that was never initialized returns an error.""" with pytest.raises(grpc.RpcError) as exc_info: - grpc_stub.Tokenize( - tokenizer_pb2.TokenizeRequest( - input="Hello", + grpc_stub.Render( + tokenizer_pb2.RenderRequest( + text="Hello", model_name="meta-llama/Meta-Llama-3-8B", # Assuming this model is not initialized in this test add_special_tokens=True, ) ) assert exc_info.value.code() == grpc.StatusCode.INTERNAL - def test_tokenize_deterministic(self, grpc_stub, test_model): - """Tokenizing the same input twice produces identical results.""" + def test_render_deterministic(self, grpc_stub, test_model): + """Rendering the same input twice produces identical results.""" grpc_stub.InitializeTokenizer( - tokenizer_pb2.InitializeTokenizerRequest(model_name=test_model) + 
tokenizer_pb2.InitializeTokenizerRequest(model=test_model) ) - req = tokenizer_pb2.TokenizeRequest( - input="Determinism check.", + req = tokenizer_pb2.RenderRequest( + text="Determinism check.", model_name=test_model, add_special_tokens=True, ) - resp1 = grpc_stub.Tokenize(req) - resp2 = grpc_stub.Tokenize(req) + resp1 = grpc_stub.Render(req) + resp2 = grpc_stub.Render(req) assert list(resp1.input_ids) == list(resp2.input_ids) assert list(resp1.offset_pairs) == list(resp2.offset_pairs) # --------------------------------------------------------------------------- -# RenderChatTemplate +# RenderChat # --------------------------------------------------------------------------- -class TestRenderChatTemplate: - """Tests for the RenderChatTemplate RPC. +class TestRenderChat: + """Tests for the RenderChat RPC. NOTE: Not all models ship with a chat template (e.g. openai-community/gpt2 does not). Tests that require a chat template are expected to fail @@ -251,25 +227,21 @@ class TestRenderChatTemplate: """ def _make_request(self, model_name, messages, add_generation_prompt=True): - """Helper: build a ChatTemplateRequest.""" - turns = [ - tokenizer_pb2.ConversationTurn( - messages=[ - tokenizer_pb2.ChatMessage(role=m["role"], content=m["content"]) - for m in messages - ] - ) + """Helper: build a RenderChatRequest.""" + chat_messages = [ + tokenizer_pb2.ChatMessage(role=m["role"], content=m["content"]) + for m in messages ] - return tokenizer_pb2.ChatTemplateRequest( - conversation_turns=turns, + return tokenizer_pb2.RenderChatRequest( + conversation=chat_messages, model_name=model_name, add_generation_prompt=add_generation_prompt, ) def test_render_multi_turn(self, grpc_stub, test_model): - """RenderChatTemplate handles a multi-turn conversation.""" + """RenderChat handles a multi-turn conversation.""" grpc_stub.InitializeTokenizer( - tokenizer_pb2.InitializeTokenizerRequest(model_name=test_model) + tokenizer_pb2.InitializeTokenizerRequest(model=test_model) ) messages = [ 
{"role": "user", "content": "What is 2+2?"}, @@ -277,51 +249,72 @@ def test_render_multi_turn(self, grpc_stub, test_model): {"role": "user", "content": "And 3+3?"}, ] - resp = grpc_stub.RenderChatTemplate( + resp = grpc_stub.RenderChat( self._make_request(test_model, messages) ) assert resp.success - - for msg in messages: - assert msg["role"] in resp.rendered_prompt - assert msg["content"] in resp.rendered_prompt + assert len(resp.input_ids) > 0 def test_render_empty_messages(self, grpc_stub, test_model): - """RenderChatTemplate with empty messages.""" + """RenderChat with empty messages.""" grpc_stub.InitializeTokenizer( - tokenizer_pb2.InitializeTokenizerRequest(model_name=test_model) + tokenizer_pb2.InitializeTokenizerRequest(model=test_model) ) - # Empty messages should raise an error - with pytest.raises(grpc.RpcError) as exc_info: - grpc_stub.RenderChatTemplate( - self._make_request(test_model, []) - ) - assert exc_info.value.code() == grpc.StatusCode.INTERNAL + # Empty messages should still succeed, may return special tokens only + resp = grpc_stub.RenderChat( + self._make_request(test_model, []) + ) + # Should succeed but may have only special tokens + assert resp.success def test_render_uninitialized_model(self, grpc_stub): - """RenderChatTemplate for an uninitialized model returns an error.""" + """RenderChat for an uninitialized model returns an error.""" messages = [{"role": "user", "content": "Hi"}] with pytest.raises(grpc.RpcError) as exc_info: - grpc_stub.RenderChatTemplate( + grpc_stub.RenderChat( self._make_request("openai-community/gpt2", messages) ) assert exc_info.value.code() == grpc.StatusCode.INTERNAL - - def test_render_for_model_without_template(self, grpc_stub): - """RenderChatTemplate for a model without a chat template returns an error.""" - - model_name = "openai-community/gpt2" # This model is known to lack a chat template. 
+ def test_render_with_tools(self, grpc_stub, test_model): + """RenderChat with tools parameter.""" grpc_stub.InitializeTokenizer( - tokenizer_pb2.InitializeTokenizerRequest(model_name=model_name) + tokenizer_pb2.InitializeTokenizerRequest(model=test_model) + ) + messages = [ + {"role": "user", "content": "What is 2+2?"}, + ] + + # Create a simple tool definition + tool = tokenizer_pb2.Value(struct_value=tokenizer_pb2.StructValue(fields={ + "type": tokenizer_pb2.Value(string_value="function"), + "function": tokenizer_pb2.Value(struct_value=tokenizer_pb2.StructValue(fields={ + "name": tokenizer_pb2.Value(string_value="calculator"), + "description": tokenizer_pb2.Value(string_value="A simple calculator"), + "parameters": tokenizer_pb2.Value(struct_value=tokenizer_pb2.StructValue(fields={ + "type": tokenizer_pb2.Value(string_value="object"), + "properties": tokenizer_pb2.Value(struct_value=tokenizer_pb2.StructValue(fields={ + "operation": tokenizer_pb2.Value(struct_value=tokenizer_pb2.StructValue(fields={ + "type": tokenizer_pb2.Value(string_value="string"), + "enum": tokenizer_pb2.Value(list_value=tokenizer_pb2.ListValue(values=[ + tokenizer_pb2.Value(string_value="add"), + tokenizer_pb2.Value(string_value="subtract") + ])) + })) + })) + })) + })) + })) + + req = tokenizer_pb2.RenderChatRequest( + conversation=[tokenizer_pb2.ChatMessage(role=m["role"], content=m["content"]) for m in messages], + tools=[tool], + model_name=test_model, + add_generation_prompt=True, ) - messages = [{"role": "user", "content": "Hi"}] - with pytest.raises(grpc.RpcError) as exc_info: - grpc_stub.RenderChatTemplate( - self._make_request(model_name, messages) - ) - assert exc_info.value.code() == grpc.StatusCode.INTERNAL - assert "chat template" in str(exc_info.value.details()).lower() + resp = grpc_stub.RenderChat(req) + assert resp.success + assert len(resp.input_ids) > 0 diff --git a/services/uds_tokenizer/tokenizer_grpc_service.py b/services/uds_tokenizer/tokenizer_grpc_service.py index 
c998d04a..15bc2716 100644 --- a/services/uds_tokenizer/tokenizer_grpc_service.py +++ b/services/uds_tokenizer/tokenizer_grpc_service.py @@ -17,106 +17,232 @@ import grpc from grpc_reflection.v1alpha import reflection import logging +import json +import threading import os import sys + +from readerwriterlock import rwlock # Ensure current directory is on sys.path for protobuf imports sys.path.append(os.path.dirname(__file__)) # Import protobuf-generated modules import tokenizerpb.tokenizer_pb2 as tokenizer_pb2 import tokenizerpb.tokenizer_pb2_grpc as tokenizer_pb2_grpc -from tokenizer_service.tokenizer import TokenizerService from utils.thread_pool_utils import get_thread_pool_size +# Add the preprocessing directory to the Python path to import tokenizer_wrapper +runtime_path = '/app/preprocessing/chat_completions' +current_file_dir = os.path.dirname(os.path.abspath(__file__)) +dev_path = os.path.join(os.path.dirname(os.path.dirname(current_file_dir)), 'pkg', 'preprocessing', 'chat_completions') +dev_path = os.path.normpath(dev_path) # Normalize the path to resolve '..' 
-class TokenizationServiceServicer(tokenizer_pb2_grpc.TokenizationServiceServicer): - """Synchronous gRPC service implementation class, optimized for CPU-intensive operations""" +if runtime_path not in sys.path: + sys.path.insert(0, runtime_path) +if dev_path not in sys.path: + sys.path.insert(0, dev_path) - def __init__(self, tokenizer_service: TokenizerService): - self.tokenizer_service = tokenizer_service - logging.info("TokenizationServiceServicer initialized") +# Import the tokenizer wrapper functions +from tokenizer_wrapper import render, render_chat, get_or_create_tokenizer_key - def Tokenize(self, request, context): - """Implement the synchronous Tokenize RPC method""" - try: - # logging.info(f"Received tokenize request for model: {request.model_name}") - # Use tokenizer_service for tokenization, with add_special_tokens from request - batch_encoding = self.tokenizer_service.tokenize_and_process( - request.input, - request.add_special_tokens, - request.model_name - ) +class TokenizationServiceServicer(tokenizer_pb2_grpc.TokenizationServiceServicer): + """Synchronous gRPC service implementation class, optimized for CPU-intensive operations""" - # Convert result format - input_ids = batch_encoding['input_ids'] - offset_mapping = batch_encoding.get('offset_mapping', []) + def __init__(self): + self._model_to_key_map = {} + self._map_lock = rwlock.RWLockWrite() # Reader-writer lock for thread-safe access + + def _get_tokenizer_key(self, model_name): + """Thread-safe method to get tokenizer key for a model name""" + with self._map_lock.gen_rlock(): + return self._model_to_key_map.get(model_name) + + def _set_tokenizer_key(self, model_name, tokenizer_key): + """Thread-safe method to set tokenizer key for a model name""" + with self._map_lock.gen_wlock(): + self._model_to_key_map[model_name] = tokenizer_key + + def _has_model(self, model_name): + """Thread-safe method to check if model is initialized""" + with self._map_lock.gen_rlock(): + return model_name in 
self._model_to_key_map + + def _protobuf_value_to_python(self, value): + """Convert protobuf Value to Python native type""" + if value.HasField("string_value"): + return value.string_value + elif value.HasField("number_value"): + return value.number_value + elif value.HasField("bool_value"): + return value.bool_value + elif value.HasField("list_value"): + return [self._protobuf_value_to_python(v) for v in value.list_value.values] + elif value.HasField("struct_value"): + result = {} + for key, val in value.struct_value.fields.items(): + result[key] = self._protobuf_value_to_python(val) + return result + else: + return None + + def Render(self, request, context): + """Implement the synchronous Render RPC method""" + try: + # Get tokenizer key from model name mapping + model_name = request.model_name + if not self._has_model(model_name): + # Model not initialized, raise gRPC error + logging.warning(f"Model {request.model_name} not initialized, cannot render") + context.abort(grpc.StatusCode.INTERNAL, f"Model {model_name} not initialized") + return tokenizer_pb2.RenderResponse() + + tokenizer_key = self._get_tokenizer_key(model_name) + + # Prepare request for Python wrapper function + render_request = { + "key": tokenizer_key, + "text": request.text, + "add_special_tokens": request.add_special_tokens + } + + # Call the Python render function directly + result_json = render(json.dumps(render_request)) + logging.debug(f"Render result: {result_json}") + result_data = json.loads(result_json) + + input_ids = result_data.get('input_ids', []) + offset_mapping = result_data.get('offset_mapping', []) # Create offset_pairs format (flattened array of [start, end, start, end, ...]) offset_pairs = [] for offset in offset_mapping: offset_pairs.extend([int(offset[0]), int(offset[1])]) - response = tokenizer_pb2.TokenizeResponse( + response = tokenizer_pb2.RenderResponse( input_ids=list(input_ids), - offset_pairs=offset_pairs, # Only use offset_pairs field + offset_pairs=offset_pairs, 
success=True ) - # logging.info(f"Tokenization completed with {len(input_ids)} tokens") return response except Exception as e: - logging.error(f"Tokenization failed: {e}", exc_info=True) + logging.error(f"Render failed: {e}", exc_info=True) context.abort(grpc.StatusCode.INTERNAL, str(e)) - def RenderChatTemplate(self, request, context): - """Implement the synchronous RenderChatTemplate RPC method""" + def RenderChat(self, request, context): + """Implement the synchronous RenderChat RPC method""" try: - # logging.info(f"Received chat template request") - - # Convert the nested conversation turns to a flat list of messages - messages = [] - for turn in request.conversation_turns: - for msg in turn.messages: - messages.append({"role": msg.role, "content": msg.content}) + # Get tokenizer key from model name mapping + model_name = request.model_name + if not self._has_model(model_name): + # Model not initialized, raise gRPC error + logging.warning(f"Model {request.model_name} not initialized, cannot render chat") + context.abort(grpc.StatusCode.INTERNAL, f"Model {model_name} not initialized") + return tokenizer_pb2.RenderResponse() + + tokenizer_key = self._get_tokenizer_key(model_name) + + # Convert conversation messages (list of ChatMessage objects) to the expected format + conversation = [] + for msg in request.conversation: + conversation.append({"role": msg.role, "content": msg.content}) + + # Convert tools from protobuf Value array to Python objects + tools = [] + for tool in request.tools: + tools.append(self._protobuf_value_to_python(tool)) + + # Convert documents from protobuf Value array to Python objects + documents = [] + for document in request.documents: + documents.append(self._protobuf_value_to_python(document)) + + # Convert chat_template_kwargs from protobuf format to dict + chat_template_kwargs = {} + for key, value in request.chat_template_kwargs.items(): + chat_template_kwargs[key] = self._protobuf_value_to_python(value) + + # Prepare request for 
Python wrapper function + render_chat_request = { + "key": tokenizer_key, + "conversation": conversation, + "return_assistant_tokens_mask": request.return_assistant_tokens_mask, + "continue_final_message": request.continue_final_message, + "add_generation_prompt": request.add_generation_prompt, + "chat_template_kwargs": chat_template_kwargs + } + + # Add optional fields if they exist + if request.HasField("chat_template"): + render_chat_request["chat_template"] = request.chat_template + + if tools: + render_chat_request["tools"] = tools + + if documents: + render_chat_request["documents"] = documents + + # Call the Python render_chat function directly + result_json = render_chat(json.dumps(render_chat_request)) + logging.debug(f"RenderChat result: {result_json}") + result_data = json.loads(result_json) + + input_ids = result_data.get('input_ids', []) + offset_mapping = result_data.get('offset_mapping', []) - # Call tokenizer_service method with model name - prompt = self.tokenizer_service.apply_template(messages, request.model_name) + # Create offset_pairs format (flattened array of [start, end, start, end, ...]) + offset_pairs = [] + for offset in offset_mapping: + offset_pairs.extend([int(offset[0]), int(offset[1])]) - response = tokenizer_pb2.ChatTemplateResponse( - rendered_prompt=prompt, + response = tokenizer_pb2.RenderResponse( + input_ids=list(input_ids), + offset_pairs=offset_pairs, success=True ) - # logging.info(f"Chat template rendered successfully") return response except Exception as e: - logging.error(f"Chat template rendering failed: {e}", exc_info=True) + logging.error(f"RenderChat failed: {e}", exc_info=True) context.abort(grpc.StatusCode.INTERNAL, str(e)) def InitializeTokenizer(self, request, context): """Implement the synchronous InitializeTokenizer RPC method""" try: - logging.info(f"Initializing tokenizer for model: {request.model_name}") - - success = self.tokenizer_service.load_tokenizer( - request.model_name, - request.enable_thinking, - 
request.add_generation_prompt - ) + logging.info(f"Initializing tokenizer for model: {request.model}") - if success: + # Check if tokenizer is already initialized for this model + model_name = request.model + if self._has_model(model_name): + logging.info(f"Tokenizer for model {request.model} already initialized") response = tokenizer_pb2.InitializeTokenizerResponse( success=True ) - else: - response = tokenizer_pb2.InitializeTokenizerResponse( - success=False, - error_message=f"Failed to initialize tokenizer for model: {request.model_name}" - ) + return response + + # Create tokenizer key request using parameters from the gRPC request + tokenizer_request = { + "is_local": request.is_local, + "model": request.model, + "revision": request.revision if request.HasField("revision") else None, + "token": request.token if request.HasField("token") else os.getenv("HF_TOKEN", ""), + "download_dir": request.download_dir if request.HasField("download_dir") else None + } + + # Create the tokenizer key which will cache the tokenizer + tokenizer_key = get_or_create_tokenizer_key(json.dumps(tokenizer_request)) + + # Store the mapping from model name to tokenizer key + self._set_tokenizer_key(model_name, tokenizer_key) + + # If we reach here, the tokenizer was successfully created/cached + response = tokenizer_pb2.InitializeTokenizerResponse( + success=True + ) return response @@ -128,7 +254,7 @@ def InitializeTokenizer(self, request, context): ) -def create_grpc_server(tokenizer_service: TokenizerService, uds_socket_path: str, thread_pool, tcp_port: str = ""): +def create_grpc_server(uds_socket_path: str, thread_pool, tcp_port: str = ""): """Create a synchronous gRPC server. 
Args: @@ -156,7 +282,7 @@ def create_grpc_server(tokenizer_service: TokenizerService, uds_socket_path: str ) # Create service implementation - servicer = TokenizationServiceServicer(tokenizer_service) + servicer = TokenizationServiceServicer() # Register service tokenizer_pb2_grpc.add_TokenizationServiceServicer_to_server(servicer, server) diff --git a/services/uds_tokenizer/tokenizer_service/tokenizer.py b/services/uds_tokenizer/tokenizer_service/tokenizer.py deleted file mode 100644 index fe346423..00000000 --- a/services/uds_tokenizer/tokenizer_service/tokenizer.py +++ /dev/null @@ -1,276 +0,0 @@ -# Copyright 2025 The llm-d Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Tokenizer service for handling LLM tokenization operations.""" - -import logging -import os -from pathlib import Path -from dataclasses import dataclass -from typing import Optional, List, Dict, Union -from transformers import (AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast) -from transformers.tokenization_utils_base import BatchEncoding -from modelscope import snapshot_download -from huggingface_hub import snapshot_download as hf_snapshot_download -from .exceptions import TokenizerError, ModelDownloadError, TokenizationError - -AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] - -DEFAULT_TOKENIZERS_DIR = str(Path(__file__).parent.parent / "tokenizers") - - -@dataclass -class TokenizerConfig: - """Configuration for tokenizer processing""" - model: str - enable_thinking: bool = False - add_generation_prompt: bool = True - - -class TokenizerService: - """Service for handling tokenizer operations""" - - def __init__(self, config: TokenizerConfig = None): - """Initialize service with optional configuration""" - self.tokenizers = {} # Dictionary to store multiple tokenizers by model name - self.configs = {} # Dictionary to store configurations by model name - - # Set tokenizers directory (configurable via TOKENIZERS_DIR env var) - self.tokenizers_dir = os.environ.get('TOKENIZERS_DIR', DEFAULT_TOKENIZERS_DIR) - - # If a config is provided, initialize the default tokenizer - if config: - self.tokenizer = self._create_tokenizer(config.model) - self.config = config - self.tokenizers[config.model] = self.tokenizer - self.configs[config.model] = config - - def _create_tokenizer(self, model_identifier: str) -> AnyTokenizer: - """Create a tokenizer, using cached files if available or downloading from ModelScope or Hugging Face""" - # Check if the model_identifier is a remote model name or a local path - # More robust check similar to what vLLM does - is_remote_model = self._is_remote_model(model_identifier) - - # For local paths, use directly 
- if not is_remote_model: - logging.info(f"Loading tokenizer from {model_identifier}") - base_tokenizer = AutoTokenizer.from_pretrained( - model_identifier, - trust_remote_code=True, - padding_side="left", - truncation_side="left", - use_fast=True, - ) - return base_tokenizer - - # Determine download source: ModelScope (if USE_MODELSCOPE=true) or Hugging Face (default) - use_modelscope = os.getenv('USE_MODELSCOPE', 'false').lower() == 'true' - - # Convert model identifier to local path (e.g., qwen/Qwen2-7B -> tokenizers/qwen/Qwen2-7B) - org_name, model_name = model_identifier.split('/', 1) - local_model_path = os.path.join(self.tokenizers_dir, org_name, model_name) - - # If the model is already cached, use the cached version - # Check that required files exist before trying to load - required_files = [ - "config.json", - "tokenizer.json", - ] - if (os.path.exists(local_model_path) and - all(os.path.exists(os.path.join(local_model_path, f)) for f in required_files)): - logging.info(f"Using cached tokenizer from {local_model_path}") - base_tokenizer = AutoTokenizer.from_pretrained( - local_model_path, - trust_remote_code=True, - padding_side="left", - truncation_side="left", - use_fast=True, - ) - return base_tokenizer - - # Download the tokenizer files from ModelScope or Hugging Face - if use_modelscope: - return self._download_from_modelscope(model_identifier, local_model_path) - else: - return self._download_from_huggingface(model_identifier, local_model_path) - - def _download_from_modelscope(self, model_identifier: str, local_model_path: str) -> AnyTokenizer: - """Download tokenizer files from ModelScope""" - logging.info(f"Downloading tokenizer for {model_identifier} from ModelScope") - try: - # Ensure the local model directory exists - os.makedirs(local_model_path, exist_ok=True) - - # Download only the tokenizer related files from ModelScope - snapshot_download( - model_identifier, - local_dir=local_model_path, - allow_patterns=[ - "tokenizer.json", - 
"tokenizer_config.json", - "special_tokens_map.json", - "vocab.json", - "merges.txt", - "config.json", - "generation_config.json" - ] - ) - logging.info(f"Successfully downloaded tokenizer to {local_model_path}") - except Exception as e: - # Clean up potentially incomplete download directory - if os.path.exists(local_model_path) and not os.listdir(local_model_path): - os.rmdir(local_model_path) - logging.info(f"Removed empty directory {local_model_path}") - logging.error(f"Failed to download tokenizer for {model_identifier} from ModelScope: {e}") - raise ModelDownloadError(f"Failed to download model from ModelScope: {e}") from e - - # Load the tokenizer from the downloaded files - try: - base_tokenizer = AutoTokenizer.from_pretrained( - local_model_path, - trust_remote_code=True, - padding_side="left", - truncation_side="left", - use_fast=True, - ) - return base_tokenizer - except Exception as e: - logging.error(f"Failed to load tokenizer from {local_model_path}: {e}") - raise TokenizerError(f"Failed to load tokenizer: {e}") from e - - def _download_from_huggingface(self, model_identifier: str, local_model_path: str) -> AnyTokenizer: - """Download tokenizer files from Hugging Face""" - logging.info(f"Downloading tokenizer for {model_identifier} from Hugging Face") - try: - # Ensure the local model directory exists - os.makedirs(local_model_path, exist_ok=True) - - # Download only the tokenizer related files from Hugging Face - hf_snapshot_download( - model_identifier, - local_dir=local_model_path, - allow_patterns=[ - "tokenizer.json", - "tokenizer_config.json", - "special_tokens_map.json", - "vocab.json", - "merges.txt", - "config.json", - "generation_config.json" - ] - ) - logging.info(f"Successfully downloaded tokenizer to {local_model_path}") - except Exception as e: - # Clean up potentially incomplete download directory - if os.path.exists(local_model_path) and not os.listdir(local_model_path): - os.rmdir(local_model_path) - logging.info(f"Removed empty 
directory {local_model_path}") - logging.error(f"Failed to download tokenizer for {model_identifier} from Hugging Face: {e}") - raise ModelDownloadError(f"Failed to download model from Hugging Face: {e}") from e - - # Load the tokenizer from the downloaded files - try: - base_tokenizer = AutoTokenizer.from_pretrained( - local_model_path, - trust_remote_code=True, - padding_side="left", - truncation_side="left", - use_fast=True, - ) - return base_tokenizer - except Exception as e: - logging.error(f"Failed to load tokenizer from {local_model_path}: {e}") - raise TokenizerError(f"Failed to load tokenizer: {e}") from e - - def _is_remote_model(self, model_identifier: str) -> bool: - """Check if the model identifier is a remote model name or a local path.""" - # Check if it's an absolute path - if os.path.isabs(model_identifier): - return False - - # Check if it's a relative path (starts with ./ or ../) - if model_identifier.startswith("./") or model_identifier.startswith("../"): - return False - - # Check if it's a local directory that exists - if os.path.exists(model_identifier): - return False - - # Check for protocol prefixes (s3://, etc.) 
- if "://" in model_identifier.split("/")[0]: - return False - - # If none of the above, it's likely a remote model identifier - # containing organization/model format - return "/" in model_identifier - - def load_tokenizer(self, model_name: str, enable_thinking: bool = False, add_generation_prompt: bool = True) -> bool: - """Load a tokenizer for a specific model""" - try: - config = TokenizerConfig( - model=model_name, - enable_thinking=enable_thinking, - add_generation_prompt=add_generation_prompt - ) - - tokenizer = self._create_tokenizer(model_name) - self.tokenizers[model_name] = tokenizer - self.configs[model_name] = config - - logging.info(f"Successfully initialized tokenizer for model: {model_name}") - return True - except Exception as e: - logging.error(f"Failed to initialize tokenizer for model {model_name}: {e}") - return False - - def get_tokenizer_for_model(self, model_name: str): - """Get the tokenizer for a specific model""" - if model_name not in self.tokenizers: - raise TokenizerError(f"Tokenizer not initialized for model: {model_name}") - - return self.tokenizers[model_name], self.configs[model_name] - - def apply_template(self, messages: List[Dict[str, str]], model_name: str) -> str: - """Apply chat template to messages""" - try: - tokenizer, config = self.get_tokenizer_for_model(model_name) - prompt = tokenizer.apply_chat_template( - conversation=messages, - tokenize=False, - add_generation_prompt=config.add_generation_prompt, - enable_thinking=config.enable_thinking, - ) - - logging.debug(f"Prompt: {prompt}") - return prompt - except Exception as e: - logging.error(f"Failed to apply chat template: {e}") - raise TokenizationError(f"Failed to apply chat template: {e}") from e - - def tokenize_and_process(self, prompt: str, add_special_tokens: bool, model_name: str) -> BatchEncoding: - """ - Tokenize the prompt with the specified add_special_tokens value. 
- """ - try: - tokenizer, _ = self.get_tokenizer_for_model(model_name) - token_id_offsets = tokenizer.encode_plus( - prompt, - add_special_tokens=add_special_tokens, - return_offsets_mapping=True - ) - logging.debug(f"Encoded prompt: {token_id_offsets}") - return token_id_offsets - except Exception as e: - logging.error(f"Failed to tokenize prompt: {e}") - raise TokenizationError(f"Failed to tokenize prompt: {e}") from e \ No newline at end of file diff --git a/services/uds_tokenizer/tokenizerpb/tokenizer_pb2.py b/services/uds_tokenizer/tokenizerpb/tokenizer_pb2.py index ba7837a9..8f156f2f 100644 --- a/services/uds_tokenizer/tokenizerpb/tokenizer_pb2.py +++ b/services/uds_tokenizer/tokenizerpb/tokenizer_pb2.py @@ -38,7 +38,7 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1btokenizerpb/tokenizer.proto\x12\x0ctokenization\"P\n\x0fTokenizeRequest\x12\r\n\x05input\x18\x01 \x01(\t\x12\x12\n\nmodel_name\x18\x02 \x01(\t\x12\x1a\n\x12\x61\x64\x64_special_tokens\x18\x03 \x01(\x08\"c\n\x10TokenizeResponse\x12\x11\n\tinput_ids\x18\x01 \x03(\r\x12\x0f\n\x07success\x18\x02 \x01(\x08\x12\x15\n\rerror_message\x18\x03 \x01(\t\x12\x14\n\x0coffset_pairs\x18\x04 \x03(\r\"?\n\x10\x43onversationTurn\x12+\n\x08messages\x18\x01 \x03(\x0b\x32\x19.tokenization.ChatMessage\"\xe3\x03\n\x13\x43hatTemplateRequest\x12:\n\x12\x63onversation_turns\x18\x01 \x03(\x0b\x32\x1e.tokenization.ConversationTurn\x12,\n\x05tools\x18\x02 \x03(\x0b\x32\x1d.tokenization.ToolDescription\x12)\n\tdocuments\x18\x03 \x03(\x0b\x32\x16.tokenization.Document\x12\x15\n\rchat_template\x18\x04 \x01(\t\x12$\n\x1creturn_assistant_tokens_mask\x18\x05 \x01(\x08\x12\x1e\n\x16\x63ontinue_final_message\x18\x06 \x01(\x08\x12\x1d\n\x15\x61\x64\x64_generation_prompt\x18\x07 \x01(\x08\x12W\n\x14\x63hat_template_kwargs\x18\x08 \x03(\x0b\x32\x39.tokenization.ChatTemplateRequest.ChatTemplateKwargsEntry\x12\x12\n\nmodel_name\x18\t \x01(\t\x1aN\n\x17\x43hatTemplateKwargsEntry\x12\x0b\n\x03key\x18\x01 
\x01(\t\x12\"\n\x05value\x18\x02 \x01(\x0b\x32\x13.tokenization.Value:\x02\x38\x01\",\n\x0b\x43hatMessage\x12\x0c\n\x04role\x18\x01 \x01(\t\x12\x0f\n\x07\x63ontent\x18\x02 \x01(\t\"\x8a\x01\n\x0fToolDescription\x12\x35\n\x04tool\x18\x01 \x03(\x0b\x32\'.tokenization.ToolDescription.ToolEntry\x1a@\n\tToolEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\"\n\x05value\x18\x02 \x01(\x0b\x32\x13.tokenization.Value:\x02\x38\x01\"\x88\x01\n\x08\x44ocument\x12\x36\n\x08\x64ocument\x18\x01 \x03(\x0b\x32$.tokenization.Document.DocumentEntry\x1a\x44\n\rDocumentEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\"\n\x05value\x18\x02 \x01(\x0b\x32\x13.tokenization.Value:\x02\x38\x01\"\xb8\x01\n\x05Value\x12\x16\n\x0cstring_value\x18\x01 \x01(\tH\x00\x12\x16\n\x0cnumber_value\x18\x02 \x01(\x01H\x00\x12\x14\n\nbool_value\x18\x03 \x01(\x08H\x00\x12-\n\nlist_value\x18\x04 \x01(\x0b\x32\x17.tokenization.ListValueH\x00\x12\x31\n\x0cstruct_value\x18\x05 \x01(\x0b\x32\x19.tokenization.StructValueH\x00\x42\x07\n\x05value\"0\n\tListValue\x12#\n\x06values\x18\x01 \x03(\x0b\x32\x13.tokenization.Value\"\x88\x01\n\x0bStructValue\x12\x35\n\x06\x66ields\x18\x01 \x03(\x0b\x32%.tokenization.StructValue.FieldsEntry\x1a\x42\n\x0b\x46ieldsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\"\n\x05value\x18\x02 \x01(\x0b\x32\x13.tokenization.Value:\x02\x38\x01\"W\n\x14\x43hatTemplateResponse\x12\x17\n\x0frendered_prompt\x18\x01 \x01(\t\x12\x0f\n\x07success\x18\x02 \x01(\x08\x12\x15\n\rerror_message\x18\x03 \x01(\t\"h\n\x1aInitializeTokenizerRequest\x12\x12\n\nmodel_name\x18\x01 \x01(\t\x12\x17\n\x0f\x65nable_thinking\x18\x02 \x01(\x08\x12\x1d\n\x15\x61\x64\x64_generation_prompt\x18\x03 \x01(\x08\"E\n\x1bInitializeTokenizerResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x15\n\rerror_message\x18\x02 
\x01(\t2\xa9\x02\n\x13TokenizationService\x12I\n\x08Tokenize\x12\x1d.tokenization.TokenizeRequest\x1a\x1e.tokenization.TokenizeResponse\x12[\n\x12RenderChatTemplate\x12!.tokenization.ChatTemplateRequest\x1a\".tokenization.ChatTemplateResponse\x12j\n\x13InitializeTokenizer\x12(.tokenization.InitializeTokenizerRequest\x1a).tokenization.InitializeTokenizerResponseB=Z;github.com/llm-d/llm-d-kv-cache/api/tokenizerpb;tokenizerpbb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1btokenizerpb/tokenizer.proto\x12\x0ctokenization\",\n\x0b\x43hatMessage\x12\x0c\n\x04role\x18\x01 \x01(\t\x12\x0f\n\x07\x63ontent\x18\x02 \x01(\t\"\xb8\x01\n\x05Value\x12\x16\n\x0cstring_value\x18\x01 \x01(\tH\x00\x12\x16\n\x0cnumber_value\x18\x02 \x01(\x01H\x00\x12\x14\n\nbool_value\x18\x03 \x01(\x08H\x00\x12-\n\nlist_value\x18\x04 \x01(\x0b\x32\x17.tokenization.ListValueH\x00\x12\x31\n\x0cstruct_value\x18\x05 \x01(\x0b\x32\x19.tokenization.StructValueH\x00\x42\x07\n\x05value\"0\n\tListValue\x12#\n\x06values\x18\x01 \x03(\x0b\x32\x13.tokenization.Value\"\x88\x01\n\x0bStructValue\x12\x35\n\x06\x66ields\x18\x01 \x03(\x0b\x32%.tokenization.StructValue.FieldsEntry\x1a\x42\n\x0b\x46ieldsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\"\n\x05value\x18\x02 \x01(\x0b\x32\x13.tokenization.Value:\x02\x38\x01\"\xab\x01\n\x1aInitializeTokenizerRequest\x12\x10\n\x08is_local\x18\x01 \x01(\x08\x12\r\n\x05model\x18\x02 \x01(\t\x12\x15\n\x08revision\x18\x03 \x01(\tH\x00\x88\x01\x01\x12\x12\n\x05token\x18\x04 \x01(\tH\x01\x88\x01\x01\x12\x19\n\x0c\x64ownload_dir\x18\x05 \x01(\tH\x02\x88\x01\x01\x42\x0b\n\t_revisionB\x08\n\x06_tokenB\x0f\n\r_download_dir\"E\n\x1bInitializeTokenizerResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x15\n\rerror_message\x18\x02 \x01(\t\"M\n\rRenderRequest\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x12\n\nmodel_name\x18\x02 \x01(\t\x12\x1a\n\x12\x61\x64\x64_special_tokens\x18\x03 \x01(\x08\"\xc3\x04\n\x11RenderChatRequest\x12/\n\x0c\x63onversation\x18\x01 
\x03(\x0b\x32\x19.tokenization.ChatMessage\x12\"\n\x05tools\x18\x02 \x03(\x0b\x32\x13.tokenization.Value\x12&\n\tdocuments\x18\x03 \x03(\x0b\x32\x13.tokenization.Value\x12\x1a\n\rchat_template\x18\x04 \x01(\tH\x00\x88\x01\x01\x12)\n\x1creturn_assistant_tokens_mask\x18\x05 \x01(\x08H\x01\x88\x01\x01\x12#\n\x16\x63ontinue_final_message\x18\x06 \x01(\x08H\x02\x88\x01\x01\x12\"\n\x15\x61\x64\x64_generation_prompt\x18\x07 \x01(\x08H\x03\x88\x01\x01\x12U\n\x14\x63hat_template_kwargs\x18\x08 \x03(\x0b\x32\x37.tokenization.RenderChatRequest.ChatTemplateKwargsEntry\x12\x12\n\nmodel_name\x18\t \x01(\t\x1aN\n\x17\x43hatTemplateKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\"\n\x05value\x18\x02 \x01(\x0b\x32\x13.tokenization.Value:\x02\x38\x01\x42\x10\n\x0e_chat_templateB\x1f\n\x1d_return_assistant_tokens_maskB\x19\n\x17_continue_final_messageB\x18\n\x16_add_generation_prompt\"a\n\x0eRenderResponse\x12\x11\n\tinput_ids\x18\x01 \x03(\r\x12\x0f\n\x07success\x18\x02 \x01(\x08\x12\x15\n\rerror_message\x18\x03 \x01(\t\x12\x14\n\x0coffset_pairs\x18\x04 \x03(\r2\x93\x02\n\x13TokenizationService\x12j\n\x13InitializeTokenizer\x12(.tokenization.InitializeTokenizerRequest\x1a).tokenization.InitializeTokenizerResponse\x12\x43\n\x06Render\x12\x1b.tokenization.RenderRequest\x1a\x1c.tokenization.RenderResponse\x12K\n\nRenderChat\x12\x1f.tokenization.RenderChatRequest\x1a\x1c.tokenization.RenderResponseB=Z;github.com/llm-d/llm-d-kv-cache/api/tokenizerpb;tokenizerpbb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -46,48 +46,32 @@ if not _descriptor._USE_C_DESCRIPTORS: _globals['DESCRIPTOR']._loaded_options = None _globals['DESCRIPTOR']._serialized_options = b'Z;github.com/llm-d/llm-d-kv-cache/api/tokenizerpb;tokenizerpb' - _globals['_CHATTEMPLATEREQUEST_CHATTEMPLATEKWARGSENTRY']._loaded_options = None - _globals['_CHATTEMPLATEREQUEST_CHATTEMPLATEKWARGSENTRY']._serialized_options = b'8\001' - 
_globals['_TOOLDESCRIPTION_TOOLENTRY']._loaded_options = None - _globals['_TOOLDESCRIPTION_TOOLENTRY']._serialized_options = b'8\001' - _globals['_DOCUMENT_DOCUMENTENTRY']._loaded_options = None - _globals['_DOCUMENT_DOCUMENTENTRY']._serialized_options = b'8\001' _globals['_STRUCTVALUE_FIELDSENTRY']._loaded_options = None _globals['_STRUCTVALUE_FIELDSENTRY']._serialized_options = b'8\001' - _globals['_TOKENIZEREQUEST']._serialized_start=45 - _globals['_TOKENIZEREQUEST']._serialized_end=125 - _globals['_TOKENIZERESPONSE']._serialized_start=127 - _globals['_TOKENIZERESPONSE']._serialized_end=226 - _globals['_CONVERSATIONTURN']._serialized_start=228 - _globals['_CONVERSATIONTURN']._serialized_end=291 - _globals['_CHATTEMPLATEREQUEST']._serialized_start=294 - _globals['_CHATTEMPLATEREQUEST']._serialized_end=777 - _globals['_CHATTEMPLATEREQUEST_CHATTEMPLATEKWARGSENTRY']._serialized_start=699 - _globals['_CHATTEMPLATEREQUEST_CHATTEMPLATEKWARGSENTRY']._serialized_end=777 - _globals['_CHATMESSAGE']._serialized_start=779 - _globals['_CHATMESSAGE']._serialized_end=823 - _globals['_TOOLDESCRIPTION']._serialized_start=826 - _globals['_TOOLDESCRIPTION']._serialized_end=964 - _globals['_TOOLDESCRIPTION_TOOLENTRY']._serialized_start=900 - _globals['_TOOLDESCRIPTION_TOOLENTRY']._serialized_end=964 - _globals['_DOCUMENT']._serialized_start=967 - _globals['_DOCUMENT']._serialized_end=1103 - _globals['_DOCUMENT_DOCUMENTENTRY']._serialized_start=1035 - _globals['_DOCUMENT_DOCUMENTENTRY']._serialized_end=1103 - _globals['_VALUE']._serialized_start=1106 - _globals['_VALUE']._serialized_end=1290 - _globals['_LISTVALUE']._serialized_start=1292 - _globals['_LISTVALUE']._serialized_end=1340 - _globals['_STRUCTVALUE']._serialized_start=1343 - _globals['_STRUCTVALUE']._serialized_end=1479 - _globals['_STRUCTVALUE_FIELDSENTRY']._serialized_start=1413 - _globals['_STRUCTVALUE_FIELDSENTRY']._serialized_end=1479 - _globals['_CHATTEMPLATERESPONSE']._serialized_start=1481 - 
_globals['_CHATTEMPLATERESPONSE']._serialized_end=1568 - _globals['_INITIALIZETOKENIZERREQUEST']._serialized_start=1570 - _globals['_INITIALIZETOKENIZERREQUEST']._serialized_end=1674 - _globals['_INITIALIZETOKENIZERRESPONSE']._serialized_start=1676 - _globals['_INITIALIZETOKENIZERRESPONSE']._serialized_end=1745 - _globals['_TOKENIZATIONSERVICE']._serialized_start=1748 - _globals['_TOKENIZATIONSERVICE']._serialized_end=2045 + _globals['_RENDERCHATREQUEST_CHATTEMPLATEKWARGSENTRY']._loaded_options = None + _globals['_RENDERCHATREQUEST_CHATTEMPLATEKWARGSENTRY']._serialized_options = b'8\001' + _globals['_CHATMESSAGE']._serialized_start=45 + _globals['_CHATMESSAGE']._serialized_end=89 + _globals['_VALUE']._serialized_start=92 + _globals['_VALUE']._serialized_end=276 + _globals['_LISTVALUE']._serialized_start=278 + _globals['_LISTVALUE']._serialized_end=326 + _globals['_STRUCTVALUE']._serialized_start=329 + _globals['_STRUCTVALUE']._serialized_end=465 + _globals['_STRUCTVALUE_FIELDSENTRY']._serialized_start=399 + _globals['_STRUCTVALUE_FIELDSENTRY']._serialized_end=465 + _globals['_INITIALIZETOKENIZERREQUEST']._serialized_start=468 + _globals['_INITIALIZETOKENIZERREQUEST']._serialized_end=639 + _globals['_INITIALIZETOKENIZERRESPONSE']._serialized_start=641 + _globals['_INITIALIZETOKENIZERRESPONSE']._serialized_end=710 + _globals['_RENDERREQUEST']._serialized_start=712 + _globals['_RENDERREQUEST']._serialized_end=789 + _globals['_RENDERCHATREQUEST']._serialized_start=792 + _globals['_RENDERCHATREQUEST']._serialized_end=1371 + _globals['_RENDERCHATREQUEST_CHATTEMPLATEKWARGSENTRY']._serialized_start=1189 + _globals['_RENDERCHATREQUEST_CHATTEMPLATEKWARGSENTRY']._serialized_end=1267 + _globals['_RENDERRESPONSE']._serialized_start=1373 + _globals['_RENDERRESPONSE']._serialized_end=1470 + _globals['_TOKENIZATIONSERVICE']._serialized_start=1473 + _globals['_TOKENIZATIONSERVICE']._serialized_end=1748 # @@protoc_insertion_point(module_scope) diff --git 
a/services/uds_tokenizer/tokenizerpb/tokenizer_pb2_grpc.py b/services/uds_tokenizer/tokenizerpb/tokenizer_pb2_grpc.py index 3c2ab9c0..c66c748d 100644 --- a/services/uds_tokenizer/tokenizerpb/tokenizer_pb2_grpc.py +++ b/services/uds_tokenizer/tokenizerpb/tokenizer_pb2_grpc.py @@ -49,43 +49,43 @@ def __init__(self, channel): Args: channel: A grpc.Channel. """ - self.Tokenize = channel.unary_unary( - '/tokenization.TokenizationService/Tokenize', - request_serializer=tokenizerpb_dot_tokenizer__pb2.TokenizeRequest.SerializeToString, - response_deserializer=tokenizerpb_dot_tokenizer__pb2.TokenizeResponse.FromString, - _registered_method=True) - self.RenderChatTemplate = channel.unary_unary( - '/tokenization.TokenizationService/RenderChatTemplate', - request_serializer=tokenizerpb_dot_tokenizer__pb2.ChatTemplateRequest.SerializeToString, - response_deserializer=tokenizerpb_dot_tokenizer__pb2.ChatTemplateResponse.FromString, - _registered_method=True) self.InitializeTokenizer = channel.unary_unary( '/tokenization.TokenizationService/InitializeTokenizer', request_serializer=tokenizerpb_dot_tokenizer__pb2.InitializeTokenizerRequest.SerializeToString, response_deserializer=tokenizerpb_dot_tokenizer__pb2.InitializeTokenizerResponse.FromString, _registered_method=True) + self.Render = channel.unary_unary( + '/tokenization.TokenizationService/Render', + request_serializer=tokenizerpb_dot_tokenizer__pb2.RenderRequest.SerializeToString, + response_deserializer=tokenizerpb_dot_tokenizer__pb2.RenderResponse.FromString, + _registered_method=True) + self.RenderChat = channel.unary_unary( + '/tokenization.TokenizationService/RenderChat', + request_serializer=tokenizerpb_dot_tokenizer__pb2.RenderChatRequest.SerializeToString, + response_deserializer=tokenizerpb_dot_tokenizer__pb2.RenderResponse.FromString, + _registered_method=True) class TokenizationServiceServicer(object): """TokenizationService defines the gRPC service for tokenization """ - def Tokenize(self, request, context): - 
"""Tokenize converts a text input to token IDs + def InitializeTokenizer(self, request, context): + """InitializeTokenizer initializes the tokenizer for a specific model """ context.set_code(grpc.StatusCode.UNIMPLEMENTED) context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') - def RenderChatTemplate(self, request, context): - """RenderChatTemplate renders a chat template with the given messages + def Render(self, request, context): + """Render renders (tokenizes) a text input """ context.set_code(grpc.StatusCode.UNIMPLEMENTED) context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') - def InitializeTokenizer(self, request, context): - """InitializeTokenizer initializes the tokenizer for a specific model + def RenderChat(self, request, context): + """RenderChat renders a chat conversation to tokens and offsets """ context.set_code(grpc.StatusCode.UNIMPLEMENTED) context.set_details('Method not implemented!') @@ -94,21 +94,21 @@ def InitializeTokenizer(self, request, context): def add_TokenizationServiceServicer_to_server(servicer, server): rpc_method_handlers = { - 'Tokenize': grpc.unary_unary_rpc_method_handler( - servicer.Tokenize, - request_deserializer=tokenizerpb_dot_tokenizer__pb2.TokenizeRequest.FromString, - response_serializer=tokenizerpb_dot_tokenizer__pb2.TokenizeResponse.SerializeToString, - ), - 'RenderChatTemplate': grpc.unary_unary_rpc_method_handler( - servicer.RenderChatTemplate, - request_deserializer=tokenizerpb_dot_tokenizer__pb2.ChatTemplateRequest.FromString, - response_serializer=tokenizerpb_dot_tokenizer__pb2.ChatTemplateResponse.SerializeToString, - ), 'InitializeTokenizer': grpc.unary_unary_rpc_method_handler( servicer.InitializeTokenizer, request_deserializer=tokenizerpb_dot_tokenizer__pb2.InitializeTokenizerRequest.FromString, response_serializer=tokenizerpb_dot_tokenizer__pb2.InitializeTokenizerResponse.SerializeToString, ), + 'Render': 
grpc.unary_unary_rpc_method_handler( + servicer.Render, + request_deserializer=tokenizerpb_dot_tokenizer__pb2.RenderRequest.FromString, + response_serializer=tokenizerpb_dot_tokenizer__pb2.RenderResponse.SerializeToString, + ), + 'RenderChat': grpc.unary_unary_rpc_method_handler( + servicer.RenderChat, + request_deserializer=tokenizerpb_dot_tokenizer__pb2.RenderChatRequest.FromString, + response_serializer=tokenizerpb_dot_tokenizer__pb2.RenderResponse.SerializeToString, + ), } generic_handler = grpc.method_handlers_generic_handler( 'tokenization.TokenizationService', rpc_method_handlers) @@ -122,7 +122,7 @@ class TokenizationService(object): """ @staticmethod - def Tokenize(request, + def InitializeTokenizer(request, target, options=(), channel_credentials=None, @@ -135,9 +135,9 @@ def Tokenize(request, return grpc.experimental.unary_unary( request, target, - '/tokenization.TokenizationService/Tokenize', - tokenizerpb_dot_tokenizer__pb2.TokenizeRequest.SerializeToString, - tokenizerpb_dot_tokenizer__pb2.TokenizeResponse.FromString, + '/tokenization.TokenizationService/InitializeTokenizer', + tokenizerpb_dot_tokenizer__pb2.InitializeTokenizerRequest.SerializeToString, + tokenizerpb_dot_tokenizer__pb2.InitializeTokenizerResponse.FromString, options, channel_credentials, insecure, @@ -149,7 +149,7 @@ def Tokenize(request, _registered_method=True) @staticmethod - def RenderChatTemplate(request, + def Render(request, target, options=(), channel_credentials=None, @@ -162,9 +162,9 @@ def RenderChatTemplate(request, return grpc.experimental.unary_unary( request, target, - '/tokenization.TokenizationService/RenderChatTemplate', - tokenizerpb_dot_tokenizer__pb2.ChatTemplateRequest.SerializeToString, - tokenizerpb_dot_tokenizer__pb2.ChatTemplateResponse.FromString, + '/tokenization.TokenizationService/Render', + tokenizerpb_dot_tokenizer__pb2.RenderRequest.SerializeToString, + tokenizerpb_dot_tokenizer__pb2.RenderResponse.FromString, options, channel_credentials, insecure, @@ 
-176,7 +176,7 @@ def RenderChatTemplate(request, _registered_method=True) @staticmethod - def InitializeTokenizer(request, + def RenderChat(request, target, options=(), channel_credentials=None, @@ -189,9 +189,9 @@ def InitializeTokenizer(request, return grpc.experimental.unary_unary( request, target, - '/tokenization.TokenizationService/InitializeTokenizer', - tokenizerpb_dot_tokenizer__pb2.InitializeTokenizerRequest.SerializeToString, - tokenizerpb_dot_tokenizer__pb2.InitializeTokenizerResponse.FromString, + '/tokenization.TokenizationService/RenderChat', + tokenizerpb_dot_tokenizer__pb2.RenderChatRequest.SerializeToString, + tokenizerpb_dot_tokenizer__pb2.RenderResponse.FromString, options, channel_credentials, insecure, diff --git a/tests/e2e/uds_tokenizer/uds_e2e_test.go b/tests/e2e/uds_tokenizer/uds_e2e_test.go index d22c076b..419a5550 100644 --- a/tests/e2e/uds_tokenizer/uds_e2e_test.go +++ b/tests/e2e/uds_tokenizer/uds_e2e_test.go @@ -46,40 +46,32 @@ func (s *UDSTokenizerSuite) TestTokenize() { s.Require().Equal(offsets1, offsets2, "repeated tokenization should be deterministic (offsets)") } -// TestTokenizeWithSpecialTokens verifies that Encode(prompt, true) includes special tokens -// and Encode(prompt, false) does not. -// Uses BERT model which always adds [CLS] and [SEP] tokens for strict greater-than comparison. +// TestTokenizeWithSpecialTokens verifies that Render(prompt) includes special tokens +// by comparing with a known tokenizer that has special tokens. +// Uses BERT model which always adds [CLS] and [SEP] tokens. 
func (s *UDSTokenizerSuite) TestTokenizeWithSpecialTokens() { // Switch to BERT model which adds [CLS] and [SEP] special tokens s.switchTokenizer("google-bert/bert-base-uncased") prompt := "Hello world" - tokensWithSpecial, _, err := s.tokenizer.Encode(prompt, true) - s.Require().NoError(err) - s.Require().NotEmpty(tokensWithSpecial) - - tokensWithoutSpecial, _, err := s.tokenizer.Encode(prompt, false) + tokens, _, err := s.tokenizer.Render(prompt) s.Require().NoError(err) - s.Require().NotEmpty(tokensWithoutSpecial) - - // BERT adds [CLS] at the start and [SEP] at the end when add_special_tokens=true. - // So tokens with special tokens should always be strictly greater. - s.Require().Greater(len(tokensWithSpecial), len(tokensWithoutSpecial), - "encoding with special tokens should produce more tokens (BERT adds [CLS] and [SEP])") + s.Require().NotEmpty(tokens) // Verify BERT-specific special token IDs bosTokenID := uint32(101) // [CLS] eosTokenID := uint32(102) // [SEP] - s.Require().Equal(bosTokenID, tokensWithSpecial[0], "first token should be [CLS] (101)") - s.Require().Equal(eosTokenID, tokensWithSpecial[len(tokensWithSpecial)-1], "last token should be [SEP] (102)") + s.Require().Equal(bosTokenID, tokens[0], "first token should be [CLS] (101)") + s.Require().Greater(len(tokens), 1, "should have more than just the first token") + s.Require().Equal(eosTokenID, tokens[len(tokens)-1], "last token should be [SEP] (102)") - s.T().Logf("Tokens with special: %d, without special: %d", len(tokensWithSpecial), len(tokensWithoutSpecial)) + s.T().Logf("Tokens with special: %d", len(tokens)) } -// TestRenderChatTemplate tests rendering a multi-turn conversation via the +// TestRenderChat tests rendering a multi-turn conversation via the // model's tokenizer chat template. 
-func (s *UDSTokenizerSuite) TestRenderChatTemplate() { +func (s *UDSTokenizerSuite) TestRenderChat() { conversation := []types.Conversation{ {Role: "user", Content: "What is machine learning?"}, {Role: "assistant", Content: "Machine learning is a subset of AI."},