@@ -106,61 +106,160 @@ func (c *Client) Status() Status {
106106
107107func (c * Client ) Pull (model string , ignoreRuntimeMemoryCheck bool , printer standalone.StatusPrinter ) (string , bool , error ) {
108108 model = normalizeHuggingFaceModelName (model )
109- jsonData , err := json .Marshal (dmrm.ModelCreateRequest {From : model , IgnoreRuntimeMemoryCheck : ignoreRuntimeMemoryCheck })
110- if err != nil {
111- return "" , false , fmt .Errorf ("error marshaling request: %w" , err )
109+
110+ return c .withRetries ("download" , 3 , printer , func (attempt int ) (string , bool , error , bool ) {
111+ jsonData , err := json .Marshal (dmrm.ModelCreateRequest {From : model , IgnoreRuntimeMemoryCheck : ignoreRuntimeMemoryCheck })
112+ if err != nil {
113+ // Marshaling errors are not retryable
114+ return "" , false , fmt .Errorf ("error marshaling request: %w" , err ), false
115+ }
116+
117+ createPath := inference .ModelsPrefix + "/create"
118+ resp , err := c .doRequest (
119+ http .MethodPost ,
120+ createPath ,
121+ bytes .NewReader (jsonData ),
122+ )
123+ if err != nil {
124+ // Only retry on network errors, not on client errors
125+ if isRetryableError (err ) {
126+ return "" , false , c .handleQueryError (err , createPath ), true
127+ }
128+ return "" , false , c .handleQueryError (err , createPath ), false
129+ }
130+ // Close response body explicitly at the end of this attempt, not deferred
131+ defer resp .Body .Close ()
132+
133+ if resp .StatusCode != http .StatusOK {
134+ body , _ := io .ReadAll (resp .Body )
135+ err := fmt .Errorf ("pulling %s failed with status %s: %s" , model , resp .Status , string (body ))
136+ // Only retry on server errors (5xx), not client errors (4xx)
137+ shouldRetry := resp .StatusCode >= 500 && resp .StatusCode < 600
138+ return "" , false , err , shouldRetry
139+ }
140+
141+ // Use Docker-style progress display
142+ message , shown , err := DisplayProgress (resp .Body , printer )
143+ if err != nil {
144+ // Retry on progress display errors (likely network interruption)
145+ shouldRetry := isRetryableError (err )
146+ return "" , shown , err , shouldRetry
147+ }
148+
149+ return message , shown , nil , false
150+ })
151+ }
152+
153+ // isRetryableError determines if an error is retryable (network-related)
154+ func isRetryableError (err error ) bool {
155+ if err == nil {
156+ return false
112157 }
113158
114- createPath := inference .ModelsPrefix + "/create"
115- resp , err := c .doRequest (
116- http .MethodPost ,
117- createPath ,
118- bytes .NewReader (jsonData ),
119- )
120- if err != nil {
121- return "" , false , c .handleQueryError (err , createPath )
159+ // First check for specific error types using errors.Is
160+ if errors .Is (err , context .DeadlineExceeded ) ||
161+ errors .Is (err , io .ErrUnexpectedEOF ) ||
162+ errors .Is (err , io .EOF ) ||
163+ errors .Is (err , ErrServiceUnavailable ) {
164+ return true
122165 }
123- defer resp .Body .Close ()
124166
125- if resp .StatusCode != http .StatusOK {
126- body , _ := io .ReadAll (resp .Body )
127- return "" , false , fmt .Errorf ("pulling %s failed with status %s: %s" , model , resp .Status , string (body ))
167+ // Fall back to string matching for network errors that don't have specific types
168+ // This is necessary because many network errors are only available as strings
169+ errStr := err .Error ()
170+ retryablePatterns := []string {
171+ "connection refused" ,
172+ "connection reset" ,
173+ "broken pipe" ,
174+ "timeout" ,
175+ "temporary failure" ,
176+ "no such host" ,
177+ "no route to host" ,
178+ "network is unreachable" ,
179+ "i/o timeout" ,
128180 }
129181
130- // Use Docker-style progress display
131- message , progressShown , err := DisplayProgress ( resp . Body , printer )
132- if err != nil {
133- return "" , progressShown , err
182+ for _ , pattern := range retryablePatterns {
183+ if strings . Contains ( strings . ToLower ( errStr ), pattern ) {
184+ return true
185+ }
134186 }
135187
136- return message , progressShown , nil
188+ return false
189+ }
190+
191+ // withRetries executes an operation with automatic retry logic for transient failures
192+ func (c * Client ) withRetries (
193+ operationName string ,
194+ maxRetries int ,
195+ printer standalone.StatusPrinter ,
196+ operation func (attempt int ) (message string , progressShown bool , err error , shouldRetry bool ),
197+ ) (string , bool , error ) {
198+ var lastErr error
199+ var progressShown bool
200+
201+ for attempt := 0 ; attempt <= maxRetries ; attempt ++ {
202+ if attempt > 0 {
203+ // Calculate exponential backoff: 2^(attempt-1) seconds (1s, 2s, 4s)
204+ backoffDuration := time .Duration (1 << uint (attempt - 1 )) * time .Second
205+ printer .PrintErrf ("Retrying %s (attempt %d/%d) in %v...\n " , operationName , attempt , maxRetries , backoffDuration )
206+ time .Sleep (backoffDuration )
207+ }
208+
209+ message , shown , err , shouldRetry := operation (attempt )
210+ progressShown = progressShown || shown
211+
212+ if err == nil {
213+ return message , progressShown , nil
214+ }
215+
216+ lastErr = err
217+ if ! shouldRetry {
218+ return "" , progressShown , err
219+ }
220+ }
221+
222+ return "" , progressShown , fmt .Errorf ("failed to %s after %d retries: %w" , operationName , maxRetries , lastErr )
137223}
138224
139225func (c * Client ) Push (model string , printer standalone.StatusPrinter ) (string , bool , error ) {
140226 model = normalizeHuggingFaceModelName (model )
141- pushPath := inference .ModelsPrefix + "/" + model + "/push"
142- resp , err := c .doRequest (
143- http .MethodPost ,
144- pushPath ,
145- nil , // Assuming no body is needed for the push request
146- )
147- if err != nil {
148- return "" , false , c .handleQueryError (err , pushPath )
149- }
150- defer resp .Body .Close ()
151227
152- if resp .StatusCode != http .StatusOK {
153- body , _ := io .ReadAll (resp .Body )
154- return "" , false , fmt .Errorf ("pushing %s failed with status %s: %s" , model , resp .Status , string (body ))
155- }
228+ return c .withRetries ("push" , 3 , printer , func (attempt int ) (string , bool , error , bool ) {
229+ pushPath := inference .ModelsPrefix + "/" + model + "/push"
230+ resp , err := c .doRequest (
231+ http .MethodPost ,
232+ pushPath ,
233+ nil , // Assuming no body is needed for the push request
234+ )
235+ if err != nil {
236+ // Only retry on network errors, not on client errors
237+ if isRetryableError (err ) {
238+ return "" , false , c .handleQueryError (err , pushPath ), true
239+ }
240+ return "" , false , c .handleQueryError (err , pushPath ), false
241+ }
242+ // Close response body explicitly at the end of this attempt, not deferred
243+ defer resp .Body .Close ()
156244
157- // Use Docker-style progress display
158- message , progressShown , err := DisplayProgress (resp .Body , printer )
159- if err != nil {
160- return "" , progressShown , err
161- }
245+ if resp .StatusCode != http .StatusOK {
246+ body , _ := io .ReadAll (resp .Body )
247+ err := fmt .Errorf ("pushing %s failed with status %s: %s" , model , resp .Status , string (body ))
248+ // Only retry on server errors (5xx), not client errors (4xx)
249+ shouldRetry := resp .StatusCode >= 500 && resp .StatusCode < 600
250+ return "" , false , err , shouldRetry
251+ }
252+
253+ // Use Docker-style progress display
254+ message , shown , err := DisplayProgress (resp .Body , printer )
255+ if err != nil {
256+ // Retry on progress display errors (likely network interruption)
257+ shouldRetry := isRetryableError (err )
258+ return "" , shown , err , shouldRetry
259+ }
162260
163- return message , progressShown , nil
261+ return message , shown , nil , false
262+ })
164263}
165264
166265func (c * Client ) List () ([]dmrm.Model , error ) {
0 commit comments