@@ -164,7 +164,34 @@ def next_policy(policy: dict) -> dict[str, str] | None:
  return None


-def largest_batch_size(base_argv, policy, min_pdb, max_pdb=64) -> int:
+def find_pdb_scalar(config):
168+ """Calculates the scaling factor to normalize the Per-Device Batch (PDB) size.
169+
170+ In distributed training, the batch size is divided across various mesh axes.
171+ When using non-batch-based sharding (like Tensor Parallelism), the raw
172+ per-device batch size can become a fractional value.
173+
174+ This function identifies those non-batch axes (e.g., 'tensor') and calculates
175+ a multiplier. This scalar represents the value by which a fractional per-device
176+ batch size must be multiplied to result in an integer value, ensuring
177+ compatibility with memory and compute estimation logic.
178+
179+ Args:
180+ config: The configuration object containing 'mesh_axes' and the
181+ corresponding 'ici_{axis}_parallelism' values.
182+
183+ Returns:
184+ float: The aggregate parallelism degree of all non-data/non-FSDP axes,
185+ serving as the integer-normalization constant for the PDB.
186+ """
187+ pdb_scalar = 1.0
188+ for mesh_axis in config .mesh_axes :
189+ if mesh_axis not in ("data" , "fsdp" , "fsdp_transpose" , "expert" , "stage" ):
190+ pdb_scalar *= getattr (config , f"ici_{ mesh_axis } _parallelism" )
191+ return pdb_scalar
192+
193+
194+ def largest_batch_size (base_argv , policy , min_pdb , max_pdb = 64 , pdb_scalar = 1.0 ) -> int :
168195 """
169196 Finds the largest possible per_device_batch_size (pdb) that does not cause an OOM error.
170197
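To make the new scalar concrete: a minimal usage sketch, assuming find_pdb_scalar is in scope and using a SimpleNamespace as a hypothetical stand-in for the real pyconfig object.

# Illustration only: SimpleNamespace stands in for a pyconfig config object.
from types import SimpleNamespace

config = SimpleNamespace(
  mesh_axes=["data", "fsdp", "tensor"],
  ici_tensor_parallelism=4,
)
print(find_pdb_scalar(config))  # 4.0 -- only "tensor" is a non-batch axis here
# A fractional per-device batch size of 0.25 then normalizes to 0.25 * 4.0 == 1.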
@@ -181,26 +208,29 @@ def largest_batch_size(base_argv, policy, min_pdb, max_pdb=64) -> int:
181208 """
182209 print (f"Starting binary search for the largest batch size between { min_pdb } and { max_pdb } ." )
183210
+  if pdb_scalar == 0.0:
+    raise ValueError("pdb_scalar cannot be zero.")
+
  if is_oom(base_argv, policy, min_pdb):
    print(f"OOM at minimum batch size {min_pdb}.")
-    return min_pdb - 1
+    return min_pdb - (1 / pdb_scalar)  # one search step below the minimum signals failure
  if not is_oom(base_argv, policy, max_pdb):
    print(f"No OOM at maximum batch size {max_pdb}.")
    return max_pdb

-  low, high, result = min_pdb, max_pdb, min_pdb
+  # Search in scaled units so the binary search steps over integers.
+  low, high, result = min_pdb * pdb_scalar, max_pdb * pdb_scalar, min_pdb * pdb_scalar
  while low <= high:
    mid = (low + high) // 2
-    if mid < min_pdb:
+    if mid < min_pdb * pdb_scalar:
      low = mid + 1
      continue

-    if not is_oom(base_argv, policy, mid):
+    if not is_oom(base_argv, policy, mid / pdb_scalar):  # probe in real (possibly fractional) units
      result = mid
      low = mid + 1
    else:
      high = mid - 1
-  return result
+  return result / pdb_scalar


-def is_oom(base_argv, policy: dict, pdb: int) -> bool:
+def is_oom(base_argv, policy: dict, pdb: float) -> bool:
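The reworked binary search runs over integers in scaled units (pdb * pdb_scalar) and converts back to real, possibly fractional, units for each probe and for the return value. A self-contained toy sketch of that idea, with a fits() predicate standing in for `not is_oom(...)` and the early min/max checks omitted:

def toy_largest_batch_size(min_pdb, max_pdb, pdb_scalar, fits):
  # Scaled bounds are integers even when min_pdb/max_pdb are fractional.
  low, high = min_pdb * pdb_scalar, max_pdb * pdb_scalar
  result = low
  while low <= high:
    mid = (low + high) // 2
    if fits(mid / pdb_scalar):  # probe in real units
      result, low = mid, mid + 1
    else:
      high = mid - 1
  return result / pdb_scalar

# With pdb_scalar=4 the candidates step by 0.25; the largest fitting value is 1.25.
print(toy_largest_batch_size(0.25, 2.0, 4.0, fits=lambda pdb: pdb <= 1.3))  # 1.25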
@@ -294,6 +324,7 @@ def search(
    base_argv,
    init_policy: dict = None,
    max_pdb: int = 256,
+    pdb_scalar: float = 1.0,
-) -> list[tuple[int, dict]]:
+) -> list[tuple[float, dict]]:
298329 """
299330 Performs the core search algorithm to find the Pareto frontier points.
@@ -308,11 +339,13 @@ def search(
308339 """
309340 output_lst = []
310341 policy = build_full_device_policy (tensor_names ) if init_policy is None else init_policy
311- pdb = 1
342+ pdb = 1 / pdb_scalar
  while policy is not None:
-    pdb = largest_batch_size(base_argv, policy, min_pdb=pdb, max_pdb=max_pdb)
+    pdb = largest_batch_size(base_argv, policy, min_pdb=pdb, max_pdb=max_pdb, pdb_scalar=pdb_scalar)
    if pdb > 0:
      output_lst.append((pdb, policy))
+    else:
+      # OOM even at the minimum batch size; stop walking further policies.
+      break
    policy = next_policy(policy)
  return output_lst

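search() threads the last feasible pdb forward as the next policy's minimum, so the resulting (pdb, policy) list is monotone in batch size, and the new else/break stops the walk once even that minimum no longer fits. A toy walk under those assumptions, where capacity() is a hypothetical stand-in for the per-policy OOM probes:

def toy_search(policies, capacity, pdb_scalar=1.0):
  output, pdb = [], 1 / pdb_scalar
  for policy in policies:
    largest = capacity(policy)  # largest pdb this policy can fit
    pdb = largest if largest >= pdb else 0
    if pdb > 0:
      output.append((pdb, policy))
    else:
      break  # even the running minimum no longer fits
  return output

# Policies that rematerialize more tend to fit larger batches:
caps = {"none": 1.0, "partial": 2.0, "full": 4.0}
print(toy_search(["none", "partial", "full"], caps.get))
# [(1.0, 'none'), (2.0, 'partial'), (4.0, 'full')]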
@@ -432,6 +465,7 @@ def main(argv_list: Sequence[str]) -> None:
  with contextlib.redirect_stdout(devnull), contextlib.redirect_stderr(devnull):
    config = pyconfig.initialize(base_argv)
    train_compile.validate_config(config)
+  pdb_scalar = find_pdb_scalar(config)

  # Get the prioritized list of tensors to try rematerializing
  tensor_names = generate_priority_list(config, provided_tensor_names)
@@ -451,25 +485,36 @@ def main(argv_list: Sequence[str]) -> None:
    # MODE 2: No batch size. Search for both batch size and policy.
    print("No batch size provided. Searching for max batch size and policies...")
    # First, find the absolute max batch size that fits *even with full remat*
-    max_pdb = largest_batch_size(base_argv, full_remat_policy, min_pdb=1)
+    max_pdb = largest_batch_size(base_argv, full_remat_policy, min_pdb=1 / pdb_scalar, pdb_scalar=pdb_scalar)
+    suggested_list = [(max_pdb, full_remat_policy)]

    # Now, search for combinations, starting from no-remat up to max_pdb
-    suggested_list = search(tensor_names, base_argv, init_policy=full_device_policy, max_pdb=max_pdb)
+    suggested_list.extend(
+      search(tensor_names, base_argv, init_policy=full_device_policy, max_pdb=max_pdb, pdb_scalar=pdb_scalar)
+    )

  end_time = time.time()
  print(f"\nSearch completed in {end_time - start_time:.2f} seconds.")

  output_filename = "remat_commands_from_estimator.txt"
-  print(f"Writing {len(suggested_list)} suggested command(s) to {output_filename}...")

-  with open(output_filename, "w", encoding="utf-8") as f:
-    for pdb_result, policy_result in suggested_list:
-      # Build the full, runnable command string
-      final_argv = build_argv(base_argv[1:], policy_result, pdb_result)
-      command = "python -m MaxText.train " + " ".join(final_argv)
+  # Only open the file and print the status if the config allows writing
+  if config.write_estimator_result:
+    print(f"Writing {len(suggested_list)} suggested command(s) to {output_filename}...")

-      f.write(command + "\n")
-      print(f"  - Found valid combo: pdb={pdb_result}, policy={policy_result}")
+    with open(output_filename, "w", encoding="utf-8") as f:
+      for pdb_result, policy_result in suggested_list:
+        # Build the full, runnable command string
+        final_argv = build_argv(base_argv[1:], policy_result, pdb_result)
+        command = "python -m MaxText.train " + " ".join(final_argv)
+
+        f.write(command + "\n")
+        print(f"  - Found valid combo: pdb={pdb_result}, policy={policy_result}")
+
+  else:
+    for pdb_result, policy_result in suggested_list:
+      print(f"  - Found valid combo (not saved to file): pdb={pdb_result}, policy={policy_result}")

  print("Done.")

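For reference, each line written to remat_commands_from_estimator.txt is a complete training command. Illustratively (the exact remat flags depend on build_argv and the policy dict, neither of which appears in this hunk), an entry with a fractional batch size might look like:

python -m MaxText.train MaxText/configs/base.yml per_device_batch_size=0.25 ...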