
TP > 1, PP > 1, sequence packing enabled --> seq length not divisible error #2110

@pengdurice


Describe the bug

When TP > 1, PP > 1, and sequence packing is enabled, training fails with the following error:

Traceback (most recent call last):
  File "nemo_rl/models/megatron/train.py", line 101, in model_forward
    output_tensor = model(
  File "megatron/core/distributed/data_parallel_base.py", line 22, in forward
    return self.module(*inputs, **kwargs)
  File "megatron/core/transformer/module.py", line 493, in forward
    outputs = self.module(*inputs, **kwargs)
  File "megatron/core/models/gpt/gpt_model.py", line 504, in forward
    preproc_output = self._preprocess(
  File "megatron/core/models/gpt/gpt_model.py", line 313, in _preprocess
    decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids)
  File "megatron/core/models/common/embeddings/language_model_embedding.py", line 111, in forward
    word_embeddings = self.word_embeddings(input_ids)
  File "megatron/core/tensor_parallel/layers.py", line 297, in forward
    output = reduce_scatter_to_sequence_parallel_region(
  File "megatron/core/tensor_parallel/mappings.py", line 518, in reduce_scatter_to_sequence_parallel_region
    return _ReduceScatterToSequenceParallelRegion.apply(
  File "megatron/core/tensor_parallel/mappings.py", line 365, in forward
    return _reduce_scatter_along_first_dim(input_, group, input_split_sizes, use_global_buffer)
  File "megatron/core/tensor_parallel/mappings.py", line 173, in _reduce_scatter_along_first_dim
    dim_size[0] % world_size == 0
AssertionError: First dimension of the tensor should be divisible by tensor parallel size
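For reference, the failing check is the divisibility assertion in Megatron-Core's reduce-scatter along the first dimension: before the embedding output is scattered across the tensor-parallel group, its first dimension must be a multiple of the TP size. Assuming the usual sequence-parallel layout of [sequence, batch, hidden] at this point, with sequence packing that first dimension is the total packed length. The sketch below mirrors the check in plain Python; `check_reduce_scatter_divisible` is a hypothetical name used for illustration, not a Megatron API.

```python
def check_reduce_scatter_divisible(first_dim: int, tp_size: int) -> None:
    """Hypothetical mirror of the divisibility assertion hit in the traceback.

    Assumes first_dim is the (packed) sequence length, i.e. the dimension
    that the reduce-scatter splits evenly across the TP ranks.
    """
    assert first_dim % tp_size == 0, (
        f"First dimension {first_dim} is not divisible by "
        f"tensor parallel size {tp_size}"
    )


# A packed length of 3200 splits evenly across 4 TP ranks (800 tokens each).
check_reduce_scatter_divisible(3200, tp_size=4)

# A packed length of 3202 does not, so the assertion fires.
try:
    check_reduce_scatter_divisible(3202, tp_size=4)
except AssertionError as err:
    print(err)
```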

Steps/Code to reproduce bug
Take any recipe YAML file that uses the Megatron backend and configure it along these lines:

defaults: ../../sft.yaml
policy:
  model_name: Qwen/Qwen2.5-Math-7B
  train_global_batch_size: 64
  max_total_sequence_length: 3200
  megatron_cfg:
    enabled: true
    tensor_model_parallel_size: 4
    pipeline_model_parallel_size: 2
  make_sequence_length_divisible_by: 8
  sequence_packing:
    enabled: true
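With this configuration (`make_sequence_length_divisible_by: 8`, `tensor_model_parallel_size: 4`), padding every sequence to a multiple of 8 before packing would keep the packed length a multiple of 8, and therefore a multiple of 4, because 8 is itself a multiple of 4. A minimal sketch of that arithmetic, assuming each sequence in a pack is padded independently; `pad_to_multiple` and `packed_length` are hypothetical helpers, not NeMo RL functions, and the sequence lengths are made up.

```python
def pad_to_multiple(length: int, multiple: int) -> int:
    # Round length up to the nearest multiple using integer arithmetic.
    return (length + multiple - 1) // multiple * multiple


def packed_length(seq_lengths: list[int], multiple: int) -> int:
    # Total length of one pack when each sequence is padded on its own.
    return sum(pad_to_multiple(n, multiple) for n in seq_lengths)


tp_size = 4
divisible_by = 8
lengths = [1017, 530, 1200]  # hypothetical unpadded lengths in one pack

padded_total = packed_length(lengths, divisible_by)
print(padded_total, padded_total % tp_size)  # padded pack: remainder 0
print(sum(lengths), sum(lengths) % tp_size)  # unpadded pack: nonzero remainder
```

If `make_sequence_length_divisible_by` were not a multiple of the TP size (for example 6 with TP = 4), this guarantee would not hold; padding to `math.lcm(divisible_by, tp_size)` would restore it. Because this configuration already satisfies the condition, the failure suggests that the length reaching the embedding layer is not the padded total computed this way; the actual root cause is still under investigation.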


Expected behavior

Training should run without errors when PP > 1, TP > 1, and sequence packing are enabled together.

Additional context

I'm working on finding the root cause and proposing a fix (https://github.com/pengdurice/RL/tree/peng-sequence-packing-v1). Please let me know if I have misconfigured anything. Thanks!
