
[DSV3] Fix the ckpt loading issue when no MoE layer on the mtp rank#3315

Open
gdengk wants to merge 1 commit into NVIDIA-NeMo:main from gdengk:gaod/dsv3/mlper_ckpt_issue

Conversation

@gdengk
Contributor

@gdengk gdengk commented Apr 14, 2026

What does this PR do ?

Add a one line overview of what this PR aims to accomplish.

Changelog

  • Add specific line by line info of high level changes in this PR.

GitHub Actions CI

See the CI section in the Contributing doc for how to trigger the CI. An NVIDIA developer will need to approve and trigger the CI for external contributors.

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (Ex: Numba, Pynini, Apex etc)
    • Reviewer: Does the PR have correct import guards for all optional libraries?

If you haven't finished some of the above items, you can still open a "Draft" PR.

Additional Information

  • Related to # (issue)

Summary by CodeRabbit

  • Bug Fixes
  • Fixed handling of Mixture-of-Experts (MoE) layer specification selection in the transformer architecture to correctly infer the appropriate decoder layer type.

@copy-pr-bot

copy-pr-bot bot commented Apr 14, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

Signed-off-by: Gao Deng <gdeng@login-lyris02.lyris.clusters.nvidia.com>
@gdengk gdengk force-pushed the gaod/dsv3/mlper_ckpt_issue branch from 3759e55 to 3fa93de on April 14, 2026 at 07:25
@coderabbitai
Contributor

coderabbitai bot commented Apr 14, 2026

📝 Walkthrough

Walkthrough

The pull request modifies layer specification derivation logic in mtp_block_spec to handle MoE/block-spec scenarios by explicitly deriving GPT decoder layer specs and selecting the final layer type instead of applying a default, ensuring the correct decoder type (dense vs MoE) is used for the MTP transformer layer.

Changes

  • MTP Block Specification Logic (src/megatron/bridge/models/gpt_provider.py): Modified mtp_block_spec to derive GPT decoder layer specs explicitly when layer_specs is present but empty, selecting the final layer specification to determine the correct decoder type (dense vs MoE) for the MTP block.
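The selection logic described above can be sketched as follows. This is an illustration only, using hypothetical stand-in types rather than the real Megatron-Bridge specs: the MTP block reuses the final decoder layer's spec (dense vs MoE) instead of defaulting to an MoE spec, so checkpoint loading works even when the MTP rank holds no MoE layer.

```python
# Illustrative sketch of the fix's selection logic; LayerSpec and the helper
# functions below are hypothetical stand-ins, not the real Megatron-Bridge API.
from dataclasses import dataclass


@dataclass
class LayerSpec:
    kind: str  # "dense" or "moe"


def derive_decoder_layer_specs(num_layers: int, moe_layer_pattern: list) -> list:
    # moe_layer_pattern is a per-layer flag list: 1 -> MoE layer, 0 -> dense layer
    return [
        LayerSpec("moe" if moe_layer_pattern[i] else "dense")
        for i in range(num_layers)
    ]


def mtp_layer_spec(layer_specs: list) -> LayerSpec:
    # Reuse the final decoder layer's spec for the MTP block rather than
    # unconditionally assuming MoE, so dense-tail models load correctly.
    return layer_specs[-1]


specs = derive_decoder_layer_specs(4, [0, 1, 1, 1])
print(mtp_layer_spec(specs).kind)  # "moe", since the last decoder layer is MoE
```

With a pattern ending in a dense layer, e.g. `[1, 1, 1, 0]`, the same function returns a dense spec, which is the case the PR title describes.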

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

🚥 Pre-merge checks | ✅ 4
✅ Passed checks (4 passed)
  • Description Check: ✅ Passed. Check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check: ✅ Passed. The title directly addresses the specific issue being fixed: checkpoint loading when no MoE layer exists on the MTP rank, which aligns with the code change that corrects decoder layer spec selection for MoE scenarios.
  • Docstring Coverage: ✅ Passed. Docstring coverage is 100.00%, which meets the required threshold of 80.00%.
  • Test Results For Major Changes: ✅ Passed. PR meets the minor-changes criterion with +11/-1 lines fixing a MoE checkpoint loading edge case.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Comment @coderabbitai help to get the list of available commands and usage tips.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@src/megatron/bridge/models/gpt_provider.py`:
- Around line 365-370: The GPTModelProvider dataclass is missing the qk_l2_norm
attribute but the code in get_gpt_decoder_layer_specs accesses
config.qk_l2_norm; add a new field to the GPTModelProvider dataclass named
qk_l2_norm: bool = False (matching the style/pattern of qk_layernorm and fp8) so
instances have this attribute and runtime AttributeError is avoided; update any
dataclass default initialization or constructor usage as needed to include the
new field.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro Plus

Run ID: bb23f6f8-8f8b-442c-b0b5-75eda2b121db

📥 Commits

Reviewing files that changed from the base of the PR and between ad27e2c and 3fa93de.

📒 Files selected for processing (1)
  • src/megatron/bridge/models/gpt_provider.py

Comment on lines +365 to +370
decoder_layer_specs = get_gpt_decoder_layer_specs(
    config,
    use_transformer_engine=True,
    normalization=config.normalization,
    qk_l2_norm=config.qk_l2_norm,
)
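The reviewer's concern is that `config.qk_l2_norm` is a direct attribute access. As a defensive alternative to adding the field (an illustration only, not the PR's code), `getattr` with a default avoids an `AttributeError` on config objects that predate the field:

```python
# Hypothetical defensive variant, not the PR's actual code: tolerate config
# objects that were created before qk_l2_norm existed as a field.
class LegacyConfig:
    qk_layernorm = True  # older config defines only qk_layernorm


config = LegacyConfig()
# getattr with a default returns False instead of raising AttributeError.
qk_l2_norm = getattr(config, "qk_l2_norm", False)
print(qk_l2_norm)  # False
```

Adding the field to the dataclass, as the review suggests, is the cleaner fix since every instance then carries an explicit default.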
Copy link
Copy Markdown
Contributor


⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "1) Check qk-related fields/usage in gpt_provider.py"
rg -n -C2 '\bqk_(layernorm|l2_norm)\b' src/megatron/bridge/models/gpt_provider.py

echo
echo "2) Find TransformerConfig definitions and qk fields"
fd -i 'transformer_config.py' | while read -r f; do
  echo "=== $f ==="
  rg -n -C2 '\bqk_(layernorm|l2_norm)\b' "$f" || true
done

echo
echo "3) Inspect get_gpt_decoder_layer_specs definition/signature"
rg -n -C4 'def\s+get_gpt_decoder_layer_specs\s*\('

echo
echo "4) Compare argument naming at all call sites"
rg -n -C2 'get_gpt_decoder_layer_specs\s*\('

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 1626


🏁 Script executed:

# Search for get_gpt_decoder_layer_specs definition more broadly
rg -n 'def\s+get_gpt_decoder_layer_specs' --type py

# Search for qk_l2_norm references anywhere in the codebase
rg -n 'qk_l2_norm' --type py

# Check the TransformerConfig class in gpt_provider.py more completely
sed -n '165,190p' src/megatron/bridge/models/gpt_provider.py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 1749


🏁 Script executed:

# Search for get_gpt_decoder_layer_specs definition across the repo
find . -type f -name "*.py" -exec grep -l "def get_gpt_decoder_layer_specs" {} \;

# Also check imports in gpt_provider.py to find where this function comes from
grep -n "get_gpt_decoder_layer_specs" src/megatron/bridge/models/gpt_provider.py | head -20

# Check the imports at the top of gpt_provider.py
head -50 src/megatron/bridge/models/gpt_provider.py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 2201


🏁 Script executed:

# Find all usages of get_gpt_decoder_layer_specs in the repo
rg -n -B2 -A5 'get_gpt_decoder_layer_specs\s*\(' src/megatron/bridge --type py

# Check if there's a parent/base TransformerConfig that defines qk_l2_norm
rg -n 'qk_l2_norm' src/megatron/bridge/models/transformer_config.py

# Check GPTModelProvider config class more comprehensively
sed -n '150,210p' src/megatron/bridge/models/gpt_provider.py

# Look at the llama4_utils.py usage to understand the pattern
rg -n -B5 -A5 'qk_l2_norm' src/megatron/bridge/models/llama/llama4_utils.py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 4172


🏁 Script executed:

# Check the base class of GPTModelProvider
grep -n "class GPTModelProvider" src/megatron/bridge/models/gpt_provider.py

# Check if there's a parent config that might define qk_l2_norm
sed -n '140,180p' src/megatron/bridge/models/gpt_provider.py

# Check the full dataclass definition
rg -n "@dataclass|class GPTModelProvider" src/megatron/bridge/models/gpt_provider.py -A1

# Look for any initialization or property definition of qk_l2_norm
rg -n "qk_l2_norm" src/megatron/bridge/models/gpt_provider.py

# Check if ModelProviderMixin or parent classes define this field
grep -rn "class ModelProviderMixin" src/megatron/bridge/models/

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 2339


🏁 Script executed:

# Check the base TransformerConfig class from megatron.bridge.models.transformer_config
cat src/megatron/bridge/models/transformer_config.py

# Also check if there's a relationship between qk_l2_norm and qk_layernorm
rg -n "qk_" src/megatron/bridge/models/transformer_config.py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 11837


🏁 Script executed:

# Check if qk_l2_norm is defined in megatron.core by looking at the function signature
# that calls get_gpt_decoder_layer_specs in other files to understand the expected parameter
cat src/megatron/bridge/training/mlm_compat/model.py | head -120

# Also check if there's any fallback or default logic for qk_l2_norm
rg -n "getattr.*qk_l2_norm|hasattr.*qk_l2_norm" --type py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 4966


🏁 Script executed:

# Look at the actual function signature or docstring for get_gpt_decoder_layer_specs
# by checking what parameters it accepts in different contexts
rg -n "def get_gpt_decoder_layer_specs|def get_gpt_decoder_block_spec" --type py -A 10

# Check what fields args has in mlm_compat
rg -n "add_argument.*qk_l2_norm|add_argument.*qk_layernorm" --type py -A 2 -B 2

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 53


🏁 Script executed:

# Check the arguments defined in mlm_compat to see what fields it expects
rg -n "qk_l2_norm|qk_layernorm" src/megatron/bridge/training/mlm_compat/arguments.py -B 2 -A 2

# Check if there's any documentation or patterns for how config fields are managed
rg -n "class.*TransformerConfig|@dataclass" src/megatron/bridge/models/gpt_provider.py | head -5

# Look for any property or __getattr__ definitions that might create qk_l2_norm dynamically
rg -n "property|__getattr__|def __init__" src/megatron/bridge/models/gpt_provider.py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 260


🏁 Script executed:

# Check if qk_l2_norm might be inherited from megatron.core.transformer.transformer_config.TransformerConfig
# by searching for it in megatron imports or documentation
rg -n "qk_l2_norm" src/megatron/bridge/models/ -l

# Also check the exact inheritance chain and what TransformerConfig from megatron.core provides
# by looking at how it's used in mlm_compat where both qk_layernorm and qk_l2_norm are passed
sed -n '40,80p' src/megatron/bridge/training/mlm_compat/model.py

# Check if the args object in mlm_compat defines qk_l2_norm
sed -n '1,200p' src/megatron/bridge/training/mlm_compat/arguments.py | grep -A 5 -B 5 "qk_l2_norm"

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 1602


🏁 Script executed:

# Check the full arguments.py file to see if qk_l2_norm is defined as an argument
cat src/megatron/bridge/training/mlm_compat/arguments.py | grep -n "qk_" -A 2 -B 2

# Also check what _transformer_config_from_args does to understand the config construction
rg -n "_transformer_config_from_args" src/megatron/bridge/training/mlm_compat/arguments.py -A 30

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 1811


🏁 Script executed:

# Check if qk_l2_norm is defined in megatron.core.transformer.transformer_config
# by searching for references in the entire repo
rg -n "class TransformerConfig" src/megatron/bridge/models/transformer_config.py -A 50 | head -100

# Look at line 75-77 of arguments.py more carefully - it copies fields that exist in args
# Let's see what args defines
rg -n "add_argument.*qk" src/megatron/bridge/training/mlm_compat/arguments.py

# Check if qk_l2_norm is added anywhere in the argument parser
rg -n "add_argument.*qk_l2_norm" src/megatron/bridge/training/

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 2411


🏁 Script executed:

# Let's check what MCoreTransformerConfig defines for qk fields
# by examining what fields are expected in the inherited class
rg -n "qk_l2_norm|qk_layernorm" tests/unit_tests/training/mlm_compat/test_model.py -B 3 -A 3

# Also look at what the actual error would be if config.qk_l2_norm doesn't exist
# Check if there's a fallback or default mechanism
sed -n '360,375p' src/megatron/bridge/models/gpt_provider.py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 1836


🏁 Script executed:

# Check the exact line 369 context and what config is being used there
sed -n '355,375p' src/megatron/bridge/models/gpt_provider.py

# Check if GPTModelProvider or its base TransformerConfig class actually has qk_l2_norm field
# by examining what fields it exposes
rg -n "qk_l2_norm.*:" src/megatron/bridge/models/gpt_provider.py

# Check if there's a default or if it's only conditionally added
rg -n "qk_l2_norm\s*=" src/megatron/bridge/models/gpt_provider.py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 1166


🏁 Script executed:

# Check if qk_l2_norm is used with getattr elsewhere for safe access
rg -n "getattr.*qk_l2_norm|getattr.*qk_layernorm" --type py

# Look at how other config accesses are done in the same MoE fallback section
sed -n '350,380p' src/megatron/bridge/models/gpt_provider.py

# Verify what fields are in GPTModelProvider vs what might be inherited
rg -n "^\s+[a-z_]+:\s+(Optional\[)?[a-zA-Z]" src/megatron/bridge/models/gpt_provider.py | grep -E "qk|norm|fp8"

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 1603


🏁 Script executed:

# Check if qk_l2_norm should be added as a field to GPTModelProvider
# by looking at all fields defined in the class
sed -n '124,230p' src/megatron/bridge/models/gpt_provider.py | grep -n ":" | head -30

# Also check if there's documentation or comments about this field being intentionally missing
rg -n "qk_l2_norm|qk_layernorm" src/megatron/bridge/models/gpt_provider.py

# Look at llama4_utils to see if it defines qk_l2_norm or just uses it
sed -n '140,160p' src/megatron/bridge/models/llama/llama4_utils.py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 2896


Add missing qk_l2_norm field to GPTModelProvider config.

The code at line 369 accesses config.qk_l2_norm, but this field is not defined in the GPTModelProvider dataclass. Only qk_layernorm is defined (line 180). Direct attribute access will raise AttributeError at runtime if the config instance lacks this field. Add qk_l2_norm: bool = False to the dataclass to match the pattern of related fields like qk_layernorm and fp8.
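A minimal sketch of the suggested change (the field name and default come from the review; the real GPTModelProvider dataclass has many more fields than shown here):

```python
# Sketch of the reviewer's suggestion: give the provider dataclass an explicit
# qk_l2_norm field with a False default, mirroring qk_layernorm.
from dataclasses import dataclass


@dataclass
class GPTModelProviderSketch:
    qk_layernorm: bool = False  # existing field (line 180 per the review)
    qk_l2_norm: bool = False    # proposed new field


cfg = GPTModelProviderSketch()
print(cfg.qk_l2_norm)  # False; direct attribute access no longer raises
```

With the field present, the call site's `qk_l2_norm=config.qk_l2_norm` is safe for every instance, and callers that need the norm can opt in with `GPTModelProviderSketch(qk_l2_norm=True)`.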

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/megatron/bridge/models/gpt_provider.py` around lines 365 - 370, The
GPTModelProvider dataclass is missing the qk_l2_norm attribute but the code in
get_gpt_decoder_layer_specs accesses config.qk_l2_norm; add a new field to the
GPTModelProvider dataclass named qk_l2_norm: bool = False (matching the
style/pattern of qk_layernorm and fp8) so instances have this attribute and
runtime AttributeError is avoided; update any dataclass default initialization
or constructor usage as needed to include the new field.

