diff --git a/scripts/performance/run_script.py b/scripts/performance/run_script.py index f5df41fc3e..2787f8c6f3 100644 --- a/scripts/performance/run_script.py +++ b/scripts/performance/run_script.py @@ -95,6 +95,11 @@ def main(): config_variant=args.config_variant, ) + # Set NCCL env vars for nccl_ub enabled via recipe config (not just CLI). + if getattr(recipe.ddp, "nccl_ub", False): + os.environ.setdefault("NCCL_NVLS_ENABLE", "1") + os.environ.setdefault("NCCL_CTA_POLICY", "1") + # Select forward step function based on the model family name. if args.domain == "vlm": forward_step_func = vlm_forward_step