-
Notifications
You must be signed in to change notification settings - Fork 342
fix: preserve RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES to prevent NCCL NVSwitch bugs #2252
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -497,8 +497,13 @@ def _create_workers_from_bundle_indices( | |
| "AVAILABLE_PORT_LIST": str(available_ports), | ||
| } | ||
| ) | ||
| # Remove Ray-specific environment variables, let the worker itself set them. | ||
| worker_env_vars.pop("RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES", None) | ||
| # Preserve RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES=1 to prevent Ray | ||
| # from masking CUDA_VISIBLE_DEVICES per actor. GPU masking triggers NCCL | ||
| # bugs on NVSwitch topologies (H200/P5en, H100/P5) including cuMem import | ||
| # penalty (nccl#1749) and NVLS rank ordering corruption (nccl#1906). | ||
| # Workers use explicit torch.cuda.set_device(local_rank) instead. | ||
| # See: https://github.com/NVIDIA-NeMo/RL/issues/1963 | ||
| worker_env_vars["RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES"] = "1" | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Critical: DTensor V1 This unconditional # torch==2.8 uses LOCAL_RANK to set the device here
# but CUDA_VISIBLE_DEVICES is set to only 1 gpu, so we need to temporarily set LOCAL_RANK to 0.
prev_local_rank = os.environ["LOCAL_RANK"]
os.environ["LOCAL_RANK"] = "0"
device_mesh = torch.distributed.device_mesh.init_device_mesh(...)With all GPUs visible, Fix: Remove the |
||
| worker_env_vars.pop("RAY_CLIENT_MODE", None) | ||
| worker_env_vars.pop("RAY_JOB_ID", None) | ||
| worker_env_vars.pop("RAY_LD_PRELOAD", None) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This comment is inaccurate — no production
nemo_rlcode callstorch.cuda.set_device(local_rank). Device binding happens implicitly viainit_device_mesh/ Megatron internals reading theLOCAL_RANKenv var.