
fix: preserve RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES to prevent NCCL NVSwitch bugs#2252

Open
dmvevents wants to merge 1 commit into NVIDIA-NeMo:main from dmvevents:fix/preserve-ray-cuda-env

Conversation

@dmvevents
Contributor

@dmvevents dmvevents commented Apr 12, 2026

Fixes #1963. Related: #1961.

Summary

worker_groups.py removes RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES from the worker environment (line 501), which forces Ray to set per-actor CUDA_VISIBLE_DEVICES masking (e.g., CUDA_VISIBLE_DEVICES=3 for the 4th GPU). This triggers three confirmed NCCL bugs on NVSwitch topologies (H200 P5en, H100 P5):

  1. cuMem import penalty (NVIDIA/nccl#1749) — p2pMap() iterates over all devices when importing cuMem handles with non-overlapping CUDA_VISIBLE_DEVICES. Causes 3,660ms first-operation penalty vs 1.5ms with torchrun.

  2. NVLS rank ordering corruption (NVIDIA/nccl#1906) — allgather in nvls.cc is missing a user rank table when GPU indices are permuted by CUDA_VISIBLE_DEVICES. Causes hang or silent data corruption on NVSwitch systems.

  3. Multi-channel P2P hang at >8M elements — AllReduce hangs for tensors larger than ~32MB even with NCCL_CUMEM_ENABLE=0 and NCCL_NVLS_ENABLE=0.
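
To make the masking behavior concrete, here is a minimal sketch (assuming an 8-GPU node; the helper is illustrative, not NeMo-RL code) of what each worker sees with and without Ray's per-actor mask:

```python
def visible_gpus(env: dict, node_gpu_count: int = 8) -> list[int]:
    """Illustrative helper: which physical GPUs a process can see."""
    mask = env.get("CUDA_VISIBLE_DEVICES")
    if mask is None:
        return list(range(node_gpu_count))
    return [int(i) for i in mask.split(",")]

# Ray's default masking gives the 4th worker CUDA_VISIBLE_DEVICES=3:
# it sees exactly one GPU, re-indexed as device 0 inside the process.
masked = visible_gpus({"CUDA_VISIBLE_DEVICES": "3"})

# torchrun-style (no mask): every worker sees all 8 GPUs and picks its
# own with torch.cuda.set_device(local_rank).
unmasked = visible_gpus({})
```

It is this single-GPU, re-indexed view that trips the non-overlapping-mask paths in the NCCL bugs above.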

Fix

Preserve RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES=1 instead of removing it. This tells Ray not to mask CUDA_VISIBLE_DEVICES, so each worker sees all GPUs. Workers use explicit torch.cuda.set_device(local_rank) instead, which mirrors torchrun behavior and works correctly on both NVSwitch and NVLink topologies.
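A minimal sketch of the new env handling, assuming the surrounding merge logic stays unchanged (the helper name is illustrative, not the actual worker_groups.py function):

```python
def build_worker_env(base_env: dict) -> dict:
    """Illustrative sketch of worker env handling after this change."""
    env = dict(base_env)
    # Keep Ray from masking CUDA_VISIBLE_DEVICES per actor; workers
    # select their device from LOCAL_RANK instead (mirrors torchrun).
    env["RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES"] = "1"
    # Other Ray-specific variables are still stripped, as before.
    for var in ("RAY_CLIENT_MODE", "RAY_JOB_ID", "RAY_LD_PRELOAD",
                "RAY_RAYLET_PID", "RAY_USAGE_STATS_ENABLED"):
        env.pop(var, None)
    return env
```

The only behavioral difference from the old code is that the flag is set to "1" rather than popped.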

Benchmarks (same hardware, same NCCL)

| Method | AllReduce 4KB | AllReduce 933MB |
| --- | --- | --- |
| torchrun (no GPU masking) | 1.5ms | 1.5ms |
| Ray (GPU masking forced) | 3,660ms | hangs forever |

Both tests: P5en.48xlarge (8x H200, NVSwitch), NCCL 2.27.5, EFA, same container.

Test plan

  • Verified fix on 4x P5en.48xlarge (32x H200) — multi-node GRPO training completes successfully
  • Verified no regression on P5.48xlarge (H100, NVLink PXN) — works as before
  • LOCAL_RANK is already set in worker_env_vars (line 492), so workers can torch.cuda.set_device() correctly

Summary by CodeRabbit

  • Bug Fixes
    • Improved distributed training GPU configuration to prevent unintended CUDA device masking and properly preserve GPU device visibility settings across worker processes.

…NCCL NVSwitch bugs

Fixes NVIDIA-NeMo#1963. Related: NVIDIA-NeMo#1961.

The current code removes RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES
from the worker environment, which forces Ray to mask
CUDA_VISIBLE_DEVICES per actor. This triggers known NCCL bugs on
NVSwitch topologies (H200/P5en, H100/P5):

- cuMem import penalty causing 2400x slower first AllReduce (nccl#1749)
- NVLS rank ordering corruption causing hangs (nccl#1906)
- Multi-channel P2P hangs at >8M elements

Fix: set RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES=1 so each worker
sees all GPUs and uses explicit torch.cuda.set_device(local_rank)
instead. This mirrors torchrun behavior and works correctly on both
NVSwitch and NVLink topologies.
@dmvevents dmvevents requested a review from a team as a code owner April 12, 2026 18:22
@copy-pr-bot

copy-pr-bot bot commented Apr 12, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@coderabbitai
Contributor

coderabbitai bot commented Apr 12, 2026

📝 Walkthrough

The change modifies environment variable handling in worker initialization to explicitly preserve the RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES flag instead of removing it, preventing Ray from applying per-actor GPU masking that causes NCCL issues on H200/NVSwitch hardware.

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| Ray Environment Variable Handling<br>nemo_rl/distributed/worker_groups.py | Replaced removal of RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES with explicit assignment of "1" to preserve Ray's no-GPU-masking behavior. Maintains removal of other Ray-specific environment variables (RAY_CLIENT_MODE, RAY_JOB_ID, RAY_LD_PRELOAD, RAY_RAYLET_PID, RAY_USAGE_STATS_ENABLED). |

Estimated code review effort

🎯 1 (Trivial) | ⏱️ ~3 minutes

🚥 Pre-merge checks: ✅ 5 passed | ❌ 1 warning

❌ Failed checks (1 warning)

  • Test Results For Major Changes (⚠️ Warning): The PR fixes a critical GPU/NCCL issue on NVSwitch hardware affecting performance and correctness, but test results are not formally documented in the PR description despite being mentioned in the objectives. Resolution: add a formal Testing section to the PR description documenting the P5en and P5 benchmark results, regression-testing outcomes, and multi-node GRPO training verification details.

✅ Passed checks (5 passed)

  • Description Check: skipped because CodeRabbit's high-level summary is enabled.
  • Title check: the title accurately reflects the main change (preserving RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES to prevent NCCL NVSwitch bugs), which is the core fix in the changeset.
  • Linked Issues check: the PR directly addresses issue #1963 by preserving RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES=1, preventing Ray GPU masking and avoiding the three confirmed NCCL bugs on NVSwitch hardware, meeting all stated coding objectives.
  • Out of Scope Changes check: all changes are scoped to the specific issue; worker environment handling in _create_workers_from_bundle_indices preserves the variable while the removal of unrelated Ray env vars is unchanged.
  • Docstring Coverage: 100.00%, above the required 80.00% threshold.


Contributor

@coderabbitai coderabbitai bot left a comment


🧹 Nitpick comments (1)
nemo_rl/distributed/worker_groups.py (1)

506-506: Use setdefault() to preserve caller-provided overrides while maintaining the safe default.

Line 506 unconditionally overwrites any value for RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES, contradicting the comment which says "Preserve". The current merge pattern (lines 430–433) respects caller-provided env_vars, and the pattern of explicit pop() calls (lines 508–512) shows selective env var management. Using setdefault() aligns the implementation with both the comment's intent and the method's design pattern of respecting caller values while providing a safe default for this critical NCCL bug workaround.

♻️ Suggested refactor

```diff
-                worker_env_vars["RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES"] = "1"
+                worker_env_vars.setdefault(
+                    "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES", "1"
+                )
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@nemo_rl/distributed/worker_groups.py` at line 506, The code unconditionally
overwrites RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES on worker_env_vars which
violates the "Preserve" intent; change the assignment to use
worker_env_vars.setdefault("RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES", "1")
so that any caller-provided value is preserved while still providing the safe
default for the NCCL workaround (locate the usage of worker_env_vars and the
RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES symbol in worker_groups.py).

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: fd3c68a0-4013-4655-bc91-decde792ee45

📥 Commits

Reviewing files that changed from the base of the PR and between 69d1872 and 42f95f6.

📒 Files selected for processing (1)
  • nemo_rl/distributed/worker_groups.py

@yuki-97
Contributor

yuki-97 commented Apr 13, 2026

hi @guyueh1 @terrykong, could you help review this and check whether it has any side effects?

@yuki-97 yuki-97 requested review from guyueh1 and terrykong April 13, 2026 03:37
Collaborator

@terrykong terrykong left a comment


Review Summary

The approach (preserving RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES=1) is sound and consistent with how vLLM and SGLang workers already behave. However, there is a critical regression in the default DTensor V1 path and several secondary issues that need attention before merge.

Historical context: The original pop() was introduced in PR #432 by @hemildesai for ray job submit support. After investigation, the KubeRay risk is low — NeMo-RL uses whole-node GPU allocations, and vLLM/SGLang already set this env var to 1.

```python
# penalty (nccl#1749) and NVLS rank ordering corruption (nccl#1906).
# Workers use explicit torch.cuda.set_device(local_rank) instead.
# See: https://github.com/NVIDIA-NeMo/RL/issues/1963
worker_env_vars["RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES"] = "1"
```
Critical: DTensor V1 LOCAL_RANK=0 regression

This unconditional = "1" makes all GPUs visible to every worker. However, dtensor_policy_worker.py:320-329 has a LOCAL_RANK=0 hack that assumes only 1 GPU is visible:

```python
# torch==2.8 uses LOCAL_RANK to set the device here
# but CUDA_VISIBLE_DEVICES is set to only 1 gpu, so we need to temporarily set LOCAL_RANK to 0.
prev_local_rank = os.environ["LOCAL_RANK"]
os.environ["LOCAL_RANK"] = "0"
device_mesh = torch.distributed.device_mesh.init_device_mesh(...)
```

With all GPUs visible, init_device_mesh calls set_device(0) for every worker → all workers fight over GPU 0 → OOM. DTensor V1 is the default (lm_policy.py:113: _v2=False). DTensor V2 is unaffected (uses FSDP2Manager which reads real LOCAL_RANK).

Fix: Remove the LOCAL_RANK=0 hack (lines 323-324). With all GPUs visible, init_device_mesh should use the real LOCAL_RANK=bundle_idx.
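
The rank selection implied by this fix can be sketched as follows (purely illustrative; in the real code path init_device_mesh reads LOCAL_RANK from the environment):

```python
def effective_local_rank(all_gpus_visible: bool, local_rank: int) -> int:
    """Which LOCAL_RANK the device-mesh init should see (illustrative)."""
    if all_gpus_visible:
        # With RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES=1 each worker
        # sees all GPUs, so it must use its real local rank; forcing 0
        # would put every worker on GPU 0 and OOM.
        return local_rank
    # Old masked behavior: one visible GPU, always re-indexed as 0.
    return 0
```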

```python
# from masking CUDA_VISIBLE_DEVICES per actor. GPU masking triggers NCCL
# bugs on NVSwitch topologies (H200/P5en, H100/P5) including cuMem import
# penalty (nccl#1749) and NVLS rank ordering corruption (nccl#1906).
# Workers use explicit torch.cuda.set_device(local_rank) instead.
```

This comment is inaccurate — no production nemo_rl code calls torch.cuda.set_device(local_rank). Device binding happens implicitly via init_device_mesh / Megatron internals reading the LOCAL_RANK env var.

Suggested change

```diff
-# Workers use explicit torch.cuda.set_device(local_rank) instead.
+# Workers rely on LOCAL_RANK env var for device selection via
+# init_device_mesh / Megatron internals.
```

@terrykong
Collaborator

Additional Findings

vLLM Non-Parallel GPU Collision Risk

vllm_worker.py:362-364 — Non-parallel vLLM workers (TP=1) see all 8 GPUs but don't explicitly select a device. Multiple DP replicas on one node may collide on GPU 0. SGLang handles this via base_gpu_id (sglang_worker.py:120,123) — vLLM's non-parallel path lacks equivalent logic.
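
For comparison, SGLang's base_gpu_id scheme can be sketched roughly like this (hypothetical helper; the actual logic lives in sglang_worker.py):

```python
def replica_gpus(replica_idx: int, gpus_per_replica: int,
                 base_gpu_id: int = 0) -> list[int]:
    """Assign each data-parallel replica a disjoint GPU range."""
    start = base_gpu_id + replica_idx * gpus_per_replica
    return list(range(start, start + gpus_per_replica))

# With all 8 GPUs visible, TP=1 replicas no longer collide on GPU 0:
# replica 0 gets [0], replica 3 gets [3], and so on.
```

vLLM's non-parallel path would need equivalent logic to be safe with unmasked GPUs.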

Stale Comments

These comments reference CUDA_VISIBLE_DEVICES masking behavior that no longer applies with this change:

Megatron Seed Computation Change

In megatron_policy_worker.py:777-780, torch.cuda.device_count() now returns 8 instead of 1, changing seed values. The new computation is actually more correct semantically (a proper node_idx), but it breaks reproducibility versus previous runs. Worth noting in the PR description.
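
The reproducibility shift can be illustrated with a toy node-index computation (hypothetical; the real seed formula lives in megatron_policy_worker.py):

```python
def node_idx(global_rank: int, visible_device_count: int) -> int:
    """Toy illustration of a rank-to-node mapping used in seeding."""
    return global_rank // max(visible_device_count, 1)

# Old masked behavior: device_count == 1, so every rank looked like its
# own node. Unmasked 8-GPU node: ranks 0-7 map to node 0, 8-15 to node 1.
```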

@terrykong
Collaborator

Re: CodeRabbit's setdefault() suggestion — disagree.

The unconditional = "1" is intentional. This is a safety invariant for NCCL correctness, not a default. If a caller had RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES=0 in their env, setdefault() would silently not apply the fix, defeating the entire purpose of the PR.

Every other site in the codebase that sets this variable uses unconditional assignment:

The word "Preserve" in the comment means "keep this env var present" (vs. the old pop()), not "defer to caller-provided values."
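
The difference is easy to demonstrate with a standalone sketch (helper names are illustrative):

```python
def with_setdefault(env: dict) -> dict:
    env = dict(env)
    env.setdefault("RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES", "1")
    return env

def with_assignment(env: dict) -> dict:
    env = dict(env)
    env["RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES"] = "1"
    return env

caller = {"RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "0"}
# setdefault keeps the caller's "0" and silently skips the NCCL
# workaround; unconditional assignment enforces the invariant.
```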

@terrykong
Copy link
Copy Markdown
Collaborator

hey @dmvevents, the above review was Claude-generated, so we should just use it as a guide. I think this change makes sense, but we'll need to run the nightly CI to make sure it doesn't break anything, since this change has a large impact.

@chtruong814 chtruong814 added and then removed the waiting-for-customer (Waiting for response from the original author) label on Apr 14, 2026

Development

Successfully merging this pull request may close these issues.

worker_groups.py removes RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES, breaking NCCL on H200/NVSwitch
