Skip to content

[recipe] feat: add Qwen3-0.6B 128K SFT recipe with YaRN RoPE scaling#3316

Open
RayenTian wants to merge 2 commits intomainfrom
ruit/128k_sft_recipe
Open

[recipe] feat: add Qwen3-0.6B 128K SFT recipe with YaRN RoPE scaling#3316
RayenTian wants to merge 2 commits intomainfrom
ruit/128k_sft_recipe

Conversation

@RayenTian
Copy link
Copy Markdown
Contributor

@RayenTian RayenTian commented Apr 14, 2026

What does this PR do ?

Adds a new SFT training recipe for Qwen3-0.6B at 128K context length using YaRN RoPE scaling, together with a reference launch script.

YaRN scaling extends Qwen3-0.6B's native 40K context window to 128K:

  • yarn_rotary_scaling_factor = 128K / 40K = 3.2
  • yarn_beta_fast = 32.0, yarn_beta_slow = 1.0

Long-context SFT stability settings:

  • cross_entropy_loss_fusion = False
  • calculate_per_token_loss = True
  • ddp.average_in_collective = False
  • use_distributed_optimizer = False

Dataset

Specific data file to avoid too mush time spend on data download.
nvidia/Nemotron-Cascade-2-SFT-Data

  • subset: math
  • split:train
  • file: math/math_notool.jsonl
  • first 10000 samples

Overview of dataset

=== Sequence length stats over 10000 examples ===
  min   : 350
  max   : 140459
  mean  : 18371.6
  median: 12432.5
  p90   : 43089
  p95   : 57582
  p99   : 84679
  > 40K : 1104 (11.0%) # 40k is the original seqlen of qwen3-0.6B
  > 128K: 1 (0.0%)

Result

image

Changelog

  • Add specific line by line info of high level changes in this PR.

GitHub Actions CI

See the CI sectionin the Contributing doc for how to trigger the CI. A Nvidia developer will need to approve and trigger the CI for external contributors.

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (Ex: Numba, Pynini, Apex etc)
    • Reviewer: Does the PR have correct import guards for all optional libraries?

If you haven't finished some of the above items you can still open "Draft" PR.

Additional Information

  • Related to # (issue)

Summary by CodeRabbit

Release Notes

  • New Features
    • Added supervised fine-tuning support for Qwen3 600M model with 128K extended sequence length for long-context training.
    • Provided training script and optimized configuration for distributed fine-tuning workflows.
    • Configured with advanced position scaling and specialized optimizations for long-context scenarios.

@copy-pr-bot
Copy link
Copy Markdown

copy-pr-bot bot commented Apr 14, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@RayenTian RayenTian requested a review from yaoyu-33 April 14, 2026 08:28
@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai bot commented Apr 14, 2026

📝 Walkthrough

Walkthrough

Adds a new long-context supervised fine-tuning recipe for Qwen3-0.6B with YaRN RoPE scaling at 128K sequence length. Includes a Bash training script, exported recipe configuration, and a recipe factory function that configures the model, dataset with HuggingFace integration, context parallelism, and long-context training hyperparameters.

Changes

Cohort / File(s) Summary
Training Script
examples/long_context/qwen3_600m_sft_yarn_128k.sh
New executable Bash script that configures and launches distributed fine-tuning using torch.distributed.run with 8 processes per node, setting workspace paths, batch sizes, iteration counts, warmup schedules, and WandB logging.
Recipe Export
src/megatron/bridge/recipes/qwen/__init__.py
Added qwen3_600m_sft_yarn_128k_config to __all__ exports for public access.
Recipe Implementation
src/megatron/bridge/recipes/qwen/qwen3.py
New qwen3_600m_sft_yarn_128k_config() factory function that inherits from qwen3_600m_sft_config() and configures YaRN RoPE scaling, 128K sequence length, context parallelism (size=8), HuggingFace dataset (nvidia/Nemotron-Cascade-2-SFT-Data/math subset), and adjusted long-context training parameters (global batch size, optimizer settings, DDP configuration).

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~15 minutes

