[Bugfix] Resolve Rank index out of range during BWD when sp_size < world_size in Ulysses#7809

Merged
tohtana merged 4 commits into deepspeedai:master from Flink-ddd:fix/issue-7672-ulysses-sp-backward-stability
Jan 28, 2026

Conversation

@Flink-ddd
Contributor

@Flink-ddd Flink-ddd commented Jan 23, 2026

Description

This PR addresses Issue #7672.

When sequence_parallel_size is smaller than world_size (e.g., sp_size=2 on 4 GPUs) with PyTorch < 2.3, using torch.distributed.nn.functional.all_gather for loss aggregation triggers an IndexError: tuple index out of range during the backward pass. This is due to a known PyTorch issue where the backward hook accesses the global rank instead of the group rank.
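
For context, here is a minimal sketch of the failing usage pattern (the helper name, loss tensor, and `sp_group` are illustrative assumptions, not code from DeepSpeed or from the issue reporter's script):

```python
# Illustrative sketch only: aggregating a per-rank loss across the
# sequence-parallel group with the autograd-aware all_gather.
import torch
from torch.distributed.nn.functional import all_gather

def aggregate_loss_via_all_gather(loss: torch.Tensor, sp_group) -> torch.Tensor:
    # all_gather returns one tensor per rank in sp_group and supports autograd.
    gathered = all_gather(loss, group=sp_group)
    # On PyTorch < 2.3 the backward of this op indexes the gathered results
    # with the global rank instead of the rank within sp_group, so with
    # sp_size=2 on 4 GPUs ranks 2 and 3 hit
    # "IndexError: tuple index out of range" during loss.backward().
    return torch.stack(gathered).mean()
```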

Solution

  1. Regression Test & Workaround: Updated the regression test TestUlyssesLossBackward to implement a Weighted All-Reduce pattern (see the sketch after this list).
  • Before: all_gather -> manual sum (vulnerable to rank-indexing mismatch on older PyTorch).
  • After: all_reduce(weighted_loss) / all_reduce(total_weight) (robust and supports weighted averaging).
  2. Runtime Warning: Added a version check (required_torch_version) in DeepSpeedEngine. It now logs a warning if Sequence Parallelism is enabled on PyTorch < 2.3, providing a link to the workaround test case.
  3. Documentation: Updated ulysses-alst-sequence-parallelism.md with a note regarding legacy PyTorch versions and the recommended workaround.
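
A minimal sketch of the weighted all-reduce pattern from item 1; the function name, `num_tokens`, `sp_group`, and the use of the autograd-aware collective are assumptions for illustration, not the exact code in the regression test:

```python
# Sketch only; see tests/unit/sequence_parallelism/test_ulysses.py for the
# actual regression test.
import torch
import torch.distributed as dist
from torch.distributed.nn.functional import all_reduce

def weighted_average_loss(loss: torch.Tensor, num_tokens: int, sp_group) -> torch.Tensor:
    # Weight the local loss by this rank's token count.
    weight = torch.tensor(float(num_tokens), device=loss.device)
    # The autograd-aware all_reduce just all-reduces the gradient in its
    # backward pass; it never indexes results by rank, so it is safe when
    # sp_size < world_size even on PyTorch < 2.3.
    total_loss = all_reduce(loss * weight, op=dist.ReduceOp.SUM, group=sp_group)
    total_weight = all_reduce(weight, op=dist.ReduceOp.SUM, group=sp_group)
    return total_loss / total_weight
```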

Verification

Added and verified the regression test tests/unit/sequence_parallelism/test_ulysses.py which now validates the weighted averaging logic.

1. Reproduction (Before Fix)
Confirmed IndexError crash on Rank 2/3 with sp_size=2 on a 4-GPU setup.
[Screenshot 2026-01-23 at 23 53 42]

2. Verification (After Fix)
Verified the fix using the regression test logic on 4x RTX A6000. The backward pass now completes successfully on all ranks without error.
[Screenshot 2026-01-23 at 23 52 54]

Signed-off-by: vensen <vensenmu@gmail.com>
@Flink-ddd Flink-ddd force-pushed the fix/issue-7672-ulysses-sp-backward-stability branch from 2b386ab to 4dc7846 on January 23, 2026 15:58
@tohtana
Collaborator

tohtana commented Jan 23, 2026

@Flink-ddd Thank you for opening this PR! I only see changes in tests. Did you miss committing some changes?

@Flink-ddd
Contributor Author

Hi @tohtana, thanks for the review. There are no missing commits. The issue reported in #7672 stems from using all_gather for loss aggregation in the user's training loop, rather than from a bug within DeepSpeed's internal runtime.

Since we cannot patch user scripts directly, I submitted this regression test to:

  1. Verify that the correct approach (using all_reduce) works stably when sp_size < world_size.
  2. Prevent future regressions or confusion regarding this usage pattern.

But please let me know if you have another option in mind.

@tohtana
Collaborator

tohtana commented Jan 24, 2026

Thank you for your clarification, @Flink-ddd! It looks like a bug in PyTorch. In AllGather’s backward pass, we should use the rank within the given process group. It appears this was fixed in v2.3.

  • v2.3: rank = dist.get_rank(group=ctx.group)
  • v2.2: rank = dist.get_rank()

As you said, we can’t force client code to implement loss calculation in a particular way. So I’m wondering whether we should simply add an assertion to check the PyTorch version when SP is enabled. We could also note that SP requires v2.3 or later in the document, even though the DeepSpeed code itself doesn’t have an issue with older versions.
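
For illustration, a hypothetical sketch of such a version guard (the merged change uses DeepSpeed's required_torch_version helper and logs a warning rather than asserting; the function name and message below are assumptions):

```python
# Hypothetical sketch of a PyTorch-version guard for sequence parallelism.
import logging
import torch

def check_sp_torch_version(sequence_parallel_size: int) -> None:
    # Parse the major/minor version, ignoring local build suffixes like "+cu121".
    major, minor = (int(v) for v in torch.__version__.split("+")[0].split(".")[:2])
    if sequence_parallel_size > 1 and (major, minor) < (2, 3):
        logging.warning(
            "Sequence parallelism with torch.distributed.nn.functional.all_gather "
            "for loss aggregation can fail in backward on PyTorch < 2.3 when "
            "sp_size < world_size; prefer a weighted all_reduce for loss averaging."
        )
```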

It would still be good to add a regression test. One concern is that the all-reduce approach can’t implement weighted loss averaging, which is used in the original example.

What are your thoughts?

@Flink-ddd
Contributor Author

Flink-ddd commented Jan 25, 2026

Hi @tohtana Thanks for the suggestion. I agree that simulating the weighted averaging pattern is better for real-world scenarios. I will update the test case to implement the weighted all-reduce pattern (reducing both the weighted loss and total weights separately) to address this.

@tohtana
Collaborator

tohtana commented Jan 25, 2026

Hi @Flink-ddd
Do you think we should support SP with v2.2 or older?

@Flink-ddd
Contributor Author

Hi @tohtana, yes, I believe we should. Many production environments and clusters are still pinned to PyTorch v2.1 or v2.2 due to CUDA driver constraints or stability requirements, so maintaining support for SP on these older versions adds significant value to DeepSpeed's compatibility. This regression test ensures that we continue to support these users stably. However, it depends on your perspective. If you think it's unnecessary, we can instead assert that the PyTorch version is >= 2.3 and close this PR.

@tohtana
Collaborator

tohtana commented Jan 25, 2026

Okay, then let's keep supporting the older versions. But adding a regression test alone doesn't prevent the strange error from confusing users. How about the following?

  • Update your new regression test for weighted loss averaging
  • Add a note in the tutorial about the allgather issue and a link to your new regression test (or the code snippet of your solution)
  • Show a warning message during initialization when SP is enabled, telling users that all-gather might cause an issue, and include a link. We can also link to the new test case (or the relevant section of the tutorial).

Signed-off-by: vensen <vensenmu@gmail.com>
@Flink-ddd
Contributor Author

Hi @tohtana, thanks for your advice. That sounds like a solid plan. I agree that adding a warning and updating the documentation will greatly improve the user experience for those on older PyTorch versions. I have already pushed these updates.

Collaborator

@tohtana tohtana left a comment


Thank you for the update! I left one more comment. Once it is fixed and tests pass, let's merge this.

Signed-off-by: vensen <vensenmu@gmail.com>
@Flink-ddd
Contributor Author

Hi @tohtana, this is ready for review again. Thank you for your suggestions and help.

@tohtana tohtana enabled auto-merge (squash) January 28, 2026 07:19
@Flink-ddd
Contributor Author

Hi @tohtana, all CI tests pass now. Could you please approve it when you have a chance? Thanks.

@Flink-ddd Flink-ddd requested a review from tohtana January 28, 2026 10:25
@tohtana tohtana merged commit bb250a2 into deepspeedai:master Jan 28, 2026
13 checks passed
@tohtana
Collaborator

tohtana commented Jan 28, 2026

Merged. Thank you for your contribution, @Flink-ddd!

phalani-paladugu pushed a commit to phalani-paladugu/DeepSpeed that referenced this pull request Jan 29, 2026
…rld_size in Ulysses (deepspeedai#7809)

ksugama pushed a commit to ksugama/DeepSpeed that referenced this pull request Feb 9, 2026
…rld_size in Ulysses (deepspeedai#7809)
