|
25 | 25 | from typing import List, Mapping, Optional, Sequence, Tuple, Union |
26 | 26 |
|
27 | 27 | import torch |
| 28 | +try: |
| 29 | + from torch.distributed.tensor import DTensor |
| 30 | + has_dtensor = True |
| 31 | +except ImportError: |
| 32 | + has_dtensor = False |
28 | 33 |
|
29 | 34 | from ._types import ParamsT |
30 | 35 | from .adamw import adamw |
@@ -159,21 +164,31 @@ def zeropower_via_newtonschulz( |
159 | 164 | else: |
160 | 165 | X.div_(X.norm(2, dim=(-2, -1), keepdim=True).mul(safety_factor).clamp_(min=eps)) |
161 | 166 |
|
162 | | - # Batched vs unbatched fused MM |
163 | | - mm_fn = torch.baddbmm if X.ndim > 2 else torch.addmm |
164 | | - |
165 | | - # Pre-allocate |
166 | | - X = X.contiguous() |
167 | | - A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) |
168 | | - B = torch.empty_like(A) |
169 | | - C = torch.empty_like(X) |
170 | | - |
171 | | - # Perform Newton-Schulz iterations |
172 | | - for a, b, c in coeff_sequence: |
173 | | - mm_fn(A, X, X.mT, beta=0.0, alpha=1.0, out=A) # A = X @ X.mT |
174 | | - mm_fn(A, A, A, beta=b, alpha=c, out=B) # B = b * A + c * A @ A |
175 | | - mm_fn(X, B, X, beta=a, alpha=1.0, out=C) # C = a * X + B @ X |
176 | | - X, C = C, X # swap refs to avoid copy |
| 167 | + is_dtensor = has_dtensor and isinstance(G, DTensor) |
| 168 | + if is_dtensor: |
| 169 | + # Basic, DTensor-friendly Newton-Schulz |
| 170 | + for a, b, c in coeff_sequence: |
| 171 | + A = X @ X.mT |
| 172 | + B = b * A + c * (A @ A) |
| 173 | + X = a * X + (B @ X) |
| 174 | + else: |
| 175 | + # Fast prealloc/out= path |
| 176 | + |
| 177 | + # Batched vs unbatched fused MM |
| 178 | + mm_fn = torch.baddbmm if X.ndim > 2 else torch.addmm |
| 179 | + |
| 180 | + # Pre-allocate |
| 181 | + X = X.contiguous() |
| 182 | + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) |
| 183 | + B = torch.empty_like(A) |
| 184 | + C = torch.empty_like(X) |
| 185 | + |
| 186 | + # Perform Newton-Schulz iterations |
| 187 | + for a, b, c in coeff_sequence: |
| 188 | + mm_fn(A, X, X.mT, beta=0.0, alpha=1.0, out=A) # A = X @ X.mT |
| 189 | + mm_fn(A, A, A, beta=b, alpha=c, out=B) # B = b * A + c * A @ A |
| 190 | + mm_fn(X, B, X, beta=a, alpha=1.0, out=C) # C = a * X + B @ X |
| 191 | + X, C = C, X # swap refs to avoid copy |
177 | 192 |
|
178 | 193 | if transposed: |
179 | 194 | X = X.mT |
|
0 commit comments