Possibly related PRs

  • PR #2951: Adds a complementary non-YaRN qwen3_600m_sft_128k_config variant for 128K context, sharing the same configuration infrastructure and recipe updates.

Suggested labels

area:model, area:recipe

Suggested reviewers

  • cuichenx
🚥 Pre-merge checks | ✅ 4
✅ Passed checks (4 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title directly and clearly describes the main change: adding a new Qwen3-0.6B recipe with 128K context length and YaRN RoPE scaling, which matches all three modified files.
Docstring Coverage ✅ Passed Docstring coverage is 100.00% which is sufficient. The required threshold is 80.00%.
Test Results For Major Changes ✅ Passed PR introduces major changes (new 128K context SFT recipe with YaRN scaling) and includes comprehensive training metric visuals documenting convergence, gradient behavior, and validation performance without overfitting indicators.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
📝 Generate docstrings
  • Create stacked PR
  • Commit on current branch
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch ruit/128k_sft_recipe

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🧹 Nitpick comments (2)
examples/long_context/qwen3_600m_sft_yarn_128k.sh (2)

26-26: Unused variable SEQ_LENGTH.

SEQ_LENGTH is defined but never referenced in the script. The sequence length is embedded in the recipe itself (128*1024). Consider removing this variable or adding a comment explaining it's for documentation purposes.

Suggested fix
-SEQ_LENGTH=131072
+# Sequence length is configured in the recipe (128K = 131072)

Or simply remove the line if it serves no purpose.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/long_context/qwen3_600m_sft_yarn_128k.sh` at line 26, The variable
SEQ_LENGTH is defined but never used; either remove the SEQ_LENGTH line or make
its intent explicit by converting it into a documented constant/comment (e.g.,
note that the recipe uses 128*1024) so it isn't unused — update the SEQ_LENGTH
declaration accordingly or delete it to eliminate the dead variable.

37-50: Consider quoting variable expansions to prevent word splitting.

Static analysis flagged unquoted variables. While these controlled variables are unlikely to cause issues, quoting prevents unexpected behavior if paths contain spaces.

Suggested fix
 uv run python -m torch.distributed.run --nproc_per_node=8 scripts/training/run_recipe.py \
-    --recipe ${MODEL_NAME}_sft_yarn_128k_config \
-    checkpoint.pretrained_checkpoint=$PRETRAINED_CHECKPOINT \
+    --recipe "${MODEL_NAME}_sft_yarn_128k_config" \
+    checkpoint.pretrained_checkpoint="$PRETRAINED_CHECKPOINT" \
     train.train_iters=$TRAIN_ITERS \
     train.global_batch_size=$GLOBAL_BATCH_SIZE \
     train.micro_batch_size=$MICRO_BATCH_SIZE \
     validation.eval_iters=$EVAL_ITERS \
     validation.eval_interval=$EVAL_INTERVAL \
     scheduler.lr_warmup_iters=$LR_WARMUP_ITERS \
-    checkpoint.save=${WORKSPACE}/results/${MODEL_NAME}_yarn_128k_sft \
-    checkpoint.load=${WORKSPACE}/results/${MODEL_NAME}_yarn_128k_sft \
+    checkpoint.save="${WORKSPACE}/results/${MODEL_NAME}_yarn_128k_sft" \
+    checkpoint.load="${WORKSPACE}/results/${MODEL_NAME}_yarn_128k_sft" \
     logger.log_interval=$LOG_INTERVAL \
-    logger.wandb_project=$WANDB_PROJECT \
-    logger.wandb_exp_name=${MODEL_NAME}_${DATASET_NAME}_yarn_128k_sft
+    logger.wandb_project="$WANDB_PROJECT" \
+    logger.wandb_exp_name="${MODEL_NAME}_${DATASET_NAME}_yarn_128k_sft"
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/long_context/qwen3_600m_sft_yarn_128k.sh` around lines 37 - 50, The
shell command uses unquoted variable expansions (e.g.,
${MODEL_NAME}_sft_yarn_128k_config, $PRETRAINED_CHECKPOINT, $TRAIN_ITERS,
$GLOBAL_BATCH_SIZE, $MICRO_BATCH_SIZE, $EVAL_ITERS, $EVAL_INTERVAL,
$LR_WARMUP_ITERS, ${WORKSPACE}/results/${MODEL_NAME}_yarn_128k_sft,
$LOG_INTERVAL, $WANDB_PROJECT, ${MODEL_NAME}_${DATASET_NAME}_yarn_128k_sft)
which can cause word-splitting if any contain spaces; update the invocation in
the uv run / python -m torch.distributed.run command to wrap each variable
expansion in double quotes (e.g., "$PRETRAINED_CHECKPOINT",
"${WORKSPACE}/results/${MODEL_NAME}_yarn_128k_sft", etc.) so all arguments are
passed as single tokens.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@src/megatron/bridge/recipes/qwen/qwen3.py`:
- Around line 659-660: The block that sets context parallelism only assigns
cfg.model.context_parallel_size = 8 but omits the CP communication mode; mirror
the earlier recipe (qwen3_600m_sft_128k_config) by also setting
cfg.model.cp_comm_type = "a2a" to use all-to-all CP and avoid NaN gradients—add
cfg.model.cp_comm_type = "a2a" alongside cfg.model.context_parallel_size = 8 in
the same config function/block.

---

Nitpick comments:
In `@examples/long_context/qwen3_600m_sft_yarn_128k.sh`:
- Line 26: The variable SEQ_LENGTH is defined but never used; either remove the
SEQ_LENGTH line or make its intent explicit by converting it into a documented
constant/comment (e.g., note that the recipe uses 128*1024) so it isn't unused —
update the SEQ_LENGTH declaration accordingly or delete it to eliminate the dead
variable.
- Around line 37-50: The shell command uses unquoted variable expansions (e.g.,
${MODEL_NAME}_sft_yarn_128k_config, $PRETRAINED_CHECKPOINT, $TRAIN_ITERS,
$GLOBAL_BATCH_SIZE, $MICRO_BATCH_SIZE, $EVAL_ITERS, $EVAL_INTERVAL,
$LR_WARMUP_ITERS, ${WORKSPACE}/results/${MODEL_NAME}_yarn_128k_sft,
$LOG_INTERVAL, $WANDB_PROJECT, ${MODEL_NAME}_${DATASET_NAME}_yarn_128k_sft)
which can cause word-splitting if any contain spaces; update the invocation in
the uv run / python -m torch.distributed.run command to wrap each variable
expansion in double quotes (e.g., "$PRETRAINED_CHECKPOINT",
"${WORKSPACE}/results/${MODEL_NAME}_yarn_128k_sft", etc.) so all arguments are
passed as single tokens.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro Plus

Run ID: c1053130-9c1f-4032-83c2-5af417cafb7e

📥 Commits

Reviewing files that changed from the base of the PR and between ad27e2c and 119631c.

📒 Files selected for processing (3)
  • examples/long_context/qwen3_600m_sft_yarn_128k.sh
  • src/megatron/bridge/recipes/qwen/__init__.py
  • src/megatron/bridge/recipes/qwen/qwen3.py

Comment thread src/megatron/bridge/recipes/qwen/qwen3.py
@RayenTian
Copy link
Copy Markdown
Contributor Author

/ok to test 119631c

Comment thread src/megatron/bridge/recipes/qwen/qwen3.py Outdated
Comment thread src/megatron/bridge/recipes/qwen/qwen3.py Outdated
Signed-off-by: ruit <ruit@nvidia.com>
@RayenTian RayenTian force-pushed the ruit/128k_sft_recipe branch from 119631c to 9bbddfe Compare April 15, 2026 07:44
@yaoyu-33
Copy link
Copy Markdown
Contributor

/ok to test 9bbddfe

@yaoyu-33 yaoyu-33 added the ready-to-merge PR is approved, current, and only waiting for CI to pass before merge label Apr 15, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ready-to-merge PR is approved, current, and only waiting for CI to pass before merge

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants