[BUG] Fails to substitute an expression with a variable in indexing code #1642

@iclementine

What version of TileLang are you using?

0.1.6.post2

System information

Python: 3.12.3 (main, Feb  4 2025, 14:48:35) [GCC 13.3.0] linux
TileLang: 0.1.6.post2
PyTorch: 2.7.0a0+79aa17489c.nv25.04

Problem description

When I replace a common subexpression with a variable in indexing code (a slice bound passed to T.copy), the compiler reports a vague error. Using the same variable in non-indexing code seems to work fine.

Reproducible example code

The Python snippets:

import torch
import tilelang
import tilelang.language as T

m, h, d, n = 1024, 16, 128, 32768
query = torch.randn((m, h, d), dtype=torch.bfloat16, device="cuda")
key = torch.randn((n, d), dtype=torch.bfloat16, device="cuda")
weight = torch.randn((m, h), dtype=torch.float32, device="cuda")
starts = torch.randint(0, 512, (m,), dtype=torch.int32, device="cuda")
ends = torch.randint(n - 512, n, (m,), dtype=torch.int32, device="cuda")

m, h, d = query.shape
n = key.shape[0]
device = query.device
scores = torch.full((m, n), float("-inf"), dtype=torch.float32, device=device)
tile_n = 128
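# NOTE: topk is not defined in this snippet; it is passed to the template but unused in the kernel body.
kernel = indexer_template(m, h, n, d, topk, tile_n)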
kernel(query, key, weight, starts, ends, scores)


@tilelang.jit
def indexer_template(m, h, n, d, topk, tile_n):
    tile_m = h

    @T.prim_func
    def indexer_kernel(
        query: T.Tensor[(m, h, d), "bfloat16"],
        key: T.Tensor[(n, d), "bfloat16"],
        weight: T.Tensor[(m, h), "float32"],
        starts: T.Tensor[(m, ), "int32"],
        ends: T.Tensor[(m, ), "int32"],
        scores: T.Tensor[(m, n), "float32"],
    ):
        with T.Kernel(m, threads=128) as bx:
            q_tile_s = T.alloc_shared((tile_m, d), dtype="bfloat16")
            k_tile_s = T.alloc_shared((tile_n, d), dtype="bfloat16")
            score_tile_r = T.alloc_fragment((tile_n, tile_m), dtype="float32") # reverse it
            w_tile_r = T.alloc_fragment((tile_m,), dtype="float32")
            ss_tile_r = T.alloc_fragment((tile_n,), dtype="float32")

            T.copy(query[bx, 0: tile_m, 0: d], q_tile_s)
            T.copy(weight[bx, 0: tile_m] , w_tile_r)

            start = starts[bx]
            end = ends[bx]
            aligned_start = T.floordiv(start, tile_n) * tile_n
            aligned_end = T.ceildiv(end, tile_n) * tile_n
            n_iters = T.ceildiv(aligned_end - aligned_start, tile_n)

            tile_start = T.alloc_var("int32", 0)
            for si in T.serial(n_iters):
                tile_start = aligned_start + si * tile_n # this is the subexpression I want to replace with a variable
                T.copy(key[tile_start : aligned_start + (si + 1) * tile_n, 0: d], k_tile_s)
                T.gemm(k_tile_s, q_tile_s, score_tile_r, transpose_B=True, clear_accum=True)
                for i, j in T.Parallel(tile_n, tile_m):
                    score_tile_r[i, j] = T.max(0, score_tile_r[i, j]) * w_tile_r[j]

                T.reduce_sum(score_tile_r, ss_tile_r, dim=-1, clear=True) 
                for i in T.Parallel(tile_n):
                    ss_tile_r[i] = T.if_then_else(((tile_start + i) >= start) and ((aligned_start + si * tile_n + i) < end), ss_tile_r[i], float("-inf"))

                T.copy(ss_tile_r, scores[bx, aligned_start + (si * tile_n) : aligned_start + ((si + 1) * tile_n)])
    return indexer_kernel
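
For readers following the index arithmetic, here is a minimal pure-Python sketch of the tile ranges the loop above is meant to walk. `tile_ranges` is a hypothetical helper written for illustration, not part of TileLang. It uses plain integers in place of the symbolic `starts[bx]` and `ends[bx]` values, so it shows only the intended bounds and does not reproduce the compiler error.

```python
def tile_ranges(start: int, end: int, tile_n: int) -> list[tuple[int, int]]:
    """Return the [lo, hi) row ranges of the tiles that cover [start, end)."""
    aligned_start = (start // tile_n) * tile_n   # round start down to a tile boundary
    aligned_end = -(-end // tile_n) * tile_n     # round end up (ceiling division)
    # Both bounds are multiples of tile_n, so this division is exact
    # (the kernel uses ceildiv here, which gives the same result).
    n_iters = (aligned_end - aligned_start) // tile_n
    ranges = []
    for si in range(n_iters):
        tile_start = aligned_start + si * tile_n  # the subexpression factored into a variable
        ranges.append((tile_start, tile_start + tile_n))
    return ranges


print(tile_ranges(130, 400, 128))  # [(128, 256), (256, 384), (384, 512)]
```

In this plain-integer setting, `tile_start + tile_n` and `aligned_start + (si + 1) * tile_n` are the same value, which is why the substitution is expected to be harmless.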

Traceback

error: Cannot use and / or / not operator to Expr, hint: use tvm.tir.all / tvm.tir.any instead
 --> fused_indexer_topk.py:65:17
    |  
 65 |                  T.copy(key[tile_start : aligned_start + (si + 1) * tile_n, 0: d], k_tile_s)
    |                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
note: run with `TVM_BACKTRACE=1` environment variable to display a backtrace.
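
The wording of this message is typical of what a symbolic-expression library raises when Python's `and` / `or` / `not` forces an expression into a plain bool. The toy sketch below illustrates that mechanism only. `SymExpr` and `_text` are hypothetical stand-ins, not TVM or TileLang classes, and the sketch does not establish which line of the kernel triggers the error in this report.

```python
class SymExpr:
    """Toy symbolic expression; a stand-in for illustration, NOT TVM's PrimExpr."""

    def __init__(self, text: str):
        self.text = text

    def __ge__(self, other):
        return SymExpr(f"({self.text} >= {_text(other)})")

    def __lt__(self, other):
        return SymExpr(f"({self.text} < {_text(other)})")

    def __and__(self, other):
        # Overloadable operator: builds a new symbolic node instead of a bool.
        return SymExpr(f"({self.text} & {_text(other)})")

    def __bool__(self):
        # `and` / `or` / `not` / `if` call this; a symbolic value has no truth value yet.
        raise ValueError("cannot convert a symbolic expression to bool")


def _text(value) -> str:
    return value.text if isinstance(value, SymExpr) else str(value)


idx, start, end = SymExpr("idx"), SymExpr("start"), SymExpr("end")

# `&` stays symbolic.
combined = (idx >= start) & (idx < end)
print(combined.text)  # ((idx >= start) & (idx < end))

# `and` asks for the truth value of the left operand and fails.
try:
    (idx >= start) and (idx < end)
except ValueError as exc:
    print("error:", exc)
```

Any code path that evaluates the truth value of a symbolic expression would hit the same kind of error. That includes an explicit `and` in user code and an implicit truthiness check inside library code that handles slice bounds.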

Expected behavior

I expect the substitution to work in both indexing and non-indexing code.
