Skip to content

Commit 3367f8f

Browse files
committed
feat: add per-model keep_alive configuration for idle eviction
Allow developers to control how long a model stays loaded in memory before being evicted, following Ollama API semantics. Supports duration strings (5m, 1h), 0 for immediate unload, and -1 for never. Signed-off-by: Dorin Geman <dorin.geman@docker.com>
1 parent aa6b09e commit 3367f8f

File tree

14 files changed

+669
-86
lines changed

14 files changed

+669
-86
lines changed

cmd/cli/commands/configure.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ func newConfigureCmd() *cobra.Command {
1111
var flags ConfigureFlags
1212

1313
c := &cobra.Command{
14-
Use: "configure [--context-size=<n>] [--speculative-draft-model=<model>] [--hf_overrides=<json>] [--gpu-memory-utilization=<float>] [--mode=<mode>] [--think] MODEL [-- <runtime-flags...>]",
14+
Use: "configure [--context-size=<n>] [--speculative-draft-model=<model>] [--hf_overrides=<json>] [--gpu-memory-utilization=<float>] [--mode=<mode>] [--think] [--keep-alive=<duration>] MODEL [-- <runtime-flags...>]",
1515
Aliases: []string{"config"},
1616
Short: "Manage model runtime configurations",
1717
Hidden: true,

cmd/cli/commands/configure_flags.go

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,8 +133,8 @@ type ConfigureFlags struct {
133133
// vLLM-specific flags
134134
HFOverrides string
135135
GPUMemoryUtilization *float64
136-
// Think parameter for reasoning models
137-
Think *bool
136+
Think *bool
137+
KeepAlive string
138138
}
139139

140140
// RegisterFlags registers all configuration flags on the given cobra command.
@@ -147,6 +147,7 @@ func (f *ConfigureFlags) RegisterFlags(cmd *cobra.Command) {
147147
cmd.Flags().Var(NewFloat64PtrValue(&f.GPUMemoryUtilization), "gpu-memory-utilization", "fraction of GPU memory to use for the model executor (0.0-1.0) - vLLM only")
148148
cmd.Flags().Var(NewBoolPtrValue(&f.Think), "think", "enable reasoning mode for thinking models")
149149
cmd.Flags().StringVar(&f.Mode, "mode", "", "backend operation mode (completion, embedding, reranking, image-generation)")
150+
cmd.Flags().StringVar(&f.KeepAlive, "keep-alive", "", "duration to keep model loaded (e.g., '5m', '1h', '0' to unload immediately, '-1' to never unload)")
150151
}
151152

152153
// BuildConfigureRequest builds a scheduling.ConfigureRequest from the flags.
@@ -205,6 +206,14 @@ func (f *ConfigureFlags) BuildConfigureRequest(model string) (scheduling.Configu
205206
req.LlamaCpp.ReasoningBudget = reasoningBudget
206207
}
207208

209+
if f.KeepAlive != "" {
210+
ka, err := inference.ParseKeepAlive(f.KeepAlive)
211+
if err != nil {
212+
return req, err
213+
}
214+
req.KeepAlive = &ka
215+
}
216+
208217
// Parse mode if provided
209218
if f.Mode != "" {
210219
parsedMode, err := parseBackendMode(f.Mode)

cmd/cli/commands/configure_test.go

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ package commands
22

33
import (
44
"testing"
5+
6+
"github.com/docker/model-runner/pkg/inference"
57
)
68

79
func TestConfigureCmdHfOverridesFlag(t *testing.T) {
@@ -326,6 +328,115 @@ func TestThinkFlagBehavior(t *testing.T) {
326328
}
327329
}
328330

331+
func TestConfigureCmdKeepAliveFlag(t *testing.T) {
332+
// Create the configure command
333+
cmd := newConfigureCmd()
334+
335+
// Verify the --keep-alive flag exists
336+
keepAliveFlag := cmd.Flags().Lookup("keep-alive")
337+
if keepAliveFlag == nil {
338+
t.Fatal("--keep-alive flag not found")
339+
return
340+
}
341+
342+
// Verify the default value is empty
343+
if keepAliveFlag.DefValue != "" {
344+
t.Errorf("Expected default keep-alive value to be empty, got '%s'", keepAliveFlag.DefValue)
345+
}
346+
347+
// Verify the flag type
348+
if keepAliveFlag.Value.Type() != "string" {
349+
t.Errorf("Expected keep-alive flag type to be 'string', got '%s'", keepAliveFlag.Value.Type())
350+
}
351+
352+
// Test setting the flag value
353+
if err := cmd.Flags().Set("keep-alive", "10m"); err != nil {
354+
t.Errorf("Failed to set keep-alive flag: %v", err)
355+
}
356+
357+
// Verify the value was set
358+
if keepAliveFlag.Value.String() != "10m" {
359+
t.Errorf("Expected keep-alive flag value to be '10m', got '%s'", keepAliveFlag.Value.String())
360+
}
361+
}
362+
363+
func TestKeepAliveFlagBehavior(t *testing.T) {
364+
tests := []struct {
365+
name string
366+
keepAlive string
367+
expectSet bool
368+
expectError bool
369+
expectedValue inference.KeepAlive
370+
}{
371+
{
372+
name: "default - not set",
373+
keepAlive: "",
374+
expectSet: false,
375+
},
376+
{
377+
name: "5 minutes",
378+
keepAlive: "5m",
379+
expectSet: true,
380+
expectedValue: inference.KeepAliveDefault,
381+
},
382+
{
383+
name: "unload immediately",
384+
keepAlive: "0",
385+
expectSet: true,
386+
expectedValue: inference.KeepAliveImmediate,
387+
},
388+
{
389+
name: "never unload",
390+
keepAlive: "-1",
391+
expectSet: true,
392+
expectedValue: inference.KeepAliveForever,
393+
},
394+
{
395+
name: "negative duration means forever",
396+
keepAlive: "-1m",
397+
expectSet: true,
398+
expectedValue: inference.KeepAliveForever,
399+
},
400+
{
401+
name: "invalid value",
402+
keepAlive: "abc",
403+
expectError: true,
404+
},
405+
}
406+
407+
for _, tt := range tests {
408+
t.Run(tt.name, func(t *testing.T) {
409+
flags := ConfigureFlags{
410+
KeepAlive: tt.keepAlive,
411+
}
412+
413+
req, err := flags.BuildConfigureRequest("test-model")
414+
if tt.expectError {
415+
if err == nil {
416+
t.Fatal("Expected error but got none")
417+
}
418+
return
419+
}
420+
if err != nil {
421+
t.Fatalf("Unexpected error: %v", err)
422+
}
423+
424+
if tt.expectSet {
425+
if req.KeepAlive == nil {
426+
t.Fatal("Expected KeepAlive to be set")
427+
}
428+
if *req.KeepAlive != tt.expectedValue {
429+
t.Errorf("Expected KeepAlive to be %v, got %v", tt.expectedValue, *req.KeepAlive)
430+
}
431+
} else {
432+
if req.KeepAlive != nil {
433+
t.Errorf("Expected KeepAlive to be nil, got %v", *req.KeepAlive)
434+
}
435+
}
436+
})
437+
}
438+
}
439+
329440
func TestRuntimeFlagsValidation(t *testing.T) {
330441
tests := []struct {
331442
name string

cmd/cli/commands/ps.go

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"github.com/docker/go-units"
99
"github.com/docker/model-runner/cmd/cli/commands/completion"
1010
"github.com/docker/model-runner/cmd/cli/desktop"
11+
"github.com/docker/model-runner/pkg/inference"
1112
"github.com/spf13/cobra"
1213
)
1314

@@ -31,43 +32,45 @@ func newPSCmd() *cobra.Command {
3132
func psTable(ps []desktop.BackendStatus) string {
3233
var buf bytes.Buffer
3334
table := newTable(&buf)
34-
table.Header([]string{"MODEL NAME", "BACKEND", "MODE", "LAST USED"})
35+
table.Header([]string{"MODEL NAME", "BACKEND", "MODE", "UNTIL"})
3536

3637
for _, status := range ps {
3738
modelName := status.ModelName
3839
if strings.HasPrefix(modelName, "sha256:") {
3940
modelName = modelName[7:19]
4041
} else {
41-
// Strip default "ai/" prefix and ":latest" tag for display
4242
modelName = stripDefaultsFromModelName(modelName)
4343
}
4444

45-
var lastUsed string
46-
if status.InUse {
47-
lastUsed = "in use"
48-
} else if !status.LastUsed.IsZero() {
49-
duration := time.Since(status.LastUsed)
50-
if duration < 0 {
51-
duration = 0
52-
}
53-
if duration < time.Second {
54-
lastUsed = "just now"
55-
} else {
56-
lastUsed = units.HumanDuration(duration) + " ago"
57-
}
58-
} else {
59-
// This case should not happen if InUse is properly set, but fallback to "in use" for zero time
60-
lastUsed = "in use"
61-
}
62-
6345
table.Append([]string{
6446
modelName,
6547
status.BackendName,
6648
status.Mode,
67-
lastUsed,
49+
formatUntil(status),
6850
})
6951
}
7052

7153
table.Render()
7254
return buf.String()
7355
}
56+
57+
func formatUntil(status desktop.BackendStatus) string {
58+
keepAlive := inference.KeepAliveDefault
59+
if status.KeepAlive != nil {
60+
keepAlive = *status.KeepAlive
61+
}
62+
63+
if keepAlive == inference.KeepAliveForever {
64+
return "Forever"
65+
}
66+
67+
if status.InUse || status.LastUsed.IsZero() {
68+
return units.HumanDuration(keepAlive.Duration()) + " from now"
69+
}
70+
71+
remaining := keepAlive.Duration() - time.Since(status.LastUsed)
72+
if remaining <= 0 {
73+
return "Expiring"
74+
}
75+
return units.HumanDuration(remaining) + " from now"
76+
}

cmd/cli/desktop/desktop.go

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -650,16 +650,12 @@ func (c *Client) Remove(modelArgs []string, force bool) (string, error) {
650650

651651
// BackendStatus to be imported from docker/model-runner when https://github.com/docker/model-runner/pull/42 is merged.
652652
type BackendStatus struct {
653-
// BackendName is the name of the backend
654-
BackendName string `json:"backend_name"`
655-
// ModelName is the name of the model loaded in the backend
656-
ModelName string `json:"model_name"`
657-
// Mode is the mode the backend is operating in
658-
Mode string `json:"mode"`
659-
// LastUsed represents when this backend was last used (if it's idle)
660-
LastUsed time.Time `json:"last_used,omitempty"`
661-
// InUse indicates whether this backend is currently handling a request
662-
InUse bool `json:"in_use,omitempty"`
653+
BackendName string `json:"backend_name"`
654+
ModelName string `json:"model_name"`
655+
Mode string `json:"mode"`
656+
LastUsed time.Time `json:"last_used,omitempty"`
657+
InUse bool `json:"in_use,omitempty"`
658+
KeepAlive *inference.KeepAlive `json:"keep_alive,omitempty"`
663659
}
664660

665661
func (c *Client) PS() ([]BackendStatus, error) {

cmd/cli/docs/reference/docker_model_configure.yaml

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ command: docker model configure
22
aliases: docker model configure, docker model config
33
short: Manage model runtime configurations
44
long: Manage model runtime configurations
5-
usage: docker model configure [--context-size=<n>] [--speculative-draft-model=<model>] [--hf_overrides=<json>] [--gpu-memory-utilization=<float>] [--mode=<mode>] [--think] MODEL [-- <runtime-flags...>]
5+
usage: docker model configure [--context-size=<n>] [--speculative-draft-model=<model>] [--hf_overrides=<json>] [--gpu-memory-utilization=<float>] [--mode=<mode>] [--think] [--keep-alive=<duration>] MODEL [-- <runtime-flags...>]
66
pname: docker model
77
plink: docker_model.yaml
88
cname:
@@ -38,6 +38,16 @@ options:
3838
experimentalcli: false
3939
kubernetes: false
4040
swarm: false
41+
- option: keep-alive
42+
value_type: string
43+
description: |
44+
duration to keep model loaded (e.g., '5m', '1h', '0' to unload immediately, '-1' to never unload)
45+
deprecated: false
46+
hidden: false
47+
experimental: false
48+
experimentalcli: false
49+
kubernetes: false
50+
swarm: false
4151
- option: mode
4252
value_type: string
4353
description: |

pkg/inference/backend.go

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"encoding/json"
66
"fmt"
77
"net/http"
8+
"time"
89
)
910

1011
// BackendMode encodes the mode in which a backend should operate.
@@ -108,11 +109,70 @@ type LlamaCppConfig struct {
108109
ReasoningBudget *int32 `json:"reasoning-budget,omitempty"`
109110
}
110111

112+
// KeepAlive is a duration controlling how long a model stays loaded in memory.
//
// The JSON representation is a string: a Go duration ("5m", "1h"), "0"
// (unload immediately) or "-1" (never unload). Any negative value means
// "never unload". A nil *KeepAlive means use the default (5 minutes).
type KeepAlive time.Duration

const (
	// KeepAliveDefault is the keep-alive used when none is configured.
	KeepAliveDefault = KeepAlive(5 * time.Minute)
	// KeepAliveImmediate requests that the model be unloaded immediately.
	KeepAliveImmediate = KeepAlive(0)
	// KeepAliveForever requests that the model never be unloaded. It is the
	// canonical value that every negative keep-alive is normalized to.
	KeepAliveForever = KeepAlive(-1)
)

// Duration returns the keep-alive as a plain time.Duration. The result is
// negative for KeepAliveForever.
func (d KeepAlive) Duration() time.Duration {
	return time.Duration(d)
}

// MarshalJSON encodes the keep-alive as a JSON string: "-1" for any negative
// (never unload) value, otherwise the Go duration string such as "5m0s".
func (d KeepAlive) MarshalJSON() ([]byte, error) {
	// Emit the canonical "-1" for every negative value so the wire format
	// never carries strings like "-1m0s" for the same meaning.
	if d < 0 {
		return json.Marshal("-1")
	}
	return json.Marshal(time.Duration(d).String())
}

// UnmarshalJSON decodes a keep-alive from a JSON string accepted by
// ParseKeepAlive. A JSON null leaves the receiver unchanged, following the
// encoding/json convention for Unmarshalers. Any other non-string JSON value
// or an unparsable string yields an error and leaves the receiver unchanged.
func (d *KeepAlive) UnmarshalJSON(data []byte) error {
	if string(data) == "null" {
		return nil
	}
	var s string
	if err := json.Unmarshal(data, &s); err != nil {
		return err
	}
	parsed, err := ParseKeepAlive(s)
	if err != nil {
		return err
	}
	*d = parsed
	return nil
}

// ParseKeepAlive converts a keep_alive string to a KeepAlive value.
// Accepts:
//   - Go duration strings: "5m", "1h", "30s"
//   - "0" to unload immediately
//   - Any negative value ("-1", "-1m") to keep loaded forever
//
// Any other input, including the empty string, returns an error wrapping the
// time.ParseDuration failure.
func ParseKeepAlive(s string) (KeepAlive, error) {
	// The bare "0" and "-1" forms carry no unit, so time.ParseDuration would
	// reject "-1"; handle both explicitly before parsing.
	switch s {
	case "0":
		return KeepAliveImmediate, nil
	case "-1":
		return KeepAliveForever, nil
	}

	d, err := time.ParseDuration(s)
	if err != nil {
		return 0, fmt.Errorf("invalid keep_alive duration %q: %w", s, err)
	}
	// Normalize every negative duration to the single canonical value.
	if d < 0 {
		return KeepAliveForever, nil
	}
	return KeepAlive(d), nil
}
169+
111170
type BackendConfiguration struct {
112171
// Shared configuration across all backends
113172
ContextSize *int32 `json:"context-size,omitempty"`
114173
RuntimeFlags []string `json:"runtime-flags,omitempty"`
115174
Speculative *SpeculativeDecodingConfig `json:"speculative,omitempty"`
175+
KeepAlive *KeepAlive `json:"keep_alive,omitempty"`
116176

117177
// Backend-specific configuration
118178
VLLM *VLLMConfig `json:"vllm,omitempty"`

0 commit comments

Comments
 (0)