|
# NOTE: This backward (bwd) script is not an official upstream script; it is community-written and provided for reference only.
# See PR: https://github.com/tile-ai/tilelang/pull/1758
| 3 | +import torch |
| 4 | + |
| 5 | +import tilelang |
| 6 | +import tilelang.language as T |
| 7 | +from tilelang.autotuner import set_autotune_inputs |
| 8 | +from tqdm import trange |
| 9 | + |
| 10 | + |
| 11 | +dtype = torch.float32 |
| 12 | + |
| 13 | +seqlen = 65536 |
| 14 | +n_stream = 16 |
| 15 | +iters = 100 |
| 16 | +repeat = 512 |
| 17 | + |
| 18 | +EPS = 1e-10 |
| 19 | + |
| 20 | + |
| 21 | +def sinkhorn_forward(M, iters=20): |
| 22 | + P = torch.exp(M) |
| 23 | + R = P |
| 24 | + |
| 25 | + for _ in range(iters): |
| 26 | + R = R / R.sum(-2, keepdim=True) |
| 27 | + R = R / R.sum(-1, keepdim=True) |
| 28 | + |
| 29 | + return R, P |
| 30 | + |
| 31 | + |
| 32 | +def sinkhorn_bwd_configs(n_stream, seqlen): |
| 33 | + """Generate autotune configurations for different tilesize and threads""" |
| 34 | + configs = [] |
| 35 | + |
| 36 | + # Explore different tile sizes and thread counts |
| 37 | + tilesizes = [1, 2, 4, 8, 16, 32, 64] |
| 38 | + thread_counts = [32, 64, 128, 256] |
| 39 | + |
| 40 | + for tilesize in tilesizes: |
| 41 | + # Skip if tilesize doesn't divide seqlen evenly (optional constraint) |
| 42 | + if seqlen % tilesize != 0: |
| 43 | + continue |
| 44 | + |
| 45 | + for threads in thread_counts: |
| 46 | + configs.append({"tilesize": tilesize, "threads": threads}) |
| 47 | + |
| 48 | + return configs |
| 49 | + |
| 50 | + |
| 51 | +@tilelang.autotune( |
| 52 | + configs=sinkhorn_bwd_configs(n_stream, seqlen), |
| 53 | + warmup=4, |
| 54 | + rep=repeat, |
| 55 | +) |
| 56 | +@tilelang.jit( |
| 57 | + out_idx=[2], |
| 58 | + pass_configs={ |
| 59 | + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, |
| 60 | + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, |
| 61 | + }, |
| 62 | +) |
| 63 | +def sinkhorn_bwd_implicit_cg(n_stream: int, tilesize: int = 32, threads: int = 128): |
| 64 | + seqlen = T.dynamic("seqlen") |
| 65 | + tensor_shape = [seqlen, n_stream, n_stream] |
| 66 | + dtype = T.float32 |
| 67 | + |
| 68 | + @T.macro |
| 69 | + def matvec_A(R, x1, x2, buf, y1, y2): |
| 70 | + for i_tile, i, j in T.Parallel(tilesize, n_stream, n_stream): |
| 71 | + buf[i_tile, i, j] = R[i_tile, i, j] * x2[i_tile, j] |
| 72 | + T.reduce_sum(buf, y1, dim=-1) |
| 73 | + |
| 74 | + for i_tile, i, j in T.Parallel(tilesize, n_stream, n_stream): |
| 75 | + buf[i_tile, i, j] = R[i_tile, i, j] * x1[i_tile, i] |
| 76 | + T.reduce_sum(buf, y2, dim=-2) |
| 77 | + |
| 78 | + for i_tile, i in T.Parallel(tilesize, n_stream): |
| 79 | + y1[i_tile, i] += x1[i_tile, i] |
| 80 | + y2[i_tile, i] += x2[i_tile, i] |
| 81 | + |
| 82 | + @T.macro |
| 83 | + def dot(x1, x2, y1, y2, buf, out): |
| 84 | + for i_tile, i in T.Parallel(tilesize, n_stream): |
| 85 | + buf[i_tile, i] = x1[i_tile, i] * y1[i_tile, i] + x2[i_tile, i] * y2[i_tile, i] |
| 86 | + |
| 87 | + T.reduce_sum(buf, out, dim=-1) |
| 88 | + |
| 89 | + @T.prim_func |
| 90 | + def main( |
| 91 | + out: T.Tensor(tensor_shape, dtype), |
| 92 | + dout: T.Tensor(tensor_shape, dtype), |
| 93 | + res: T.Tensor(tensor_shape, dtype), |
| 94 | + ): |
| 95 | + with T.Kernel(T.ceildiv(seqlen, tilesize), threads=threads) as i_seq: |
| 96 | + R = T.alloc_fragment([tilesize, n_stream, n_stream], dtype=dtype) |
| 97 | + dR = T.alloc_fragment([tilesize, n_stream, n_stream], dtype=dtype) |
| 98 | + RdR = T.alloc_fragment([tilesize, n_stream, n_stream], dtype=dtype) |
| 99 | + res_tile = T.alloc_shared([tilesize, n_stream, n_stream], dtype=dtype) |
| 100 | + b1 = T.alloc_shared([tilesize, n_stream], dtype=dtype) |
| 101 | + b2 = T.alloc_shared([tilesize, n_stream], dtype=dtype) |
| 102 | + x1 = T.alloc_shared([tilesize, n_stream], dtype=dtype) |
| 103 | + x2 = T.alloc_shared([tilesize, n_stream], dtype=dtype) |
| 104 | + r1 = T.alloc_shared([tilesize, n_stream], dtype=dtype) |
| 105 | + r2 = T.alloc_shared([tilesize, n_stream], dtype=dtype) |
| 106 | + p1 = T.alloc_shared([tilesize, n_stream], dtype=dtype) |
| 107 | + p2 = T.alloc_shared([tilesize, n_stream], dtype=dtype) |
| 108 | + alpha = T.alloc_fragment([tilesize, n_stream], dtype=dtype) |
| 109 | + beta = T.alloc_fragment([tilesize, n_stream], dtype=dtype) |
| 110 | + r_normsq = T.alloc_fragment([tilesize], dtype=dtype) |
| 111 | + r_new_normsq = T.alloc_fragment([tilesize], dtype=dtype) |
| 112 | + Ap1 = T.alloc_shared([tilesize, n_stream], dtype=dtype) |
| 113 | + Ap2 = T.alloc_shared([tilesize, n_stream], dtype=dtype) |
| 114 | + pAp = T.alloc_fragment([tilesize], dtype=dtype) |
| 115 | + |
| 116 | + # Buffers for intermediate results |
| 117 | + buf1 = T.alloc_shared([tilesize, n_stream, n_stream], dtype=dtype) |
| 118 | + buf2 = T.alloc_shared([tilesize, n_stream], dtype=dtype) |
| 119 | + |
| 120 | + T.copy(out[i_seq * tilesize : (i_seq + 1) * tilesize, :, :], R) |
| 121 | + T.copy(dout[i_seq * tilesize : (i_seq + 1) * tilesize, :, :], dR) |
| 122 | + |
| 123 | + for i_tile, i_nx, i_ny in T.Parallel(tilesize, n_stream, n_stream): |
| 124 | + RdR[i_tile, i_nx, i_ny] = R[i_tile, i_nx, i_ny] * dR[i_tile, i_nx, i_ny] |
| 125 | + |
| 126 | + T.reduce_sum(RdR, b1, dim=-1) |
| 127 | + T.reduce_sum(RdR, b2, dim=-2) |
| 128 | + |
| 129 | + T.fill(x1, 0.0) |
| 130 | + T.fill(x2, 0.0) |
| 131 | + |
| 132 | + matvec_A(R, x1, x2, buf1, r1, r2) |
| 133 | + |
| 134 | + for i_tile, i_n in T.Parallel(tilesize, n_stream): |
| 135 | + r1[i_tile, i_n] = b1[i_tile, i_n] - r1[i_tile, i_n] |
| 136 | + |
| 137 | + for i_tile, i_n in T.Parallel(tilesize, n_stream): |
| 138 | + r2[i_tile, i_n] = b2[i_tile, i_n] - r2[i_tile, i_n] |
| 139 | + |
| 140 | + T.copy(r1, p1) |
| 141 | + T.copy(r2, p2) |
| 142 | + |
| 143 | + dot(r1, r2, r1, r2, buf2, r_normsq) |
| 144 | + |
| 145 | + # Conjugate gradient: iteration starts |
| 146 | + for _ in T.serial(2 * n_stream): |
| 147 | + matvec_A(R, p1, p2, buf1, Ap1, Ap2) |
| 148 | + |
| 149 | + dot(p1, p2, Ap1, Ap2, buf2, pAp) |
| 150 | + |
| 151 | + for i_tile, i_n in T.Parallel(tilesize, n_stream): |
| 152 | + # VERY important to avoid divide by zero |
| 153 | + alpha[i_tile, i_n] = r_normsq[i_tile] / (pAp[i_tile] + EPS) |
| 154 | + for i_tile, i_n in T.Parallel(tilesize, n_stream): |
| 155 | + x1[i_tile, i_n] += alpha[i_tile, i_n] * p1[i_tile, i_n] |
| 156 | + for i_tile, i_n in T.Parallel(tilesize, n_stream): |
| 157 | + x2[i_tile, i_n] += alpha[i_tile, i_n] * p2[i_tile, i_n] |
| 158 | + for i_tile, i_n in T.Parallel(tilesize, n_stream): |
| 159 | + r1[i_tile, i_n] -= alpha[i_tile, i_n] * Ap1[i_tile, i_n] |
| 160 | + for i_tile, i_n in T.Parallel(tilesize, n_stream): |
| 161 | + r2[i_tile, i_n] -= alpha[i_tile, i_n] * Ap2[i_tile, i_n] |
| 162 | + |
| 163 | + dot(r1, r2, r1, r2, buf2, r_new_normsq) |
| 164 | + |
| 165 | + for i_tile, i_n in T.Parallel(tilesize, n_stream): |
| 166 | + # not very important to avoid divide by zero, but it's good to have it |
| 167 | + beta[i_tile, i_n] = r_new_normsq[i_tile] / (r_normsq[i_tile] + EPS) |
| 168 | + for i_tile, i_n in T.Parallel(tilesize, n_stream): |
| 169 | + p1[i_tile, i_n] = r1[i_tile, i_n] + beta[i_tile, i_n] * p1[i_tile, i_n] |
| 170 | + for i_tile, i_n in T.Parallel(tilesize, n_stream): |
| 171 | + p2[i_tile, i_n] = r2[i_tile, i_n] + beta[i_tile, i_n] * p2[i_tile, i_n] |
| 172 | + |
| 173 | + T.copy(r_new_normsq, r_normsq) |
| 174 | + # Conjugate gradient: iteration ends |
| 175 | + |
| 176 | + for i_tile, i_nx, i_ny in T.Parallel(tilesize, n_stream, n_stream): |
| 177 | + res_tile[i_tile, i_nx, i_ny] = (dR[i_tile, i_nx, i_ny] - x1[i_tile, i_nx] - x2[i_tile, i_ny]) * R[i_tile, i_nx, i_ny] |
| 178 | + |
| 179 | + T.copy(res_tile, res[i_seq * tilesize : (i_seq + 1) * tilesize, :, :]) |
| 180 | + |
| 181 | + return main |
| 182 | + |
| 183 | + |
| 184 | +def main(): |
| 185 | + print("Autotuning TileLang kernel for sinkhorn backward pass") |
| 186 | + print(f"{seqlen = }") |
| 187 | + print(f"{n_stream = }") |
| 188 | + print(f"{iters = }") |
| 189 | + print(f"{repeat = }") |
| 190 | + |
| 191 | + ###################################################################### |
| 192 | + # Variable |
| 193 | + ###################################################################### |
| 194 | + dist = torch.distributions.uniform.Uniform(0.0, 4.0) |
| 195 | + device = torch.device("cuda") |
| 196 | + M = dist.sample((seqlen, n_stream, n_stream)).to(device) |
| 197 | + M.requires_grad_() |
| 198 | + |
| 199 | + ###################################################################### |
| 200 | + # Shared forward + one shared loss weight |
| 201 | + ###################################################################### |
| 202 | + R, P = sinkhorn_forward(M, iters) |
| 203 | + loss_weight = torch.randn_like(R) |
| 204 | + |
| 205 | + ###################################################################### |
| 206 | + # Method A: Autograd (reference) |
| 207 | + ###################################################################### |
| 208 | + loss_a = (R * loss_weight).sum() |
| 209 | + loss_a.backward() |
| 210 | + grad_M_autograd = M.grad.detach().clone() |
| 211 | + |
| 212 | + ###################################################################### |
| 213 | + # Method B: Implicit differentiation with autotuning |
| 214 | + ###################################################################### |
| 215 | + grad_R = loss_weight |
| 216 | + |
| 217 | + print("\n" + "=" * 60) |
| 218 | + print("Starting autotuning...") |
| 219 | + print("=" * 60) |
| 220 | + |
| 221 | + # Set autotune inputs |
| 222 | + with set_autotune_inputs(R, grad_R): |
| 223 | + kernel = sinkhorn_bwd_implicit_cg(n_stream) |
| 224 | + print(kernel.get_kernel_source()) |
| 225 | + print("\n" + "=" * 60) |
| 226 | + print("Autotuning completed! Running with best configuration...") |
| 227 | + print("=" * 60) |
| 228 | + |
| 229 | + # Warmup and timing with best config |
| 230 | + a = torch.randn(8192, 8192, device=device) |
| 231 | + for _ in trange(4, desc="Warmup"): |
| 232 | + _ = a @ a |
| 233 | + grad_M_implicit = kernel(R, grad_R) |
| 234 | + torch.cuda.synchronize() |
| 235 | + |
| 236 | + # Timing |
| 237 | + start_event = torch.cuda.Event(enable_timing=True) |
| 238 | + end_event = torch.cuda.Event(enable_timing=True) |
| 239 | + |
| 240 | + torch.cuda.synchronize() |
| 241 | + start_event.record() |
| 242 | + |
| 243 | + for _ in range(repeat): |
| 244 | + grad_M_implicit = kernel(R, grad_R) |
| 245 | + |
| 246 | + end_event.record() |
| 247 | + torch.cuda.synchronize() |
| 248 | + |
| 249 | + elapsed_time_ms = start_event.elapsed_time(end_event) |
| 250 | + |
| 251 | + print(f"\nKernel execution time ({repeat = }): {elapsed_time_ms:.3f} ms") |
| 252 | + print(f"Average time per iteration: {elapsed_time_ms / repeat:.3f} ms") |
| 253 | + |
| 254 | + ###################################################################### |
| 255 | + # Compare |
| 256 | + ###################################################################### |
| 257 | + g1 = grad_M_autograd |
| 258 | + g2 = grad_M_implicit |
| 259 | + |
| 260 | + abs_diff = (g1 - g2).abs() |
| 261 | + # Use max of absolute values for more stable relative error |
| 262 | + rel_diff = abs_diff / (torch.maximum(g1.abs(), g2.abs()) + 1e-8) |
| 263 | + |
| 264 | + print("\n" + "=" * 60) |
| 265 | + print("Comparison of gradients dL/dM") |
| 266 | + print("=" * 60) |
| 267 | + |
| 268 | + def format_list(ls): |
| 269 | + return [f"{x:.2e}" for x in ls] |
| 270 | + |
| 271 | + MAE = abs_diff.mean(dim=(-1, -2)).tolist() |
| 272 | + max_abs_diff = abs_diff.reshape(seqlen, -1).max(-1).values.tolist() |
| 273 | + mean_rel_diff = rel_diff.mean(dim=(-1, -2)).tolist() |
| 274 | + max_rel_diff = rel_diff.reshape(seqlen, -1).max(-1).values.tolist() |
| 275 | + |
| 276 | + print(f"Max MAE = {max(MAE):.6e}") |
| 277 | + print(f"Max max_abs_diff = {max(max_abs_diff):.6e}") |
| 278 | + print(f"Max mean_rel_diff = {max(mean_rel_diff):.6e}") |
| 279 | + print(f"Max max_rel_diff = {max(max_rel_diff):.6e}") |
| 280 | + |
| 281 | + print("\nGrad (autograd) sample:\n", g1[0, :3, :3]) |
| 282 | + print("\nGrad (implicit) sample:\n", g2[0, :3, :3]) |
| 283 | + |
| 284 | + |
| 285 | +if __name__ == "__main__": |
| 286 | + main() |