Skip to content

Commit 1beffc5

Browse files
authored
Merge pull request #711 from docker/envconfig
Add pkg/envconfig package and consolidate environment variable access
2 parents c21f944 + 2f1da40 commit 1beffc5

File tree

3 files changed

+255
-72
lines changed

3 files changed

+255
-72
lines changed

main.go

Lines changed: 24 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"syscall"
1616
"time"
1717

18+
"github.com/docker/model-runner/pkg/envconfig"
1819
"github.com/docker/model-runner/pkg/inference"
1920
"github.com/docker/model-runner/pkg/inference/backends/diffusers"
2021
"github.com/docker/model-runner/pkg/inference/backends/llamacpp"
@@ -27,15 +28,9 @@ import (
2728
modeltls "github.com/docker/model-runner/pkg/tls"
2829
)
2930

30-
const (
31-
// DefaultTLSPort is the default TLS port for Moby
32-
DefaultTLSPort = "12444"
33-
)
34-
3531
// initLogger creates the application logger based on LOG_LEVEL env var.
3632
func initLogger() *slog.Logger {
37-
level := logging.ParseLevel(os.Getenv("LOG_LEVEL"))
38-
return logging.NewLogger(level)
33+
return logging.NewLogger(envconfig.LogLevel())
3934
}
4035

