From c63b435e76e2f93ce77c292374499485e6ffe81e Mon Sep 17 00:00:00 2001 From: Antonio Cardace Date: Wed, 21 Jan 2026 13:44:41 +0100 Subject: [PATCH 1/2] refactor: remove tokenization from GetPodScores Change GetPodScores signature to accept tokens instead of handling tokenization internally. Signed-off-by: Antonio Cardace --- .../kvcache_aware_scorer.go | 69 +------------------ examples/kv_cache_index/main.go | 9 ++- .../kv_cache_index_service/server/main.go | 3 +- .../kv_cache_index_service/server/server.go | 5 +- examples/kv_events/offline/main.go | 13 ++-- examples/kv_events/online/main.go | 8 ++- examples/valkey_example/main.go | 13 ++-- pkg/kvcache/indexer.go | 27 ++++++-- tests/e2e/redis_mock/e2e_test.go | 58 ++++++++++------ 9 files changed, 93 insertions(+), 112 deletions(-) diff --git a/examples/kv_cache_aware_scorer/kvcache_aware_scorer.go b/examples/kv_cache_aware_scorer/kvcache_aware_scorer.go index 30895afd..fdace8fa 100644 --- a/examples/kv_cache_aware_scorer/kvcache_aware_scorer.go +++ b/examples/kv_cache_aware_scorer/kvcache_aware_scorer.go @@ -18,7 +18,6 @@ package scorer import ( "context" "encoding/json" - "errors" "fmt" "os" "time" @@ -26,7 +25,6 @@ import ( "github.com/llm-d/llm-d-kv-cache/pkg/kvcache" "github.com/llm-d/llm-d-kv-cache/pkg/kvcache/kvblock" "github.com/llm-d/llm-d-kv-cache/pkg/kvevents" - preprocessing "github.com/llm-d/llm-d-kv-cache/pkg/preprocessing/chat_completions" "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework" @@ -216,7 +214,8 @@ func (s *PrecisePrefixCacheScorer) Score(ctx context.Context, cycleState *types. return nil } - scores, err := s.getScores(ctx, request) + tokens := s.kvCacheIndexer.Tokenize(nil, request.Prompt) + scores, err := s.kvCacheIndexer.GetPodScores(ctx, tokens, request.TargetModel, nil) if err != nil { logger.Error(err, "Failed to get pod scores") return nil @@ -248,67 +247,3 @@ func (s *PrecisePrefixCacheScorer) Score(ctx context.Context, cycleState *types. return indexedScoresToNormalizedScoredPods(pods, podToKey, scores) } -// getScores retrieves the pod scores from the KV-cache indexer -// based on the provided LLM request. -// If the request contains chat completions, it processes them accordingly. -// If the request contains regular completions, it uses the prompt directly. -func (s *PrecisePrefixCacheScorer) getScores(ctx context.Context, request *types.LLMRequest) (map[string]float64, error) { - logger := log.FromContext(ctx).WithName(s.typedName.String()) - traceLogger := logger.V(logutil.TRACE) - - traceLogger.Info("Getting scores", - "isChatCompletions", request.Body != nil && request.Body.ChatCompletions != nil, - "isCompletions", request.Body != nil && request.Body.Completions != nil) - - // The upstream parser guarantees exactly one body is populated, but we defensively prioritize chat completions. - // If an unexpected dual payload slips through (parser regression/new client), log it and use chat semantics. - if request.Body != nil && request.Body.ChatCompletions != nil { - if request.Body.Completions != nil { - traceLogger.Info("Both chat/completions and completions present; defaulting to chat/completions") - } - - renderReq := &preprocessing.ApplyChatTemplateRequest{ - Conversation: make([][]preprocessing.Conversation, 0), - Tools: request.Body.ChatCompletions.Tools, - Documents: request.Body.ChatCompletions.Documents, - ChatTemplate: request.Body.ChatCompletions.ChatTemplate, - ReturnAssistantTokensMask: request.Body.ChatCompletions.ReturnAssistantTokensMask, - ContinueFinalMessage: request.Body.ChatCompletions.ContinueFinalMessage, - AddGenerationPrompt: request.Body.ChatCompletions.AddGenerationPrompt, - ChatTemplateKWArgs: request.Body.ChatCompletions.ChatTemplateKWArgs, - } - - // Convert messages to the format expected by the renderer - for _, msg := range request.Body.ChatCompletions.Messages { - renderReq.Conversation = append(renderReq.Conversation, []preprocessing.Conversation{{ - Role: msg.Role, - Content: msg.Content.Raw, - }}) - } - - traceLogger.Info("Processing chat completion request", - "conversationCount", len(renderReq.Conversation), - "toolsCount", len(renderReq.Tools), - "documentsCount", len(renderReq.Documents)) - - scores, err := s.kvCacheIndexer.GetPodScores(ctx, renderReq, "", request.TargetModel, nil) - if err != nil { - return nil, fmt.Errorf("failed to get pod scores for chat/completions: %w", err) - } - return scores, nil - } - - // For regular completions, use the prompt directly - if request.Body != nil && request.Body.Completions != nil { - prompt := request.Body.Completions.Prompt - traceLogger.Info("Using completion prompt directly", "promptLength", len(prompt)) - - scores, err := s.kvCacheIndexer.GetPodScores(ctx, nil, prompt, request.TargetModel, nil) - if err != nil { - return nil, fmt.Errorf("failed to get pod scores for completions: %w", err) - } - return scores, nil - } - - return nil, errors.New("no valid input found in request") -} diff --git a/examples/kv_cache_index/main.go b/examples/kv_cache_index/main.go index 951b4df1..6bce29b4 100644 --- a/examples/kv_cache_index/main.go +++ b/examples/kv_cache_index/main.go @@ -129,8 +129,11 @@ func runPrompts(ctx context.Context, kvCacheIndexer *kvcache.Indexer) error { modelName := getModelName() logger.Info("Started Indexer", "model", modelName) + // Tokenize the prompt + tokens := kvCacheIndexer.Tokenize(testdata.RenderReq, testdata.Prompt) + // Get pods for the prompt - pods, err := kvCacheIndexer.GetPodScores(ctx, testdata.RenderReq, testdata.Prompt, modelName, nil) + pods, err := kvCacheIndexer.GetPodScores(ctx, tokens, modelName, nil) if err != nil { return err } @@ -153,8 +156,8 @@ func runPrompts(ctx context.Context, kvCacheIndexer *kvcache.Indexer) error { // Sleep 3 secs time.Sleep(3 * time.Second) - // Get pods for the prompt - pods, err = kvCacheIndexer.GetPodScores(ctx, testdata.RenderReq, testdata.Prompt, modelName, nil) + // Get pods for the prompt (reuse tokens from above) + pods, err = kvCacheIndexer.GetPodScores(ctx, tokens, modelName, nil) if err != nil { return err } diff --git a/examples/kv_cache_index_service/server/main.go b/examples/kv_cache_index_service/server/main.go index 25d2c884..b774fb37 100644 --- a/examples/kv_cache_index_service/server/main.go +++ b/examples/kv_cache_index_service/server/main.go @@ -60,7 +60,8 @@ func main() { } // Initial query - should be empty since no events have been published - pods, err := indexerSvc.indexer.GetPodScores(ctx, testdata.RenderReq, testdata.Prompt, testdata.ModelName, nil) + tokens := indexerSvc.indexer.Tokenize(testdata.RenderReq, testdata.Prompt) + pods, err := indexerSvc.indexer.GetPodScores(ctx, tokens, testdata.ModelName, nil) if err != nil { logger.Error(err, "failed to get pod scores") } diff --git a/examples/kv_cache_index_service/server/server.go b/examples/kv_cache_index_service/server/server.go index 37b0cf84..2480d8e5 100644 --- a/examples/kv_cache_index_service/server/server.go +++ b/examples/kv_cache_index_service/server/server.go @@ -71,8 +71,11 @@ func (s *IndexerService) GetPodScores(ctx context.Context, return nil, fmt.Errorf("request cannot be nil") } + // Tokenize the prompt + tokens := s.indexer.Tokenize(nil, req.Prompt) + // Call the underlying indexer - podScores, err := s.indexer.GetPodScores(ctx, nil, req.Prompt, req.ModelName, + podScores, err := s.indexer.GetPodScores(ctx, tokens, req.ModelName, req.PodIdentifiers) if err != nil { return nil, fmt.Errorf("failed to get pod scores: %w", err) diff --git a/examples/kv_events/offline/main.go b/examples/kv_events/offline/main.go index 3c1f0651..133da3f1 100644 --- a/examples/kv_events/offline/main.go +++ b/examples/kv_events/offline/main.go @@ -145,8 +145,11 @@ func RunEventsDemo(ctx context.Context, kvCacheIndexer *kvcache.Indexer, publish logger.Info("@@@ Starting KV Events Demo", "model", testdata.ModelName) + // Tokenize the prompt + tokens := kvCacheIndexer.Tokenize(testdata.RenderReq, testdata.Prompt) + // Initial query - should be empty since no events have been published - pods, err := kvCacheIndexer.GetPodScores(ctx, testdata.RenderReq, testdata.Prompt, testdata.ModelName, nil) + pods, err := kvCacheIndexer.GetPodScores(ctx, tokens, testdata.ModelName, nil) if err != nil { return err } @@ -158,8 +161,8 @@ func RunEventsDemo(ctx context.Context, kvCacheIndexer *kvcache.Indexer, publish return err } - // Query again to see the effect of the events - pods, err = kvCacheIndexer.GetPodScores(ctx, testdata.RenderReq, testdata.Prompt, testdata.ModelName, nil) + // Query again to see the effect of the events (reuse tokens) + pods, err = kvCacheIndexer.GetPodScores(ctx, tokens, testdata.ModelName, nil) if err != nil { return err } @@ -171,8 +174,8 @@ func RunEventsDemo(ctx context.Context, kvCacheIndexer *kvcache.Indexer, publish return err } - // Final query - pods, err = kvCacheIndexer.GetPodScores(ctx, testdata.RenderReq, testdata.Prompt, testdata.ModelName, nil) + // Final query (reuse tokens) + pods, err = kvCacheIndexer.GetPodScores(ctx, tokens, testdata.ModelName, nil) if err != nil { return err } diff --git a/examples/kv_events/online/main.go b/examples/kv_events/online/main.go index 8fd14a1f..9578ff9d 100644 --- a/examples/kv_events/online/main.go +++ b/examples/kv_events/online/main.go @@ -297,7 +297,8 @@ func setupUnifiedHTTPEndpoints( return } - pods, err := kvCacheIndexer.GetPodScores(ctx, nil, req.Prompt, req.Model, nil) + tokens := kvCacheIndexer.Tokenize(nil, req.Prompt) + pods, err := kvCacheIndexer.GetPodScores(ctx, tokens, req.Model, nil) if err != nil { http.Error(w, fmt.Sprintf("error: %v", err), http.StatusInternalServerError) return @@ -337,8 +338,9 @@ func setupUnifiedHTTPEndpoints( return } - // Get score - pods, err := kvCacheIndexer.GetPodScores(ctx, nil, renderedPrompt, req.Model, nil) + // Tokenize and get score + tokens := kvCacheIndexer.Tokenize(nil, renderedPrompt) + pods, err := kvCacheIndexer.GetPodScores(ctx, tokens, req.Model, nil) if err != nil { http.Error(w, fmt.Sprintf("Failed to get score request: %v", err), http.StatusInternalServerError) return diff --git a/examples/valkey_example/main.go b/examples/valkey_example/main.go index b7dfec3c..9ef944cf 100644 --- a/examples/valkey_example/main.go +++ b/examples/valkey_example/main.go @@ -133,8 +133,11 @@ func demonstrateValkeyOperations(ctx context.Context, indexer *kvcache.Indexer) logger.Info("Processing testdata prompt", "model", modelName, "promptLength", len(prompt)) + // Tokenize the prompt + tokens := indexer.Tokenize(nil, prompt) + // First, let's demonstrate basic scoring without any cache entries - scores, err := indexer.GetPodScores(ctx, nil, prompt, modelName, []string{"demo-pod-1", "demo-pod-2"}) + scores, err := indexer.GetPodScores(ctx, tokens, modelName, []string{"demo-pod-1", "demo-pod-2"}) if err != nil { return fmt.Errorf("failed to get pod scores: %w", err) } @@ -159,8 +162,8 @@ func demonstrateValkeyOperations(ctx context.Context, indexer *kvcache.Indexer) logger.Info("Added cache entries", "keys", len(promptKeys), "pods", len(podEntries)) - // Query for cache scores again - scores, err = indexer.GetPodScores(ctx, nil, prompt, modelName, []string{"demo-pod-1", "demo-pod-2"}) + // Query for cache scores again (reuse tokens) + scores, err = indexer.GetPodScores(ctx, tokens, modelName, []string{"demo-pod-1", "demo-pod-2"}) if err != nil { return fmt.Errorf("failed to get pod scores after adding entries: %w", err) } @@ -194,8 +197,8 @@ func demonstrateValkeyOperations(ctx context.Context, indexer *kvcache.Indexer) logger.Info("Cache lookup after eviction", "keysFound", len(lookupAfterEvict)) - // Final score check to see the difference - finalScores, err := indexer.GetPodScores(ctx, nil, prompt, modelName, []string{"demo-pod-1", "demo-pod-2"}) + // Final score check to see the difference (reuse tokens) + finalScores, err := indexer.GetPodScores(ctx, tokens, modelName, []string{"demo-pod-1", "demo-pod-2"}) if err != nil { return fmt.Errorf("failed to get final pod scores: %w", err) } diff --git a/pkg/kvcache/indexer.go b/pkg/kvcache/indexer.go index 7f95a5da..ef5f77dd 100644 --- a/pkg/kvcache/indexer.go +++ b/pkg/kvcache/indexer.go @@ -128,15 +128,15 @@ func (k *Indexer) KVBlockIndex() kvblock.Index { // relevant. // // The function returns a map of pod identifiers to scores. -func (k *Indexer) GetPodScores(ctx context.Context, renderReq *preprocessing.ApplyChatTemplateRequest, prompt, modelName string, +func (k *Indexer) GetPodScores( + ctx context.Context, + tokens []uint32, + modelName string, podIdentifiers []string, ) (map[string]float64, error) { traceLogger := log.FromContext(ctx).V(logging.TRACE).WithName("kvcache.GetPodScores") - // 1. tokenize prompt - tokens := k.tokenizersPool.Tokenize(renderReq, prompt) - - // 2. get block keys + // get block keys blockKeys := k.tokenProcessor.TokensToKVBlockKeys(kvblock.EmptyBlockHash, tokens, modelName) if len(blockKeys) == 0 { traceLogger.Info("no block keys found, returning empty scores") @@ -146,7 +146,7 @@ func (k *Indexer) GetPodScores(ctx context.Context, renderReq *preprocessing.App traceLogger.Info("found tokens", "tokens", tokens, "block-keys", blockKeys) - // 3. query kvblock indexer for pods + // query kvblock indexer for pods keyToPods, err := k.kvBlockIndex.Lookup(ctx, blockKeys, sets.New(podIdentifiers...)) if err != nil { return nil, fmt.Errorf("failed to query kvblock indexer: %w", err) @@ -154,7 +154,7 @@ func (k *Indexer) GetPodScores(ctx context.Context, renderReq *preprocessing.App traceLogger.Info("found block keys", "block-keys", blockKeys, "pods", podsPerKeyPrintHelper(keyToPods)) - // 4. score pods + // score pods podScores, err := k.kvBlockScorer.Score(blockKeys, keyToPods) if err != nil { return nil, fmt.Errorf("failed to query kvblock scorer: %w", err) @@ -181,3 +181,16 @@ func podsPerKeyPrintHelper(ks map[kvblock.BlockHash][]kvblock.PodEntry) string { func (k *Indexer) SetTokenizer(tokenizer tokenization.Tokenizer, modelName string) { k.tokenizersPool.SetTokenizer(tokenizer, modelName) } + +// Tokenize converts a prompt string into token IDs using the appropriate tokenizer. +// It queues a tokenization task to the internal tokenizers pool and blocks until +// the result is available. +// +// The renderReq parameter provides model-specific context via its Key field, which +// is used to select the correct tokenizer for the model. If renderReq is nil, the +// prompt is tokenized directly without chat template rendering. +// +// Returns the token IDs as a uint32 slice, or an error if tokenization fails. +func (k *Indexer) Tokenize(renderReq *preprocessing.ApplyChatTemplateRequest, prompt string) []uint32 { + return k.tokenizersPool.Tokenize(renderReq, prompt) +} diff --git a/tests/e2e/redis_mock/e2e_test.go b/tests/e2e/redis_mock/e2e_test.go index 50fc7836..e88ece58 100644 --- a/tests/e2e/redis_mock/e2e_test.go +++ b/tests/e2e/redis_mock/e2e_test.go @@ -121,7 +121,8 @@ func (s *KVCacheSuite) TestCacheHit() { engineKeys, requestKeys := s.promptToEngineAndRequestKeys(prompt, defaultModelName) s.addEntriesToIndex(engineKeys, requestKeys, fakePodList) - pods, err := s.indexer.GetPodScores(s.ctx, nil, prompt, defaultModelName, fakePodList) + tokens := s.indexer.Tokenize(nil, prompt) + pods, err := s.indexer.GetPodScores(s.ctx, tokens, defaultModelName, fakePodList) s.Require().NoError(err) s.T().Logf("Received pod scores: %+v", pods) s.Len(pods, len(fakePodList), "expected pod scores length to match candidate pods") @@ -132,7 +133,8 @@ func (s *KVCacheSuite) TestCacheMiss() { prompt := "What is the capital of France?" fakePodList := []string{s.Pod1IP} - pods, err := s.indexer.GetPodScores(s.ctx, nil, prompt, defaultModelName, fakePodList) + tokens := s.indexer.Tokenize(nil, prompt) + pods, err := s.indexer.GetPodScores(s.ctx, tokens, defaultModelName, fakePodList) s.Require().NoError(err) s.T().Logf("Received pod scores: %+v", pods) s.Empty(pods, "expected no pod scores since no keys were added to the index") @@ -150,7 +152,8 @@ func (s *KVCacheSuite) TestPrefixReduction() { fakePodList := []string{s.Pod1IP} // Test 1: Full prompt (no match expected) - pods, err := s.indexer.GetPodScores(s.ctx, nil, fullPrompt, defaultModelName, []string{s.Pod1IP}) + fullTokens := s.indexer.Tokenize(nil, fullPrompt) + pods, err := s.indexer.GetPodScores(s.ctx, fullTokens, defaultModelName, []string{s.Pod1IP}) s.Require().NoError(err) s.T().Logf("Received pod scores: %+v", pods) s.Empty(pods, "expected no pod scores") @@ -158,14 +161,16 @@ func (s *KVCacheSuite) TestPrefixReduction() { s.addEntriesToIndex(fullPromptEngineKeys, fullPromptRequestKeys, fakePodList) // Test 2: mid-length prompt(should return a match) - pods, err = s.indexer.GetPodScores(s.ctx, nil, midPrompt, defaultModelName, []string{s.Pod1IP}) + midTokens := s.indexer.Tokenize(nil, midPrompt) + pods, err = s.indexer.GetPodScores(s.ctx, midTokens, defaultModelName, []string{s.Pod1IP}) s.Require().NoError(err) s.T().Logf("Received pod scores: %+v", pods) s.Greater(int(pods[s.Pod1IP]), 0, "mid-prompt block keys should have been indexed") // Test 3: short prompt(should return a match) - pods, err = s.indexer.GetPodScores(s.ctx, nil, shortPrompt, defaultModelName, []string{s.Pod1IP}) + shortTokens := s.indexer.Tokenize(nil, shortPrompt) + pods, err = s.indexer.GetPodScores(s.ctx, shortTokens, defaultModelName, []string{s.Pod1IP}) s.Require().NoError(err) s.Len(pods, len(fakePodList), "expected pod scores length to match candidate pods") @@ -185,7 +190,8 @@ func (s *KVCacheSuite) TestPrefixExpansion() { fakePodList := []string{s.Pod1IP} // Test 1: short prompt - pods, err := s.indexer.GetPodScores(s.ctx, nil, shortPrompt, modelName, []string{s.Pod1IP}) + shortTokens := s.indexer.Tokenize(nil, shortPrompt) + pods, err := s.indexer.GetPodScores(s.ctx, shortTokens, modelName, []string{s.Pod1IP}) s.Require().NoError(err) s.T().Logf("Received pod scores: %+v", pods) s.Empty(pods, "expected no pod scores") @@ -194,7 +200,8 @@ func (s *KVCacheSuite) TestPrefixExpansion() { s.addEntriesToIndex(shortPromptEngineKeys, shortPromptRequestKeys, fakePodList) // Test 2: mid prompt - pods, err = s.indexer.GetPodScores(s.ctx, nil, midPrompt, modelName, []string{s.Pod1IP}) + midTokens := s.indexer.Tokenize(nil, midPrompt) + pods, err = s.indexer.GetPodScores(s.ctx, midTokens, modelName, []string{s.Pod1IP}) s.Require().NoError(err) s.T().Logf("Received pod scores: %+v", pods) @@ -204,7 +211,8 @@ func (s *KVCacheSuite) TestPrefixExpansion() { s.addEntriesToIndex(midPromptEngineKeys, midPromptRequestKeys, fakePodList) // Test 3: full prompt - pods, err = s.indexer.GetPodScores(s.ctx, nil, fullPrompt, modelName, []string{s.Pod1IP}) + fullTokens := s.indexer.Tokenize(nil, fullPrompt) + pods, err = s.indexer.GetPodScores(s.ctx, fullTokens, modelName, []string{s.Pod1IP}) s.Require().NoError(err) s.T().Logf("Received pod scores: %+v", pods) @@ -224,7 +232,8 @@ func (s *KVCacheSuite) TestLongPrefixExpansion() { fakePodList := []string{s.Pod1IP} // Test 1: short prompt (should return no pod scores yet) - pods, err := s.indexer.GetPodScores(s.ctx, nil, shortPrompt, modelName, []string{s.Pod1IP}) + shortTokens := s.indexer.Tokenize(nil, shortPrompt) + pods, err := s.indexer.GetPodScores(s.ctx, shortTokens, modelName, []string{s.Pod1IP}) s.Require().NoError(err) s.T().Logf("Short prompt scores: %+v", pods) s.Empty(pods, "expected no pod scores") @@ -234,7 +243,8 @@ func (s *KVCacheSuite) TestLongPrefixExpansion() { s.addEntriesToIndex(shortPromptEngineKeys, shortPromptRequestKeys, fakePodList) // Test 2: mid prompt (should return partial match if indexer picks it up) - pods, err = s.indexer.GetPodScores(s.ctx, nil, midPrompt, modelName, []string{s.Pod1IP}) + midTokens := s.indexer.Tokenize(nil, midPrompt) + pods, err = s.indexer.GetPodScores(s.ctx, midTokens, modelName, []string{s.Pod1IP}) s.Require().NoError(err) s.T().Logf("Mid prompt scores: %+v", pods) s.True(len(pods) > 0, "expected at least one pod score for mid prompt") @@ -244,7 +254,8 @@ func (s *KVCacheSuite) TestLongPrefixExpansion() { s.addEntriesToIndex(midPromptEngineKeys, midPromptRequestKeys, fakePodList) // Test 3: long prompt (should return higher score) - pods, err = s.indexer.GetPodScores(s.ctx, nil, longPrompt, modelName, []string{s.Pod1IP}) + longTokens := s.indexer.Tokenize(nil, longPrompt) + pods, err = s.indexer.GetPodScores(s.ctx, longTokens, modelName, []string{s.Pod1IP}) s.Require().NoError(err) s.T().Logf("Long prompt scores: %+v", pods) s.True(len(pods) > 0, "expected at least one pod score for long prompt") @@ -292,7 +303,8 @@ func (s *KVCacheSuite) TestChatCompletionsE2E() { fakePodList := []string{s.Pod1IP} // First lookup - should return no scores initially. - pods, err := s.indexer.GetPodScores(s.ctx, nil, flattenedPrompt, "ibm-granite/granite-3.3-8b-instruct", []string{s.Pod1IP}) + tokens := s.indexer.Tokenize(nil, flattenedPrompt) + pods, err := s.indexer.GetPodScores(s.ctx, tokens, "ibm-granite/granite-3.3-8b-instruct", []string{s.Pod1IP}) s.Require().NoError(err) s.T().Logf("First lookup - Received pod scores: %+v", pods) s.Empty(pods, "expected no pod scores on first lookup") @@ -301,7 +313,7 @@ func (s *KVCacheSuite) TestChatCompletionsE2E() { s.addEntriesToIndex(engineKeys, requestKeys, fakePodList) // Second lookup - should return scores. - pods, err = s.indexer.GetPodScores(s.ctx, nil, flattenedPrompt, "ibm-granite/granite-3.3-8b-instruct", []string{s.Pod1IP}) + pods, err = s.indexer.GetPodScores(s.ctx, tokens, "ibm-granite/granite-3.3-8b-instruct", []string{s.Pod1IP}) s.Require().NoError(err) s.T().Logf("Second lookup - Received pod scores: %+v", pods) s.Len(pods, 1, "expected one pod score") @@ -366,7 +378,8 @@ func (s *KVCacheSuite) TestLongChatCompletionsE2E() { fakePodList := []string{s.Pod1IP} // First lookup. - pods, err := s.indexer.GetPodScores(s.ctx, nil, flattenedPrompt, "ibm-granite/granite-3.3-8b-instruct", []string{s.Pod1IP}) + tokens := s.indexer.Tokenize(nil, flattenedPrompt) + pods, err := s.indexer.GetPodScores(s.ctx, tokens, "ibm-granite/granite-3.3-8b-instruct", []string{s.Pod1IP}) s.Require().NoError(err) s.T().Logf("First lookup - Received pod scores: %+v", pods) s.Empty(pods, "expected no pod scores on first lookup") @@ -375,7 +388,7 @@ func (s *KVCacheSuite) TestLongChatCompletionsE2E() { s.addEntriesToIndex(engineKeys, requestKeys, fakePodList) // Second lookup. - pods, err = s.indexer.GetPodScores(s.ctx, nil, flattenedPrompt, "ibm-granite/granite-3.3-8b-instruct", []string{s.Pod1IP}) + pods, err = s.indexer.GetPodScores(s.ctx, tokens, "ibm-granite/granite-3.3-8b-instruct", []string{s.Pod1IP}) s.Require().NoError(err) s.T().Logf("Second lookup - Received pod scores: %+v", pods) s.Len(pods, 1, "expected one pod score") @@ -415,7 +428,8 @@ func (s *KVCacheSuite) TestCacheHitWithLocalTokenizer() { s.addEntriesToIndex(engineKeys, requestKeys, fakePodList) // Verify that we can retrieve the entries we just added using GetPodScores - pods, err := s.indexer.GetPodScores(s.ctx, nil, prompt, modelName, fakePodList) + promptTokens := s.indexer.Tokenize(nil, prompt) + pods, err := s.indexer.GetPodScores(s.ctx, promptTokens, modelName, fakePodList) s.Require().NoError(err) s.Require().NotEmpty(pods, "should find pod scores after adding entries") s.Require().Greater(pods[s.Pod1IP], float64(0), "expected positive pod score") @@ -555,7 +569,8 @@ func (s *KVCacheSuite) TestLocalTokenizerChatTemplateE2E() { fakePodList := []string{s.Pod1IP} s.addEntriesToIndex(engineKeys, requestKeys, fakePodList) // Verify retrieval using GetPodScores with the rendered prompt - pods, err := s.indexer.GetPodScores(s.ctx, nil, renderedPrompt, tc.modelName, fakePodList) + renderedTokens := s.indexer.Tokenize(nil, renderedPrompt) + pods, err := s.indexer.GetPodScores(s.ctx, renderedTokens, tc.modelName, fakePodList) s.Require().NoError(err) s.Require().NotEmpty(pods, "should find pod scores after adding entries") s.Require().Greater(pods[s.Pod1IP], float64(0), "expected positive pod score") @@ -688,7 +703,8 @@ func (s *KVCacheSuite) TestLocalTokenizerChatTemplateMultiTurnE2E() { s.addEntriesToIndex(extendedEngineKeys, extendedRequestKeys, fakePodList) // Verify that querying with the short conversation still works (prefix sharing in KV-cache) - pods, err := s.indexer.GetPodScores(s.ctx, nil, shortPrompt, tc.modelName, fakePodList) + shortPromptTokens := s.indexer.Tokenize(nil, shortPrompt) + pods, err := s.indexer.GetPodScores(s.ctx, shortPromptTokens, tc.modelName, fakePodList) s.Require().NoError(err) s.Require().NotEmpty(pods, "Short conversation should still match after adding extended conversation") s.T().Logf("Short conversation match score: %v", pods[s.Pod1IP]) @@ -766,7 +782,8 @@ func (s *KVCacheSuite) TestLocalVsHFChatTemplateConsistency() { fakePodList := []string{s.Pod1IP} s.addEntriesToIndex(engineKeys, requestKeys, fakePodList) - pods, err := s.indexer.GetPodScores(s.ctx, nil, localRendered, tc.modelName, fakePodList) + localRenderedTokens := s.indexer.Tokenize(nil, localRendered) + pods, err := s.indexer.GetPodScores(s.ctx, localRenderedTokens, tc.modelName, fakePodList) s.Require().NoError(err) s.Require().NotEmpty(pods, "should find pod scores after adding entries") s.Require().Greater(pods[s.Pod1IP], float64(0), "expected positive pod score") @@ -909,7 +926,8 @@ func (s *KVCacheSuite) TestLocalTokenizerChatTemplateLongConversation() { s.addEntriesToIndex(engineKeys, requestKeys, fakePodList) // Verify retrieval using GetPodScores // Note: This works now because the test suite uses a composite tokenizer that includes the local models - pods, err := s.indexer.GetPodScores(s.ctx, nil, renderedPrompt, tc.modelName, fakePodList) + promptTokens := s.indexer.Tokenize(nil, renderedPrompt) + pods, err := s.indexer.GetPodScores(s.ctx, promptTokens, tc.modelName, fakePodList) s.Require().NoError(err) s.Require().NotEmpty(pods, "should find pod scores after adding entries") s.Require().Greater(pods[s.Pod1IP], float64(0), "expected positive pod score") From cc03a33a2dc83aff38e24426d4575ec88de7b918 Mon Sep 17 00:00:00 2001 From: Antonio Cardace Date: Wed, 21 Jan 2026 13:58:50 +0100 Subject: [PATCH 2/2] Add error handling to Tokenize method Signed-off-by: Antonio Cardace --- .../kvcache_aware_scorer.go | 6 ++- examples/kv_cache_index/main.go | 5 +- .../kv_cache_index_service/server/main.go | 6 ++- .../kv_cache_index_service/server/server.go | 5 +- examples/kv_events/offline/main.go | 6 ++- examples/kv_events/online/main.go | 12 ++++- examples/valkey_example/main.go | 5 +- pkg/kvcache/indexer.go | 2 +- pkg/tokenization/pool.go | 12 +++-- pkg/tokenization/pool_test.go | 16 ++++-- tests/e2e/redis_mock/e2e_test.go | 54 ++++++++++++------- 11 files changed, 94 insertions(+), 35 deletions(-) diff --git a/examples/kv_cache_aware_scorer/kvcache_aware_scorer.go b/examples/kv_cache_aware_scorer/kvcache_aware_scorer.go index fdace8fa..83b801a2 100644 --- a/examples/kv_cache_aware_scorer/kvcache_aware_scorer.go +++ b/examples/kv_cache_aware_scorer/kvcache_aware_scorer.go @@ -214,7 +214,11 @@ func (s *PrecisePrefixCacheScorer) Score(ctx context.Context, cycleState *types. return nil } - tokens := s.kvCacheIndexer.Tokenize(nil, request.Prompt) + tokens, err := s.kvCacheIndexer.Tokenize(nil, request.Prompt) + if err != nil { + logger.Error(err, "Failed to tokenize prompt") + return nil + } scores, err := s.kvCacheIndexer.GetPodScores(ctx, tokens, request.TargetModel, nil) if err != nil { logger.Error(err, "Failed to get pod scores") diff --git a/examples/kv_cache_index/main.go b/examples/kv_cache_index/main.go index 6bce29b4..d792d6b7 100644 --- a/examples/kv_cache_index/main.go +++ b/examples/kv_cache_index/main.go @@ -130,7 +130,10 @@ func runPrompts(ctx context.Context, kvCacheIndexer *kvcache.Indexer) error { logger.Info("Started Indexer", "model", modelName) // Tokenize the prompt - tokens := kvCacheIndexer.Tokenize(testdata.RenderReq, testdata.Prompt) + tokens, err := kvCacheIndexer.Tokenize(testdata.RenderReq, testdata.Prompt) + if err != nil { + return fmt.Errorf("failed to tokenize prompt: %w", err) + } // Get pods for the prompt pods, err := kvCacheIndexer.GetPodScores(ctx, tokens, modelName, nil) diff --git a/examples/kv_cache_index_service/server/main.go b/examples/kv_cache_index_service/server/main.go index b774fb37..8ab096f4 100644 --- a/examples/kv_cache_index_service/server/main.go +++ b/examples/kv_cache_index_service/server/main.go @@ -60,7 +60,11 @@ func main() { } // Initial query - should be empty since no events have been published - tokens := indexerSvc.indexer.Tokenize(testdata.RenderReq, testdata.Prompt) + tokens, err := indexerSvc.indexer.Tokenize(testdata.RenderReq, testdata.Prompt) + if err != nil { + logger.Error(err, "failed to tokenize prompt") + return + } pods, err := indexerSvc.indexer.GetPodScores(ctx, tokens, testdata.ModelName, nil) if err != nil { logger.Error(err, "failed to get pod scores") diff --git a/examples/kv_cache_index_service/server/server.go b/examples/kv_cache_index_service/server/server.go index 2480d8e5..c5cc1338 100644 --- a/examples/kv_cache_index_service/server/server.go +++ b/examples/kv_cache_index_service/server/server.go @@ -72,7 +72,10 @@ func (s *IndexerService) GetPodScores(ctx context.Context, } // Tokenize the prompt - tokens := s.indexer.Tokenize(nil, req.Prompt) + tokens, err := s.indexer.Tokenize(nil, req.Prompt) + if err != nil { + return nil, fmt.Errorf("failed to tokenize prompt: %w", err) + } // Call the underlying indexer podScores, err := s.indexer.GetPodScores(ctx, tokens, req.ModelName, diff --git a/examples/kv_events/offline/main.go b/examples/kv_events/offline/main.go index 133da3f1..0cbcb649 100644 --- a/examples/kv_events/offline/main.go +++ b/examples/kv_events/offline/main.go @@ -19,6 +19,7 @@ package main import ( "context" _ "embed" + "fmt" "os" "os/signal" "syscall" @@ -146,7 +147,10 @@ func RunEventsDemo(ctx context.Context, kvCacheIndexer *kvcache.Indexer, publish logger.Info("@@@ Starting KV Events Demo", "model", testdata.ModelName) // Tokenize the prompt - tokens := kvCacheIndexer.Tokenize(testdata.RenderReq, testdata.Prompt) + tokens, err := kvCacheIndexer.Tokenize(testdata.RenderReq, testdata.Prompt) + if err != nil { + return fmt.Errorf("failed to tokenize prompt: %w", err) + } // Initial query - should be empty since no events have been published pods, err := kvCacheIndexer.GetPodScores(ctx, tokens, testdata.ModelName, nil) diff --git a/examples/kv_events/online/main.go b/examples/kv_events/online/main.go index 9578ff9d..31b98a1e 100644 --- a/examples/kv_events/online/main.go +++ b/examples/kv_events/online/main.go @@ -297,7 +297,11 @@ func setupUnifiedHTTPEndpoints( return } - tokens := kvCacheIndexer.Tokenize(nil, req.Prompt) + tokens, err := kvCacheIndexer.Tokenize(nil, req.Prompt) + if err != nil { + http.Error(w, fmt.Sprintf("failed to tokenize: %v", err), http.StatusInternalServerError) + return + } pods, err := kvCacheIndexer.GetPodScores(ctx, tokens, req.Model, nil) if err != nil { http.Error(w, fmt.Sprintf("error: %v", err), http.StatusInternalServerError) @@ -339,7 +343,11 @@ func setupUnifiedHTTPEndpoints( } // Tokenize and get score - tokens := kvCacheIndexer.Tokenize(nil, renderedPrompt) + tokens, err := kvCacheIndexer.Tokenize(nil, renderedPrompt) + if err != nil { + http.Error(w, fmt.Sprintf("failed to tokenize: %v", err), http.StatusInternalServerError) + return + } pods, err := kvCacheIndexer.GetPodScores(ctx, tokens, req.Model, nil) if err != nil { http.Error(w, fmt.Sprintf("Failed to get score request: %v", err), http.StatusInternalServerError) diff --git a/examples/valkey_example/main.go b/examples/valkey_example/main.go index 9ef944cf..ad7e1500 100644 --- a/examples/valkey_example/main.go +++ b/examples/valkey_example/main.go @@ -134,7 +134,10 @@ func demonstrateValkeyOperations(ctx context.Context, indexer *kvcache.Indexer) logger.Info("Processing testdata prompt", "model", modelName, "promptLength", len(prompt)) // Tokenize the prompt - tokens := indexer.Tokenize(nil, prompt) + tokens, err := indexer.Tokenize(nil, prompt) + if err != nil { + return fmt.Errorf("failed to tokenize prompt: %w", err) + } // First, let's demonstrate basic scoring without any cache entries scores, err := indexer.GetPodScores(ctx, tokens, modelName, []string{"demo-pod-1", "demo-pod-2"}) diff --git a/pkg/kvcache/indexer.go b/pkg/kvcache/indexer.go index ef5f77dd..c5f34c7f 100644 --- a/pkg/kvcache/indexer.go +++ b/pkg/kvcache/indexer.go @@ -191,6 +191,6 @@ func (k *Indexer) SetTokenizer(tokenizer tokenization.Tokenizer, modelName strin // prompt is tokenized directly without chat template rendering. // // Returns the token IDs as a uint32 slice, or an error if tokenization fails. -func (k *Indexer) Tokenize(renderReq *preprocessing.ApplyChatTemplateRequest, prompt string) []uint32 { +func (k *Indexer) Tokenize(renderReq *preprocessing.ApplyChatTemplateRequest, prompt string) ([]uint32, error) { return k.tokenizersPool.Tokenize(renderReq, prompt) } diff --git a/pkg/tokenization/pool.go b/pkg/tokenization/pool.go index c6b17868..c0e45988 100644 --- a/pkg/tokenization/pool.go +++ b/pkg/tokenization/pool.go @@ -65,6 +65,7 @@ func DefaultConfig() (*Config, error) { // tokenizationResponse holds the result of a tokenization operation. type tokenizationResponse struct { Tokens []uint32 + Err error } // Task represents a unit of work for tokenizing a prompt. @@ -154,7 +155,7 @@ func (pool *Pool) EnqueueTokenization(prompt string) { } // Tokenize queues a task and blocks until the final result is available. -func (pool *Pool) Tokenize(renderReq *preprocessing.ApplyChatTemplateRequest, prompt string) []uint32 { +func (pool *Pool) Tokenize(renderReq *preprocessing.ApplyChatTemplateRequest, prompt string) ([]uint32, error) { resultCh := make(chan tokenizationResponse, 1) pool.queue.Add(Task{ RenderReq: renderReq, @@ -163,8 +164,10 @@ func (pool *Pool) Tokenize(renderReq *preprocessing.ApplyChatTemplateRequest, pr }) res := <-resultCh - tokens := res.Tokens - return tokens + if res.Err != nil { + return nil, res.Err + } + return res.Tokens, nil } // Run launches worker goroutines that process tasks until the context is @@ -206,7 +209,8 @@ func (pool *Pool) workerLoop(_ int) { "retries", maxRetries) pool.queue.Forget(task) if task.ResultCh != nil { - // Closing the channel signals failure (zero value received by caller) + // Send the error to the caller + task.ResultCh <- tokenizationResponse{Err: err} close(task.ResultCh) } } diff --git a/pkg/tokenization/pool_test.go b/pkg/tokenization/pool_test.go index 24177f16..f9120f23 100644 --- a/pkg/tokenization/pool_test.go +++ b/pkg/tokenization/pool_test.go @@ -214,14 +214,22 @@ func TestPool_WorkerLoop(t *testing.T) { }, verify: func(t *testing.T, pool *Pool, tasks []Task, resultCh chan tokenizationResponse) { t.Helper() - require.Eventually(t, func() bool { // channel is closed, when max retries exceeded - if result, ok := <-resultCh; !ok { - assert.Equal(t, tokenizationResponse{}, result) + // When max retries exceeded, error should be sent before channel is closed + require.Eventually(t, func() bool { + if result, ok := <-resultCh; ok { + assert.NotNil(t, result.Err, "expected error in response") + assert.Nil(t, result.Tokens, "expected nil tokens on error") return true } return false }, time.Second, 10*time.Millisecond) + // Verify channel is closed after error is sent + require.Eventually(t, func() bool { + _, ok := <-resultCh + return !ok + }, time.Second, 10*time.Millisecond) + require.Eventually(t, func() bool { return pool.queue.Len() == 0 }, time.Second, 10*time.Millisecond) @@ -411,7 +419,7 @@ func BenchmarkSyncTokenizationStress(b *testing.B) { // Submit tokenization requests in a loop until limit for i := 0; b.Loop(); i++ { prompt := generateRandomSentence(benchmarkWordLength, benchmarkMaxWords, rng) - pool.Tokenize(nil, prompt) + _, _ = pool.Tokenize(nil, prompt) } b.StopTimer() diff --git a/tests/e2e/redis_mock/e2e_test.go b/tests/e2e/redis_mock/e2e_test.go index e88ece58..03dd2ac8 100644 --- a/tests/e2e/redis_mock/e2e_test.go +++ b/tests/e2e/redis_mock/e2e_test.go @@ -121,7 +121,8 @@ func (s *KVCacheSuite) TestCacheHit() { engineKeys, requestKeys := s.promptToEngineAndRequestKeys(prompt, defaultModelName) s.addEntriesToIndex(engineKeys, requestKeys, fakePodList) - tokens := s.indexer.Tokenize(nil, prompt) + tokens, err := s.indexer.Tokenize(nil, prompt) + s.Require().NoError(err) pods, err := s.indexer.GetPodScores(s.ctx, tokens, defaultModelName, fakePodList) s.Require().NoError(err) s.T().Logf("Received pod scores: %+v", pods) @@ -133,7 +134,8 @@ func (s *KVCacheSuite) TestCacheMiss() { prompt := "What is the capital of France?" fakePodList := []string{s.Pod1IP} - tokens := s.indexer.Tokenize(nil, prompt) + tokens, err := s.indexer.Tokenize(nil, prompt) + s.Require().NoError(err) pods, err := s.indexer.GetPodScores(s.ctx, tokens, defaultModelName, fakePodList) s.Require().NoError(err) s.T().Logf("Received pod scores: %+v", pods) @@ -152,7 +154,8 @@ func (s *KVCacheSuite) TestPrefixReduction() { fakePodList := []string{s.Pod1IP} // Test 1: Full prompt (no match expected) - fullTokens := s.indexer.Tokenize(nil, fullPrompt) + fullTokens, err := s.indexer.Tokenize(nil, fullPrompt) + s.Require().NoError(err) pods, err := s.indexer.GetPodScores(s.ctx, fullTokens, defaultModelName, []string{s.Pod1IP}) s.Require().NoError(err) s.T().Logf("Received pod scores: %+v", pods) @@ -161,7 +164,8 @@ func (s *KVCacheSuite) TestPrefixReduction() { s.addEntriesToIndex(fullPromptEngineKeys, fullPromptRequestKeys, fakePodList) // Test 2: mid-length prompt(should return a match) - midTokens := s.indexer.Tokenize(nil, midPrompt) + midTokens, err := s.indexer.Tokenize(nil, midPrompt) + s.Require().NoError(err) pods, err = s.indexer.GetPodScores(s.ctx, midTokens, defaultModelName, []string{s.Pod1IP}) s.Require().NoError(err) @@ -169,7 +173,8 @@ func (s *KVCacheSuite) TestPrefixReduction() { s.Greater(int(pods[s.Pod1IP]), 0, "mid-prompt block keys should have been indexed") // Test 3: short prompt(should return a match) - shortTokens := s.indexer.Tokenize(nil, shortPrompt) + shortTokens, err := s.indexer.Tokenize(nil, shortPrompt) + s.Require().NoError(err) pods, err = s.indexer.GetPodScores(s.ctx, shortTokens, defaultModelName, []string{s.Pod1IP}) s.Require().NoError(err) @@ -190,7 +195,8 @@ func (s *KVCacheSuite) TestPrefixExpansion() { fakePodList := []string{s.Pod1IP} // Test 1: short prompt - shortTokens := s.indexer.Tokenize(nil, shortPrompt) + shortTokens, err := s.indexer.Tokenize(nil, shortPrompt) + s.Require().NoError(err) pods, err := s.indexer.GetPodScores(s.ctx, shortTokens, modelName, []string{s.Pod1IP}) s.Require().NoError(err) s.T().Logf("Received pod scores: %+v", pods) @@ -200,7 +206,8 @@ func (s *KVCacheSuite) TestPrefixExpansion() { s.addEntriesToIndex(shortPromptEngineKeys, shortPromptRequestKeys, fakePodList) // Test 2: mid prompt - midTokens := s.indexer.Tokenize(nil, midPrompt) + midTokens, err := s.indexer.Tokenize(nil, midPrompt) + s.Require().NoError(err) pods, err = s.indexer.GetPodScores(s.ctx, midTokens, modelName, []string{s.Pod1IP}) s.Require().NoError(err) @@ -211,7 +218,8 @@ func (s *KVCacheSuite) TestPrefixExpansion() { s.addEntriesToIndex(midPromptEngineKeys, midPromptRequestKeys, fakePodList) // Test 3: full prompt - fullTokens := s.indexer.Tokenize(nil, fullPrompt) + fullTokens, err := s.indexer.Tokenize(nil, fullPrompt) + s.Require().NoError(err) pods, err = s.indexer.GetPodScores(s.ctx, fullTokens, modelName, []string{s.Pod1IP}) s.Require().NoError(err) @@ -232,7 +240,8 @@ func (s *KVCacheSuite) TestLongPrefixExpansion() { fakePodList := []string{s.Pod1IP} // Test 1: short prompt (should return no pod scores yet) - shortTokens := s.indexer.Tokenize(nil, shortPrompt) + shortTokens, err := s.indexer.Tokenize(nil, shortPrompt) + s.Require().NoError(err) pods, err := s.indexer.GetPodScores(s.ctx, shortTokens, modelName, []string{s.Pod1IP}) s.Require().NoError(err) s.T().Logf("Short prompt scores: %+v", pods) @@ -243,7 +252,8 @@ func (s *KVCacheSuite) TestLongPrefixExpansion() { s.addEntriesToIndex(shortPromptEngineKeys, shortPromptRequestKeys, fakePodList) // Test 2: mid prompt (should return partial match if indexer picks it up) - midTokens := s.indexer.Tokenize(nil, midPrompt) + midTokens, err := s.indexer.Tokenize(nil, midPrompt) + s.Require().NoError(err) pods, err = s.indexer.GetPodScores(s.ctx, midTokens, modelName, []string{s.Pod1IP}) s.Require().NoError(err) s.T().Logf("Mid prompt scores: %+v", pods) @@ -254,7 +264,8 @@ func (s *KVCacheSuite) TestLongPrefixExpansion() { s.addEntriesToIndex(midPromptEngineKeys, midPromptRequestKeys, fakePodList) // Test 3: long prompt (should return higher score) - longTokens := s.indexer.Tokenize(nil, longPrompt) + longTokens, err := s.indexer.Tokenize(nil, longPrompt) + s.Require().NoError(err) pods, err = s.indexer.GetPodScores(s.ctx, longTokens, modelName, []string{s.Pod1IP}) s.Require().NoError(err) s.T().Logf("Long prompt scores: %+v", pods) @@ -303,7 +314,8 @@ func (s *KVCacheSuite) TestChatCompletionsE2E() { fakePodList := []string{s.Pod1IP} // First lookup - should return no scores initially. - tokens := s.indexer.Tokenize(nil, flattenedPrompt) + tokens, err := s.indexer.Tokenize(nil, flattenedPrompt) + s.Require().NoError(err) pods, err := s.indexer.GetPodScores(s.ctx, tokens, "ibm-granite/granite-3.3-8b-instruct", []string{s.Pod1IP}) s.Require().NoError(err) s.T().Logf("First lookup - Received pod scores: %+v", pods) @@ -378,7 +390,8 @@ func (s *KVCacheSuite) TestLongChatCompletionsE2E() { fakePodList := []string{s.Pod1IP} // First lookup. - tokens := s.indexer.Tokenize(nil, flattenedPrompt) + tokens, err := s.indexer.Tokenize(nil, flattenedPrompt) + s.Require().NoError(err) pods, err := s.indexer.GetPodScores(s.ctx, tokens, "ibm-granite/granite-3.3-8b-instruct", []string{s.Pod1IP}) s.Require().NoError(err) s.T().Logf("First lookup - Received pod scores: %+v", pods) @@ -428,7 +441,8 @@ func (s *KVCacheSuite) TestCacheHitWithLocalTokenizer() { s.addEntriesToIndex(engineKeys, requestKeys, fakePodList) // Verify that we can retrieve the entries we just added using GetPodScores - promptTokens := s.indexer.Tokenize(nil, prompt) + promptTokens, err := s.indexer.Tokenize(nil, prompt) + s.Require().NoError(err) pods, err := s.indexer.GetPodScores(s.ctx, promptTokens, modelName, fakePodList) s.Require().NoError(err) s.Require().NotEmpty(pods, "should find pod scores after adding entries") @@ -569,7 +583,8 @@ func (s *KVCacheSuite) TestLocalTokenizerChatTemplateE2E() { fakePodList := []string{s.Pod1IP} s.addEntriesToIndex(engineKeys, requestKeys, fakePodList) // Verify retrieval using GetPodScores with the rendered prompt - renderedTokens := s.indexer.Tokenize(nil, renderedPrompt) + renderedTokens, err := s.indexer.Tokenize(nil, renderedPrompt) + s.Require().NoError(err) pods, err := s.indexer.GetPodScores(s.ctx, renderedTokens, tc.modelName, fakePodList) s.Require().NoError(err) s.Require().NotEmpty(pods, "should find pod scores after adding entries") @@ -703,7 +718,8 @@ func (s *KVCacheSuite) TestLocalTokenizerChatTemplateMultiTurnE2E() { s.addEntriesToIndex(extendedEngineKeys, extendedRequestKeys, fakePodList) // Verify that querying with the short conversation still works (prefix sharing in KV-cache) - shortPromptTokens := s.indexer.Tokenize(nil, shortPrompt) + shortPromptTokens, err := s.indexer.Tokenize(nil, shortPrompt) + s.Require().NoError(err) pods, err := s.indexer.GetPodScores(s.ctx, shortPromptTokens, tc.modelName, fakePodList) s.Require().NoError(err) s.Require().NotEmpty(pods, "Short conversation should still match after adding extended conversation") @@ -782,7 +798,8 @@ func (s *KVCacheSuite) TestLocalVsHFChatTemplateConsistency() { fakePodList := []string{s.Pod1IP} s.addEntriesToIndex(engineKeys, requestKeys, fakePodList) - localRenderedTokens := s.indexer.Tokenize(nil, localRendered) + localRenderedTokens, err := s.indexer.Tokenize(nil, localRendered) + s.Require().NoError(err) pods, err := s.indexer.GetPodScores(s.ctx, localRenderedTokens, tc.modelName, fakePodList) s.Require().NoError(err) s.Require().NotEmpty(pods, "should find pod scores after adding entries") @@ -926,7 +943,8 @@ func (s *KVCacheSuite) TestLocalTokenizerChatTemplateLongConversation() { s.addEntriesToIndex(engineKeys, requestKeys, fakePodList) // Verify retrieval using GetPodScores // Note: This works now because the test suite uses a composite tokenizer that includes the local models - promptTokens := s.indexer.Tokenize(nil, renderedPrompt) + promptTokens, err := s.indexer.Tokenize(nil, renderedPrompt) + s.Require().NoError(err) pods, err := s.indexer.GetPodScores(s.ctx, promptTokens, tc.modelName, fakePodList) s.Require().NoError(err) s.Require().NotEmpty(pods, "should find pod scores after adding entries")