Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,7 @@ def forward(
vision_grid_thw = None
vision_data = None
vision_mask = None
vision_embeds = None
deepstack_feature_lists = None

# position ids is computed within the model
Expand Down
51 changes: 39 additions & 12 deletions src/megatron/bridge/models/qwen_vl/qwen35_vl_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,8 @@ def provide(self, pre_process=None, post_process=None, vp_stage=None) -> Qwen3VL
vp_stage=vp_stage,
)
_patch_standard_attention_specs(block_spec, Qwen3VLSelfAttention)
mtp_spec = mtp_block_spec(self, vp_stage=vp_stage)
_patch_standard_attention_specs(mtp_spec, Qwen3VLSelfAttention)

model = Qwen3VLModel(
language_transformer_config=language_transformer_config,
Expand All @@ -202,7 +204,7 @@ def provide(self, pre_process=None, post_process=None, vp_stage=None) -> Qwen3VL
pre_process=pre_process,
post_process=post_process,
pg_collection=self._pg_collection,
mtp_block_spec=mtp_block_spec(self, vp_stage=vp_stage),
mtp_block_spec=mtp_spec,
vp_stage=vp_stage,
)

Expand Down Expand Up @@ -382,6 +384,8 @@ def provide(self, pre_process=None, post_process=None, vp_stage=None) -> Qwen3VL
# Selectively patch only the standard (full) attention layer specs
# with Qwen3VLSelfAttention for mRoPE support. GDN layers are left as-is.
_patch_standard_attention_specs(block_spec, Qwen3VLSelfAttention)
mtp_spec = mtp_block_spec(self, vp_stage=vp_stage)
_patch_standard_attention_specs(mtp_spec, Qwen3VLSelfAttention)

model = Qwen3VLModel(
language_transformer_config=language_transformer_config,
Expand All @@ -390,7 +394,7 @@ def provide(self, pre_process=None, post_process=None, vp_stage=None) -> Qwen3VL
pre_process=pre_process,
post_process=post_process,
pg_collection=self._pg_collection,
mtp_block_spec=mtp_block_spec(self, vp_stage=vp_stage),
mtp_block_spec=mtp_spec,
vp_stage=vp_stage,
)

