Describe the bug
When tensor parallelism (TP) > 1, pipeline parallelism (PP) > 1, and sequence packing are all enabled, training fails with the following error:
```
Traceback (most recent call last):
  File "nemo_rl/models/megatron/train.py", line 101, in model_forward
    output_tensor = model(
  File "megatron/core/distributed/data_parallel_base.py", line 22, in forward
    return self.module(*inputs, **kwargs)
  File "megatron/core/transformer/module.py", line 493, in forward
    outputs = self.module(*inputs, **kwargs)
  File "megatron/core/models/gpt/gpt_model.py", line 504, in forward
    preproc_output = self._preprocess(
  File "megatron/core/models/gpt/gpt_model.py", line 313, in _preprocess
    decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids)
  File "megatron/core/models/common/embeddings/language_model_embedding.py", line 111, in forward
    word_embeddings = self.word_embeddings(input_ids)
  File "megatron/core/tensor_parallel/layers.py", line 297, in forward
    output = reduce_scatter_to_sequence_parallel_region(
  File "megatron/core/tensor_parallel/mappings.py", line 518, in reduce_scatter_to_sequence_parallel_region
    return _ReduceScatterToSequenceParallelRegion.apply(
  File "megatron/core/tensor_parallel/mappings.py", line 365, in forward
    return reduce_scatter_along_first_dim(input, group, input_split_sizes, use_global_buffer)
  File "megatron/core/tensor_parallel/mappings.py", line 173, in _reduce_scatter_along_first_dim
    dim_size[0] % world_size == 0
AssertionError: First dimension of the tensor should be divisible by tensor parallel size
```
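The call path (`reduce_scatter_to_sequence_parallel_region` inside the word-embedding forward) shows the embedding output is being reduce-scattered along its first dimension across the TP group, which only works if that dimension splits evenly over the TP ranks. A minimal sketch of the condition the assertion enforces, using hypothetical helper names (these are not Megatron APIs) and assuming the first dimension is the packed sequence length:

```python
def reduce_scatter_first_dim_ok(first_dim: int, tp_size: int) -> bool:
    # Hypothetical helper: mirrors the `dim_size[0] % world_size == 0`
    # assertion from the traceback.
    return first_dim % tp_size == 0


def shard_size(first_dim: int, tp_size: int) -> int:
    # Hypothetical helper: each TP rank keeps an equal slice of the first
    # dimension after the reduce-scatter, so the split must be exact.
    if not reduce_scatter_first_dim_ok(first_dim, tp_size):
        raise AssertionError(
            "First dimension of the tensor should be divisible by tensor parallel size"
        )
    return first_dim // tp_size


if __name__ == "__main__":
    print(shard_size(3200, 4))  # 800 tokens per TP rank
    try:
        shard_size(3201, 4)
    except AssertionError as err:
        print("fails:", err)
```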
Steps/Code to reproduce bug
Take any recipe YAML that uses the Megatron backend and configure it along these lines:
```yaml
defaults: ../../sft.yaml
policy:
  model_name: Qwen/Qwen2.5-Math-7B
  train_global_batch_size: 64
  max_total_sequence_length: 3200
  megatron_cfg:
    enabled: true
    tensor_model_parallel_size: 4
    pipeline_model_parallel_size: 2
  make_sequence_length_divisible_by: 8
  sequence_packing:
    enabled: true
```
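In this config, `make_sequence_length_divisible_by: 8` is a multiple of `tensor_model_parallel_size: 4`. Assuming each sequence is padded individually to that multiple before packing (an assumption, not a description of NeMo RL's packing code), the packed length would also divide evenly by TP. A minimal sketch of that arithmetic, with hypothetical helper names and made-up sequence lengths:

```python
from typing import List


def pad_to_multiple(length: int, multiple: int) -> int:
    # Round a length up to the next multiple of `multiple`.
    return ((length + multiple - 1) // multiple) * multiple


def packed_length(seq_lens: List[int], divisible_by: int) -> int:
    # Hypothetical: total length of one packed row when every sequence
    # is padded to `divisible_by` before being concatenated.
    return sum(pad_to_multiple(n, divisible_by) for n in seq_lens)


if __name__ == "__main__":
    tp_size = 4
    seq_lens = [517, 1203, 90]  # made-up sequence lengths
    padded = packed_length(seq_lens, divisible_by=8)
    unpadded = sum(seq_lens)
    print(padded, padded % tp_size)      # 1824 0
    print(unpadded, unpadded % tp_size)  # 1810 2
```

If the tensor reaching the embedding is not padded this way, its first dimension can end up indivisible by TP, as with the unpadded total above.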
Expected behavior
Training should work with PP > 1, TP > 1, and sequence packing enabled together.
Additional context
I'm working on finding the root cause and proposing a fix (https://github.com/pengdurice/RL/tree/peng-sequence-packing-v1). Please let me know if I've misconfigured anything. Thanks!