Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 6 additions & 67 deletions examples/kv_cache_aware_scorer/kvcache_aware_scorer.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,13 @@ package scorer
import (
"context"
"encoding/json"
"errors"
"fmt"
"os"
"time"

"github.com/llm-d/llm-d-kv-cache/pkg/kvcache"
"github.com/llm-d/llm-d-kv-cache/pkg/kvcache/kvblock"
"github.com/llm-d/llm-d-kv-cache/pkg/kvevents"
preprocessing "github.com/llm-d/llm-d-kv-cache/pkg/preprocessing/chat_completions"
"sigs.k8s.io/controller-runtime/pkg/log"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
Expand Down Expand Up @@ -216,7 +214,12 @@ func (s *PrecisePrefixCacheScorer) Score(ctx context.Context, cycleState *types.
return nil
}

scores, err := s.getScores(ctx, request)
tokens, err := s.kvCacheIndexer.Tokenize(nil, request.Prompt)
if err != nil {
logger.Error(err, "Failed to tokenize prompt")
return nil
}
scores, err := s.kvCacheIndexer.GetPodScores(ctx, tokens, request.TargetModel, nil)
if err != nil {
logger.Error(err, "Failed to get pod scores")
return nil
Expand Down Expand Up @@ -248,67 +251,3 @@ func (s *PrecisePrefixCacheScorer) Score(ctx context.Context, cycleState *types.
return indexedScoresToNormalizedScoredPods(pods, podToKey, scores)
}

// getScores retrieves the pod scores from the KV-cache indexer
// based on the provided LLM request.
// If the request contains chat completions, it processes them accordingly.
// If the request contains regular completions, it uses the prompt directly.
func (s *PrecisePrefixCacheScorer) getScores(ctx context.Context, request *types.LLMRequest) (map[string]float64, error) {
logger := log.FromContext(ctx).WithName(s.typedName.String())
traceLogger := logger.V(logutil.TRACE)

traceLogger.Info("Getting scores",
"isChatCompletions", request.Body != nil && request.Body.ChatCompletions != nil,
"isCompletions", request.Body != nil && request.Body.Completions != nil)

// The upstream parser guarantees exactly one body is populated, but we defensively prioritize chat completions.
// If an unexpected dual payload slips through (parser regression/new client), log it and use chat semantics.
if request.Body != nil && request.Body.ChatCompletions != nil {
if request.Body.Completions != nil {
traceLogger.Info("Both chat/completions and completions present; defaulting to chat/completions")
}

renderReq := &preprocessing.ApplyChatTemplateRequest{
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems that the chat template is lost with this refactoring.

Conversation: make([][]preprocessing.Conversation, 0),
Tools: request.Body.ChatCompletions.Tools,
Documents: request.Body.ChatCompletions.Documents,
ChatTemplate: request.Body.ChatCompletions.ChatTemplate,
ReturnAssistantTokensMask: request.Body.ChatCompletions.ReturnAssistantTokensMask,
ContinueFinalMessage: request.Body.ChatCompletions.ContinueFinalMessage,
AddGenerationPrompt: request.Body.ChatCompletions.AddGenerationPrompt,
ChatTemplateKWArgs: request.Body.ChatCompletions.ChatTemplateKWArgs,
}

// Convert messages to the format expected by the renderer
for _, msg := range request.Body.ChatCompletions.Messages {
renderReq.Conversation = append(renderReq.Conversation, []preprocessing.Conversation{{
Role: msg.Role,
Content: msg.Content.Raw,
}})
}

traceLogger.Info("Processing chat completion request",
"conversationCount", len(renderReq.Conversation),
"toolsCount", len(renderReq.Tools),
"documentsCount", len(renderReq.Documents))

scores, err := s.kvCacheIndexer.GetPodScores(ctx, renderReq, "", request.TargetModel, nil)
if err != nil {
return nil, fmt.Errorf("failed to get pod scores for chat/completions: %w", err)
}
return scores, nil
}

// For regular completions, use the prompt directly
if request.Body != nil && request.Body.Completions != nil {
prompt := request.Body.Completions.Prompt
traceLogger.Info("Using completion prompt directly", "promptLength", len(prompt))

scores, err := s.kvCacheIndexer.GetPodScores(ctx, nil, prompt, request.TargetModel, nil)
if err != nil {
return nil, fmt.Errorf("failed to get pod scores for completions: %w", err)
}
return scores, nil
}

return nil, errors.New("no valid input found in request")
}
12 changes: 9 additions & 3 deletions examples/kv_cache_index/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,14 @@ func runPrompts(ctx context.Context, kvCacheIndexer *kvcache.Indexer) error {
modelName := getModelName()
logger.Info("Started Indexer", "model", modelName)

// Tokenize the prompt
tokens, err := kvCacheIndexer.Tokenize(testdata.RenderReq, testdata.Prompt)
if err != nil {
return fmt.Errorf("failed to tokenize prompt: %w", err)
}

// Get pods for the prompt
pods, err := kvCacheIndexer.GetPodScores(ctx, testdata.RenderReq, testdata.Prompt, modelName, nil)
pods, err := kvCacheIndexer.GetPodScores(ctx, tokens, modelName, nil)
if err != nil {
return err
}
Expand All @@ -153,8 +159,8 @@ func runPrompts(ctx context.Context, kvCacheIndexer *kvcache.Indexer) error {
// Sleep 3 secs
time.Sleep(3 * time.Second)

// Get pods for the prompt
pods, err = kvCacheIndexer.GetPodScores(ctx, testdata.RenderReq, testdata.Prompt, modelName, nil)
// Get pods for the prompt (reuse tokens from above)
pods, err = kvCacheIndexer.GetPodScores(ctx, tokens, modelName, nil)
if err != nil {
return err
}
Expand Down
7 changes: 6 additions & 1 deletion examples/kv_cache_index_service/server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,12 @@ func main() {
}

// Initial query - should be empty since no events have been published
pods, err := indexerSvc.indexer.GetPodScores(ctx, testdata.RenderReq, testdata.Prompt, testdata.ModelName, nil)
tokens, err := indexerSvc.indexer.Tokenize(testdata.RenderReq, testdata.Prompt)
if err != nil {
logger.Error(err, "failed to tokenize prompt")
return
}
pods, err := indexerSvc.indexer.GetPodScores(ctx, tokens, testdata.ModelName, nil)
if err != nil {
logger.Error(err, "failed to get pod scores")
}
Expand Down
8 changes: 7 additions & 1 deletion examples/kv_cache_index_service/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,14 @@ func (s *IndexerService) GetPodScores(ctx context.Context,
return nil, fmt.Errorf("request cannot be nil")
}

// Tokenize the prompt
tokens, err := s.indexer.Tokenize(nil, req.Prompt)
if err != nil {
return nil, fmt.Errorf("failed to tokenize prompt: %w", err)
}

// Call the underlying indexer
podScores, err := s.indexer.GetPodScores(ctx, nil, req.Prompt, req.ModelName,
podScores, err := s.indexer.GetPodScores(ctx, tokens, req.ModelName,
req.PodIdentifiers)
if err != nil {
return nil, fmt.Errorf("failed to get pod scores: %w", err)
Expand Down
17 changes: 12 additions & 5 deletions examples/kv_events/offline/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package main
import (
"context"
_ "embed"
"fmt"
"os"
"os/signal"
"syscall"
Expand Down Expand Up @@ -145,8 +146,14 @@ func RunEventsDemo(ctx context.Context, kvCacheIndexer *kvcache.Indexer, publish

logger.Info("@@@ Starting KV Events Demo", "model", testdata.ModelName)

// Tokenize the prompt
tokens, err := kvCacheIndexer.Tokenize(testdata.RenderReq, testdata.Prompt)
if err != nil {
return fmt.Errorf("failed to tokenize prompt: %w", err)
}

// Initial query - should be empty since no events have been published
pods, err := kvCacheIndexer.GetPodScores(ctx, testdata.RenderReq, testdata.Prompt, testdata.ModelName, nil)
pods, err := kvCacheIndexer.GetPodScores(ctx, tokens, testdata.ModelName, nil)
if err != nil {
return err
}
Expand All @@ -158,8 +165,8 @@ func RunEventsDemo(ctx context.Context, kvCacheIndexer *kvcache.Indexer, publish
return err
}

// Query again to see the effect of the events
pods, err = kvCacheIndexer.GetPodScores(ctx, testdata.RenderReq, testdata.Prompt, testdata.ModelName, nil)
// Query again to see the effect of the events (reuse tokens)
pods, err = kvCacheIndexer.GetPodScores(ctx, tokens, testdata.ModelName, nil)
if err != nil {
return err
}
Expand All @@ -171,8 +178,8 @@ func RunEventsDemo(ctx context.Context, kvCacheIndexer *kvcache.Indexer, publish
return err
}

// Final query
pods, err = kvCacheIndexer.GetPodScores(ctx, testdata.RenderReq, testdata.Prompt, testdata.ModelName, nil)
// Final query (reuse tokens)
pods, err = kvCacheIndexer.GetPodScores(ctx, tokens, testdata.ModelName, nil)
if err != nil {
return err
}
Expand Down
16 changes: 13 additions & 3 deletions examples/kv_events/online/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,12 @@ func setupUnifiedHTTPEndpoints(
return
}

pods, err := kvCacheIndexer.GetPodScores(ctx, nil, req.Prompt, req.Model, nil)
tokens, err := kvCacheIndexer.Tokenize(nil, req.Prompt)
if err != nil {
http.Error(w, fmt.Sprintf("failed to tokenize: %v", err), http.StatusInternalServerError)
return
}
pods, err := kvCacheIndexer.GetPodScores(ctx, tokens, req.Model, nil)
if err != nil {
http.Error(w, fmt.Sprintf("error: %v", err), http.StatusInternalServerError)
return
Expand Down Expand Up @@ -337,8 +342,13 @@ func setupUnifiedHTTPEndpoints(
return
}

// Get score
pods, err := kvCacheIndexer.GetPodScores(ctx, nil, renderedPrompt, req.Model, nil)
// Tokenize and get score
tokens, err := kvCacheIndexer.Tokenize(nil, renderedPrompt)
if err != nil {
http.Error(w, fmt.Sprintf("failed to tokenize: %v", err), http.StatusInternalServerError)
return
}
pods, err := kvCacheIndexer.GetPodScores(ctx, tokens, req.Model, nil)
if err != nil {
http.Error(w, fmt.Sprintf("Failed to get score request: %v", err), http.StatusInternalServerError)
return
Expand Down
16 changes: 11 additions & 5 deletions examples/valkey_example/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,14 @@ func demonstrateValkeyOperations(ctx context.Context, indexer *kvcache.Indexer)

logger.Info("Processing testdata prompt", "model", modelName, "promptLength", len(prompt))

// Tokenize the prompt
tokens, err := indexer.Tokenize(nil, prompt)
if err != nil {
return fmt.Errorf("failed to tokenize prompt: %w", err)
}

// First, let's demonstrate basic scoring without any cache entries
scores, err := indexer.GetPodScores(ctx, nil, prompt, modelName, []string{"demo-pod-1", "demo-pod-2"})
scores, err := indexer.GetPodScores(ctx, tokens, modelName, []string{"demo-pod-1", "demo-pod-2"})
if err != nil {
return fmt.Errorf("failed to get pod scores: %w", err)
}
Expand All @@ -159,8 +165,8 @@ func demonstrateValkeyOperations(ctx context.Context, indexer *kvcache.Indexer)

logger.Info("Added cache entries", "keys", len(promptKeys), "pods", len(podEntries))

// Query for cache scores again
scores, err = indexer.GetPodScores(ctx, nil, prompt, modelName, []string{"demo-pod-1", "demo-pod-2"})
// Query for cache scores again (reuse tokens)
scores, err = indexer.GetPodScores(ctx, tokens, modelName, []string{"demo-pod-1", "demo-pod-2"})
if err != nil {
return fmt.Errorf("failed to get pod scores after adding entries: %w", err)
}
Expand Down Expand Up @@ -194,8 +200,8 @@ func demonstrateValkeyOperations(ctx context.Context, indexer *kvcache.Indexer)

logger.Info("Cache lookup after eviction", "keysFound", len(lookupAfterEvict))

// Final score check to see the difference
finalScores, err := indexer.GetPodScores(ctx, nil, prompt, modelName, []string{"demo-pod-1", "demo-pod-2"})
// Final score check to see the difference (reuse tokens)
finalScores, err := indexer.GetPodScores(ctx, tokens, modelName, []string{"demo-pod-1", "demo-pod-2"})
if err != nil {
return fmt.Errorf("failed to get final pod scores: %w", err)
}
Expand Down
27 changes: 20 additions & 7 deletions pkg/kvcache/indexer.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,15 +128,15 @@ func (k *Indexer) KVBlockIndex() kvblock.Index {
// relevant.
//
// The function returns a map of pod identifiers to scores.
func (k *Indexer) GetPodScores(ctx context.Context, renderReq *preprocessing.ApplyChatTemplateRequest, prompt, modelName string,
func (k *Indexer) GetPodScores(
ctx context.Context,
tokens []uint32,
modelName string,
podIdentifiers []string,
) (map[string]float64, error) {
traceLogger := log.FromContext(ctx).V(logging.TRACE).WithName("kvcache.GetPodScores")

// 1. tokenize prompt
tokens := k.tokenizersPool.Tokenize(renderReq, prompt)

// 2. get block keys
// get block keys
blockKeys := k.tokenProcessor.TokensToKVBlockKeys(kvblock.EmptyBlockHash, tokens, modelName)
if len(blockKeys) == 0 {
traceLogger.Info("no block keys found, returning empty scores")
Expand All @@ -146,15 +146,15 @@ func (k *Indexer) GetPodScores(ctx context.Context, renderReq *preprocessing.App

traceLogger.Info("found tokens", "tokens", tokens, "block-keys", blockKeys)

// 3. query kvblock indexer for pods
// query kvblock indexer for pods
keyToPods, err := k.kvBlockIndex.Lookup(ctx, blockKeys, sets.New(podIdentifiers...))
if err != nil {
return nil, fmt.Errorf("failed to query kvblock indexer: %w", err)
}
traceLogger.Info("found block keys", "block-keys", blockKeys,
"pods", podsPerKeyPrintHelper(keyToPods))

// 4. score pods
// score pods
podScores, err := k.kvBlockScorer.Score(blockKeys, keyToPods)
if err != nil {
return nil, fmt.Errorf("failed to query kvblock scorer: %w", err)
Expand All @@ -181,3 +181,16 @@ func podsPerKeyPrintHelper(ks map[kvblock.BlockHash][]kvblock.PodEntry) string {
func (k *Indexer) SetTokenizer(tokenizer tokenization.Tokenizer, modelName string) {
// Forward the tokenizer and model name unchanged to the indexer's tokenizers pool.
k.tokenizersPool.SetTokenizer(tokenizer, modelName)
}

// Tokenize turns prompt into token IDs by delegating to the indexer's
// tokenizers pool and returning whatever that pool reports.
//
// The call is synchronous: the pool enqueues a tokenization task and the
// caller waits until its result is available.
//
// renderReq is forwarded to the pool unchanged. Its Key field supplies the
// model-specific context used to pick the right tokenizer. When renderReq is
// nil, the prompt is tokenized as-is, with no chat template rendering.
//
// On success the token IDs are returned as a []uint32; if tokenization fails,
// a non-nil error is returned.
func (k *Indexer) Tokenize(renderReq *preprocessing.ApplyChatTemplateRequest, prompt string) ([]uint32, error) {
	tokenIDs, err := k.tokenizersPool.Tokenize(renderReq, prompt)
	return tokenIDs, err
}
12 changes: 8 additions & 4 deletions pkg/tokenization/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ func DefaultConfig() (*Config, error) {
// tokenizationResponse holds the result of a tokenization operation.
// It is the value delivered to a waiting caller over a task's result channel.
type tokenizationResponse struct {
// Tokens is the token-ID sequence handed back to the caller on success.
Tokens []uint32
// Err is the failure reported to the caller; the worker loop sets it when
// it gives up on a task after exhausting its retries.
Err error
}

// Task represents a unit of work for tokenizing a prompt.
Expand Down Expand Up @@ -154,7 +155,7 @@ func (pool *Pool) EnqueueTokenization(prompt string) {
}

// Tokenize queues a task and blocks until the final result is available.
func (pool *Pool) Tokenize(renderReq *preprocessing.ApplyChatTemplateRequest, prompt string) []uint32 {
func (pool *Pool) Tokenize(renderReq *preprocessing.ApplyChatTemplateRequest, prompt string) ([]uint32, error) {
resultCh := make(chan tokenizationResponse, 1)
pool.queue.Add(Task{
RenderReq: renderReq,
Expand All @@ -163,8 +164,10 @@ func (pool *Pool) Tokenize(renderReq *preprocessing.ApplyChatTemplateRequest, pr
})

res := <-resultCh
tokens := res.Tokens
return tokens
if res.Err != nil {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I think `return res.Tokens, res.Err` should be enough

return nil, res.Err
}
return res.Tokens, nil
}

// Run launches worker goroutines that process tasks until the context is
Expand Down Expand Up @@ -206,7 +209,8 @@ func (pool *Pool) workerLoop(_ int) {
"retries", maxRetries)
pool.queue.Forget(task)
if task.ResultCh != nil {
// Closing the channel signals failure (zero value received by caller)
// Send the error to the caller
task.ResultCh <- tokenizationResponse{Err: err}
close(task.ResultCh)
}
}
Expand Down
16 changes: 12 additions & 4 deletions pkg/tokenization/pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,14 +214,22 @@
},
verify: func(t *testing.T, pool *Pool, tasks []Task, resultCh chan tokenizationResponse) {
t.Helper()
require.Eventually(t, func() bool { // channel is closed, when max retries exceeded
if result, ok := <-resultCh; !ok {
assert.Equal(t, tokenizationResponse{}, result)
// When max retries exceeded, error should be sent before channel is closed
require.Eventually(t, func() bool {
if result, ok := <-resultCh; ok {
assert.NotNil(t, result.Err, "expected error in response")
assert.Nil(t, result.Tokens, "expected nil tokens on error")
return true
}
return false
}, time.Second, 10*time.Millisecond)

// Verify channel is closed after error is sent
require.Eventually(t, func() bool {
_, ok := <-resultCh
return !ok
}, time.Second, 10*time.Millisecond)

require.Eventually(t, func() bool {
return pool.queue.Len() == 0
}, time.Second, 10*time.Millisecond)
Expand Down Expand Up @@ -411,7 +419,7 @@
// Submit tokenization requests in a loop until limit
for i := 0; b.Loop(); i++ {
prompt := generateRandomSentence(benchmarkWordLength, benchmarkMaxWords, rng)
pool.Tokenize(nil, prompt)
_, _ = pool.Tokenize(nil, prompt)

Check failure on line 422 in pkg/tokenization/pool_test.go

View workflow job for this annotation

GitHub Actions / lint-and-test

Error return value of `pool.Tokenize` is not checked (errcheck)
}

b.StopTimer()
Expand Down
Loading
Loading