
Commit 69b080c

refactor: extract shared infrastructure into pkg/routing
Signed-off-by: Dorin Geman <dorin.geman@docker.com>
1 parent f974104 commit 69b080c

7 files changed: +397 -197 lines

main.go

Lines changed: 63 additions & 175 deletions
@@ -12,22 +12,13 @@ import (
 	"syscall"
 	"time"

-	"github.com/docker/model-runner/pkg/anthropic"
 	"github.com/docker/model-runner/pkg/inference"
 	"github.com/docker/model-runner/pkg/inference/backends/diffusers"
 	"github.com/docker/model-runner/pkg/inference/backends/llamacpp"
-	"github.com/docker/model-runner/pkg/inference/backends/mlx"
 	"github.com/docker/model-runner/pkg/inference/backends/sglang"
-	"github.com/docker/model-runner/pkg/inference/backends/vllm"
-	"github.com/docker/model-runner/pkg/inference/backends/vllmmetal"
 	"github.com/docker/model-runner/pkg/inference/config"
 	"github.com/docker/model-runner/pkg/inference/models"
-	"github.com/docker/model-runner/pkg/inference/platform"
-	"github.com/docker/model-runner/pkg/inference/scheduling"
 	"github.com/docker/model-runner/pkg/metrics"
-	"github.com/docker/model-runner/pkg/middleware"
-	"github.com/docker/model-runner/pkg/ollama"
-	"github.com/docker/model-runner/pkg/responses"
 	"github.com/docker/model-runner/pkg/routing"
 	modeltls "github.com/docker/model-runner/pkg/tls"
 	"github.com/sirupsen/logrus"
@@ -99,17 +90,6 @@ func main() {
 	}
 	baseTransport.Proxy = http.ProxyFromEnvironment

-	clientConfig := models.ClientConfig{
-		StoreRootPath: modelPath,
-		Logger:        log.WithFields(logrus.Fields{"component": "model-manager"}),
-		Transport:     baseTransport,
-	}
-	modelManager := models.NewManager(log.WithFields(logrus.Fields{"component": "model-manager"}), clientConfig)
-	modelHandler := models.NewHTTPHandler(
-		log,
-		modelManager,
-		nil,
-	)
 	log.Infof("LLAMA_SERVER_PATH: %s", llamaServerPath)
 	if vllmServerPath != "" {
 		log.Infof("VLLM_SERVER_PATH: %s", vllmServerPath)
@@ -127,168 +107,76 @@ func main() {
 	// Create llama.cpp configuration from environment variables
 	llamaCppConfig := createLlamaCppConfigFromEnv()

-	llamaCppBackend, err := llamacpp.New(
-		log,
-		modelManager,
-		log.WithFields(logrus.Fields{"component": llamacpp.Name}),
-		llamaServerPath,
-		func() string {
-			wd, _ := os.Getwd()
-			d := filepath.Join(wd, "updated-inference", "bin")
-			_ = os.MkdirAll(d, 0o755)
-			return d
-		}(),
-		llamaCppConfig,
-	)
-	if err != nil {
-		log.Fatalf("unable to initialize %s backend: %v", llamacpp.Name, err)
-	}
-
-	vllmBackend, err := initVLLMBackend(log, modelManager, vllmServerPath)
-	if err != nil {
-		log.Fatalf("unable to initialize %s backend: %v", vllm.Name, err)
-	}
-
-	mlxBackend, err := mlx.New(
-		log,
-		modelManager,
-		log.WithFields(logrus.Fields{"component": mlx.Name}),
-		nil,
-		mlxServerPath,
-	)
-	if err != nil {
-		log.Fatalf("unable to initialize %s backend: %v", mlx.Name, err)
-	}
-
-	sglangBackend, err := sglang.New(
-		log,
-		modelManager,
-		log.WithFields(logrus.Fields{"component": sglang.Name}),
-		nil,
-		sglangServerPath,
-	)
-	if err != nil {
-		log.Fatalf("unable to initialize %s backend: %v", sglang.Name, err)
-	}
-
-	diffusersBackend, err := diffusers.New(
-		log,
-		modelManager,
-		log.WithFields(logrus.Fields{"component": diffusers.Name}),
-		nil,
-		diffusersServerPath,
-	)
-
-	if err != nil {
-		log.Fatalf("unable to initialize diffusers backend: %v", err)
-	}
-
-	var vllmMetalBackend inference.Backend
-	if platform.SupportsVLLMMetal() {
-		vllmMetalBackend, err = vllmmetal.New(
-			log,
-			modelManager,
-			log.WithFields(logrus.Fields{"component": vllmmetal.Name}),
-			vllmMetalServerPath,
-		)
-		if err != nil {
-			log.Warnf("Failed to initialize vllm-metal backend: %v", err)
-		}
-	}
-
-	backends := map[string]inference.Backend{
-		llamacpp.Name:  llamaCppBackend,
-		mlx.Name:       mlxBackend,
-		sglang.Name:    sglangBackend,
-		diffusers.Name: diffusersBackend,
-	}
-	registerVLLMBackend(backends, vllmBackend)
-
-	if vllmMetalBackend != nil {
-		backends[vllmmetal.Name] = vllmMetalBackend
-	}
-
-	// Backends whose installation is deferred until explicitly requested.
-	var deferredBackends []string
-	if vllmMetalBackend != nil {
-		deferredBackends = append(deferredBackends, vllmmetal.Name)
-	}
+	updatedServerPath := func() string {
+		wd, _ := os.Getwd()
+		d := filepath.Join(wd, "updated-inference", "bin")
+		_ = os.MkdirAll(d, 0o755)
+		return d
+	}()

-	scheduler := scheduling.NewScheduler(
-		log,
-		backends,
-		llamaCppBackend,
-		modelManager,
-		http.DefaultClient,
-		metrics.NewTracker(
+	svc := routing.NewService(routing.ServiceConfig{
+		Log: log,
+		ClientConfig: models.ClientConfig{
+			StoreRootPath: modelPath,
+			Logger:        log.WithFields(logrus.Fields{"component": "model-manager"}),
+			Transport:     baseTransport,
+		},
+		Backends: append(append(
+			routing.DefaultBackendDefs(routing.BackendsConfig{
+				Log:                  log,
+				LlamaCppVendoredPath: llamaServerPath,
+				LlamaCppUpdatedPath:  updatedServerPath,
+				LlamaCppConfig:       llamaCppConfig,
+				IncludeMLX:           true,
+				MLXPath:              mlxServerPath,
+			}),
+			routing.BackendDef{Name: sglang.Name, Init: func(mm *models.Manager) (inference.Backend, error) {
+				return sglang.New(log, mm, log.WithFields(logrus.Fields{"component": sglang.Name}), nil, sglangServerPath)
+			}},
+			routing.BackendDef{Name: diffusers.Name, Init: func(mm *models.Manager) (inference.Backend, error) {
+				return diffusers.New(log, mm, log.WithFields(logrus.Fields{"component": diffusers.Name}), nil, diffusersServerPath)
+			}},
+		), vllmBackendDefs(log, vllmServerPath)...),
+		OnBackendError: func(name string, err error) {
+			log.Fatalf("unable to initialize %s backend: %v", name, err)
+		},
+		DefaultBackendName:  llamacpp.Name,
+		VLLMMetalServerPath: vllmMetalServerPath,
+		HTTPClient:          http.DefaultClient,
+		MetricsTracker: metrics.NewTracker(
 			http.DefaultClient,
 			log.WithField("component", "metrics"),
 			"",
 			false,
 		),
-		deferredBackends,
-	)
-
-	// Create the HTTP handler for the scheduler
-	schedulerHTTP := scheduling.NewHTTPHandler(scheduler, modelHandler, nil)
-
-	router := routing.NewNormalizedServeMux()
-
-	// Register path prefixes to forward all HTTP methods (including OPTIONS) to components
-	// Components handle method routing internally
-	// Register both with and without trailing slash to avoid redirects
-	router.Handle(inference.ModelsPrefix, modelHandler)
-	router.Handle(inference.ModelsPrefix+"/", modelHandler)
-	router.Handle(inference.InferencePrefix+"/", schedulerHTTP)
-	// Add OpenAI Responses API compatibility layer
-	responsesHandler := responses.NewHTTPHandler(log, schedulerHTTP, nil)
-	router.Handle(responses.APIPrefix+"/", responsesHandler)
-	router.Handle(responses.APIPrefix, responsesHandler) // Also register for exact match without trailing slash
-	router.Handle("/v1"+responses.APIPrefix+"/", responsesHandler)
-	router.Handle("/v1"+responses.APIPrefix, responsesHandler)
-	// Also register Responses API under inference prefix to support all inference engines
-	router.Handle(inference.InferencePrefix+responses.APIPrefix+"/", responsesHandler)
-	router.Handle(inference.InferencePrefix+responses.APIPrefix, responsesHandler)
-
-	// Add path aliases: /v1 -> /engines/v1, /rerank -> /engines/rerank, /score -> /engines/score.
-	aliasHandler := &middleware.AliasHandler{Handler: schedulerHTTP}
-	router.Handle("/v1/", aliasHandler)
-	router.Handle("/rerank", aliasHandler)
-	router.Handle("/score", aliasHandler)
-
-	// Add Ollama API compatibility layer (only register with trailing slash to catch sub-paths)
-	ollamaHandler := ollama.NewHTTPHandler(log, scheduler, schedulerHTTP, nil, modelManager)
-	router.Handle(ollama.APIPrefix+"/", ollamaHandler)
-
-	// Add Anthropic Messages API compatibility layer
-	anthropicHandler := anthropic.NewHandler(log, schedulerHTTP, nil, modelManager)
-	router.Handle(anthropic.APIPrefix+"/", anthropicHandler)
-
-	// Register root handler LAST - it will only catch exact "/" requests that don't match other patterns
-	router.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
-		// Only respond to exact root path
-		if r.URL.Path != "/" {
-			http.NotFound(w, r)
-			return
-		}
-		w.WriteHeader(http.StatusOK)
-		_, _ = w.Write([]byte("Docker Model Runner is running"))
+		IncludeResponsesAPI: true,
+		ExtraRoutes: func(r *routing.NormalizedServeMux, s *routing.Service) {
+			// Root handler – only catches exact "/" requests
+			r.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
+				if req.URL.Path != "/" {
+					http.NotFound(w, req)
+					return
+				}
+				w.WriteHeader(http.StatusOK)
+				_, _ = w.Write([]byte("Docker Model Runner is running"))
+			})
+
+			// Metrics endpoint
+			if os.Getenv("DISABLE_METRICS") != "1" {
+				metricsHandler := metrics.NewAggregatedMetricsHandler(
+					log.WithField("component", "metrics"),
+					s.SchedulerHTTP,
+				)
+				r.Handle("/metrics", metricsHandler)
+				log.Info("Metrics endpoint enabled at /metrics")
+			} else {
+				log.Info("Metrics endpoint disabled")
+			}
+		},
 	})

-	// Add metrics endpoint if enabled
-	if os.Getenv("DISABLE_METRICS") != "1" {
-		metricsHandler := metrics.NewAggregatedMetricsHandler(
-			log.WithField("component", "metrics"),
-			schedulerHTTP,
-		)
-		router.Handle("/metrics", metricsHandler)
-		log.Info("Metrics endpoint enabled at /metrics")
-	} else {
-		log.Info("Metrics endpoint disabled")
-	}
-
 	server := &http.Server{
-		Handler:           router,
+		Handler:           svc.Router,
 		ReadHeaderTimeout: 10 * time.Second,
 	}
 	serverErrors := make(chan error, 1)
@@ -358,7 +246,7 @@ func main() {

 	tlsServer = &http.Server{
 		Addr:              ":" + tlsPort,
-		Handler:           router,
+		Handler:           svc.Router,
 		TLSConfig:         tlsConfig,
 		ReadHeaderTimeout: 10 * time.Second,
 	}
@@ -377,7 +265,7 @@ func main() {

 	schedulerErrors := make(chan error, 1)
 	go func() {
-		schedulerErrors <- scheduler.Run(ctx)
+		schedulerErrors <- svc.Scheduler.Run(ctx)
 	}()

 	var tlsServerErrorsChan <-chan error
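
Taken together, the main.go changes replace roughly 175 lines of hand-rolled backend construction, scheduler setup, and route registration with a single routing.NewService call. As a rough orientation, a minimal entry point after this refactor could look like the sketch below; it is not part of the commit, uses only ServiceConfig fields visible in this diff, and assumes log, modelPath, baseTransport, and ctx are set up as earlier in main.go.

// Minimal sketch, assuming only the ServiceConfig fields this diff
// exercises. Omitted optional fields (OnBackendError, MetricsTracker,
// ExtraRoutes, ...) fall back to whatever defaults the package provides.
svc := routing.NewService(routing.ServiceConfig{
	Log: log,
	ClientConfig: models.ClientConfig{
		StoreRootPath: modelPath,
		Logger:        log.WithFields(logrus.Fields{"component": "model-manager"}),
		Transport:     baseTransport,
	},
	Backends:           routing.DefaultBackendDefs(routing.BackendsConfig{Log: log}),
	DefaultBackendName: llamacpp.Name,
	HTTPClient:         http.DefaultClient,
})

server := &http.Server{
	Handler:           svc.Router, // models, inference, and compat routes come pre-registered
	ReadHeaderTimeout: 10 * time.Second,
}
go func() { _ = svc.Scheduler.Run(ctx) }() // the scheduler still runs alongside the HTTP server

Note that svc.Router and svc.Scheduler replace the old router and scheduler locals, which is why the TLS server and the scheduler goroutine in the hunks above switch to them.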

pkg/inference/backends/vllmmetal/vllmmetal.go

Lines changed: 17 additions & 0 deletions

@@ -19,6 +19,7 @@ import (
 	"github.com/docker/model-runner/pkg/inference/platform"
 	"github.com/docker/model-runner/pkg/internal/dockerhub"
 	"github.com/docker/model-runner/pkg/logging"
+	"github.com/sirupsen/logrus"
 )

 const (
@@ -71,6 +72,22 @@ func New(log logging.Logger, modelManager *models.Manager, serverLog logging.Log
 	}, nil
 }

+// TryRegister initializes the vllm-metal backend if the platform supports it
+// and registers it in the provided backends map. It returns the backend names
+// whose installation should be deferred until explicitly requested.
+func TryRegister(log logging.Logger, modelManager *models.Manager, backends map[string]inference.Backend, serverPath string) []string {
+	if !platform.SupportsVLLMMetal() {
+		return nil
+	}
+	backend, err := New(log, modelManager, log.WithFields(logrus.Fields{"component": Name}), serverPath)
+	if err != nil {
+		log.Warnf("Failed to initialize vllm-metal backend: %v", err)
+		return nil
+	}
+	backends[Name] = backend
+	return []string{Name}
+}
+
 // Name implements inference.Backend.Name.
 func (v *vllmMetal) Name() string {
 	return Name
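
TryRegister collapses the platform gate, the construction, the warn-only error handling, and the deferred-install bookkeeping that main.go previously inlined. A hypothetical call site (variable names illustrative, mirroring the old main.go locals) would be:

backends := map[string]inference.Backend{}
// No-op on unsupported platforms or init failure: both paths return nil,
// so the result can be appended to a deferred-install list unconditionally.
deferred := vllmmetal.TryRegister(log, modelManager, backends, vllmMetalServerPath)
// deferred is nil or []string{vllmmetal.Name}; these backends are
// installed only when explicitly requested.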

pkg/routing/backends.go

Lines changed: 65 additions & 0 deletions

@@ -0,0 +1,65 @@
+package routing
+
+import (
+	"github.com/docker/model-runner/pkg/inference"
+	"github.com/docker/model-runner/pkg/inference/backends/llamacpp"
+	"github.com/docker/model-runner/pkg/inference/backends/mlx"
+	"github.com/docker/model-runner/pkg/inference/backends/vllm"
+	"github.com/docker/model-runner/pkg/inference/config"
+	"github.com/docker/model-runner/pkg/inference/models"
+	"github.com/docker/model-runner/pkg/logging"
+)
+
+// BackendsConfig configures which inference backends to create and how.
+type BackendsConfig struct {
+	// Log is the main logger passed to each backend.
+	Log logging.Logger
+
+	// ServerLogFactory creates the server-process logger for a backend.
+	// If nil, Log is used directly as the server logger.
+	ServerLogFactory func(backendName string) logging.Logger
+
+	// LlamaCpp settings (always included).
+	LlamaCppVendoredPath string
+	LlamaCppUpdatedPath  string
+	LlamaCppConfig       config.BackendConfig
+
+	// Optional backends and their custom server paths.
+	IncludeMLX bool
+	MLXPath    string
+
+	IncludeVLLM bool
+	VLLMPath    string
+}
+
+// DefaultBackendDefs returns BackendDef entries for the configured backends.
+// It always includes llamacpp; MLX and vLLM are included based on the
+// boolean flags.
+func DefaultBackendDefs(cfg BackendsConfig) []BackendDef {
+	sl := func(name string) logging.Logger {
+		if cfg.ServerLogFactory != nil {
+			return cfg.ServerLogFactory(name)
+		}
+		return cfg.Log
+	}
+
+	defs := []BackendDef{
+		{Name: llamacpp.Name, Init: func(mm *models.Manager) (inference.Backend, error) {
+			return llamacpp.New(cfg.Log, mm, sl(llamacpp.Name), cfg.LlamaCppVendoredPath, cfg.LlamaCppUpdatedPath, cfg.LlamaCppConfig)
+		}},
+	}
+
+	if cfg.IncludeMLX {
+		defs = append(defs, BackendDef{Name: mlx.Name, Init: func(mm *models.Manager) (inference.Backend, error) {
+			return mlx.New(cfg.Log, mm, sl(mlx.Name), nil, cfg.MLXPath)
+		}})
+	}
+
+	if cfg.IncludeVLLM {
+		defs = append(defs, BackendDef{Name: vllm.Name, Init: func(mm *models.Manager) (inference.Backend, error) {
+			return vllm.New(cfg.Log, mm, sl(vllm.Name), nil, cfg.VLLMPath)
+		}})
+	}
+
+	return defs
+}
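
main.go above leaves ServerLogFactory nil, so every backend's server logger falls back to Log. A hedged sketch (variable names as in main.go, not part of this commit) of restoring the old per-component logger pattern through the factory:

defs := routing.DefaultBackendDefs(routing.BackendsConfig{
	Log: log,
	// Reproduces the log.WithFields(logrus.Fields{"component": ...})
	// tagging used elsewhere in this commit.
	ServerLogFactory: func(name string) logging.Logger {
		return log.WithFields(logrus.Fields{"component": name})
	},
	LlamaCppVendoredPath: llamaServerPath,
	LlamaCppUpdatedPath:  updatedServerPath,
	LlamaCppConfig:       llamaCppConfig,
	IncludeMLX:           true,
	MLXPath:              mlxServerPath,
	IncludeVLLM:          true,
	VLLMPath:             vllmServerPath,
})

Note that main.go does not set IncludeVLLM; it appends vllmBackendDefs(log, vllmServerPath) instead, presumably because vLLM registration is platform-dependent.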
