2525from typing import List , Mapping , Optional , Sequence , Tuple , Union
2626
2727import torch
28+ try :
29+ from torch .distributed .tensor import DTensor
30+ has_dtensor = True
31+ except ImportError :
32+ has_dtensor = False
2833
2934from ._types import ParamsT
3035from .adamw import adamw
@@ -145,7 +150,9 @@ def zeropower_via_newtonschulz(
145150 if scale_eps :
146151 eps = scale_eps_for_ns (eps , G .shape )
147152
148- X = G .to (dtype = dtype , copy = True )
153+ is_dtensor = has_dtensor and isinstance (G , DTensor )
154+
155+ X = G .to (dtype = dtype ) if is_dtensor else G .to (dtype = dtype , copy = True )
149156
150157 # Transpose if needed (operate on dimension with fewer elements)
151158 transposed = X .size (- 2 ) > X .size (- 1 )
@@ -159,21 +166,30 @@ def zeropower_via_newtonschulz(
159166 else :
160167 X .div_ (X .norm (2 , dim = (- 2 , - 1 ), keepdim = True ).mul (safety_factor ).clamp_ (min = eps ))
161168
162- # Batched vs unbatched fused MM
163- mm_fn = torch .baddbmm if X .ndim > 2 else torch .addmm
164-
165- # Pre-allocate
166- X = X .contiguous ()
167- A = torch .empty ((* X .shape [:- 1 ], X .size (- 2 )), device = X .device , dtype = X .dtype )
168- B = torch .empty_like (A )
169- C = torch .empty_like (X )
170-
171- # Perform Newton-Schulz iterations
172- for a , b , c in coeff_sequence :
173- mm_fn (A , X , X .mT , beta = 0.0 , alpha = 1.0 , out = A ) # A = X @ X.mT
174- mm_fn (A , A , A , beta = b , alpha = c , out = B ) # B = b * A + c * A @ A
175- mm_fn (X , B , X , beta = a , alpha = 1.0 , out = C ) # C = a * X + B @ X
176- X , C = C , X # swap refs to avoid copy
169+ if is_dtensor :
170+ # Basic, DTensor-friendly Newton-Schulz
171+ for a , b , c in coeff_sequence :
172+ A = X @ X .mT
173+ B = b * A + c * (A @ A )
174+ X = a * X + (B @ X )
175+ else :
176+ # Fast prealloc/out= path
177+
178+ # Batched vs unbatched fused MM
179+ mm_fn = torch .baddbmm if X .ndim > 2 else torch .addmm
180+
181+ # Pre-allocate
182+ X = X .contiguous ()
183+ A = torch .empty ((* X .shape [:- 1 ], X .size (- 2 )), device = X .device , dtype = X .dtype )
184+ B = torch .empty_like (A )
185+ C = torch .empty_like (X )
186+
187+ # Perform Newton-Schulz iterations
188+ for a , b , c in coeff_sequence :
189+ mm_fn (A , X , X .mT , beta = 0.0 , alpha = 1.0 , out = A ) # A = X @ X.mT
190+ mm_fn (A , A , A , beta = b , alpha = c , out = B ) # B = b * A + c * A @ A
191+ mm_fn (X , B , X , beta = a , alpha = 1.0 , out = C ) # C = a * X + B @ X
192+ X , C = C , X # swap refs to avoid copy
177193
178194 if transposed :
179195 X = X .mT
0 commit comments