Skip to content

Commit 0652fa5

Browse files
committed
Automatic retries
When you run `cmd/cli/model-cli pull gemma3` and the download gets interrupted:

1. The CLI detects the interruption (network error, timeout, etc.).
2. It waits 1 second and automatically retries.
3. If it fails again, it waits 2 seconds and retries.
4. If it fails a third time, it waits 4 seconds and retries.
5. After 3 failed retries, it reports the error to the user.

The user will see messages like:

    Retrying download (attempt 1/3) in 1s...
    Retrying download (attempt 2/3) in 2s...

Signed-off-by: Eric Curtin <eric.curtin@docker.com>
1 parent 77e5c56 commit 0652fa5

File tree

2 files changed

+309
-40
lines changed

2 files changed

+309
-40
lines changed

cmd/cli/desktop/desktop.go

Lines changed: 139 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -106,61 +106,160 @@ func (c *Client) Status() Status {
106106

107107
func (c *Client) Pull(model string, ignoreRuntimeMemoryCheck bool, printer standalone.StatusPrinter) (string, bool, error) {
108108
model = normalizeHuggingFaceModelName(model)
109-
jsonData, err := json.Marshal(dmrm.ModelCreateRequest{From: model, IgnoreRuntimeMemoryCheck: ignoreRuntimeMemoryCheck})
110-
if err != nil {
111-
return "", false, fmt.Errorf("error marshaling request: %w", err)
109+
110+
return c.withRetries("download", 3, printer, func(attempt int) (string, bool, error, bool) {
111+
jsonData, err := json.Marshal(dmrm.ModelCreateRequest{From: model, IgnoreRuntimeMemoryCheck: ignoreRuntimeMemoryCheck})
112+
if err != nil {
113+
// Marshaling errors are not retryable
114+
return "", false, fmt.Errorf("error marshaling request: %w", err), false
115+
}
116+
117+
createPath := inference.ModelsPrefix + "/create"
118+
resp, err := c.doRequest(
119+
http.MethodPost,
120+
createPath,
121+
bytes.NewReader(jsonData),
122+
)
123+
if err != nil {
124+
// Only retry on network errors, not on client errors
125+
if isRetryableError(err) {
126+
return "", false, c.handleQueryError(err, createPath), true
127+
}
128+
return "", false, c.handleQueryError(err, createPath), false
129+
}
130+
// Close response body explicitly at the end of this attempt, not deferred
131+
defer resp.Body.Close()
132+
133+
if resp.StatusCode != http.StatusOK {
134+
body, _ := io.ReadAll(resp.Body)
135+
err := fmt.Errorf("pulling %s failed with status %s: %s", model, resp.Status, string(body))
136+
// Only retry on server errors (5xx), not client errors (4xx)
137+
shouldRetry := resp.StatusCode >= 500 && resp.StatusCode < 600
138+
return "", false, err, shouldRetry
139+
}
140+
141+
// Use Docker-style progress display
142+
message, shown, err := DisplayProgress(resp.Body, printer)
143+
if err != nil {
144+
// Retry on progress display errors (likely network interruption)
145+
shouldRetry := isRetryableError(err)
146+
return "", shown, err, shouldRetry
147+
}
148+
149+
return message, shown, nil, false
150+
})
151+
}
152+
153+
// isRetryableError determines if an error is retryable (network-related)
154+
func isRetryableError(err error) bool {
155+
if err == nil {
156+
return false
112157
}
113158

114-
createPath := inference.ModelsPrefix + "/create"
115-
resp, err := c.doRequest(
116-
http.MethodPost,
117-
createPath,
118-
bytes.NewReader(jsonData),
119-
)
120-
if err != nil {
121-
return "", false, c.handleQueryError(err, createPath)
159+
// First check for specific error types using errors.Is
160+
if errors.Is(err, context.DeadlineExceeded) ||
161+
errors.Is(err, io.ErrUnexpectedEOF) ||
162+
errors.Is(err, io.EOF) ||
163+
errors.Is(err, ErrServiceUnavailable) {
164+
return true
122165
}
123-
defer resp.Body.Close()
124166

125-
if resp.StatusCode != http.StatusOK {
126-
body, _ := io.ReadAll(resp.Body)
127-
return "", false, fmt.Errorf("pulling %s failed with status %s: %s", model, resp.Status, string(body))
167+
// Fall back to string matching for network errors that don't have specific types
168+
// This is necessary because many network errors are only available as strings
169+
errStr := err.Error()
170+
retryablePatterns := []string{
171+
"connection refused",
172+
"connection reset",
173+
"broken pipe",
174+
"timeout",
175+
"temporary failure",
176+
"no such host",
177+
"no route to host",
178+
"network is unreachable",
179+
"i/o timeout",
128180
}
129181

130-
// Use Docker-style progress display
131-
message, progressShown, err := DisplayProgress(resp.Body, printer)
132-
if err != nil {
133-
return "", progressShown, err
182+
for _, pattern := range retryablePatterns {
183+
if strings.Contains(strings.ToLower(errStr), pattern) {
184+
return true
185+
}
134186
}
135187

136-
return message, progressShown, nil
188+
return false
189+
}
190+
191+
// withRetries executes an operation with automatic retry logic for transient failures
192+
func (c *Client) withRetries(
193+
operationName string,
194+
maxRetries int,
195+
printer standalone.StatusPrinter,
196+
operation func(attempt int) (message string, progressShown bool, err error, shouldRetry bool),
197+
) (string, bool, error) {
198+
var lastErr error
199+
var progressShown bool
200+
201+
for attempt := 0; attempt <= maxRetries; attempt++ {
202+
if attempt > 0 {
203+
// Calculate exponential backoff: 2^(attempt-1) seconds (1s, 2s, 4s)
204+
backoffDuration := time.Duration(1<<uint(attempt-1)) * time.Second
205+
printer.PrintErrf("Retrying %s (attempt %d/%d) in %v...\n", operationName, attempt, maxRetries, backoffDuration)
206+
time.Sleep(backoffDuration)
207+
}
208+
209+
message, shown, err, shouldRetry := operation(attempt)
210+
progressShown = progressShown || shown
211+
212+
if err == nil {
213+
return message, progressShown, nil
214+
}
215+
216+
lastErr = err
217+
if !shouldRetry {
218+
return "", progressShown, err
219+
}
220+
}
221+
222+
return "", progressShown, fmt.Errorf("failed to %s after %d retries: %w", operationName, maxRetries, lastErr)
137223
}
138224

139225
func (c *Client) Push(model string, printer standalone.StatusPrinter) (string, bool, error) {
140226
model = normalizeHuggingFaceModelName(model)
141-
pushPath := inference.ModelsPrefix + "/" + model + "/push"
142-
resp, err := c.doRequest(
143-
http.MethodPost,
144-
pushPath,
145-
nil, // Assuming no body is needed for the push request
146-
)
147-
if err != nil {
148-
return "", false, c.handleQueryError(err, pushPath)
149-
}
150-
defer resp.Body.Close()
151227

152-
if resp.StatusCode != http.StatusOK {
153-
body, _ := io.ReadAll(resp.Body)
154-
return "", false, fmt.Errorf("pushing %s failed with status %s: %s", model, resp.Status, string(body))
155-
}
228+
return c.withRetries("push", 3, printer, func(attempt int) (string, bool, error, bool) {
229+
pushPath := inference.ModelsPrefix + "/" + model + "/push"
230+
resp, err := c.doRequest(
231+
http.MethodPost,
232+
pushPath,
233+
nil, // Assuming no body is needed for the push request
234+
)
235+
if err != nil {
236+
// Only retry on network errors, not on client errors
237+
if isRetryableError(err) {
238+
return "", false, c.handleQueryError(err, pushPath), true
239+
}
240+
return "", false, c.handleQueryError(err, pushPath), false
241+
}
242+
// Close response body explicitly at the end of this attempt, not deferred
243+
defer resp.Body.Close()
156244

157-
// Use Docker-style progress display
158-
message, progressShown, err := DisplayProgress(resp.Body, printer)
159-
if err != nil {
160-
return "", progressShown, err
161-
}
245+
if resp.StatusCode != http.StatusOK {
246+
body, _ := io.ReadAll(resp.Body)
247+
err := fmt.Errorf("pushing %s failed with status %s: %s", model, resp.Status, string(body))
248+
// Only retry on server errors (5xx), not client errors (4xx)
249+
shouldRetry := resp.StatusCode >= 500 && resp.StatusCode < 600
250+
return "", false, err, shouldRetry
251+
}
252+
253+
// Use Docker-style progress display
254+
message, shown, err := DisplayProgress(resp.Body, printer)
255+
if err != nil {
256+
// Retry on progress display errors (likely network interruption)
257+
shouldRetry := isRetryableError(err)
258+
return "", shown, err, shouldRetry
259+
}
162260

163-
return message, progressShown, nil
261+
return message, shown, nil, false
262+
})
164263
}
165264

166265
func (c *Client) List() ([]dmrm.Model, error) {

cmd/cli/desktop/desktop_test.go

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@ package desktop
22

33
import (
44
"bytes"
5+
"context"
56
"encoding/json"
7+
"errors"
68
"io"
79
"net/http"
810
"testing"
@@ -228,3 +230,171 @@ func TestInspectOpenAIHuggingFaceModel(t *testing.T) {
228230
assert.Equal(t, expectedLowercase, model.ID)
229231
}
230232

233+
func TestPullRetryOnNetworkError(t *testing.T) {
234+
ctrl := gomock.NewController(t)
235+
defer ctrl.Finish()
236+
237+
modelName := "test-model"
238+
mockClient := mockdesktop.NewMockDockerHttpClient(ctrl)
239+
mockContext := NewContextForMock(mockClient)
240+
client := New(mockContext)
241+
242+
// First two attempts fail with network error, third succeeds
243+
gomock.InOrder(
244+
mockClient.EXPECT().Do(gomock.Any()).Return(nil, io.ErrUnexpectedEOF),
245+
mockClient.EXPECT().Do(gomock.Any()).Return(nil, io.ErrUnexpectedEOF),
246+
mockClient.EXPECT().Do(gomock.Any()).Return(&http.Response{
247+
StatusCode: http.StatusOK,
248+
Body: io.NopCloser(bytes.NewBufferString(`{"type":"success","message":"Model pulled successfully"}`)),
249+
}, nil),
250+
)
251+
252+
printer := NewSimplePrinter(func(s string) {})
253+
_, _, err := client.Pull(modelName, false, printer)
254+
assert.NoError(t, err)
255+
}
256+
257+
func TestPullNoRetryOn4xxError(t *testing.T) {
258+
ctrl := gomock.NewController(t)
259+
defer ctrl.Finish()
260+
261+
modelName := "test-model"
262+
mockClient := mockdesktop.NewMockDockerHttpClient(ctrl)
263+
mockContext := NewContextForMock(mockClient)
264+
client := New(mockContext)
265+
266+
// Should not retry on 404 (client error)
267+
mockClient.EXPECT().Do(gomock.Any()).Return(&http.Response{
268+
StatusCode: http.StatusNotFound,
269+
Body: io.NopCloser(bytes.NewBufferString("Model not found")),
270+
}, nil).Times(1)
271+
272+
printer := NewSimplePrinter(func(s string) {})
273+
_, _, err := client.Pull(modelName, false, printer)
274+
assert.Error(t, err)
275+
assert.Contains(t, err.Error(), "Model not found")
276+
}
277+
278+
func TestPullRetryOn5xxError(t *testing.T) {
279+
ctrl := gomock.NewController(t)
280+
defer ctrl.Finish()
281+
282+
modelName := "test-model"
283+
mockClient := mockdesktop.NewMockDockerHttpClient(ctrl)
284+
mockContext := NewContextForMock(mockClient)
285+
client := New(mockContext)
286+
287+
// First attempt fails with 500, second succeeds
288+
gomock.InOrder(
289+
mockClient.EXPECT().Do(gomock.Any()).Return(&http.Response{
290+
StatusCode: http.StatusInternalServerError,
291+
Body: io.NopCloser(bytes.NewBufferString("Internal server error")),
292+
}, nil),
293+
mockClient.EXPECT().Do(gomock.Any()).Return(&http.Response{
294+
StatusCode: http.StatusOK,
295+
Body: io.NopCloser(bytes.NewBufferString(`{"type":"success","message":"Model pulled successfully"}`)),
296+
}, nil),
297+
)
298+
299+
printer := NewSimplePrinter(func(s string) {})
300+
_, _, err := client.Pull(modelName, false, printer)
301+
assert.NoError(t, err)
302+
}
303+
304+
func TestPullRetryOnServiceUnavailable(t *testing.T) {
305+
ctrl := gomock.NewController(t)
306+
defer ctrl.Finish()
307+
308+
modelName := "test-model"
309+
mockClient := mockdesktop.NewMockDockerHttpClient(ctrl)
310+
mockContext := NewContextForMock(mockClient)
311+
client := New(mockContext)
312+
313+
// First attempt fails with 503 (converted to ErrServiceUnavailable), second succeeds
314+
// Note: 503 is handled specially in doRequestWithAuthContext and returns ErrServiceUnavailable
315+
gomock.InOrder(
316+
mockClient.EXPECT().Do(gomock.Any()).Return(&http.Response{
317+
StatusCode: http.StatusServiceUnavailable,
318+
Body: io.NopCloser(bytes.NewBufferString("Service temporarily unavailable")),
319+
}, nil),
320+
mockClient.EXPECT().Do(gomock.Any()).Return(&http.Response{
321+
StatusCode: http.StatusOK,
322+
Body: io.NopCloser(bytes.NewBufferString(`{"type":"success","message":"Model pulled successfully"}`)),
323+
}, nil),
324+
)
325+
326+
printer := NewSimplePrinter(func(s string) {})
327+
_, _, err := client.Pull(modelName, false, printer)
328+
assert.NoError(t, err)
329+
}
330+
331+
func TestPullMaxRetriesExhausted(t *testing.T) {
332+
ctrl := gomock.NewController(t)
333+
defer ctrl.Finish()
334+
335+
modelName := "test-model"
336+
mockClient := mockdesktop.NewMockDockerHttpClient(ctrl)
337+
mockContext := NewContextForMock(mockClient)
338+
client := New(mockContext)
339+
340+
// All 4 attempts (1 initial + 3 retries) fail with network error
341+
mockClient.EXPECT().Do(gomock.Any()).Return(nil, io.EOF).Times(4)
342+
343+
printer := NewSimplePrinter(func(s string) {})
344+
_, _, err := client.Pull(modelName, false, printer)
345+
assert.Error(t, err)
346+
assert.Contains(t, err.Error(), "failed to download after 3 retries")
347+
}
348+
349+
func TestPushRetryOnNetworkError(t *testing.T) {
350+
ctrl := gomock.NewController(t)
351+
defer ctrl.Finish()
352+
353+
modelName := "test-model"
354+
mockClient := mockdesktop.NewMockDockerHttpClient(ctrl)
355+
mockContext := NewContextForMock(mockClient)
356+
client := New(mockContext)
357+
358+
// First attempt fails with network error, second succeeds
359+
gomock.InOrder(
360+
mockClient.EXPECT().Do(gomock.Any()).Return(nil, io.ErrUnexpectedEOF),
361+
mockClient.EXPECT().Do(gomock.Any()).Return(&http.Response{
362+
StatusCode: http.StatusOK,
363+
Body: io.NopCloser(bytes.NewBufferString(`{"type":"success","message":"Model pushed successfully"}`)),
364+
}, nil),
365+
)
366+
367+
printer := NewSimplePrinter(func(s string) {})
368+
_, _, err := client.Push(modelName, printer)
369+
assert.NoError(t, err)
370+
}
371+
372+
func TestIsRetryableError(t *testing.T) {
373+
tests := []struct {
374+
name string
375+
err error
376+
expected bool
377+
}{
378+
{"nil error", nil, false},
379+
{"EOF error", io.EOF, true},
380+
{"UnexpectedEOF error", io.ErrUnexpectedEOF, true},
381+
{"connection reset in string", errors.New("some error: connection reset by peer"), true},
382+
{"timeout in string", errors.New("operation failed: i/o timeout"), true},
383+
{"connection refused", errors.New("dial tcp: connection refused"), true},
384+
{"broken pipe", errors.New("write: broken pipe"), true},
385+
{"network unreachable", errors.New("network is unreachable"), true},
386+
{"no such host", errors.New("lookup failed: no such host"), true},
387+
{"no route to host", errors.New("read tcp: no route to host"), true},
388+
{"generic non-retryable error", errors.New("a generic non-retryable error"), false},
389+
{"service unavailable error", ErrServiceUnavailable, true},
390+
{"deadline exceeded", context.DeadlineExceeded, true},
391+
}
392+
393+
for _, tt := range tests {
394+
t.Run(tt.name, func(t *testing.T) {
395+
result := isRetryableError(tt.err)
396+
assert.Equal(t, tt.expected, result)
397+
})
398+
}
399+
}
400+

0 commit comments

Comments
 (0)