Commit 9a12636

refactor: extract shared infrastructure into pkg/routing

Signed-off-by: Dorin Geman <dorin.geman@docker.com>
Parent: a1e9a79

File tree

7 files changed: +405 -205 lines


main.go

Lines changed: 71 additions & 183 deletions
@@ -13,22 +13,13 @@ import (
 	"syscall"
 	"time"
 
-	"github.com/docker/model-runner/pkg/anthropic"
 	"github.com/docker/model-runner/pkg/inference"
 	"github.com/docker/model-runner/pkg/inference/backends/diffusers"
 	"github.com/docker/model-runner/pkg/inference/backends/llamacpp"
-	"github.com/docker/model-runner/pkg/inference/backends/mlx"
 	"github.com/docker/model-runner/pkg/inference/backends/sglang"
-	"github.com/docker/model-runner/pkg/inference/backends/vllm"
-	"github.com/docker/model-runner/pkg/inference/backends/vllmmetal"
 	"github.com/docker/model-runner/pkg/inference/config"
 	"github.com/docker/model-runner/pkg/inference/models"
-	"github.com/docker/model-runner/pkg/inference/platform"
-	"github.com/docker/model-runner/pkg/inference/scheduling"
 	"github.com/docker/model-runner/pkg/metrics"
-	"github.com/docker/model-runner/pkg/middleware"
-	"github.com/docker/model-runner/pkg/ollama"
-	"github.com/docker/model-runner/pkg/responses"
 	"github.com/docker/model-runner/pkg/routing"
 	modeltls "github.com/docker/model-runner/pkg/tls"
 	"github.com/sirupsen/logrus"
@@ -100,17 +91,6 @@ func main() {
 	}
 	baseTransport.Proxy = http.ProxyFromEnvironment
 
-	clientConfig := models.ClientConfig{
-		StoreRootPath: modelPath,
-		Logger:        log.WithFields(logrus.Fields{"component": "model-manager"}),
-		Transport:     baseTransport,
-	}
-	modelManager := models.NewManager(log.WithFields(logrus.Fields{"component": "model-manager"}), clientConfig)
-	modelHandler := models.NewHTTPHandler(
-		log,
-		modelManager,
-		nil,
-	)
 	log.Infof("LLAMA_SERVER_PATH: %s", llamaServerPath)
 	if vllmServerPath != "" {
 		log.Infof("VLLM_SERVER_PATH: %s", vllmServerPath)
@@ -128,176 +108,84 @@ func main() {
 	// Create llama.cpp configuration from environment variables
 	llamaCppConfig := createLlamaCppConfigFromEnv()
 
-	llamaCppBackend, err := llamacpp.New(
-		log,
-		modelManager,
-		log.WithFields(logrus.Fields{"component": llamacpp.Name}),
-		llamaServerPath,
-		func() string {
-			wd, _ := os.Getwd()
-			d := filepath.Join(wd, "updated-inference", "bin")
-			_ = os.MkdirAll(d, 0o755)
-			return d
-		}(),
-		llamaCppConfig,
-	)
-	if err != nil {
-		log.Fatalf("unable to initialize %s backend: %v", llamacpp.Name, err)
-	}
-
-	vllmBackend, err := initVLLMBackend(log, modelManager, vllmServerPath)
-	if err != nil {
-		log.Fatalf("unable to initialize %s backend: %v", vllm.Name, err)
-	}
-
-	mlxBackend, err := mlx.New(
-		log,
-		modelManager,
-		log.WithFields(logrus.Fields{"component": mlx.Name}),
-		nil,
-		mlxServerPath,
-	)
-	if err != nil {
-		log.Fatalf("unable to initialize %s backend: %v", mlx.Name, err)
-	}
-
-	sglangBackend, err := sglang.New(
-		log,
-		modelManager,
-		log.WithFields(logrus.Fields{"component": sglang.Name}),
-		nil,
-		sglangServerPath,
-	)
-	if err != nil {
-		log.Fatalf("unable to initialize %s backend: %v", sglang.Name, err)
-	}
-
-	diffusersBackend, err := diffusers.New(
-		log,
-		modelManager,
-		log.WithFields(logrus.Fields{"component": diffusers.Name}),
-		nil,
-		diffusersServerPath,
-	)
-
-	if err != nil {
-		log.Fatalf("unable to initialize diffusers backend: %v", err)
-	}
-
-	var vllmMetalBackend inference.Backend
-	if platform.SupportsVLLMMetal() {
-		vllmMetalBackend, err = vllmmetal.New(
-			log,
-			modelManager,
-			log.WithFields(logrus.Fields{"component": vllmmetal.Name}),
-			vllmMetalServerPath,
-		)
-		if err != nil {
-			log.Warnf("Failed to initialize vllm-metal backend: %v", err)
-		}
-	}
-
-	backends := map[string]inference.Backend{
-		llamacpp.Name:  llamaCppBackend,
-		mlx.Name:       mlxBackend,
-		sglang.Name:    sglangBackend,
-		diffusers.Name: diffusersBackend,
-	}
-	registerVLLMBackend(backends, vllmBackend)
-
-	if vllmMetalBackend != nil {
-		backends[vllmmetal.Name] = vllmMetalBackend
-	}
-
-	// Backends whose installation is deferred until explicitly requested.
-	var deferredBackends []string
-	if vllmMetalBackend != nil {
-		deferredBackends = append(deferredBackends, vllmmetal.Name)
-	}
+	updatedServerPath := func() string {
+		wd, _ := os.Getwd()
+		d := filepath.Join(wd, "updated-inference", "bin")
+		_ = os.MkdirAll(d, 0o755)
+		return d
+	}()
 
-	scheduler := scheduling.NewScheduler(
-		log,
-		backends,
-		llamaCppBackend,
-		modelManager,
-		http.DefaultClient,
-		metrics.NewTracker(
+	svc := routing.NewService(routing.ServiceConfig{
+		Log: log,
+		ClientConfig: models.ClientConfig{
+			StoreRootPath: modelPath,
+			Logger:        log.WithFields(logrus.Fields{"component": "model-manager"}),
+			Transport:     baseTransport,
+		},
+		Backends: append(append(
+			routing.DefaultBackendDefs(routing.BackendsConfig{
+				Log:                  log,
+				LlamaCppVendoredPath: llamaServerPath,
+				LlamaCppUpdatedPath:  updatedServerPath,
+				LlamaCppConfig:       llamaCppConfig,
+				IncludeMLX:           true,
+				MLXPath:              mlxServerPath,
+			}),
+			routing.BackendDef{Name: sglang.Name, Init: func(mm *models.Manager) (inference.Backend, error) {
+				return sglang.New(log, mm, log.WithFields(logrus.Fields{"component": sglang.Name}), nil, sglangServerPath)
+			}},
+			routing.BackendDef{Name: diffusers.Name, Init: func(mm *models.Manager) (inference.Backend, error) {
+				return diffusers.New(log, mm, log.WithFields(logrus.Fields{"component": diffusers.Name}), nil, diffusersServerPath)
+			}},
+		), vllmBackendDefs(log, vllmServerPath)...),
+		OnBackendError: func(name string, err error) {
+			log.Fatalf("unable to initialize %s backend: %v", name, err)
+		},
+		DefaultBackendName:  llamacpp.Name,
+		VLLMMetalServerPath: vllmMetalServerPath,
+		HTTPClient:          http.DefaultClient,
+		MetricsTracker: metrics.NewTracker(
 			http.DefaultClient,
 			log.WithField("component", "metrics"),
 			"",
 			false,
 		),
-		deferredBackends,
-	)
-
-	// Create the HTTP handler for the scheduler
-	schedulerHTTP := scheduling.NewHTTPHandler(scheduler, modelHandler, nil)
-
-	router := routing.NewNormalizedServeMux()
-
-	// Register path prefixes to forward all HTTP methods (including OPTIONS) to components
-	// Components handle method routing internally
-	// Register both with and without trailing slash to avoid redirects
-	router.Handle(inference.ModelsPrefix, modelHandler)
-	router.Handle(inference.ModelsPrefix+"/", modelHandler)
-	router.Handle(inference.InferencePrefix+"/", schedulerHTTP)
-	// Add OpenAI Responses API compatibility layer
-	responsesHandler := responses.NewHTTPHandler(log, schedulerHTTP, nil)
-	router.Handle(responses.APIPrefix+"/", responsesHandler)
-	router.Handle(responses.APIPrefix, responsesHandler) // Also register for exact match without trailing slash
-	router.Handle("/v1"+responses.APIPrefix+"/", responsesHandler)
-	router.Handle("/v1"+responses.APIPrefix, responsesHandler)
-	// Also register Responses API under inference prefix to support all inference engines
-	router.Handle(inference.InferencePrefix+responses.APIPrefix+"/", responsesHandler)
-	router.Handle(inference.InferencePrefix+responses.APIPrefix, responsesHandler)
-
-	// Add path aliases: /v1 -> /engines/v1, /rerank -> /engines/rerank, /score -> /engines/score.
-	aliasHandler := &middleware.AliasHandler{Handler: schedulerHTTP}
-	router.Handle("/v1/", aliasHandler)
-	router.Handle("/rerank", aliasHandler)
-	router.Handle("/score", aliasHandler)
-
-	// Add Ollama API compatibility layer (only register with trailing slash to catch sub-paths)
-	ollamaHandler := ollama.NewHTTPHandler(log, scheduler, schedulerHTTP, nil, modelManager)
-	router.Handle(ollama.APIPrefix+"/", ollamaHandler)
-
-	// Add Anthropic Messages API compatibility layer
-	anthropicHandler := anthropic.NewHandler(log, schedulerHTTP, nil, modelManager)
-	router.Handle(anthropic.APIPrefix+"/", anthropicHandler)
-
-	// Register /version endpoint
-	router.HandleFunc("/version", func(w http.ResponseWriter, r *http.Request) {
-		w.Header().Set("Content-Type", "application/json")
-		if err := json.NewEncoder(w).Encode(map[string]string{"version": Version}); err != nil {
-			log.Warnf("failed to write version response: %v", err)
-		}
-	})
-
-	// Register root handler LAST - it will only catch exact "/" requests that don't match other patterns
-	router.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
-		// Only respond to exact root path
-		if r.URL.Path != "/" {
-			http.NotFound(w, r)
-			return
-		}
-		w.WriteHeader(http.StatusOK)
-		_, _ = w.Write([]byte("Docker Model Runner is running"))
+		IncludeResponsesAPI: true,
+		ExtraRoutes: func(r *routing.NormalizedServeMux, s *routing.Service) {
+			// Root handler – only catches exact "/" requests
+			r.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
+				if req.URL.Path != "/" {
+					http.NotFound(w, req)
+					return
+				}
+				w.WriteHeader(http.StatusOK)
+				_, _ = w.Write([]byte("Docker Model Runner is running"))
+			})
+
+			// Version endpoint
+			r.HandleFunc("/version", func(w http.ResponseWriter, req *http.Request) {
+				w.Header().Set("Content-Type", "application/json")
+				if err := json.NewEncoder(w).Encode(map[string]string{"version": Version}); err != nil {
+					log.Warnf("failed to write version response: %v", err)
+				}
+			})
+
+			// Metrics endpoint
+			if os.Getenv("DISABLE_METRICS") != "1" {
+				metricsHandler := metrics.NewAggregatedMetricsHandler(
+					log.WithField("component", "metrics"),
+					s.SchedulerHTTP,
+				)
+				r.Handle("/metrics", metricsHandler)
+				log.Info("Metrics endpoint enabled at /metrics")
+			} else {
+				log.Info("Metrics endpoint disabled")
+			}
+		},
 	})
 
-	// Add metrics endpoint if enabled
-	if os.Getenv("DISABLE_METRICS") != "1" {
-		metricsHandler := metrics.NewAggregatedMetricsHandler(
-			log.WithField("component", "metrics"),
-			schedulerHTTP,
-		)
-		router.Handle("/metrics", metricsHandler)
-		log.Info("Metrics endpoint enabled at /metrics")
-	} else {
-		log.Info("Metrics endpoint disabled")
-	}
-
 	server := &http.Server{
-		Handler:           router,
+		Handler:           svc.Router,
 		ReadHeaderTimeout: 10 * time.Second,
 	}
 	serverErrors := make(chan error, 1)
@@ -367,7 +255,7 @@ func main() {
 
 	tlsServer = &http.Server{
 		Addr:              ":" + tlsPort,
-		Handler:           router,
+		Handler:           svc.Router,
 		TLSConfig:         tlsConfig,
 		ReadHeaderTimeout: 10 * time.Second,
 	}
@@ -386,7 +274,7 @@ func main() {
 
 	schedulerErrors := make(chan error, 1)
 	go func() {
-		schedulerErrors <- scheduler.Run(ctx)
+		schedulerErrors <- svc.Scheduler.Run(ctx)
 	}()
 
 	var tlsServerErrorsChan <-chan error
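
Note on the new wiring: main.go now declares each backend as a routing.BackendDef whose Init closure receives the shared *models.Manager, while routing.NewService owns manager construction, backend initialization, scheduler setup, and route registration. pkg/routing itself is not among the files shown here, so the following is only a sketch of how NewService plausibly consumes these definitions, reconstructed from the call site above; initBackends is a hypothetical name.

package routing

import (
	"github.com/docker/model-runner/pkg/inference"
	"github.com/docker/model-runner/pkg/inference/models"
)

// BackendDef pairs a backend name with a constructor that receives the
// shared model manager, so callers declare backends instead of hand-wiring
// each one as main.go did before this commit.
type BackendDef struct {
	Name string
	Init func(*models.Manager) (inference.Backend, error)
}

// initBackends (hypothetical) shows the pattern the call site implies:
// build every declared backend and route failures to the caller's hook
// (main.go passes a log.Fatalf wrapper via OnBackendError, so any failure
// still aborts startup there).
func initBackends(defs []BackendDef, mm *models.Manager, onErr func(string, error)) map[string]inference.Backend {
	backends := make(map[string]inference.Backend, len(defs))
	for _, def := range defs {
		backend, err := def.Init(mm)
		if err != nil {
			onErr(def.Name, err)
			continue
		}
		backends[def.Name] = backend
	}
	return backends
}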

pkg/inference/backends/vllmmetal/vllmmetal.go

Lines changed: 17 additions & 0 deletions
@@ -19,6 +19,7 @@ import (
 	"github.com/docker/model-runner/pkg/inference/platform"
 	"github.com/docker/model-runner/pkg/internal/dockerhub"
 	"github.com/docker/model-runner/pkg/logging"
+	"github.com/sirupsen/logrus"
 )
 
 const (
@@ -71,6 +72,22 @@ func New(log logging.Logger, modelManager *models.Manager, serverLog logging.Log
 	}, nil
 }
 
+// TryRegister initializes the vllm-metal backend if the platform supports it
+// and registers it in the provided backends map. It returns the backend names
+// whose installation should be deferred until explicitly requested.
+func TryRegister(log logging.Logger, modelManager *models.Manager, backends map[string]inference.Backend, serverPath string) []string {
+	if !platform.SupportsVLLMMetal() {
+		return nil
+	}
+	backend, err := New(log, modelManager, log.WithFields(logrus.Fields{"component": Name}), serverPath)
+	if err != nil {
+		log.Warnf("Failed to initialize vllm-metal backend: %v", err)
+		return nil
+	}
+	backends[Name] = backend
+	return []string{Name}
+}
+
 // Name implements inference.Backend.Name.
 func (v *vllmMetal) Name() string {
 	return Name
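
TryRegister folds into one call what main.go previously did inline: the platform gate, construction with a component-scoped logger, warning-only error handling, and the deferred-installation bookkeeping. A minimal caller sketch, assuming the shared routing code invokes it right after building the backends map; the helper name and signature below are illustrative, not taken from this commit.

package routing

import (
	"github.com/docker/model-runner/pkg/inference"
	"github.com/docker/model-runner/pkg/inference/backends/vllmmetal"
	"github.com/docker/model-runner/pkg/inference/models"
	"github.com/docker/model-runner/pkg/logging"
)

// addPlatformBackends is a hypothetical helper showing the intended use of
// TryRegister: vllm-metal joins the backends map only on supported platforms,
// and the returned names mark backends whose installation stays deferred
// until explicitly requested (previously tracked by main.go's
// deferredBackends slice).
func addPlatformBackends(log logging.Logger, mm *models.Manager,
	backends map[string]inference.Backend, vllmMetalServerPath string) []string {
	// Returns nil on unsupported platforms or when New fails; failure is
	// logged as a warning rather than aborting startup, matching the old
	// inline behavior in main.go.
	return vllmmetal.TryRegister(log, mm, backends, vllmMetalServerPath)
}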
