Skip to content

Commit c22d482

Browse files
committed
Start implementing /api
From ollama, this now works: MODEL_RUNNER_PORT=13434 make run OLLAMA_HOST="127.0.0.1:13434" ollama run ai/smollm2 OLLAMA_HOST="127.0.0.1:13434" ollama ls OLLAMA_HOST="127.0.0.1:13434" ollama ps OLLAMA_HOST="127.0.0.1:13434" ollama stop ai/smollm2 Signed-off-by: Eric Curtin <eric.curtin@docker.com>
1 parent e5aef7f commit c22d482

File tree

5 files changed

+1003
-1
lines changed

5 files changed

+1003
-1
lines changed

main.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"github.com/docker/model-runner/pkg/inference/scheduling"
2121
"github.com/docker/model-runner/pkg/metrics"
2222
"github.com/docker/model-runner/pkg/middleware"
23+
"github.com/docker/model-runner/pkg/ollama"
2324
"github.com/docker/model-runner/pkg/routing"
2425
"github.com/sirupsen/logrus"
2526
)
@@ -157,6 +158,21 @@ func main() {
157158
// Add /v1 as an alias for /engines/v1
158159
router.Handle("/v1/", &middleware.V1AliasHandler{Handler: scheduler})
159160

161+
// Add Ollama API compatibility layer (only register with trailing slash to catch sub-paths)
162+
ollamaHandler := ollama.NewHandler(log, modelManager, scheduler, nil)
163+
router.Handle(ollama.APIPrefix+"/", ollamaHandler)
164+
165+
// Register root handler LAST - it will only catch exact "/" requests that don't match other patterns
166+
router.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
167+
// Only respond to exact root path
168+
if r.URL.Path != "/" {
169+
http.NotFound(w, r)
170+
return
171+
}
172+
w.WriteHeader(http.StatusOK)
173+
w.Write([]byte("Docker Model Runner is running"))
174+
})
175+
160176
// Add metrics endpoint if enabled
161177
if os.Getenv("DISABLE_METRICS") != "1" {
162178
metricsHandler := metrics.NewAggregatedMetricsHandler(

pkg/distribution/distribution/client.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ func (c *Client) GetModel(reference string) (types.Model, error) {
276276
model, err := c.store.Read(reference)
277277
if err != nil {
278278
c.log.Errorln("Failed to get model:", err, "reference:", utils.SanitizeForLog(reference))
279-
return nil, fmt.Errorf("get model '%q': %w", reference, err)
279+
return nil, fmt.Errorf("get model '%q': %w", utils.SanitizeForLog(reference), err)
280280
}
281281

282282
return model, nil

pkg/inference/models/manager.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -887,6 +887,31 @@ func (m *Manager) IsModelInStore(ref string) (bool, error) {
887887
return m.distributionClient.IsModelInStore(ref)
888888
}
889889

890+
// GetModels returns all models.
891+
func (m *Manager) GetModels() ([]*Model, error) {
892+
if m.distributionClient == nil {
893+
return nil, fmt.Errorf("model distribution service unavailable")
894+
}
895+
896+
// Query models.
897+
models, err := m.distributionClient.ListModels()
898+
if err != nil {
899+
return nil, fmt.Errorf("error while listing models: %w", err)
900+
}
901+
902+
apiModels := make([]*Model, 0, len(models))
903+
for _, model := range models {
904+
apiModel, err := ToModel(model)
905+
if err != nil {
906+
m.log.Warnf("error while converting model, skipping: %v", err)
907+
continue
908+
}
909+
apiModels = append(apiModels, apiModel)
910+
}
911+
912+
return apiModels, nil
913+
}
914+
890915
// GetModel returns a single model.
891916
func (m *Manager) GetModel(ref string) (types.Model, error) {
892917
model, err := m.distributionClient.GetModel(ref)

pkg/inference/scheduling/scheduler.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,11 @@ func (s *Scheduler) GetRunningBackends(w http.ResponseWriter, r *http.Request) {
314314
}
315315
}
316316

317+
// GetRunningBackendsInfo returns information about all running backends as a slice
318+
func (s *Scheduler) GetRunningBackendsInfo(ctx context.Context) []BackendStatus {
319+
return s.getLoaderStatus(ctx)
320+
}
321+
317322
// getLoaderStatus returns information about all running backends managed by the loader
318323
func (s *Scheduler) getLoaderStatus(ctx context.Context) []BackendStatus {
319324
if !s.loader.lock(ctx) {

0 commit comments

Comments
 (0)