33import json
44import subprocess
55import sys
6+ from concurrent.futures import ThreadPoolExecutor, as_completed
67
78import requests
89
@@ -38,6 +39,18 @@ def _batch_request(oids_sizes: list[tuple[str, int]], operation: str, url: str |
3839 return resp .json ()
3940
4041
def _download_one(oid: str, href: str, dl_headers: dict) -> str | None:
    """Fetch one LFS object from *href* and store it in the local cache.

    Returns the oid when the object was fetched and stored, or None when
    the download failed (the error is reported on stderr).
    """
    try:
        response = requests.get(href, headers=dl_headers, timeout=120)
        response.raise_for_status()
        store_object(oid, response.content)
    # Best-effort boundary: one failed object must not abort the whole
    # batch, and the caller calls future.result() unguarded — so swallow
    # everything here and signal failure via the return value.
    except Exception as e:
        print(f"  error downloading {oid[:12]}: {e}", file=sys.stderr)
        return None
    return oid
52+
53+
4154def download_objects (oids_sizes : list [tuple [str , int ]], progress : bool = True ) -> int :
4255 """Download LFS objects that aren't in the local cache. Returns count downloaded."""
4356 # filter out already-cached objects
@@ -48,7 +61,8 @@ def download_objects(oids_sizes: list[tuple[str, int]], progress: bool = True) -
4861 downloaded = 0
4962 total = len (needed )
5063
51- # process in batches
64+ # collect all download URLs via batch API, then fetch in parallel
65+ to_fetch : list [tuple [str , str , dict ]] = [] # (oid, href, headers)
5266 for i in range (0 , len (needed ), BATCH_SIZE ):
5367 batch = needed [i :i + BATCH_SIZE ]
5468 result = _batch_request (batch , "download" )
@@ -61,19 +75,17 @@ def download_objects(oids_sizes: list[tuple[str, int]], progress: bool = True) -
6175 actions = obj .get ("actions" , {})
6276 dl = actions .get ("download" )
6377 if dl is None :
64- # server says we already have it or it doesn't exist
6578 continue
66-
67- href = dl ["href" ]
68- dl_headers = dl .get ("header" , {})
69- resp = requests .get (href , headers = dl_headers , timeout = 120 )
70- resp .raise_for_status ()
71- data = resp .content
72-
73- store_object (oid , data )
74- downloaded += 1
75- if progress :
76- print (f"\r downloading: { downloaded } /{ total } " , end = "" , flush = True )
79+ to_fetch .append ((oid , dl ["href" ], dl .get ("header" , {})))
80+
81+ # parallel download
82+ with ThreadPoolExecutor (max_workers = DOWNLOAD_WORKERS ) as pool :
83+ futures = {pool .submit (_download_one , oid , href , hdrs ): oid for oid , href , hdrs in to_fetch }
84+ for future in as_completed (futures ):
85+ if future .result () is not None :
86+ downloaded += 1
87+ if progress :
88+ print (f"\r downloading: { downloaded } /{ total } " , end = "" , flush = True )
7789
7890 if progress and downloaded > 0 :
7991 print ()
0 commit comments