Skip to content

Commit a86c4f0

Browse files
committed
Add DTensor compatible NS impl for Muon
1 parent feed46c commit a86c4f0

File tree

1 file changed

+32
-16
lines changed

1 file changed

+32
-16
lines changed

timm/optim/muon.py

Lines changed: 32 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,11 @@
2525
from typing import List, Mapping, Optional, Sequence, Tuple, Union
2626

2727
import torch
28+
try:
29+
from torch.distributed.tensor import DTensor
30+
has_dtensor = True
31+
except ImportError:
32+
has_dtensor = False
2833

2934
from ._types import ParamsT
3035
from .adamw import adamw
@@ -145,7 +150,9 @@ def zeropower_via_newtonschulz(
145150
if scale_eps:
146151
eps = scale_eps_for_ns(eps, G.shape)
147152

148-
X = G.to(dtype=dtype, copy=True)
153+
is_dtensor = has_dtensor and isinstance(G, DTensor)
154+
155+
X = G.to(dtype=dtype) if is_dtensor else G.to(dtype=dtype, copy=True)
149156

150157
# Transpose if needed (operate on dimension with fewer elements)
151158
transposed = X.size(-2) > X.size(-1)
@@ -159,21 +166,30 @@ def zeropower_via_newtonschulz(
159166
else:
160167
X.div_(X.norm(2, dim=(-2, -1), keepdim=True).mul(safety_factor).clamp_(min=eps))
161168

162-
# Batched vs unbatched fused MM
163-
mm_fn = torch.baddbmm if X.ndim > 2 else torch.addmm
164-
165-
# Pre-allocate
166-
X = X.contiguous()
167-
A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype)
168-
B = torch.empty_like(A)
169-
C = torch.empty_like(X)
170-
171-
# Perform Newton-Schulz iterations
172-
for a, b, c in coeff_sequence:
173-
mm_fn(A, X, X.mT, beta=0.0, alpha=1.0, out=A) # A = X @ X.mT
174-
mm_fn(A, A, A, beta=b, alpha=c, out=B) # B = b * A + c * A @ A
175-
mm_fn(X, B, X, beta=a, alpha=1.0, out=C) # C = a * X + B @ X
176-
X, C = C, X # swap refs to avoid copy
169+
if is_dtensor:
170+
# Basic, DTensor-friendly Newton-Schulz
171+
for a, b, c in coeff_sequence:
172+
A = X @ X.mT
173+
B = b * A + c * (A @ A)
174+
X = a * X + (B @ X)
175+
else:
176+
# Fast prealloc/out= path
177+
178+
# Batched vs unbatched fused MM
179+
mm_fn = torch.baddbmm if X.ndim > 2 else torch.addmm
180+
181+
# Pre-allocate
182+
X = X.contiguous()
183+
A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype)
184+
B = torch.empty_like(A)
185+
C = torch.empty_like(X)
186+
187+
# Perform Newton-Schulz iterations
188+
for a, b, c in coeff_sequence:
189+
mm_fn(A, X, X.mT, beta=0.0, alpha=1.0, out=A) # A = X @ X.mT
190+
mm_fn(A, A, A, beta=b, alpha=c, out=B) # B = b * A + c * A @ A
191+
mm_fn(X, B, X, beta=a, alpha=1.0, out=C) # C = a * X + B @ X
192+
X, C = C, X # swap refs to avoid copy
177193

178194
if transposed:
179195
X = X.mT

0 commit comments

Comments (0)