Skip to content

Commit 4b727cc

Browse files
committed
Add DTensor-compatible Newton-Schulz (NS) implementation for Muon
1 parent feed46c commit 4b727cc

File tree

2 files changed

+32
-15
lines changed

2 files changed

+32
-15
lines changed

timm/optim/_optim_factory.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -846,13 +846,15 @@ def _register_other_optimizers(registry: OptimizerRegistry) -> None:
846846
name='kron',
847847
opt_class=Kron,
848848
description='PSGD optimizer with Kronecker-factored preconditioner',
849+
has_eps=False,
849850
has_momentum=True,
850851
),
851852
OptimInfo(
852853
name='kronw',
853854
opt_class=Kron,
854855
description='PSGD optimizer with Kronecker-factored preconditioner and decoupled weight decay',
855856
has_momentum=True,
857+
has_eps=False,
856858
defaults={'decoupled_decay': True}
857859
),
858860
OptimInfo(

timm/optim/muon.py

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,11 @@
2525
from typing import List, Mapping, Optional, Sequence, Tuple, Union
2626

2727
import torch
28+
try:
29+
from torch.distributed.tensor import DTensor
30+
has_dtensor = True
31+
except ImportError:
32+
has_dtensor = False
2833

2934
from ._types import ParamsT
3035
from .adamw import adamw
@@ -159,21 +164,31 @@ def zeropower_via_newtonschulz(
159164
else:
160165
X.div_(X.norm(2, dim=(-2, -1), keepdim=True).mul(safety_factor).clamp_(min=eps))
161166

162-
# Batched vs unbatched fused MM
163-
mm_fn = torch.baddbmm if X.ndim > 2 else torch.addmm
164-
165-
# Pre-allocate
166-
X = X.contiguous()
167-
A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype)
168-
B = torch.empty_like(A)
169-
C = torch.empty_like(X)
170-
171-
# Perform Newton-Schulz iterations
172-
for a, b, c in coeff_sequence:
173-
mm_fn(A, X, X.mT, beta=0.0, alpha=1.0, out=A) # A = X @ X.mT
174-
mm_fn(A, A, A, beta=b, alpha=c, out=B) # B = b * A + c * A @ A
175-
mm_fn(X, B, X, beta=a, alpha=1.0, out=C) # C = a * X + B @ X
176-
X, C = C, X # swap refs to avoid copy
167+
is_dtensor = has_dtensor and isinstance(G, DTensor)
168+
if is_dtensor:
169+
# Basic, DTensor-friendly Newton-Schulz: plain matmul/arithmetic ops, no preallocated buffers or out= arguments
170+
for a, b, c in coeff_sequence:
171+
A = X @ X.mT
172+
B = b * A + c * (A @ A)
173+
X = a * X + (B @ X)
174+
else:
175+
# Fast path: preallocated buffers with fused matmuls written via out=
176+
177+
# Batched vs unbatched fused MM
178+
mm_fn = torch.baddbmm if X.ndim > 2 else torch.addmm
179+
180+
# Pre-allocate
181+
X = X.contiguous()
182+
A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype)
183+
B = torch.empty_like(A)
184+
C = torch.empty_like(X)
185+
186+
# Perform Newton-Schulz iterations
187+
for a, b, c in coeff_sequence:
188+
mm_fn(A, X, X.mT, beta=0.0, alpha=1.0, out=A) # A = X @ X.mT
189+
mm_fn(A, A, A, beta=b, alpha=c, out=B) # B = b * A + c * A @ A
190+
mm_fn(X, B, X, beta=a, alpha=1.0, out=C) # C = a * X + B @ X
191+
X, C = C, X # swap refs to avoid copy
177192

178193
if transposed:
179194
X = X.mT

0 commit comments

Comments
 (0)