Expand All @@ -410,10 +414,13 @@ def provide_language_model(self, pre_process=None, post_process=None, vp_stage=N


def _patch_standard_attention_specs(
block_spec: TransformerBlockSubmodules,
block_spec: Optional[TransformerBlockSubmodules | ModuleSpec],
attention_cls,
) -> None:
"""Selectively replace the self_attention module on standard attention layer specs.
"""Selectively replace standard self-attention specs with ``attention_cls``.

This handles both the main decoder block spec and the nested TransformerLayer
spec stored inside MTP block specs.

In a hybrid block spec, each layer spec has a different self_attention submodule:
- Standard attention layers have a ``SelfAttention``-like module.
Expand All @@ -428,11 +435,31 @@ def _patch_standard_attention_specs(
"""
from megatron.core.transformer.attention import SelfAttention

for layer_spec in block_spec.layer_specs:
attn_spec = layer_spec.submodules.self_attention
# Standard attention specs use SelfAttention (or a subclass) as the module
# and have linear_qkv in their submodules. GDN specs use GatedDeltaNet.
if attn_spec.module is SelfAttention or (
isinstance(attn_spec.module, type) and issubclass(attn_spec.module, SelfAttention)
):
attn_spec.module = attention_cls
if block_spec is None:
return

if hasattr(block_spec, "layer_specs"):
for layer_spec in block_spec.layer_specs:
_patch_standard_attention_specs(layer_spec, attention_cls)
return

if not isinstance(block_spec, ModuleSpec):
return

submodules = getattr(block_spec, "submodules", None)
if submodules is None:
return

if hasattr(submodules, "mtp_model_layer"):
_patch_standard_attention_specs(submodules.mtp_model_layer, attention_cls)

if not hasattr(submodules, "self_attention"):
return

attn_spec = submodules.self_attention
# Standard attention specs use SelfAttention (or a subclass) as the module.
# and have linear_qkv in their submodules. GDN specs use GatedDeltaNet.
if attn_spec.module is SelfAttention or (
Comment thread
HollowMan6 marked this conversation as resolved.
isinstance(attn_spec.module, type) and issubclass(attn_spec.module, SelfAttention)
):
attn_spec.module = attention_cls
76 changes: 76 additions & 0 deletions tests/unit_tests/models/qwen_vl/modelling_qwen3_vl/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import datetime
import os
from dataclasses import replace
from types import SimpleNamespace

import numpy as np
import pytest
Expand Down Expand Up @@ -527,6 +528,81 @@ def test_forward_dist_train_decoder_only(self, hf_config, processor, random_imag
assert isinstance(out, torch.Tensor)
assert out.dim() >= 2

def test_forward_text_only_without_vision_inputs(self, monkeypatch):
    """Text-only forward should not require vision_embeds to be materialized."""

    # Stub the vision-input reorganizer so the forward pass sees no vision data.
    monkeypatch.setattr(
        "megatron.bridge.models.qwen_vl.modelling_qwen3_vl.model.reorganize_inputs",
        lambda **_kwargs: (None, None, None),
    )
    # Rope-index helper: return zero position ids shaped from the input_ids arg.
    monkeypatch.setattr(
        "megatron.bridge.models.qwen_vl.modelling_qwen3_vl.model.get_rope_index",
        lambda *args, **kwargs: (
            torch.zeros((3, args[4].shape[0], args[4].shape[1]), dtype=torch.long),
            None,
        ),
    )
    # NVTX range calls need CUDA; neutralize them so the test runs on CPU.
    for nvtx_target in (
        "megatron.bridge.models.qwen_vl.modelling_qwen3_vl.model.torch.cuda.nvtx.range_push",
        "megatron.bridge.models.qwen_vl.modelling_qwen3_vl.model.torch.cuda.nvtx.range_pop",
    ):
        monkeypatch.setattr(nvtx_target, lambda *_args, **_kwargs: None)

    class RecordingLanguageModel:
        """Minimal language-model stand-in that records the kwargs it is called with."""

        def __init__(self):
            self.rotary_pos_emb = SimpleNamespace(is_thd_format=False)
            self.last_kwargs = None

        def embedding(self, input_ids, position_ids=None):
            del position_ids
            batch_size, seq_len = input_ids.shape
            # [seq, batch, hidden] zeros — shape is all the assertion cares about.
            return torch.zeros((seq_len, batch_size, 4), dtype=torch.float32)

        def __call__(self, **kwargs):
            self.last_kwargs = kwargs
            return torch.ones(1)

    recording_lm = RecordingLanguageModel()
    # A bare-attribute stand-in for Qwen3VLModel carrying only what forward() reads.
    fake_model = SimpleNamespace(
        pre_process=True,
        square_merge_size=4,
        config=SimpleNamespace(
            vision_dp_when_cp=False,
            sequence_parallel=False,
            spatial_merge_size=4,
        ),
        pg_collection=SimpleNamespace(
            cp=SimpleNamespace(rank=lambda: 0, size=lambda: 1),
            tp=SimpleNamespace(rank=lambda: 0, size=lambda: 1),
            pp=object(),
        ),
        language_model=recording_lm,
        image_token_id=1,
        video_token_id=2,
        vision_start_token_id=3,
        use_dist_train=False,
    )

    text_ids = torch.tensor([[11, 12]], dtype=torch.long)

    result = Qwen3VLModel.forward(
        fake_model,
        input_ids=text_ids,
        attention_mask=None,
        pixel_values=None,
        pixel_values_videos=None,
        image_grid_thw=None,
        video_grid_thw=None,
    )

    assert torch.equal(result, torch.ones(1))
    assert recording_lm.last_kwargs is not None
    assert recording_lm.last_kwargs["visual_pos_masks"] is None
    # decoder_input is [seq, batch, hidden] from the embedding stub above.
    assert recording_lm.last_kwargs["decoder_input"].shape == (2, 1, 4)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="Qwen3VLModel.forward requires CUDA")
@pytest.mark.timeout(120)
def test_forward_non_dist_train(self, hf_config, processor, random_image):
Expand Down
62 changes: 62 additions & 0 deletions tests/unit_tests/models/qwen_vl/test_qwen35_vl_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from types import SimpleNamespace
from unittest.mock import Mock

import pytest
from megatron.core.transformer.attention import SelfAttention
from megatron.core.transformer.spec_utils import ModuleSpec

from megatron.bridge.models.gpt_provider import GPTModelProvider
from megatron.bridge.models.qwen_vl.qwen35_vl_provider import (
_TRANSFORMERS_HAS_QWEN3_5,
_TRANSFORMERS_HAS_QWEN3_5_MOE,
Qwen3VLSelfAttention,
Qwen35VLModelProvider,
Qwen35VLMoEModelProvider,
_patch_standard_attention_specs,
)


Expand Down Expand Up @@ -146,6 +153,16 @@ def test_provide_methods_exist(self):
assert hasattr(provider, "provide") and callable(provider.provide)
assert hasattr(provider, "provide_language_model") and callable(provider.provide_language_model)

def test_patch_standard_attention_specs_recurses_into_mtp_specs(self):
    """The patch helper must reach the TransformerLayer spec nested inside an MTP layer."""
    inner_attn = ModuleSpec(module=SelfAttention, submodules=SimpleNamespace())
    nested_transformer_layer = ModuleSpec(
        module=object,
        submodules=SimpleNamespace(self_attention=inner_attn),
    )
    mtp_layer_spec = ModuleSpec(
        module=object,
        submodules=SimpleNamespace(mtp_model_layer=nested_transformer_layer),
    )
    fake_mtp_block = SimpleNamespace(layer_specs=[mtp_layer_spec])

    _patch_standard_attention_specs(fake_mtp_block, Qwen3VLSelfAttention)

    # The SelfAttention spec buried under mtp_model_layer must be swapped out.
    assert nested_transformer_layer.submodules.self_attention.module is Qwen3VLSelfAttention


@pytest.mark.skipif(not _TRANSFORMERS_HAS_QWEN3_5_MOE, reason="transformers does not have qwen3_5_moe support")
class TestQwen35VLMoEModelProvider:
Expand Down Expand Up @@ -220,3 +237,48 @@ def test_vision_config_default_type(self):
num_attention_heads=32,
)
assert isinstance(provider.vision_config, Qwen3_5MoeVisionConfig)

def test_provide_patches_mtp_attention_spec(self, monkeypatch):
    """provide() must patch SelfAttention specs in both the decoder block and the MTP block."""
    decoder_attn = ModuleSpec(module=SelfAttention, submodules=SimpleNamespace())
    mtp_attn = ModuleSpec(module=SelfAttention, submodules=SimpleNamespace())
    fake_block_spec = SimpleNamespace(
        layer_specs=[
            ModuleSpec(module=object, submodules=SimpleNamespace(self_attention=decoder_attn))
        ]
    )
    # MTP layer specs nest the real TransformerLayer spec under mtp_model_layer.
    nested_mtp_layer = ModuleSpec(
        module=object,
        submodules=SimpleNamespace(self_attention=mtp_attn),
    )
    fake_mtp_spec = SimpleNamespace(
        layer_specs=[
            ModuleSpec(
                module=object,
                submodules=SimpleNamespace(mtp_model_layer=nested_mtp_layer),
            )
        ]
    )
    captured_model_ctor = Mock(return_value=Mock())

    monkeypatch.setattr(
        "megatron.bridge.models.qwen_vl.qwen35_vl_provider.get_transformer_block_with_experimental_attention_variant_spec",
        lambda *args, **kwargs: fake_block_spec,
    )
    monkeypatch.setattr(
        "megatron.bridge.models.gpt_provider.mtp_block_spec",
        lambda *args, **kwargs: fake_mtp_spec,
    )
    monkeypatch.setattr(
        "megatron.bridge.models.qwen_vl.qwen35_vl_provider.Qwen3VLModel",
        captured_model_ctor,
    )

    provider = Qwen35VLMoEModelProvider(
        num_layers=60,
        hidden_size=4096,
        num_attention_heads=32,
        mtp_num_layers=1,
    )
    provider.provide()

    kwargs = captured_model_ctor.call_args.kwargs
    patched_decoder_attn = (
        kwargs["language_transformer_layer_spec"].layer_specs[0].submodules.self_attention
    )
    assert patched_decoder_attn.module is Qwen3VLSelfAttention
    patched_mtp_attn = (
        kwargs["mtp_block_spec"].layer_specs[0].submodules.mtp_model_layer.submodules.self_attention
    )
    assert patched_mtp_attn.module is Qwen3VLSelfAttention
Loading