Commit 45a8a31: update estimator

1 parent ac74931

File tree: 4 files changed, +67 -17 lines

  src/MaxText/configs/base.yml
  src/MaxText/configs/types.py
  src/MaxText/estimator.py
  src/MaxText/train_compile.py


src/MaxText/configs/base.yml

Lines changed: 3 additions & 0 deletions

@@ -760,6 +760,9 @@ compiled_trainstep_file: "" # Name of saved serialized compiled train_step, e.g.
 compile_topology: '' # Target hardware version, e.g. 'v5e-256'
 compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.
 
+# MaxText Estimator configs
+write_estimator_result: False
+
 decode_sampling_strategy: "greedy" # decode_sampling_strategy should be one of greedy, weighted, nucleus, topk, or composite(top_k -> top_p -> weighted temperature)
 decode_sampling_nucleus_p: -1 # set if you're doing nucleus / top-p
 decode_sampling_top_k: 0 # set if you're doing top-k
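
The new flag defaults to off. As a usage sketch in the spirit of the command strings estimator.py builds below, a run could flip it with a key=value override appended to the argv; the entry point and config path here are assumptions, not taken from this commit:

# Hypothetical invocation sketch: enable the flag via a key=value override.
base_argv = [
    "src/MaxText/configs/base.yml",  # base config shipping the new default
    "write_estimator_result=True",   # flip the flag added in this commit
]
command = "python -m MaxText.estimator " + " ".join(base_argv)
print(command)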

src/MaxText/configs/types.py

Lines changed: 1 addition & 0 deletions

@@ -1178,6 +1178,7 @@ class AOT(BaseModel):
   compiled_trainstep_file: PathStr = Field("", description="Name of saved serialized compiled train_step.")
   compile_topology: str = Field("", description="Target hardware version, e.g. 'v5e-256'.")
   compile_topology_num_slices: int = Field(-1, description="Number of target slices.")
+  write_estimator_result: bool = Field(False, description="Write estimator.py results in a separate file.")
 
 
 class DevelopmentAndDebugging(BaseModel):
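
For reference, a minimal standalone sketch of how the new field validates (the model name is a hypothetical stand-in for the real AOT class):

from pydantic import BaseModel, Field

class AOTSketch(BaseModel):  # hypothetical stand-in for the AOT config model
  write_estimator_result: bool = Field(False, description="Write estimator.py results in a separate file.")

print(AOTSketch().write_estimator_result)  # False (default)
print(AOTSketch(write_estimator_result=True).write_estimator_result)  # True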

src/MaxText/estimator.py

Lines changed: 62 additions & 17 deletions

@@ -164,7 +164,34 @@ def next_policy(policy: dict) -> dict[str, str] | None:
   return None
 
 
-def largest_batch_size(base_argv, policy, min_pdb, max_pdb=64) -> int:
+def find_pdb_scalar(config):
+  """Calculates the scaling factor to normalize the Per-Device Batch (PDB) size.
+
+  In distributed training, the batch size is divided across various mesh axes.
+  When using non-batch-based sharding (like Tensor Parallelism), the raw
+  per-device batch size can become a fractional value.
+
+  This function identifies those non-batch axes (e.g., 'tensor') and calculates
+  a multiplier. This scalar represents the value by which a fractional per-device
+  batch size must be multiplied to result in an integer value, ensuring
+  compatibility with memory and compute estimation logic.
+
+  Args:
+    config: The configuration object containing 'mesh_axes' and the
+      corresponding 'ici_{axis}_parallelism' values.
+
+  Returns:
+    float: The aggregate parallelism degree of all non-data/non-FSDP axes,
+      serving as the integer-normalization constant for the PDB.
+  """
+  pdb_scalar = 1.0
+  for mesh_axis in config.mesh_axes:
+    if mesh_axis not in ("data", "fsdp", "fsdp_transpose", "expert", "stage"):
+      pdb_scalar *= getattr(config, f"ici_{mesh_axis}_parallelism")
+  return pdb_scalar
+
+
+def largest_batch_size(base_argv, policy, min_pdb, max_pdb=64, pdb_scalar=1.0) -> int:
   """
   Finds the largest possible per_device_batch_size (pdb) that does not cause an OOM error.
 
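
To make the normalization concrete, here is a standalone sketch of the same loop with the config stubbed out (the axis names and the 4-way tensor parallelism are assumed values): a fractional per-device batch size of 0.25 scales to the integer 1.

from types import SimpleNamespace

# Stub config: 4-way tensor parallelism; 'data' and 'fsdp' are batch-based axes.
config = SimpleNamespace(
    mesh_axes=["data", "fsdp", "tensor"],
    ici_tensor_parallelism=4,
)

pdb_scalar = 1.0
for mesh_axis in config.mesh_axes:
  if mesh_axis not in ("data", "fsdp", "fsdp_transpose", "expert", "stage"):
    pdb_scalar *= getattr(config, f"ici_{mesh_axis}_parallelism")

print(pdb_scalar)         # 4.0
print(0.25 * pdb_scalar)  # 1.0: a fractional pdb of 0.25 normalizes to an integer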
@@ -181,26 +208,29 @@ def largest_batch_size(base_argv, policy, min_pdb, max_pdb=64) -> int:
   """
   print(f"Starting binary search for the largest batch size between {min_pdb} and {max_pdb}.")
 
+  if pdb_scalar == 0.0:
+    raise ValueError("pdb_scalar cannot be zero.")
+
   if is_oom(base_argv, policy, min_pdb):
     print(f"OOM at minimum batch size {min_pdb}.")
-    return min_pdb - 1
+    return min_pdb - 1 / pdb_scalar
   if not is_oom(base_argv, policy, max_pdb):
     print(f"No OOM at maximum batch size {max_pdb}.")
     return max_pdb
 
-  low, high, result = min_pdb, max_pdb, min_pdb
+  low, high, result = min_pdb * pdb_scalar, max_pdb * pdb_scalar, min_pdb * pdb_scalar
   while low <= high:
     mid = (low + high) // 2
     if mid < min_pdb:
       low = mid + 1
       continue
 
-    if not is_oom(base_argv, policy, mid):
+    if not is_oom(base_argv, policy, mid / pdb_scalar):
       result = mid
       low = mid + 1
     else:
       high = mid - 1
-  return result
+  return result / pdb_scalar
 
 
 def is_oom(base_argv, policy: dict, pdb: int) -> bool:
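
The scaled search above walks integer multiples of 1 / pdb_scalar and only converts back to a (possibly fractional) pdb when probing or returning. A self-contained sketch with is_oom stubbed out (the 2.75 threshold is an arbitrary assumption) shows the mechanics:

# Standalone sketch of the scaled binary search; is_oom is a stub.
def is_oom(pdb: float) -> bool:
  return pdb > 2.75  # pretend anything above 2.75 examples per device OOMs

min_pdb, max_pdb, pdb_scalar = 0.25, 64, 4.0
low, high, result = min_pdb * pdb_scalar, max_pdb * pdb_scalar, min_pdb * pdb_scalar
while low <= high:
  mid = (low + high) // 2
  if not is_oom(mid / pdb_scalar):
    result = mid
    low = mid + 1
  else:
    high = mid - 1

print(result / pdb_scalar)  # 2.75: the largest quarter-integer pdb that fits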
@@ -294,6 +324,7 @@ def search(
     base_argv,
     init_policy: dict = None,
     max_pdb: int = 256,
+    pdb_scalar: float = 1.0,
 ) -> list[tuple[int, dict]]:
   """
   Performs the core search algorithm to find the Pareto frontier points.
 
@@ -308,11 +339,13 @@ def search(
   """
   output_lst = []
   policy = build_full_device_policy(tensor_names) if init_policy is None else init_policy
-  pdb = 1
+  pdb = 1 / pdb_scalar
   while policy is not None:
-    pdb = largest_batch_size(base_argv, policy, min_pdb=pdb, max_pdb=max_pdb)
+    pdb = largest_batch_size(base_argv, policy, min_pdb=pdb, max_pdb=max_pdb, pdb_scalar=pdb_scalar)
     if pdb > 0:
       output_lst.append((pdb, policy))
+    else:
+      break
     policy = next_policy(policy)
   return output_lst
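The new else/break matters because min_pdb ratchets upward across policies: each call passes the previous result as the new floor, and the sweep moves from no-remat toward heavier remat, so once even the scaled minimum no longer fits there is nothing left to try. A sketch with both helpers stubbed out (the policy names and feasibility table are invented):

# Sketch of the policy sweep; largest_batch_size and next_policy are stubs.
def largest_batch_size(policy, min_pdb, max_pdb, pdb_scalar):
  table = {"none": 0.5, "partial": 1.25, "full": 2.75}  # pretend search results
  return table[policy] if table[policy] >= min_pdb else min_pdb - 1 / pdb_scalar

def next_policy(policy):
  order = ["none", "partial", "full"]
  i = order.index(policy)
  return order[i + 1] if i + 1 < len(order) else None

pdb_scalar = 4.0
pdb, policy, output_lst = 1 / pdb_scalar, "none", []
while policy is not None:
  pdb = largest_batch_size(policy, min_pdb=pdb, max_pdb=64, pdb_scalar=pdb_scalar)
  if pdb > 0:
    output_lst.append((pdb, policy))
  else:
    break
  policy = next_policy(policy)

print(output_lst)  # [(0.5, 'none'), (1.25, 'partial'), (2.75, 'full')]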

@@ -432,6 +465,7 @@ def main(argv_list: Sequence[str]) -> None:
   with contextlib.redirect_stdout(devnull), contextlib.redirect_stderr(devnull):
     config = pyconfig.initialize(base_argv)
     train_compile.validate_config(config)
+  pdb_scalar = find_pdb_scalar(config)
 
   # Get the prioritized list of tensors to try rematerializing
   tensor_names = generate_priority_list(config, provided_tensor_names)

@@ -451,25 +485,36 @@ def main(argv_list: Sequence[str]) -> None:
       # MODE 2: No batch size. Search for both batch size and policy.
       print("No batch size provided. Searching for max batch size and policies...")
       # First, find the absolute max batch size that fits *even with full remat*
-      max_pdb = largest_batch_size(base_argv, full_remat_policy, min_pdb=1)
+      max_pdb = largest_batch_size(base_argv, full_remat_policy, min_pdb=1 / pdb_scalar, pdb_scalar=pdb_scalar)
+      suggested_list = [(max_pdb, full_remat_policy)]
 
       # Now, search for combinations, starting from no-remat up to max_pdb
-      suggested_list = search(tensor_names, base_argv, init_policy=full_device_policy, max_pdb=max_pdb)
+      suggested_list.extend(
+          search(tensor_names, base_argv, init_policy=full_device_policy, max_pdb=max_pdb, pdb_scalar=pdb_scalar)
+      )
 
   end_time = time.time()
   print(f"\nSearch completed in {end_time - start_time:.2f} seconds.")
 
   output_filename = "remat_commands_from_estimator.txt"
-  print(f"Writing {len(suggested_list)} suggested command(s) to {output_filename}...")
 
-  with open(output_filename, "w", encoding="utf-8") as f:
-    for pdb_result, policy_result in suggested_list:
-      # Build the full, runnable command string
-      final_argv = build_argv(base_argv[1:], policy_result, pdb_result)
-      command = "python -m MaxText.train " + " ".join(final_argv)
+  # Only open the file and print the status if the config allows writing
+  if config.write_estimator_result:
+    print(f"Writing {len(suggested_list)} suggested command(s) to {output_filename}...")
 
-      f.write(command + "\n")
-      print(f" - Found valid combo: pdb={pdb_result}, policy={policy_result}")
+    with open(output_filename, "w", encoding="utf-8") as f:
+      for pdb_result, policy_result in suggested_list:
+        # Build the full, runnable command string
+        final_argv = build_argv(base_argv[1:], policy_result, pdb_result)
+        command = "python -m MaxText.train " + " ".join(final_argv)
+
+        f.write(command + "\n")
+        print(f" - Found valid combo: pdb={pdb_result}, policy={policy_result}")
+
+    print("Done.")
+  else:
+    for pdb_result, policy_result in suggested_list:
+      print(f" - Found valid combo (not saved to file): pdb={pdb_result}, policy={policy_result}")
 
   print("Done.")
 
src/MaxText/train_compile.py

Lines changed: 1 addition & 0 deletions

@@ -186,6 +186,7 @@ def is_oom(argv: Sequence[str]) -> bool:
   except Exception as e:
     # return true if OOM error happens
     # OOM error looks like
+    # Check failed: entries[i] <= std::numeric_limits<uint32_t>::max()
     # jax.errors.JaxRuntimeError: RESOURCE_EXHAUSTED: Allocation ...
     # jax.errors.JaxRuntimeError: INTERNAL: RET_CHECK failure ...
     message = str(e).lower()
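
For context, the lowered message is presumably matched against substrings like these; a hedged sketch (the exact marker list beyond the three commented patterns above is an assumption):

# Sketch of substring-based OOM detection over the lowered message.
def looks_like_oom(message: str) -> bool:
  oom_markers = (
      "resource_exhausted",
      "ret_check failure",
      "check failed: entries[i]",  # covers the pattern documented in this commit
  )
  return any(marker in message for marker in oom_markers)

print(looks_like_oom("jax.errors.jaxruntimeerror: resource_exhausted: allocation"))  # True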
