Status: Open
Labels: bug
Description
Required prerequisites
- I have read the documentation https://tilelang.com.
- I have searched the Issue Tracker to confirm this hasn't already been reported (comment there if it has).
What version of TileLang are you using?
0.1.6.post2
System information
3.12.3 (main, Feb 4 2025, 14:48:35) [GCC 13.3.0] linux
0.1.6.post2
2.7.0a0+79aa17489c.nv25.04
Problem description
When I replace a common subexpression with a variable in indexing code, the compiler gives me a vague error. Substituting it in non-indexing code seems to work fine.
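The error in the traceback below, `Cannot use and / or / not operator to Expr`, appears to be the message TVM raises when Python asks a symbolic expression for its truth value. Python's `and` cannot be overloaded: it calls `bool()` on its left operand to decide whether to short-circuit, and a symbolic value has no truth value yet. The toy sketch below shows that mechanism; the class `SymExpr` is hypothetical and only stands in for a TVM expression. The overloadable `&` operator can build a symbolic conjunction, and the error message itself hints at `tvm.tir.all` / `tvm.tir.any` for the same purpose. Note also that the only `and` in the kernel is in the `T.if_then_else` line, while the reported location is the `T.copy(key[...])` line, which is presumably part of why the error reads as vague.

```python
class SymExpr:
    """Toy symbolic expression (hypothetical; stands in for a TVM expression)."""

    def __init__(self, text: str):
        self.text = text

    def __ge__(self, other):
        return SymExpr(f"({self.text} >= {other})")

    def __lt__(self, other):
        return SymExpr(f"({self.text} < {other})")

    def __and__(self, other):
        # `&` is overloadable, so it can return a new symbolic conjunction.
        return SymExpr(f"{self.text} && {other.text}")

    def __bool__(self):
        # Python's `and` / `or` / `not` call bool() on an operand; a symbolic
        # value cannot answer, so refuse instead of guessing.
        raise ValueError("Cannot use and / or / not operator to Expr")


lo = SymExpr("t") >= "start"
hi = SymExpr("t") < "end"
combined = lo & hi  # works: builds "(t >= start) && (t < end)"

try:
    lo and hi  # Python calls bool(lo) here, which raises
    failed = False
except ValueError:
    failed = True
```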
Reproducible example code
The Python snippets:
```python
import torch

m, h, d, n = 1024, 16, 128, 32768
query = torch.randn((m, h, d), dtype=torch.bfloat16, device="cuda")
key = torch.randn((n, d), dtype=torch.bfloat16, device="cuda")
weight = torch.randn((m, h), dtype=torch.float32, device="cuda")
starts = torch.randint(0, 512, (m,), dtype=torch.int32, device="cuda")
ends = torch.randint(n - 512, n, (m,), dtype=torch.int32, device="cuda")

m, h, d = query.shape
n = key.shape[0]
device = query.device
scores = torch.full((m, n), float("-inf"), dtype=torch.float32, device=device)
tile_n = 128
topk = 2048  # hypothetical value: topk is not defined in the report and the kernel body does not use it
kernel = indexer_template(m, h, n, d, topk, tile_n)  # indexer_template is defined in the next snippet
kernel(query, key, weight, starts, ends, scores)
```
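For orientation, here is a plain-Python sketch of what the kernel below appears to be meant to write into `scores`, read off its body: for each row `b`, every key position `t` in `[starts[b], ends[b])` gets the sum over heads of `max(0, q[b, h] · k[t]) * weight[b, h]`, and every other position stays `-inf`. The helper name `indexer_reference` is hypothetical, and it uses nested lists instead of CUDA tensors so it can run on tiny inputs.

```python
import math


def indexer_reference(query, key, weight, starts, ends):
    """Hypothetical reference for the intended scores.

    query: m x h x d nested lists, key: n x d, weight: m x h,
    starts / ends: length-m lists of ints. Returns an m x n list of floats.
    """
    m, n = len(query), len(key)
    scores = [[-math.inf] * n for _ in range(m)]  # same -inf fill as the driver
    for b in range(m):
        for t in range(starts[b], ends[b]):  # only the valid window is written
            total = 0.0
            for head, q_row in enumerate(query[b]):
                dot = sum(qv * kv for qv, kv in zip(q_row, key[t]))  # one gemm entry
                total += max(0.0, dot) * weight[b][head]  # ReLU, then head weight
            scores[b][t] = total  # reduce_sum over heads
    return scores
```

For example, with one query row holding two 2-d heads `[[1, 0], [0, 1]]`, head weights `[0.5, 2.0]`, and a window `[0, 2)` over three keys, only the first two columns of the result are finite.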
```python
import tilelang
import tilelang.language as T


@tilelang.jit
def indexer_template(m, h, n, d, topk, tile_n):
    tile_m = h

    @T.prim_func
    def indexer_kernel(
        query: T.Tensor[(m, h, d), "bfloat16"],
        key: T.Tensor[(n, d), "bfloat16"],
        weight: T.Tensor[(m, h), "float32"],
        starts: T.Tensor[(m, ), "int32"],
        ends: T.Tensor[(m, ), "int32"],
        scores: T.Tensor[(m, n), "float32"],
    ):
        with T.Kernel(m, threads=128) as bx:
            q_tile_s = T.alloc_shared((tile_m, d), dtype="bfloat16")
            k_tile_s = T.alloc_shared((tile_n, d), dtype="bfloat16")
            score_tile_r = T.alloc_fragment((tile_n, tile_m), dtype="float32")  # reverse it
            w_tile_r = T.alloc_fragment((tile_m,), dtype="float32")
            ss_tile_r = T.alloc_fragment((tile_n,), dtype="float32")
            T.copy(query[bx, 0: tile_m, 0: d], q_tile_s)
            T.copy(weight[bx, 0: tile_m], w_tile_r)
            start = starts[bx]
            end = ends[bx]
            aligned_start = T.floordiv(start, tile_n) * tile_n
            aligned_end = T.ceildiv(end, tile_n) * tile_n
            n_iters = T.ceildiv(aligned_end - aligned_start, tile_n)
            tile_start = T.alloc_var("int32", 0)
            for si in T.serial(n_iters):
                tile_start = aligned_start + si * tile_n  # this is the subexpression I want to replace with a variable
                T.copy(key[tile_start : aligned_start + (si + 1) * tile_n, 0: d], k_tile_s)
                T.gemm(k_tile_s, q_tile_s, score_tile_r, transpose_B=True, clear_accum=True)
                for i, j in T.Parallel(tile_n, tile_m):
                    score_tile_r[i, j] = T.max(0, score_tile_r[i, j]) * w_tile_r[j]
                T.reduce_sum(score_tile_r, ss_tile_r, dim=-1, clear=True)
                for i in T.Parallel(tile_n):
                    ss_tile_r[i] = T.if_then_else(((tile_start + i) >= start) and ((aligned_start + si * tile_n + i) < end), ss_tile_r[i], float("-inf"))
                T.copy(ss_tile_r, scores[bx, aligned_start + (si * tile_n) : aligned_start + ((si + 1) * tile_n)])

    return indexer_kernel
```

Traceback
```
error: Cannot use and / or / not operator to Expr, hint: use tvm.tir.all / tvm.tir.any instead
  --> fused_indexer_topk.py:65:17
   |
65 | T.copy(key[tile_start : aligned_start + (si + 1) * tile_n, 0: d], k_tile_s)
   | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
note: run with `TVM_BACKTRACE=1` environment variable to display a backtrace.
```

Expected behavior
I expect the substitution to work in both indexing and non-indexing code.
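The substitution only renames `aligned_start + si * tile_n`, so it should not change which key rows are loaded or which positions are masked. A plain-Python sketch of the kernel's tiling arithmetic makes that concrete: it mirrors `T.floordiv` / `T.ceildiv`, walks the serial loop, and computes the validity mask once with the named variable and once with the expression inlined. (The report's condition mixes the two, using the variable for the lower bound and the inlined expression for the upper bound; here each form is checked on both bounds.) The helpers `ceildiv` and `tile_plan` are hypothetical names, and the sample `start` / `end` values are illustrative picks from the ranges the driver samples.

```python
def ceildiv(a: int, b: int) -> int:
    """Ceiling division for non-negative ints (mirrors T.ceildiv)."""
    return -(-a // b)


def tile_plan(start: int, end: int, tile_n: int = 128):
    """Return (tile_start, tile_stop, n_valid) for every serial iteration si."""
    aligned_start = (start // tile_n) * tile_n  # T.floordiv(start, tile_n) * tile_n
    aligned_end = ceildiv(end, tile_n) * tile_n  # T.ceildiv(end, tile_n) * tile_n
    n_iters = ceildiv(aligned_end - aligned_start, tile_n)
    plan = []
    for si in range(n_iters):
        tile_start = aligned_start + si * tile_n  # the named subexpression
        via_var = [start <= tile_start + i < end for i in range(tile_n)]
        via_inline = [start <= aligned_start + si * tile_n + i < end for i in range(tile_n)]
        assert via_var == via_inline  # renaming must not change the mask
        plan.append((tile_start, aligned_start + (si + 1) * tile_n, sum(via_var)))
    return plan
```

For example, `tile_plan(300, 32500)` yields 252 tiles; the first covers `[256, 384)` with 84 valid positions and the last covers `[32384, 32512)` with 116, so every position in `[300, 32500)` is covered exactly once.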
Additional context
No response