Skip to content

Commit 2bedecf

Browse files
authored
Merge pull request #629 from doringeman/preload-optimization
feat: preload models on configure and in interactive mode
2 parents e9d05d6 + bdf158b commit 2bedecf

File tree

3 files changed

+92
-6
lines changed

3 files changed

+92
-6
lines changed

cmd/cli/commands/run.go

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -743,11 +743,7 @@ func newRunCmd() *cobra.Command {
743743

744744
// Handle --detach flag: just load the model without interaction
745745
if detach {
746-
// Make a minimal request to load the model into memory
747-
err := desktopClient.Chat(model, "", nil, func(content string) {
748-
// Silently discard output in detach mode
749-
}, false)
750-
if err != nil {
746+
if err := desktopClient.Preload(cmd.Context(), model); err != nil {
751747
return handleClientError(err, "Failed to load model")
752748
}
753749
if debug {
@@ -764,6 +760,14 @@ func newRunCmd() *cobra.Command {
764760
return nil
765761
}
766762

763+
// For interactive mode, eagerly load the model in the background
764+
// while the user types their first query
765+
go func() {
766+
if err := desktopClient.Preload(cmd.Context(), model); err != nil {
767+
cmd.PrintErrf("background model preload failed: %v\n", err)
768+
}
769+
}()
770+
767771
// Initialize termenv with color caching before starting interactive session.
768772
// This queries the terminal background color once and caches it, preventing
769773
// OSC response sequences from appearing in stdin during the interactive loop.

cmd/cli/desktop/desktop.go

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,45 @@ func (c *Client) Chat(model, prompt string, imageURLs []string, outputFunc func(
350350
return c.ChatWithContext(context.Background(), model, prompt, imageURLs, outputFunc, shouldUseMarkdown)
351351
}
352352

353+
// Preload loads a model into memory without running inference.
354+
// The model stays loaded for the idle timeout period.
355+
func (c *Client) Preload(ctx context.Context, model string) error {
356+
reqBody := OpenAIChatRequest{
357+
Model: model,
358+
Messages: []OpenAIChatMessage{},
359+
}
360+
361+
jsonData, err := json.Marshal(reqBody)
362+
if err != nil {
363+
return fmt.Errorf("error marshaling request: %w", err)
364+
}
365+
366+
completionsPath := c.modelRunner.OpenAIPathPrefix() + "/chat/completions"
367+
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.modelRunner.URL(completionsPath), bytes.NewReader(jsonData))
368+
if err != nil {
369+
return fmt.Errorf("error creating request: %w", err)
370+
}
371+
req.Header.Set("Content-Type", "application/json")
372+
req.Header.Set("User-Agent", "docker-model-cli/"+Version)
373+
req.Header.Set("X-Preload-Only", "true")
374+
375+
resp, err := c.modelRunner.Client().Do(req)
376+
if err != nil {
377+
return c.handleQueryError(err, completionsPath)
378+
}
379+
defer resp.Body.Close()
380+
381+
if resp.StatusCode != http.StatusOK {
382+
body, err := io.ReadAll(resp.Body)
383+
if err != nil {
384+
return fmt.Errorf("preload failed with status %d and could not read response body: %w", resp.StatusCode, err)
385+
}
386+
return fmt.Errorf("preload failed: status=%d body=%s", resp.StatusCode, body)
387+
}
388+
389+
return nil
390+
}
391+
353392
// ChatWithMessagesContext performs a chat request with conversation history and returns the assistant's response.
354393
// This allows maintaining conversation context across multiple exchanges.
355394
func (c *Client) ChatWithMessagesContext(ctx context.Context, model string, conversationHistory []OpenAIChatMessage, prompt string, imageURLs []string, outputFunc func(string), shouldUseMarkdown bool) (string, error) {

pkg/inference/scheduling/http_handler.go

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@ import (
88
"fmt"
99
"io"
1010
"net/http"
11+
"net/http/httptest"
1112
"strings"
1213
"sync"
14+
"time"
1315

1416
"github.com/docker/model-runner/pkg/distribution/distribution"
1517
"github.com/docker/model-runner/pkg/inference"
@@ -19,6 +21,10 @@ import (
1921
"github.com/docker/model-runner/pkg/middleware"
2022
)
2123

24+
type contextKey bool
25+
26+
const preloadOnlyKey contextKey = false
27+
2228
// HTTPHandler handles HTTP requests for the scheduler.
2329
// It wraps the Scheduler to provide HTTP endpoint functionality without
2430
// coupling the core scheduling logic to HTTP concerns.
@@ -223,6 +229,12 @@ func (h *HTTPHandler) handleOpenAIInference(w http.ResponseWriter, r *http.Reque
223229
}
224230
defer h.scheduler.loader.release(runner)
225231

232+
// If this is a preload-only request, return here without running inference.
233+
// Can be triggered via context (internal) or X-Preload-Only header (external).
234+
if r.Context().Value(preloadOnlyKey) != nil || r.Header.Get("X-Preload-Only") == "true" {
235+
return
236+
}
237+
226238
// Record the request in the OpenAI recorder.
227239
recordID := h.scheduler.openAIRecorder.RecordRequest(request.Model, r, body)
228240
w = h.scheduler.openAIRecorder.NewResponseRecorder(w)
@@ -357,7 +369,7 @@ func (h *HTTPHandler) Configure(w http.ResponseWriter, r *http.Request) {
357369
return
358370
}
359371

360-
_, err = h.scheduler.ConfigureRunner(r.Context(), backend, configureRequest, r.UserAgent())
372+
backend, err = h.scheduler.ConfigureRunner(r.Context(), backend, configureRequest, r.UserAgent())
361373
if err != nil {
362374
if errors.Is(err, errRunnerAlreadyActive) {
363375
http.Error(w, err.Error(), http.StatusConflict)
@@ -367,6 +379,37 @@ func (h *HTTPHandler) Configure(w http.ResponseWriter, r *http.Request) {
367379
return
368380
}
369381

382+
// Preload the model in the background by calling handleOpenAIInference with preload-only context.
383+
// This also causes Compose to preload the model, since Compose calls `configure` by default.
384+
go func() {
385+
preloadBody, err := json.Marshal(OpenAIInferenceRequest{Model: configureRequest.Model})
386+
if err != nil {
387+
h.scheduler.log.Warnf("failed to marshal preload request body: %v", err)
388+
return
389+
}
390+
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
391+
defer cancel()
392+
preloadReq, err := http.NewRequestWithContext(
393+
context.WithValue(ctx, preloadOnlyKey, true),
394+
http.MethodPost,
395+
inference.InferencePrefix+"/v1/chat/completions",
396+
bytes.NewReader(preloadBody),
397+
)
398+
if err != nil {
399+
h.scheduler.log.Warnf("failed to create preload request: %v", err)
400+
return
401+
}
402+
preloadReq.Header.Set("User-Agent", r.UserAgent())
403+
if backend != nil {
404+
preloadReq.SetPathValue("backend", backend.Name())
405+
}
406+
recorder := httptest.NewRecorder()
407+
h.handleOpenAIInference(recorder, preloadReq)
408+
if recorder.Code != http.StatusOK {
409+
h.scheduler.log.Warnf("background model preload failed with status %d: %s", recorder.Code, recorder.Body.String())
410+
}
411+
}()
412+
370413
w.WriteHeader(http.StatusAccepted)
371414
}
372415

0 commit comments

Comments
 (0)