Skip to content

Commit 70b55b0

Browse files
authored
finetuning batch split (#994)
Signed-off-by: Ananth Subramaniam <ansubramania@nvidia.com>
1 parent 761dc73 commit 70b55b0

File tree

4 files changed

+352
-72
lines changed

4 files changed

+352
-72
lines changed
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Finetuning-specific data handling utilities."""
16+
17+
from typing import Any, Iterator
18+
19+
import torch
20+
21+
22+
def split_batch_into_microbatches(
    batch: dict[str, Any], num_microbatches: int, enforce_divisible: bool = True
) -> list[dict[str, Any]]:
    """Split a batch dictionary into microbatches.

    Takes a global batch (e.g., [16, 240] for tokens) and splits it into
    num_microbatches smaller batches (e.g., 4 batches of [4, 240]).

    Args:
        batch: Dictionary containing tensors with batch_size = num_microbatches * micro_batch_size
        num_microbatches: Number of microbatches to split into
        enforce_divisible: Whether to enforce batch_size % num_microbatches == 0

    Returns:
        List of microbatch dictionaries, each containing the same keys as the input batch

    Raises:
        ValueError: If the batch contains no tensors, or if ``enforce_divisible``
            is True and the batch size is not divisible by ``num_microbatches``.

    Example:
        >>> batch = {'tokens': torch.rand(16, 240), 'labels': torch.rand(16, 240)}
        >>> microbatches = split_batch_into_microbatches(batch, num_microbatches=4)
        >>> len(microbatches)  # 4
        >>> microbatches[0]['tokens'].shape  # torch.Size([4, 240])
    """
    # Identify tensor items vs other items (like metadata)
    tensor_items = {k: v for k, v in batch.items() if isinstance(v, torch.Tensor)}
    other_items = {k: v for k, v in batch.items() if not isinstance(v, torch.Tensor)}

    if len(tensor_items) == 0:
        raise ValueError("Batch must contain at least one tensor")

    # Get batch size from first tensor
    first_key = next(iter(tensor_items.keys()))
    batch_size = tensor_items[first_key].shape[0]

    if enforce_divisible and batch_size % num_microbatches != 0:
        raise ValueError(
            f"Batch size {batch_size} is not divisible by num_microbatches {num_microbatches}. "
            f"Cannot split evenly into microbatches."
        )

    # Split all tensors along batch dimension (dim=0).
    split_tensors = {
        key: torch.tensor_split(tensor, num_microbatches, dim=0)
        for key, tensor in tensor_items.items()
    }

    # BUGFIX: torch.tensor_split produces *uneven* chunks when batch_size is not
    # divisible by num_microbatches (the first batch_size % num_microbatches chunks
    # get one extra row). The previous code sliced list-valued metadata with a
    # uniform batch_size // num_microbatches stride, silently misaligning metadata
    # with its tensor rows in the enforce_divisible=False case. Derive the slice
    # boundaries from the actual tensor chunk sizes so both stay in lockstep.
    reference_chunks = split_tensors[first_key]
    boundaries = [0]
    for chunk in reference_chunks:
        boundaries.append(boundaries[-1] + chunk.shape[0])

    # Create microbatch dictionaries
    microbatches = []
    for i in range(num_microbatches):
        # Add split tensors
        microbatch = {key: splits[i] for key, splits in split_tensors.items()}

        start_idx, end_idx = boundaries[i], boundaries[i + 1]

        # Handle non-tensor items (metadata, etc.)
        for key, value in other_items.items():
            if isinstance(value, list) and len(value) == batch_size:
                # If it's a list with length matching batch size, split it too
                microbatch[key] = value[start_idx:end_idx]
            else:
                # Otherwise copy as-is (e.g., global metadata)
                microbatch[key] = value

        microbatches.append(microbatch)

    return microbatches
90+
91+
92+
def prepare_finetuning_batch(
    data_iterator: Iterator,
    num_microbatches: int,
    default_seq_length: int,
    seq_key: str = "tokens",
) -> tuple[Iterator, int]:
    """Fetch one global batch and expose it as an iterator of microbatches.

    Finetuning-specific data flow: pull the entire global batch from the
    dataloader iterator, read the (possibly dynamic) sequence length off the
    ``seq_key`` tensor, then split the batch into ``num_microbatches`` pieces
    that all share that sequence length.

    Args:
        data_iterator: Iterator that yields global batches (e.g., from DataLoader with batch sampler)
        num_microbatches: Number of microbatches to split each global batch into
        default_seq_length: Fallback sequence length if it cannot be extracted from batch
        seq_key: Key in batch dict containing the sequence tensor (default: 'tokens')

    Returns:
        Tuple of:
        - Iterator over microbatches (each microbatch is a dict with same keys as global batch)
        - Sequence length extracted from the global batch (or default_seq_length if not found)

    Example:
        >>> # DataLoader yields global batch of shape [16, 240]
        >>> microbatch_iter, seq_len = prepare_finetuning_batch(
        ...     data_iterator=iter(dataloader),
        ...     num_microbatches=4,
        ...     default_seq_length=2048
        ... )
        >>> seq_len  # 240 (extracted from batch)
        >>> batch1 = next(microbatch_iter)
        >>> batch1['tokens'].shape  # torch.Size([4, 240])
    """
    # Pull the whole global batch off the dataloader in one step.
    global_batch = next(data_iterator)

    # Prefer the sequence length encoded in the batch itself; fall back to the
    # configured default when the key is absent or not a tensor.
    seq_source = global_batch[seq_key] if seq_key in global_batch else None
    if isinstance(seq_source, torch.Tensor):
        seq_length = seq_source.size(1)
    else:
        seq_length = default_seq_length

    # Split and hand back an iterator so megatron-core can consume one
    # microbatch per next() call, along with the extracted sequence length.
    microbatches = split_batch_into_microbatches(global_batch, num_microbatches)
    return iter(microbatches), seq_length

src/megatron/bridge/data/samplers.py

Lines changed: 19 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -250,29 +250,32 @@ def __init__(
250250
self._global_batch_size_on_this_data_parallel_rank = self._num_micro_batches * self.micro_batch_size
251251

252252
def __len__(self) -> int:
253-
"""Return the number of microbatches this sampler will yield.
253+
"""Return the number of batches this sampler will yield.
254254
255-
Since we yield one microbatch per global batch × num_micro_batches,
256-
multiply by num_micro_batches to get the total number of yields.
255+
Since we now yield the full global batch at once (not split into microbatches),
256+
this returns the number of global batches.
257257
"""
258258
num_available_samples = self.total_samples - self.consumed_samples % self.total_samples
259259
if self.drop_last:
260260
num_global_batches = num_available_samples // self._global_batch_size
261261
else:
262262
num_global_batches = (num_available_samples + self._global_batch_size - 1) // self._global_batch_size
263263

264-
# Each global batch yields num_micro_batches microbatches
265-
return num_global_batches * self._num_micro_batches
264+
# Each call to __iter__ yields one global batch
265+
return num_global_batches
266266

267267
def __iter__(self) -> Iterator[list[int]]:
268-
"""Yields lists of indices for each microbatch assigned to this rank.
268+
"""Yields lists of indices for the full global batch assigned to this rank.
269269
270270
Accumulates a full global batch, then distributes indices in interleaved fashion
271-
to data parallel ranks, yielding one microbatch at a time for megatron-core compatibility.
272-
273-
This ensures all samples in a global batch can be padded to the same max length
274-
(important for variable-length finetuning) while being compatible with megatron-core's
275-
microbatch loop that calls next() multiple times per training step.
271+
to data parallel ranks, yielding ALL indices for this rank at once. This allows
272+
the DataLoader's collate_fn to receive the full global batch and determine optimal
273+
padding across all samples before the training loop splits into microbatches.
274+
275+
This is essential for variable-length finetuning where we need to:
276+
1. Compute max_length across the entire global batch
277+
2. Pad all samples to the same length
278+
3. Then split into microbatches with consistent sequence length
276279
"""
277280
batch = []
278281
# Last batch will be dropped if drop_last is True
@@ -290,11 +293,9 @@ def __iter__(self) -> Iterator[list[int]]:
290293
]
291294
assert len(all_indices) == self._global_batch_size_on_this_data_parallel_rank
292295

293-
# Yield one microbatch at a time
294-
for microbatch_idx in range(self._num_micro_batches):
295-
start = microbatch_idx * self.micro_batch_size
296-
end = start + self.micro_batch_size
297-
yield all_indices[start:end]
296+
# Yield ALL indices at once (not split into microbatches)
297+
# The training loop will handle splitting after collation
298+
yield all_indices
298299

299300
batch = []
300301

@@ -306,12 +307,8 @@ def __iter__(self) -> Iterator[list[int]]:
306307
num_pad = self._global_batch_size // self.data_parallel_size - len(all_indices)
307308
all_indices = all_indices + [-1] * num_pad
308309

309-
# Yield one microbatch at a time
310-
for microbatch_idx in range(self._num_micro_batches):
311-
start = microbatch_idx * self.micro_batch_size
312-
end = start + self.micro_batch_size
313-
if start < len(all_indices):
314-
yield all_indices[start:end]
310+
# Yield ALL indices at once
311+
yield all_indices
315312

316313

317314
class RandomSeedDataset(Dataset):

src/megatron/bridge/training/train.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -522,16 +522,31 @@ def train_step(
522522
overlap_param_gather=cfg.ddp.overlap_param_gather,
523523
)
524524

525+
# Handle finetuning vs pretraining data consumption
526+
seq_length = model_config.seq_length # Default for pretraining
527+
forward_backward_data_iterator = data_iterator # Default for pretraining
528+
529+
if cfg.dataset.dataloader_type == "batch":
530+
# Finetuning path to support variable-length sequences
531+
from megatron.bridge.data.finetuning import prepare_finetuning_batch
532+
533+
forward_backward_data_iterator, seq_length = prepare_finetuning_batch(
534+
data_iterator=data_iterator,
535+
num_microbatches=get_num_microbatches(),
536+
default_seq_length=model_config.seq_length,
537+
seq_key="tokens",
538+
)
539+
525540
# Forward pass.
526541
forward_backward_func = get_forward_backward_func()
527542
losses_reduced = forward_backward_func(
528543
forward_step_func=forward_step_func,
529-
data_iterator=data_iterator,
544+
data_iterator=forward_backward_data_iterator,
530545
model=model,
531546
num_microbatches=get_num_microbatches(),
532-
seq_length=model_config.seq_length,
547+
seq_length=seq_length,
533548
micro_batch_size=train_config.micro_batch_size,
534-
decoder_seq_length=model_config.seq_length,
549+
decoder_seq_length=seq_length,
535550
forward_only=False,
536551
)
537552
should_checkpoint, should_exit, exit_code = rerun_state_machine.should_checkpoint_and_exit()
@@ -1074,14 +1089,20 @@ def _dummy_train_step(
10741089
global_state: Global state containing configuration
10751090
train_data_iterator: Iterator over training data
10761091
"""
1092+
cfg = global_state.cfg
10771093
num_microbatches = get_num_microbatches()
10781094
rerun_state_machine = get_rerun_state_machine()
10791095

10801096
while rerun_state_machine.should_run_forward_backward(train_data_iterator):
1081-
for _ in range(num_microbatches):
1082-
if parallel_state.is_pipeline_first_stage() or parallel_state.is_pipeline_last_stage():
1083-
if train_data_iterator is not None:
1097+
if parallel_state.is_pipeline_first_stage() or parallel_state.is_pipeline_last_stage():
1098+
if train_data_iterator is not None:
1099+
if cfg.dataset.dataloader_type == "batch":
1100+
# Finetuning: Consume global batch once
10841101
_ = next(train_data_iterator)
1102+
else:
1103+
# Pretrain: Consume microbatches one at a time
1104+
for _ in range(num_microbatches):
1105+
_ = next(train_data_iterator)
10851106

10861107

10871108
def _handle_mxfp8_param_buffer_copy(

0 commit comments

Comments
 (0)