4136
var log = initLogger()
@@ -47,45 +42,29 @@ func main() {
4742
ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
4843
defer cancel()
4944

50-
sockName := os.Getenv("MODEL_RUNNER_SOCK")
51-
if sockName == "" {
52-
sockName = "model-runner.sock"
53-
}
54-
55-
userHomeDir, err := os.UserHomeDir()
45+
sockName := envconfig.SocketPath()
46+
modelPath, err := envconfig.ModelsPath()
5647
if err != nil {
57-
log.Error("Failed to get user home directory", "error", err)
48+
log.Error("Failed to get models path", "error", err)
5849
exitFunc(1)
5950
}
6051

61-
modelPath := os.Getenv("MODELS_PATH")
62-
if modelPath == "" {
63-
modelPath = filepath.Join(userHomeDir, ".docker", "models")
64-
}
65-
66-
_, disableServerUpdate := os.LookupEnv("DISABLE_SERVER_UPDATE")
67-
if disableServerUpdate {
52+
if envconfig.DisableServerUpdate() {
6853
llamacpp.ShouldUpdateServerLock.Lock()
6954
llamacpp.ShouldUpdateServer = false
7055
llamacpp.ShouldUpdateServerLock.Unlock()
7156
}
7257

73-
desiredServerVersion, ok := os.LookupEnv("LLAMA_SERVER_VERSION")
74-
if ok {
75-
llamacpp.SetDesiredServerVersion(desiredServerVersion)
76-
}
77-
78-
llamaServerPath := os.Getenv("LLAMA_SERVER_PATH")
79-
if llamaServerPath == "" {
80-
llamaServerPath = "/Applications/Docker.app/Contents/Resources/model-runner/bin"
58+
if v := envconfig.LlamaServerVersion(); v != "" {
59+
llamacpp.SetDesiredServerVersion(v)
8160
}
8261

83-
// Get optional custom paths for other backends
84-
vllmServerPath := os.Getenv("VLLM_SERVER_PATH")
85-
sglangServerPath := os.Getenv("SGLANG_SERVER_PATH")
86-
mlxServerPath := os.Getenv("MLX_SERVER_PATH")
87-
diffusersServerPath := os.Getenv("DIFFUSERS_SERVER_PATH")
88-
vllmMetalServerPath := os.Getenv("VLLM_METAL_SERVER_PATH")
62+
llamaServerPath := envconfig.LlamaServerPath()
63+
vllmServerPath := envconfig.VLLMServerPath()
64+
sglangServerPath := envconfig.SGLangServerPath()
65+
mlxServerPath := envconfig.MLXServerPath()
66+
diffusersServerPath := envconfig.DiffusersServerPath()
67+
vllmMetalServerPath := envconfig.VLLMMetalServerPath()
8968

9069
// Create a proxy-aware HTTP transport
9170
// Use a safe type assertion with fallback, and explicitly set Proxy to http.ProxyFromEnvironment
@@ -169,6 +148,7 @@ func main() {
169148
"",
170149
false,
171150
),
151+
AllowedOrigins: envconfig.AllowedOrigins(),
172152
IncludeResponsesAPI: true,
173153
ExtraRoutes: func(r *routing.NormalizedServeMux, s *routing.Service) {
174154
// Root handler – only catches exact "/" requests
@@ -190,7 +170,7 @@ func main() {
190170
})
191171

192172
// Metrics endpoint
193-
if os.Getenv("DISABLE_METRICS") != "1" {
173+
if !envconfig.DisableMetrics() {
194174
metricsHandler := metrics.NewAggregatedMetricsHandler(
195175
log.With("component", "metrics"),
196176
s.SchedulerHTTP,
@@ -218,7 +198,7 @@ func main() {
218198
tlsServerErrors := make(chan error, 1)
219199

220200
// Check if we should use TCP port instead of Unix socket
221-
tcpPort := os.Getenv("MODEL_RUNNER_PORT")
201+
tcpPort := envconfig.TCPPort()
222202
if tcpPort != "" {
223203
// Use TCP port
224204
addr := ":" + tcpPort
@@ -246,19 +226,16 @@ func main() {
246226
}
247227

248228
// Start TLS server if enabled
249-
if os.Getenv("MODEL_RUNNER_TLS_ENABLED") == "true" {
250-
tlsPort := os.Getenv("MODEL_RUNNER_TLS_PORT")
251-
if tlsPort == "" {
252-
tlsPort = DefaultTLSPort // Default TLS port for Moby
253-
}
229+
if envconfig.TLSEnabled() {
230+
tlsPort := envconfig.TLSPort()
254231

255232
// Get certificate paths
256-
certPath := os.Getenv("MODEL_RUNNER_TLS_CERT")
257-
keyPath := os.Getenv("MODEL_RUNNER_TLS_KEY")
233+
certPath := envconfig.TLSCert()
234+
keyPath := envconfig.TLSKey()
258235

259236
// Auto-generate certificates if not provided and auto-cert is not disabled
260237
if certPath == "" || keyPath == "" {
261-
if os.Getenv("MODEL_RUNNER_TLS_AUTO_CERT") != "false" {
238+
if envconfig.TLSAutoCert(true) {
262239
log.Info("Auto-generating TLS certificates...")
263240
var err error
264241
certPath, keyPath, err = modeltls.EnsureCertificates("", "")
@@ -306,7 +283,7 @@ func main() {
306283
}()
307284

308285
var tlsServerErrorsChan <-chan error
309-
if os.Getenv("MODEL_RUNNER_TLS_ENABLED") == "true" {
286+
if envconfig.TLSEnabled() {
310287
tlsServerErrorsChan = tlsServerErrors
311288
} else {
312289
// Use a nil channel which will block forever when TLS is disabled
@@ -346,7 +323,7 @@ func main() {
346323
// Returns nil config (use defaults) when LLAMA_ARGS is unset, or an error if
347324
// the args contain disallowed flags.
348325
func createLlamaCppConfigFromEnv() (config.BackendConfig, error) {
349-
argsStr := os.Getenv("LLAMA_ARGS")
326+
argsStr := envconfig.LlamaArgs()
350327
if argsStr == "" {
351328
return nil, nil
352329
}

pkg/envconfig/envconfig.go

Lines changed: 227 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,227 @@
1+
package envconfig
2+
3+
import (
4+
"fmt"
5+
"log/slog"
6+
"net"
7+
"os"
8+
"path/filepath"
9+
"strconv"
10+
"strings"
11+
12+
"github.com/docker/model-runner/pkg/logging"
13+
)
14+
15+
// Var returns an environment variable stripped of leading/trailing quotes and spaces.
16+
func Var(key string) string {
17+
return strings.Trim(strings.TrimSpace(os.Getenv(key)), "\"'")
18+
}
19+
20+
// String returns a lazy string accessor for the given environment variable.
21+
func String(key string) func() string {
22+
return func() string {
23+
return Var(key)
24+
}
25+
}
26+
27+
// BoolWithDefault returns a lazy bool accessor for the given environment variable,
28+
// allowing a caller-specified default. If the variable is set but cannot be parsed
29+
// as a bool, the defaultValue is returned.
30+
func BoolWithDefault(key string) func(defaultValue bool) bool {
31+
return func(defaultValue bool) bool {
32+
if s := Var(key); s != "" {
33+
b, err := strconv.ParseBool(s)
34+
if err != nil {
35+
return defaultValue
36+
}
37+
return b
38+
}
39+
return defaultValue
40+
}
41+
}
42+
43+
// Bool returns a lazy bool accessor that defaults to false when the variable is unset.
44+
func Bool(key string) func() bool {
45+
withDefault := BoolWithDefault(key)
46+
return func() bool {
47+
return withDefault(false)
48+
}
49+
}
50+
51+
// LogLevel reads LOG_LEVEL and returns the corresponding slog.Level.
52+
func LogLevel() slog.Level {
53+
return logging.ParseLevel(Var("LOG_LEVEL"))
54+
}
55+
56+
// AllowedOrigins returns a list of CORS-allowed origins. It reads DMR_ORIGINS
57+
// and always appends default localhost/127.0.0.1/0.0.0.0 entries on http and
58+
// https with wildcard ports.
59+
func AllowedOrigins() (origins []string) {
60+
if s := Var("DMR_ORIGINS"); s != "" {
61+
for _, o := range strings.Split(s, ",") {
62+
if trimmed := strings.TrimSpace(o); trimmed != "" {
63+
origins = append(origins, trimmed)
64+
}
65+
}
66+
}
67+
68+
for _, host := range []string{"localhost", "127.0.0.1", "0.0.0.0"} {
69+
origins = append(origins,
70+
fmt.Sprintf("http://%s", host),
71+
fmt.Sprintf("https://%s", host),
72+
fmt.Sprintf("http://%s", net.JoinHostPort(host, "*")),
73+
fmt.Sprintf("https://%s", net.JoinHostPort(host, "*")),
74+
)
75+
}
76+
77+
return origins
78+
}
79+
80+
// SocketPath returns the Unix socket path for the model runner.
81+
// Configured via MODEL_RUNNER_SOCK; defaults to "model-runner.sock".
82+
func SocketPath() string {
83+
if s := Var("MODEL_RUNNER_SOCK"); s != "" {
84+
return s
85+
}
86+
return "model-runner.sock"
87+
}
88+
89+
// ModelsPath returns the directory where models are stored.
90+
// Configured via MODELS_PATH; defaults to ~/.docker/models.
91+
func ModelsPath() (string, error) {
92+
if s := Var("MODELS_PATH"); s != "" {
93+
return s, nil
94+
}
95+
home, err := os.UserHomeDir()
96+
if err != nil {
97+
return "", err
98+
}
99+
return filepath.Join(home, ".docker", "models"), nil
100+
}
101+
102+
// TCPPort returns the optional TCP port for the model runner HTTP server.
103+
// Configured via MODEL_RUNNER_PORT; empty string means use Unix socket.
104+
func TCPPort() string {
105+
return Var("MODEL_RUNNER_PORT")
106+
}
107+
108+
// LlamaServerPath returns the path to the llama.cpp server binary.
109+
// Configured via LLAMA_SERVER_PATH; defaults to the Docker Desktop bundle location.
110+
func LlamaServerPath() string {
111+
if s := Var("LLAMA_SERVER_PATH"); s != "" {
112+
return s
113+
}
114+
return "/Applications/Docker.app/Contents/Resources/model-runner/bin"
115+
}
116+
117+
// LlamaArgs returns custom arguments to pass to the llama.cpp server.
118+
// Configured via LLAMA_ARGS.
119+
func LlamaArgs() string {
120+
return Var("LLAMA_ARGS")
121+
}
122+
123+
// DisableServerUpdate is true when DISABLE_SERVER_UPDATE is set to a truthy value.
124+
var DisableServerUpdate = Bool("DISABLE_SERVER_UPDATE")
125+
126+
// LlamaServerVersion returns a specific llama.cpp server version to pin.
127+
// Configured via LLAMA_SERVER_VERSION; empty string means use the bundled version.
128+
func LlamaServerVersion() string {
129+
return Var("LLAMA_SERVER_VERSION")
130+
}
131+
132+
// VLLMServerPath returns the optional path to the vLLM server binary.
133+
// Configured via VLLM_SERVER_PATH.
134+
func VLLMServerPath() string {
135+
return Var("VLLM_SERVER_PATH")
136+
}
137+
138+
// SGLangServerPath returns the optional path to the SGLang server binary.
139+
// Configured via SGLANG_SERVER_PATH.
140+
func SGLangServerPath() string {
141+
return Var("SGLANG_SERVER_PATH")
142+
}
143+
144+
// MLXServerPath returns the optional path to the MLX server binary.
145+
// Configured via MLX_SERVER_PATH.
146+
func MLXServerPath() string {
147+
return Var("MLX_SERVER_PATH")
148+
}
149+
150+
// DiffusersServerPath returns the optional path to the Diffusers server binary.
151+
// Configured via DIFFUSERS_SERVER_PATH.
152+
func DiffusersServerPath() string {
153+
return Var("DIFFUSERS_SERVER_PATH")
154+
}
155+
156+
// VLLMMetalServerPath returns the optional path to the vLLM Metal server binary.
157+
// Configured via VLLM_METAL_SERVER_PATH.
158+
func VLLMMetalServerPath() string {
159+
return Var("VLLM_METAL_SERVER_PATH")
160+
}
161+
162+
// DisableMetrics is true when DISABLE_METRICS is set to a truthy value (e.g. "1").
163+
var DisableMetrics = Bool("DISABLE_METRICS")
164+
165+
// TLSEnabled is true when MODEL_RUNNER_TLS_ENABLED is set to a truthy value.
166+
var TLSEnabled = Bool("MODEL_RUNNER_TLS_ENABLED")
167+
168+
// TLSPort returns the TLS listener port.
169+
// Configured via MODEL_RUNNER_TLS_PORT; defaults to "12444".
170+
func TLSPort() string {
171+
if s := Var("MODEL_RUNNER_TLS_PORT"); s != "" {
172+
return s
173+
}
174+
return "12444"
175+
}
176+
177+
// TLSCert returns the path to the TLS certificate file.
178+
// Configured via MODEL_RUNNER_TLS_CERT.
179+
func TLSCert() string {
180+
return Var("MODEL_RUNNER_TLS_CERT")
181+
}
182+
183+
// TLSKey returns the path to the TLS private key file.
184+
// Configured via MODEL_RUNNER_TLS_KEY.
185+
func TLSKey() string {
186+
return Var("MODEL_RUNNER_TLS_KEY")
187+
}
188+
189+
// TLSAutoCert is true (default) unless MODEL_RUNNER_TLS_AUTO_CERT is set to a falsy value.
190+
// Call as TLSAutoCert(true) to get the default-true behaviour.
191+
var TLSAutoCert = BoolWithDefault("MODEL_RUNNER_TLS_AUTO_CERT")
192+
193+
// EnvVar describes a single environment variable with its current value
194+
// and a human-readable description.
195+
type EnvVar struct {
196+
Name string
197+
Value any
198+
Description string
199+
}
200+
201+
// AsMap returns a map of all model-runner environment variables with their
202+
// current values and descriptions. Useful for introspection and documentation.
203+
func AsMap() map[string]EnvVar {
204+
modelsPath, _ := ModelsPath()
205+
return map[string]EnvVar{
206+
"MODEL_RUNNER_SOCK": {"MODEL_RUNNER_SOCK", SocketPath(), "Unix socket path (default: model-runner.sock)"},
207+
"MODELS_PATH": {"MODELS_PATH", modelsPath, "Directory for model storage (default: ~/.docker/models)"},
208+
"MODEL_RUNNER_PORT": {"MODEL_RUNNER_PORT", TCPPort(), "TCP port; overrides Unix socket when set"},
209+
"LLAMA_SERVER_PATH": {"LLAMA_SERVER_PATH", LlamaServerPath(), "Path to llama.cpp server binary"},
210+
"LLAMA_ARGS": {"LLAMA_ARGS", LlamaArgs(), "Extra arguments passed to the llama.cpp server"},
211+
"DISABLE_SERVER_UPDATE": {"DISABLE_SERVER_UPDATE", DisableServerUpdate(), "Skip automatic llama.cpp server updates (any truthy value)"},
212+
"LLAMA_SERVER_VERSION": {"LLAMA_SERVER_VERSION", LlamaServerVersion(), "Pin a specific llama.cpp server version"},
213+
"VLLM_SERVER_PATH": {"VLLM_SERVER_PATH", VLLMServerPath(), "Path to vLLM server binary"},
214+
"SGLANG_SERVER_PATH": {"SGLANG_SERVER_PATH", SGLangServerPath(), "Path to SGLang server binary"},
215+
"MLX_SERVER_PATH": {"MLX_SERVER_PATH", MLXServerPath(), "Path to MLX server binary"},
216+
"DIFFUSERS_SERVER_PATH": {"DIFFUSERS_SERVER_PATH", DiffusersServerPath(), "Path to Diffusers server binary"},
217+
"VLLM_METAL_SERVER_PATH": {"VLLM_METAL_SERVER_PATH", VLLMMetalServerPath(), "Path to vLLM Metal server binary"},
218+
"DISABLE_METRICS": {"DISABLE_METRICS", DisableMetrics(), "Disable Prometheus metrics endpoint (any truthy value, e.g. 1)"},
219+
"LOG_LEVEL": {"LOG_LEVEL", LogLevel(), "Log verbosity: debug, info, warn, error (default: info)"},
220+
"DMR_ORIGINS": {"DMR_ORIGINS", AllowedOrigins(), "Comma-separated CORS allowed origins (defaults plus any env-provided origins)"},
221+
"MODEL_RUNNER_TLS_ENABLED": {"MODEL_RUNNER_TLS_ENABLED", TLSEnabled(), "Enable TLS listener"},
222+
"MODEL_RUNNER_TLS_PORT": {"MODEL_RUNNER_TLS_PORT", TLSPort(), "TLS listener port (default: 12444)"},
223+
"MODEL_RUNNER_TLS_CERT": {"MODEL_RUNNER_TLS_CERT", TLSCert(), "Path to TLS certificate file"},
224+
"MODEL_RUNNER_TLS_KEY": {"MODEL_RUNNER_TLS_KEY", TLSKey(), "Path to TLS private key file"},
225+
"MODEL_RUNNER_TLS_AUTO_CERT": {"MODEL_RUNNER_TLS_AUTO_CERT", TLSAutoCert(true), "Auto-generate TLS certificates (default: true)"},
226+
}
227+
}

0 commit comments

Comments
 (0)