Skip to content

Commit bdc2101

Browse files
committed
Merge branch 'main' into diffusers-distribution
# Conflicts: # main.go # pkg/inference/backends/diffusers/diffusers.go
2 parents a6c8cab + 3561e79 commit bdc2101

39 files changed

+443
-404
lines changed

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ require (
3030
github.com/pkg/errors v0.9.1
3131
github.com/prometheus/client_model v0.6.2
3232
github.com/prometheus/common v0.67.5
33-
github.com/sirupsen/logrus v1.9.4
3433
github.com/spf13/cobra v1.10.2
3534
github.com/spf13/pflag v1.0.10
3635
github.com/stretchr/testify v1.11.1
@@ -116,6 +115,7 @@ require (
116115
github.com/rivo/uniseg v0.4.7 // indirect
117116
github.com/russross/blackfriday/v2 v2.1.0 // indirect
118117
github.com/shirou/gopsutil/v4 v4.25.6 // indirect
118+
github.com/sirupsen/logrus v1.9.3 // indirect
119119
github.com/smallnest/ringbuffer v0.0.0-20241116012123-461381446e3d // indirect
120120
github.com/tklauser/go-sysconf v0.3.12 // indirect
121121
github.com/tklauser/numcpus v0.6.1 // indirect

go.sum

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -255,8 +255,8 @@ github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf
255255
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
256256
github.com/shirou/gopsutil/v4 v4.25.6 h1:kLysI2JsKorfaFPcYmcJqbzROzsBWEOAtw6A7dIfqXs=
257257
github.com/shirou/gopsutil/v4 v4.25.6/go.mod h1:PfybzyydfZcN+JMMjkF6Zb8Mq1A/VcogFFg7hj50W9c=
258-
github.com/sirupsen/logrus v1.9.4 h1:TsZE7l11zFCLZnZ+teH4Umoq5BhEIfIzfRDZ1Uzql2w=
259-
github.com/sirupsen/logrus v1.9.4/go.mod h1:ftWc9WdOfJ0a92nsE2jF5u5ZwH8Bv2zdeOC42RjbV2g=
258+
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
259+
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
260260
github.com/smallnest/ringbuffer v0.0.0-20241116012123-461381446e3d h1:3VwvTjiRPA7cqtgOWddEL+JrcijMlXUmj99c/6YyZoY=
261261
github.com/smallnest/ringbuffer v0.0.0-20241116012123-461381446e3d/go.mod h1:tAG61zBM1DYRaGIPloumExGvScf08oHuo0kFoOqdbT0=
262262
github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU=
@@ -268,6 +268,7 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+
268268
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
269269
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
270270
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
271+
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
271272
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
272273
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
273274
github.com/testcontainers/testcontainers-go v0.40.0 h1:pSdJYLOVgLE8YdUY2FHQ1Fxu+aMnb6JfVz1mxk7OeMU=
@@ -351,6 +352,7 @@ golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7w
351352
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
352353
golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
353354
golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
355+
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
354356
golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
355357
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
356358
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
@@ -391,6 +393,7 @@ gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntN
391393
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
392394
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
393395
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
396+
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
394397
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
395398
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
396399
gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q=

main.go

Lines changed: 61 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ import (
44
"context"
55
"crypto/tls"
66
"encoding/json"
7+
"fmt"
8+
"log/slog"
79
"net"
810
"net/http"
911
"os"
@@ -14,29 +16,39 @@ import (
1416
"time"
1517

1618
"github.com/docker/model-runner/pkg/inference"
19+
"github.com/docker/model-runner/pkg/inference/backends/diffusers"
1720
"github.com/docker/model-runner/pkg/inference/backends/llamacpp"
1821
"github.com/docker/model-runner/pkg/inference/backends/sglang"
1922
"github.com/docker/model-runner/pkg/inference/config"
2023
"github.com/docker/model-runner/pkg/inference/models"
24+
"github.com/docker/model-runner/pkg/logging"
2125
"github.com/docker/model-runner/pkg/metrics"
2226
"github.com/docker/model-runner/pkg/routing"
2327
modeltls "github.com/docker/model-runner/pkg/tls"
24-
"github.com/sirupsen/logrus"
2528
)
2629

2730
const (
2831
// DefaultTLSPort is the default TLS port for Moby
2932
DefaultTLSPort = "12444"
3033
)
3134

32-
var log = logrus.New()
35+
// initLogger creates the application logger based on LOG_LEVEL env var.
36+
func initLogger() *slog.Logger {
37+
level := logging.ParseLevel(os.Getenv("LOG_LEVEL"))
38+
return logging.NewLogger(level)
39+
}
40+
41+
var log = initLogger()
3342

3443
// Log is the logger used by the application, exported for testing purposes.
3544
var Log = log
3645

3746
// testLog is a test-override logger used by createLlamaCppConfigFromEnv.
3847
var testLog = log
3948

49+
// exitFunc is used for Fatal-like exits; overridden in tests.
50+
var exitFunc = func(code int) { os.Exit(code) }
51+
4052
func main() {
4153
ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
4254
defer cancel()
@@ -48,7 +60,8 @@ func main() {
4860

4961
userHomeDir, err := os.UserHomeDir()
5062
if err != nil {
51-
log.Fatalf("Failed to get user home directory: %v", err)
63+
log.Error("Failed to get user home directory", "error", err)
64+
exitFunc(1)
5265
}
5366

5467
modelPath := os.Getenv("MODELS_PATH")
@@ -90,21 +103,21 @@ func main() {
90103
}
91104
baseTransport.Proxy = http.ProxyFromEnvironment
92105

93-
log.Infof("LLAMA_SERVER_PATH: %s", llamaServerPath)
106+
log.Info("LLAMA_SERVER_PATH", "path", llamaServerPath)
94107
if vllmServerPath != "" {
95-
log.Infof("VLLM_SERVER_PATH: %s", vllmServerPath)
108+
log.Info("VLLM_SERVER_PATH", "path", vllmServerPath)
96109
}
97110
if sglangServerPath != "" {
98-
log.Infof("SGLANG_SERVER_PATH: %s", sglangServerPath)
111+
log.Info("SGLANG_SERVER_PATH", "path", sglangServerPath)
99112
}
100113
if mlxServerPath != "" {
101-
log.Infof("MLX_SERVER_PATH: %s", mlxServerPath)
114+
log.Info("MLX_SERVER_PATH", "path", mlxServerPath)
102115
}
103116
if diffusersServerPath != "" {
104-
log.Infof("DIFFUSERS_SERVER_PATH: %s", diffusersServerPath)
117+
log.Info("DIFFUSERS_SERVER_PATH", "path", diffusersServerPath)
105118
}
106119
if vllmMetalServerPath != "" {
107-
log.Infof("VLLM_METAL_SERVER_PATH: %s", vllmMetalServerPath)
120+
log.Info("VLLM_METAL_SERVER_PATH", "path", vllmMetalServerPath)
108121
}
109122

110123
// Create llama.cpp configuration from environment variables
@@ -121,7 +134,7 @@ func main() {
121134
Log: log,
122135
ClientConfig: models.ClientConfig{
123136
StoreRootPath: modelPath,
124-
Logger: log.WithFields(logrus.Fields{"component": "model-manager"}),
137+
Logger: log.With("component", "model-manager"),
125138
Transport: baseTransport,
126139
},
127140
Backends: append(
@@ -139,17 +152,21 @@ func main() {
139152
DiffusersPath: diffusersServerPath,
140153
}),
141154
routing.BackendDef{Name: sglang.Name, Init: func(mm *models.Manager) (inference.Backend, error) {
142-
return sglang.New(log, mm, log.WithFields(logrus.Fields{"component": sglang.Name}), nil, sglangServerPath)
155+
return sglang.New(log, mm, log.With("component", sglang.Name), nil, sglangServerPath)
156+
}},
157+
routing.BackendDef{Name: diffusers.Name, Init: func(mm *models.Manager) (inference.Backend, error) {
158+
return diffusers.New(log, mm, log.With("component", diffusers.Name), nil, diffusersServerPath)
143159
}},
144160
),
145161
OnBackendError: func(name string, err error) {
146-
log.Fatalf("unable to initialize %s backend: %v", name, err)
162+
log.Error("unable to initialize backend", "backend", name, "error", err)
163+
exitFunc(1)
147164
},
148165
DefaultBackendName: llamacpp.Name,
149166
HTTPClient: http.DefaultClient,
150167
MetricsTracker: metrics.NewTracker(
151168
http.DefaultClient,
152-
log.WithField("component", "metrics"),
169+
log.With("component", "metrics"),
153170
"",
154171
false,
155172
),
@@ -169,14 +186,14 @@ func main() {
169186
r.HandleFunc("/version", func(w http.ResponseWriter, req *http.Request) {
170187
w.Header().Set("Content-Type", "application/json")
171188
if err := json.NewEncoder(w).Encode(map[string]string{"version": Version}); err != nil {
172-
log.Warnf("failed to write version response: %v", err)
189+
log.Warn("failed to write version response", "error", err)
173190
}
174191
})
175192

176193
// Metrics endpoint
177194
if os.Getenv("DISABLE_METRICS") != "1" {
178195
metricsHandler := metrics.NewAggregatedMetricsHandler(
179-
log.WithField("component", "metrics"),
196+
log.With("component", "metrics"),
180197
s.SchedulerHTTP,
181198
)
182199
r.Handle("/metrics", metricsHandler)
@@ -187,7 +204,8 @@ func main() {
187204
},
188205
})
189206
if err != nil {
190-
log.Fatalf("failed to initialize service: %v", err)
207+
log.Error("failed to initialize service", "error", err)
208+
exitFunc(1)
191209
}
192210

193211
server := &http.Server{
@@ -205,7 +223,7 @@ func main() {
205223
if tcpPort != "" {
206224
// Use TCP port
207225
addr := ":" + tcpPort
208-
log.Infof("Listening on TCP port %s", tcpPort)
226+
log.Info("Listening on TCP port", "port", tcpPort)
209227
server.Addr = addr
210228
go func() {
211229
serverErrors <- server.ListenAndServe()
@@ -214,12 +232,14 @@ func main() {
214232
// Use Unix socket
215233
if err := os.Remove(sockName); err != nil {
216234
if !os.IsNotExist(err) {
217-
log.Fatalf("Failed to remove existing socket: %v", err)
235+
log.Error("Failed to remove existing socket", "error", err)
236+
exitFunc(1)
218237
}
219238
}
220239
ln, err := net.ListenUnix("unix", &net.UnixAddr{Name: sockName, Net: "unix"})
221240
if err != nil {
222-
log.Fatalf("Failed to listen on socket: %v", err)
241+
log.Error("Failed to listen on socket", "error", err)
242+
exitFunc(1)
223243
}
224244
go func() {
225245
serverErrors <- server.Serve(ln)
@@ -244,19 +264,22 @@ func main() {
244264
var err error
245265
certPath, keyPath, err = modeltls.EnsureCertificates("", "")
246266
if err != nil {
247-
log.Fatalf("Failed to ensure TLS certificates: %v", err)
267+
log.Error("Failed to ensure TLS certificates", "error", err)
268+
exitFunc(1)
248269
}
249-
log.Infof("Using TLS certificate: %s", certPath)
250-
log.Infof("Using TLS key: %s", keyPath)
270+
log.Info("Using TLS certificate", "cert", certPath)
271+
log.Info("Using TLS key", "key", keyPath)
251272
} else {
252-
log.Fatal("TLS enabled but no certificate provided and auto-cert is disabled")
273+
log.Error("TLS enabled but no certificate provided and auto-cert is disabled")
274+
exitFunc(1)
253275
}
254276
}
255277

256278
// Load TLS configuration
257279
tlsConfig, err := modeltls.LoadTLSConfig(certPath, keyPath)
258280
if err != nil {
259-
log.Fatalf("Failed to load TLS configuration: %v", err)
281+
log.Error("Failed to load TLS configuration", "error", err)
282+
exitFunc(1)
260283
}
261284

262285
tlsServer = &http.Server{
@@ -266,7 +289,7 @@ func main() {
266289
ReadHeaderTimeout: 10 * time.Second,
267290
}
268291

269-
log.Infof("Listening on TLS port %s", tlsPort)
292+
log.Info("Listening on TLS port", "port", tlsPort)
270293
go func() {
271294
// Use ListenAndServeTLS with empty strings since TLSConfig already has the certs
272295
ln, err := tls.Listen("tcp", tlsServer.Addr, tlsConfig)
@@ -294,30 +317,30 @@ func main() {
294317
select {
295318
case err := <-serverErrors:
296319
if err != nil {
297-
log.Errorf("Server error: %v", err)
320+
log.Error("Server error", "error", err)
298321
}
299322
case err := <-tlsServerErrorsChan:
300323
if err != nil {
301-
log.Errorf("TLS server error: %v", err)
324+
log.Error("TLS server error", "error", err)
302325
}
303326
case <-ctx.Done():
304-
log.Infoln("Shutdown signal received")
305-
log.Infoln("Shutting down the server")
327+
log.Info("Shutdown signal received")
328+
log.Info("Shutting down the server")
306329
if err := server.Close(); err != nil {
307-
log.Errorf("Server shutdown error: %v", err)
330+
log.Error("Server shutdown error", "error", err)
308331
}
309332
if tlsServer != nil {
310-
log.Infoln("Shutting down the TLS server")
333+
log.Info("Shutting down the TLS server")
311334
if err := tlsServer.Close(); err != nil {
312-
log.Errorf("TLS server shutdown error: %v", err)
335+
log.Error("TLS server shutdown error", "error", err)
313336
}
314337
}
315-
log.Infoln("Waiting for the scheduler to stop")
338+
log.Info("Waiting for the scheduler to stop")
316339
if err := <-schedulerErrors; err != nil {
317-
log.Errorf("Scheduler error: %v", err)
340+
log.Error("Scheduler error", "error", err)
318341
}
319342
}
320-
log.Infoln("Docker Model Runner stopped")
343+
log.Info("Docker Model Runner stopped")
321344
}
322345

323346
// createLlamaCppConfigFromEnv creates a LlamaCppConfig from environment variables
@@ -338,12 +361,13 @@ func createLlamaCppConfigFromEnv() config.BackendConfig {
338361
for _, arg := range args {
339362
for _, disallowed := range disallowedArgs {
340363
if arg == disallowed {
341-
testLog.Fatalf("LLAMA_ARGS cannot override the %s argument as it is controlled by the model runner", disallowed)
364+
testLog.Error(fmt.Sprintf("LLAMA_ARGS cannot override the %s argument as it is controlled by the model runner", disallowed))
365+
exitFunc(1)
342366
}
343367
}
344368
}
345369

346-
testLog.Infof("Using custom arguments: %v", args)
370+
testLog.Info("Using custom arguments", "args", fmt.Sprintf("%v", args))
347371
return &llamacpp.Config{
348372
Args: args,
349373
}

main_test.go

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import (
44
"testing"
55

66
"github.com/docker/model-runner/pkg/inference/backends/llamacpp"
7-
"github.com/sirupsen/logrus"
87
)
98

109
func TestCreateLlamaCppConfigFromEnv(t *testing.T) {
@@ -61,17 +60,14 @@ func TestCreateLlamaCppConfigFromEnv(t *testing.T) {
6160
t.Setenv("LLAMA_ARGS", tt.llamaArgs)
6261
}
6362

64-
// Create a test logger that captures fatal errors
65-
originalLog := testLog
66-
defer func() { testLog = originalLog }()
63+
// Override exitFunc to capture exit calls instead of actually exiting
64+
originalExitFunc := exitFunc
65+
defer func() { exitFunc = originalExitFunc }()
6766

68-
// Create a new logger that will exit with a special exit code
69-
newTestLog := logrus.New()
7067
var exitCode int
71-
newTestLog.ExitFunc = func(code int) {
68+
exitFunc = func(code int) {
7269
exitCode = code
7370
}
74-
testLog = newTestLog
7571

7672
config := createLlamaCppConfigFromEnv()
7773

pkg/anthropic/handler.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ func NewHandler(log logging.Logger, schedulerHTTP *scheduling.HTTPHandler, allow
5959
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
6060
safeMethod := utils.SanitizeForLog(r.Method, -1)
6161
safePath := utils.SanitizeForLog(r.URL.Path, -1)
62-
h.log.Infof("Anthropic API request: %s %s", safeMethod, safePath)
62+
h.log.Info("Anthropic API request", "method", safeMethod, "path", safePath)
6363
h.httpHandler.ServeHTTP(w, r)
6464
}
6565

@@ -169,6 +169,6 @@ func (h *Handler) writeAnthropicError(w http.ResponseWriter, statusCode int, err
169169
}
170170

171171
if err := json.NewEncoder(w).Encode(errResp); err != nil {
172-
h.log.Errorf("Failed to encode error response: %v", err)
172+
h.log.Error("Failed to encode error response", "error", err)
173173
}
174174
}

0 commit comments

Comments
 (0)