Skip to content

Commit 627b579

Browse files
Add an example: mHC residual projection backward (#1758)
* [Example] Add deepseek mHC sinkhorn backward * Remove example_mhc_res.py file from deepseek_mhc examples directory * Remove example_mhc_res.py file from deepseek_mhc examples directory --------- Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
1 parent 110ef30 commit 627b579

File tree

2 files changed

+287
-0
lines changed

2 files changed

+287
-0
lines changed

docs/spelling_wordlist.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
 cancelled
+dout
 hsa
 ist
 LOD
Lines changed: 286 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,286 @@
1+
# NOTE: This bwd script is not an official upstream script; it is community-written and provided for reference only.
2+
# checkout pr: https://github.com/tile-ai/tilelang/pull/1758
3+
import torch
4+
5+
import tilelang
6+
import tilelang.language as T
7+
from tilelang.autotuner import set_autotune_inputs
8+
from tqdm import trange
9+
10+
11+
dtype = torch.float32
12+
13+
seqlen = 65536
14+
n_stream = 16
15+
iters = 100
16+
repeat = 512
17+
18+
EPS = 1e-10
19+
20+
21+
def sinkhorn_forward(M, iters=20):
22+
P = torch.exp(M)
23+
R = P
24+
25+
for _ in range(iters):
26+
R = R / R.sum(-2, keepdim=True)
27+
R = R / R.sum(-1, keepdim=True)
28+
29+
return R, P
30+
31+
32+
def sinkhorn_bwd_configs(n_stream, seqlen):
33+
"""Generate autotune configurations for different tilesize and threads"""
34+
configs = []
35+
36+
# Explore different tile sizes and thread counts
37+
tilesizes = [1, 2, 4, 8, 16, 32, 64]
38+
thread_counts = [32, 64, 128, 256]
39+
40+
for tilesize in tilesizes:
41+
# Skip if tilesize doesn't divide seqlen evenly (optional constraint)
42+
if seqlen % tilesize != 0:
43+
continue
44+
45+
for threads in thread_counts:
46+
configs.append({"tilesize": tilesize, "threads": threads})
47+
48+
return configs
49+
50+
51+
@tilelang.autotune(
52+
configs=sinkhorn_bwd_configs(n_stream, seqlen),
53+
warmup=4,
54+
rep=repeat,
55+
)
56+
@tilelang.jit(
57+
out_idx=[2],
58+
pass_configs={
59+
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
60+
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
61+
},
62+
)
63+
def sinkhorn_bwd_implicit_cg(n_stream: int, tilesize: int = 32, threads: int = 128):
64+
seqlen = T.dynamic("seqlen")
65+
tensor_shape = [seqlen, n_stream, n_stream]
66+
dtype = T.float32
67+
68+
@T.macro
69+
def matvec_A(R, x1, x2, buf, y1, y2):
70+
for i_tile, i, j in T.Parallel(tilesize, n_stream, n_stream):
71+
buf[i_tile, i, j] = R[i_tile, i, j] * x2[i_tile, j]
72+
T.reduce_sum(buf, y1, dim=-1)
73+
74+
for i_tile, i, j in T.Parallel(tilesize, n_stream, n_stream):
75+
buf[i_tile, i, j] = R[i_tile, i, j] * x1[i_tile, i]
76+
T.reduce_sum(buf, y2, dim=-2)
77+
78+
for i_tile, i in T.Parallel(tilesize, n_stream):
79+
y1[i_tile, i] += x1[i_tile, i]
80+
y2[i_tile, i] += x2[i_tile, i]
81+
82+
@T.macro
83+
def dot(x1, x2, y1, y2, buf, out):
84+
for i_tile, i in T.Parallel(tilesize, n_stream):
85+
buf[i_tile, i] = x1[i_tile, i] * y1[i_tile, i] + x2[i_tile, i] * y2[i_tile, i]
86+
87+
T.reduce_sum(buf, out, dim=-1)
88+
89+
@T.prim_func
90+
def main(
91+
out: T.Tensor(tensor_shape, dtype),
92+
dout: T.Tensor(tensor_shape, dtype),
93+
res: T.Tensor(tensor_shape, dtype),
94+
):
95+
with T.Kernel(T.ceildiv(seqlen, tilesize), threads=threads) as i_seq:
96+
R = T.alloc_fragment([tilesize, n_stream, n_stream], dtype=dtype)
97+
dR = T.alloc_fragment([tilesize, n_stream, n_stream], dtype=dtype)
98+
RdR = T.alloc_fragment([tilesize, n_stream, n_stream], dtype=dtype)
99+
res_tile = T.alloc_shared([tilesize, n_stream, n_stream], dtype=dtype)
100+
b1 = T.alloc_shared([tilesize, n_stream], dtype=dtype)
101+
b2 = T.alloc_shared([tilesize, n_stream], dtype=dtype)
102+
x1 = T.alloc_shared([tilesize, n_stream], dtype=dtype)
103+
x2 = T.alloc_shared([tilesize, n_stream], dtype=dtype)
104+
r1 = T.alloc_shared([tilesize, n_stream], dtype=dtype)
105+
r2 = T.alloc_shared([tilesize, n_stream], dtype=dtype)
106+
p1 = T.alloc_shared([tilesize, n_stream], dtype=dtype)
107+
p2 = T.alloc_shared([tilesize, n_stream], dtype=dtype)
108+
alpha = T.alloc_fragment([tilesize, n_stream], dtype=dtype)
109+
beta = T.alloc_fragment([tilesize, n_stream], dtype=dtype)
110+
r_normsq = T.alloc_fragment([tilesize], dtype=dtype)
111+
r_new_normsq = T.alloc_fragment([tilesize], dtype=dtype)
112+
Ap1 = T.alloc_shared([tilesize, n_stream], dtype=dtype)
113+
Ap2 = T.alloc_shared([tilesize, n_stream], dtype=dtype)
114+
pAp = T.alloc_fragment([tilesize], dtype=dtype)
115+
116+
# Buffers for intermediate results
117+
buf1 = T.alloc_shared([tilesize, n_stream, n_stream], dtype=dtype)
118+
buf2 = T.alloc_shared([tilesize, n_stream], dtype=dtype)
119+
120+
T.copy(out[i_seq * tilesize : (i_seq + 1) * tilesize, :, :], R)
121+
T.copy(dout[i_seq * tilesize : (i_seq + 1) * tilesize, :, :], dR)
122+
123+
for i_tile, i_nx, i_ny in T.Parallel(tilesize, n_stream, n_stream):
124+
RdR[i_tile, i_nx, i_ny] = R[i_tile, i_nx, i_ny] * dR[i_tile, i_nx, i_ny]
125+
126+
T.reduce_sum(RdR, b1, dim=-1)
127+
T.reduce_sum(RdR, b2, dim=-2)
128+
129+
T.fill(x1, 0.0)
130+
T.fill(x2, 0.0)
131+
132+
matvec_A(R, x1, x2, buf1, r1, r2)
133+
134+
for i_tile, i_n in T.Parallel(tilesize, n_stream):
135+
r1[i_tile, i_n] = b1[i_tile, i_n] - r1[i_tile, i_n]
136+
137+
for i_tile, i_n in T.Parallel(tilesize, n_stream):
138+
r2[i_tile, i_n] = b2[i_tile, i_n] - r2[i_tile, i_n]
139+
140+
T.copy(r1, p1)
141+
T.copy(r2, p2)
142+
143+
dot(r1, r2, r1, r2, buf2, r_normsq)
144+
145+
# Conjugate gradient: iteration starts
146+
for _ in T.serial(2 * n_stream):
147+
matvec_A(R, p1, p2, buf1, Ap1, Ap2)
148+
149+
dot(p1, p2, Ap1, Ap2, buf2, pAp)
150+
151+
for i_tile, i_n in T.Parallel(tilesize, n_stream):
152+
# VERY important to avoid divide by zero
153+
alpha[i_tile, i_n] = r_normsq[i_tile] / (pAp[i_tile] + EPS)
154+
for i_tile, i_n in T.Parallel(tilesize, n_stream):
155+
x1[i_tile, i_n] += alpha[i_tile, i_n] * p1[i_tile, i_n]
156+
for i_tile, i_n in T.Parallel(tilesize, n_stream):
157+
x2[i_tile, i_n] += alpha[i_tile, i_n] * p2[i_tile, i_n]
158+
for i_tile, i_n in T.Parallel(tilesize, n_stream):
159+
r1[i_tile, i_n] -= alpha[i_tile, i_n] * Ap1[i_tile, i_n]
160+
for i_tile, i_n in T.Parallel(tilesize, n_stream):
161+
r2[i_tile, i_n] -= alpha[i_tile, i_n] * Ap2[i_tile, i_n]
162+
163+
dot(r1, r2, r1, r2, buf2, r_new_normsq)
164+
165+
for i_tile, i_n in T.Parallel(tilesize, n_stream):
166+
# not very important to avoid divide by zero, but it's good to have it
167+
beta[i_tile, i_n] = r_new_normsq[i_tile] / (r_normsq[i_tile] + EPS)
168+
for i_tile, i_n in T.Parallel(tilesize, n_stream):
169+
p1[i_tile, i_n] = r1[i_tile, i_n] + beta[i_tile, i_n] * p1[i_tile, i_n]
170+
for i_tile, i_n in T.Parallel(tilesize, n_stream):
171+
p2[i_tile, i_n] = r2[i_tile, i_n] + beta[i_tile, i_n] * p2[i_tile, i_n]
172+
173+
T.copy(r_new_normsq, r_normsq)
174+
# Conjugate gradient: iteration ends
175+
176+
for i_tile, i_nx, i_ny in T.Parallel(tilesize, n_stream, n_stream):
177+
res_tile[i_tile, i_nx, i_ny] = (dR[i_tile, i_nx, i_ny] - x1[i_tile, i_nx] - x2[i_tile, i_ny]) * R[i_tile, i_nx, i_ny]
178+
179+
T.copy(res_tile, res[i_seq * tilesize : (i_seq + 1) * tilesize, :, :])
180+
181+
return main
182+
183+
184+
def main():
185+
print("Autotuning TileLang kernel for sinkhorn backward pass")
186+
print(f"{seqlen = }")
187+
print(f"{n_stream = }")
188+
print(f"{iters = }")
189+
print(f"{repeat = }")
190+
191+
######################################################################
192+
# Variable
193+
######################################################################
194+
dist = torch.distributions.uniform.Uniform(0.0, 4.0)
195+
device = torch.device("cuda")
196+
M = dist.sample((seqlen, n_stream, n_stream)).to(device)
197+
M.requires_grad_()
198+
199+
######################################################################
200+
# Shared forward + one shared loss weight
201+
######################################################################
202+
R, P = sinkhorn_forward(M, iters)
203+
loss_weight = torch.randn_like(R)
204+
205+
######################################################################
206+
# Method A: Autograd (reference)
207+
######################################################################
208+
loss_a = (R * loss_weight).sum()
209+
loss_a.backward()
210+
grad_M_autograd = M.grad.detach().clone()
211+
212+
######################################################################
213+
# Method B: Implicit differentiation with autotuning
214+
######################################################################
215+
grad_R = loss_weight
216+
217+
print("\n" + "=" * 60)
218+
print("Starting autotuning...")
219+
print("=" * 60)
220+
221+
# Set autotune inputs
222+
with set_autotune_inputs(R, grad_R):
223+
kernel = sinkhorn_bwd_implicit_cg(n_stream)
224+
print(kernel.get_kernel_source())
225+
print("\n" + "=" * 60)
226+
print("Autotuning completed! Running with best configuration...")
227+
print("=" * 60)
228+
229+
# Warmup and timing with best config
230+
a = torch.randn(8192, 8192, device=device)
231+
for _ in trange(4, desc="Warmup"):
232+
_ = a @ a
233+
grad_M_implicit = kernel(R, grad_R)
234+
torch.cuda.synchronize()
235+
236+
# Timing
237+
start_event = torch.cuda.Event(enable_timing=True)
238+
end_event = torch.cuda.Event(enable_timing=True)
239+
240+
torch.cuda.synchronize()
241+
start_event.record()
242+
243+
for _ in range(repeat):
244+
grad_M_implicit = kernel(R, grad_R)
245+
246+
end_event.record()
247+
torch.cuda.synchronize()
248+
249+
elapsed_time_ms = start_event.elapsed_time(end_event)
250+
251+
print(f"\nKernel execution time ({repeat = }): {elapsed_time_ms:.3f} ms")
252+
print(f"Average time per iteration: {elapsed_time_ms / repeat:.3f} ms")
253+
254+
######################################################################
255+
# Compare
256+
######################################################################
257+
g1 = grad_M_autograd
258+
g2 = grad_M_implicit
259+
260+
abs_diff = (g1 - g2).abs()
261+
# Use max of absolute values for more stable relative error
262+
rel_diff = abs_diff / (torch.maximum(g1.abs(), g2.abs()) + 1e-8)
263+
264+
print("\n" + "=" * 60)
265+
print("Comparison of gradients dL/dM")
266+
print("=" * 60)
267+
268+
def format_list(ls):
269+
return [f"{x:.2e}" for x in ls]
270+
271+
MAE = abs_diff.mean(dim=(-1, -2)).tolist()
272+
max_abs_diff = abs_diff.reshape(seqlen, -1).max(-1).values.tolist()
273+
mean_rel_diff = rel_diff.mean(dim=(-1, -2)).tolist()
274+
max_rel_diff = rel_diff.reshape(seqlen, -1).max(-1).values.tolist()
275+
276+
print(f"Max MAE = {max(MAE):.6e}")
277+
print(f"Max max_abs_diff = {max(max_abs_diff):.6e}")
278+
print(f"Max mean_rel_diff = {max(mean_rel_diff):.6e}")
279+
print(f"Max max_rel_diff = {max(max_rel_diff):.6e}")
280+
281+
print("\nGrad (autograd) sample:\n", g1[0, :3, :3])
282+
print("\nGrad (implicit) sample:\n", g2[0, :3, :3])
283+
284+
285+
if __name__ == "__main__":
286+
main()

0 commit comments

Comments
 (0)