fix sp for HunyuanVideo-1.5#656

Open
cszhouY wants to merge 3 commits intoxdit-project:mainfrom
cszhouY:fix-sp

Conversation


@cszhouY cszhouY commented Mar 3, 2026

What?

Fixes issue #655, primarily by modifying the sequence-parallel (SP) logic in transformers_hunyuan_video15.py.

How?

In the original implementation, both encoder_hidden_states and encoder_attention_mask were processed by _chunk_and_pad_sequence. However, USP does not accept an attention_mask argument, so the positions that encoder_attention_mask marks as padding incorrectly participated in the attention computation.

Following the SP implementation in transformers_hunyuan_video.py, I fixed the issue by first truncating the padded portions of encoder_hidden_states, and then using USP with joint_strategy when the sequence length of encoder_hidden_states is not divisible by ulysses_degree.
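As a rough sketch of this control flow (truncate_padded and usp_text_strategy are hypothetical names for illustration, not the actual xDiT code):

```python
def truncate_padded(tokens, mask):
    """Keep only the token positions the attention mask marks as valid."""
    return [tok for tok, keep in zip(tokens, mask) if keep]

def usp_text_strategy(text_len, ulysses_degree):
    """If the truncated text length divides evenly across SP ranks, the
    text embeddings can be split in SP; otherwise they are kept whole
    and attached to the image tokens via joint_strategy="rear"."""
    if text_len % ulysses_degree == 0:
        return "split_text_embed_in_sp"
    return 'joint_strategy="rear"'
```

The real code operates on torch tensors and the xDiT runtime state; this only mirrors the branching decision.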

Test

The test code is the same as in #655.

width height frames steps seed
1280 720 81 50 42

prompt: A cat holding a paper with words "Hello, world!"

  • Diffusers
hunyuan_video_15_base.mp4
  • xDiT
hunyuan_video_15_t2v_result_ulysses4_ring1_tc_False_720x1280.mp4

@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request resolves a critical issue in the sequence parallelism implementation for the HunyuanVideo-1.5 model, specifically concerning how attention masks and encoder hidden states are managed during distributed processing. The changes ensure that padded tokens do not interfere with attention calculations and streamline the sequence parallel operations by leveraging direct PyTorch and distributed group functionalities, ultimately fixing issue #655.

Highlights

  • Corrected Sequence Parallelism (SP) for Attention Masks: Addressed a bug where USP (Ulysses Sequence Parallelism) incorrectly processed padded portions of encoder_attention_mask, leading to erroneous attention computations.
  • Refined Encoder Hidden States Handling: Modified the processing of encoder_hidden_states to truncate padded regions based on the attention mask before passing them to USP, ensuring only relevant tokens participate in attention.
  • Conditional USP Strategy: Implemented a conditional strategy for USP to use joint_strategy="rear" when the sequence length of encoder_hidden_states is not divisible by the sequence parallel world size, improving flexibility and correctness for text embeddings.
  • Simplified SP Utility Methods: Removed custom _chunk_and_pad_sequence and _gather_and_unpad methods, opting for direct torch.chunk and get_sp_group().all_gather operations for sequence parallelism.
Changelog
  • xfuser/model_executor/models/transformers/transformer_hunyuan_video15.py
    • Imported additional distributed utility functions, including get_runtime_state and get_cfg_group.
    • Removed the _chunk_and_pad_sequence and _gather_and_unpad helper methods from the xFuserHunyuanVideo15Transformer class.
    • Modified the xFuserHunyuanVideo15AttnProcessor.__call__ method to dynamically determine num_encoder_hidden_states_tokens and num_query_tokens.
    • Introduced conditional logic within xFuserHunyuanVideo15AttnProcessor.__call__ to handle sequence parallelism for attention, splitting query, key, and value tensors and using USP with joint_strategy="rear" when split_text_embed_in_sp is false.
    • Updated the xFuserHunyuanVideo15Transformer.forward method to directly chunk hidden_states for sequence parallelism.
    • Filtered encoder_hidden_states based on encoder_attention_mask to remove padded tokens before further processing.
    • Added logic to set get_runtime_state().split_text_embed_in_sp based on whether encoder_hidden_states length is divisible by sp_world_size, and then conditionally chunk encoder_hidden_states.
    • Directly chunked image_rotary_emb components (freqs_cos, freqs_sin) for sequence parallelism.
    • Replaced the encoder_attention_mask argument with None when calling transformer blocks.
    • Replaced the call to _gather_and_unpad with get_sp_group().all_gather for the final hidden states.
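The chunk/gather pattern the changelog describes can be sketched as a single-process simulation; chunk_sequence and all_gather_concat are hypothetical pure-Python stand-ins for torch.chunk and get_sp_group().all_gather, not the xDiT API:

```python
def chunk_sequence(seq, world_size):
    """Split seq into world_size contiguous, equal-sized chunks; assumes
    len(seq) is divisible by world_size (the sequence is padded first)."""
    assert len(seq) % world_size == 0
    n = len(seq) // world_size
    return [seq[i * n:(i + 1) * n] for i in range(world_size)]

def all_gather_concat(local_chunks):
    """Stand-in for all_gather along the sequence dimension: concatenate
    every rank's local chunk, in rank order, to restore the sequence."""
    gathered = []
    for chunk in local_chunks:
        gathered.extend(chunk)
    return gathered
```

The round trip chunk-then-gather is lossless exactly when every rank holds a chunk of the same length, which is what the divisibility/padding discussion below is about.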
Activity
  • The author identified and addressed issue #655 ("Outputs of the Hunyuan Video 1.5 in xDiT is different from Diffusers"), which involved incorrect handling of attention masks in sequence parallelism for HunyuanVideo-1.5.
  • The author provided detailed explanations of the problem and the implemented solution in the pull request description.
  • Test results, including visual comparisons (Diffusers vs. xDiT) with specific parameters (width, height, frames, steps, seed), were provided to demonstrate the fix's effectiveness.

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request aims to fix an issue with sequence parallelism (SP) for HunyuanVideo-1.5 by correctly modifying how encoder_hidden_states are handled, specifically by removing padding and using USP with a joint_strategy when the sequence length is not divisible by the SP world size. However, a critical security concern has been identified: the current batch truncation implementation is flawed as it relies on the mask of the first item in the batch. This can lead to incorrect processing or data loss for other items in the batch with different prompt lengths, particularly in multi-tenant environments where requests are batched together. Furthermore, the implementation introduces a critical regression by removing necessary padding and unpadding logic for hidden_states and image_rotary_emb, which will cause errors during the all_gather operation if the sequence length is not divisible by the SP world size. Minor import and style issues also require attention.

encoder_attention_mask = torch.stack(new_encoder_attention_mask)

# sequence parallel
hidden_states = torch.chunk(hidden_states, sp_world_size, dim=1)[sp_world_rank]
Contributor

critical

The padding for hidden_states has been removed. If hidden_states.shape[1] is not divisible by sp_world_size, torch.chunk will create uneven chunks across sequence parallel ranks. This will cause the all_gather operation at the end of the forward pass to fail. The padding logic from the previous version should be restored to ensure correctness. You'll also need to store the pad_amount to use for unpadding later.

        hidden_states_pad_amount = (sp_world_size - hidden_states.shape[1] % sp_world_size) % sp_world_size
        if hidden_states_pad_amount > 0:
            pad_shape = list(hidden_states.shape)
            pad_shape[1] = hidden_states_pad_amount
            padding = torch.zeros(pad_shape, dtype=hidden_states.dtype, device=hidden_states.device)
            hidden_states = torch.cat([hidden_states, padding], dim=1)
        hidden_states = torch.chunk(hidden_states, sp_world_size, dim=1)[sp_world_rank]
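To see why the padding matters, a pure-Python mirror of torch.chunk's sizing rule (chunks of ceil(n / world_size) elements, with the last chunk shorter or absent) shows how non-divisible lengths leave ranks with mismatched, or even missing, chunks; torch_chunk_sizes is a hypothetical helper for illustration:

```python
import math

def torch_chunk_sizes(seq_len, world_size):
    """Mirror torch.chunk's sizing along one dimension: each chunk has
    ceil(seq_len / world_size) elements, so the last chunk may be
    shorter, and fewer than world_size chunks may be produced."""
    step = math.ceil(seq_len / world_size)
    return [min(step, seq_len - i) for i in range(0, seq_len, step)]
```

For seq_len=9 and world_size=4 only three chunks exist, so one SP rank has no data at all, and a shape-uniform all_gather cannot work.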

Author

hidden_states.shape[1] represents the sequence length of hidden_states, i.e., the number of image tokens. Since HunyuanVideo-1.5 typically generates videos at resolutions like 480p and 720p, this sequence length is usually divisible by common sp_world_size values (e.g., 2, 4, 8, ...).

Comment on lines +306 to +313
    encoder_attention_mask = encoder_attention_mask[0].to(torch.bool)
    encoder_hidden_states_indices = torch.arange(
        encoder_hidden_states.shape[1],
        device=encoder_hidden_states.device)
    encoder_hidden_states_indices = encoder_hidden_states_indices[
        encoder_attention_mask]
    encoder_hidden_states = encoder_hidden_states[
        ..., encoder_hidden_states_indices, :]
Contributor

security-medium medium

The truncation of encoder_hidden_states based on encoder_attention_mask[0] assumes that all items in a batch have the same number of valid tokens. In a multi-tenant or shared environment where requests with different prompt lengths are batched together, this logic will lead to either truncation of valid tokens for some users (if the first item has a shorter prompt) or inclusion of padding tokens for others (if the first item has a longer prompt). This violates the integrity of the model's output and can be used to influence other users' results.

Specifically, at lines 306-313, the code uses the mask of the first item in the batch to determine the indices of valid tokens for the entire batch. If subsequent items in the batch have more valid tokens than the first item, those tokens will be lost. If they have fewer, padding tokens will be included in the attention computation, which the PR explicitly aimed to avoid.

Collaborator

Yeah this now breaks support for BS > 1 where the prompts aren't identical in length.

For HunyuanVideo, we are using this:

        encoder_attention_mask = encoder_attention_mask.to(torch.bool).any(dim=0)
        encoder_hidden_states = encoder_hidden_states[:, encoder_attention_mask, :]

this should do the same trick while still retaining support for BS > 1.
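A plain-Python mirror of this batch-safe truncation (union_mask and truncate_batch are hypothetical helpers) illustrates what any(dim=0) does: a token position is kept if any item in the batch marks it valid, so no prompt loses tokens.

```python
def union_mask(batch_masks):
    """Mirror of mask.to(torch.bool).any(dim=0): keep a position if
    ANY batch item marks it as a valid token."""
    return [any(column) for column in zip(*batch_masks)]

def truncate_batch(batch_tokens, batch_masks):
    """Truncate every batch item with the shared union mask."""
    keep = union_mask(batch_masks)
    return [[tok for tok, k in zip(tokens, keep) if k]
            for tokens in batch_tokens]
```

Note the trade-off: items with shorter prompts still retain a few of their padding positions (up to the longest prompt in the batch), but no valid token is ever dropped.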

Comment on lines 10 to 20
    from xfuser.core.distributed import (
        get_runtime_state,
        get_classifier_free_guidance_world_size,
        get_classifier_free_guidance_rank,
        get_cfg_group,
        get_sequence_parallel_world_size,
        get_sequence_parallel_rank,
        get_sp_group,
        get_sequence_parallel_world_size,
        get_pipeline_parallel_world_size
    )
Contributor

medium

There are a couple of issues with the imports in this block:

  • get_sequence_parallel_world_size is imported twice.
  • get_classifier_free_guidance_world_size, get_classifier_free_guidance_rank, and get_cfg_group are imported but are not used anywhere in this file.

Please clean up the imports to remove the duplicate and the unused functions for better code maintainability.

from xfuser.core.distributed import (
    get_runtime_state,
    get_sequence_parallel_world_size,
    get_sequence_parallel_rank,
    get_sp_group,
    get_pipeline_parallel_world_size,
)

    encoder_hidden_states = encoder_hidden_states[
        ..., encoder_hidden_states_indices, :]

    if encoder_hidden_states.shape[1] % sp_world_size!= 0:
Contributor

medium

For better readability and to follow common Python style guides (like PEP 8), please add spaces around the != operator.

        if encoder_hidden_states.shape[1] % sp_world_size != 0:
References
  1. PEP 8 recommends using spaces around binary operators for better readability. (link)

@jcaraban jcaraban requested a review from avjves March 3, 2026 09:11
@avjves
Collaborator

avjves commented Mar 4, 2026

Thanks for the fix!

Yeah, the padding we do does have an effect on the output, so this is a great fix for that :)

However, I'm not sure the assumption that we can remove padding from hidden_states always holds. While the model behaves best with 480p/720p videos, it can generate other resolutions as well, especially for the i2v task, where the resolution should match the input image. Non-standard resolutions would then crash in the all2all step if the chunks have different shapes. Does the above fix still work if we re-introduce the padding for hidden_states?

@cszhouY
Author

cszhouY commented Mar 7, 2026

@avjves

Thank you for your comment.

Padding is necessary when the sequence length is not a multiple of sp_world_size. I've added padding in the new commit. Please note that the padded portion will participate in the attention computation, thereby affecting the numerical results. However, the impact is minimal, as the maximum padding length equals sp_world_size - 1, which is much smaller than the sequence length.
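The bound on the padding's impact can be checked directly; pad_amount is a hypothetical helper mirroring the padding arithmetic, where the amount added is (sp - len % sp) % sp and is therefore at most sp_world_size - 1, negligible next to a long video-token sequence:

```python
def pad_amount(seq_len, sp_world_size):
    """Tokens of zero-padding needed to make seq_len divisible by
    sp_world_size; always in [0, sp_world_size - 1]."""
    return (sp_world_size - seq_len % sp_world_size) % sp_world_size
```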

I retested the following video generation resolutions, where the sequence length is no longer a multiple of sp_world_size. Additionally, all attention backends were set to spda_efficient to avoid numerical discrepancies caused by different attention backends. As shown, the videos are generated correctly and are nearly identical to those generated on a single GPU using Diffusers (minor differences still exist due to the reason mentioned above).

width height frames steps seed ulysses_degree
976 592 81 50 42 4
  • xDiT
hunyuan_video_15_t2v_result_ulysses4_ring1_tc_False_592x976.mp4
  • Diffusers
hunyuan_video1.5_592x976_81_steps50.mp4
