Skip to content

Commit 475eea1

Browse files
committed
[Feature] Add TIR builtins for warp-level vote and block-level predicate sync
1 parent 41b2552 commit 475eea1

File tree

4 files changed

+505
-0
lines changed

4 files changed

+505
-0
lines changed

docs/programming_guides/instructions.md

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,9 +139,32 @@ Annotation helpers
139139
- `T.annotate_l2_hit_ratio(buf, ratio)`: Cache behavior hint.
140140

141141
Synchronization helpers
142+
- `T.sync_threads([barrier_id, arrive_count])`: Block-wide barrier (`__syncthreads()`).
143+
- `T.sync_warp([mask])`: Warp-wide barrier (`__syncwarp([mask])`).
144+
- `T.sync_grid()`: Cooperative grid barrier (requires cooperative launch).
142145
- `T.pdl_trigger()`: Signal programmatic launch completion for the current kernel.
143146
- `T.pdl_sync()`: Wait until kernel dependencies are satisfied.
144147

148+
Warp-vote / warp-ballot (CUDA ≥ 9 / HIP)
149+
- `T.any_sync(mask, predicate)` → `int32`: Non-zero if ANY lane in `mask` has a non-zero predicate (`__any_sync`).
150+
- `T.all_sync(mask, predicate)` → `int32`: Non-zero if ALL lanes in `mask` have a non-zero predicate (`__all_sync`).
151+
- `T.ballot_sync(mask, predicate)` → `uint32`: Bitmask of the lanes in `mask` whose predicate is non-zero (`__ballot_sync`).
152+
- `T.ballot(predicate)` → `uint32`: Full-warp ballot (mask = `0xFFFFFFFF`); equivalent to `T.ballot_sync(0xFFFFFFFF, predicate)`.
153+
- `T.activemask()` → `uint32`: Bitmask of currently active (non-exited) lanes (`__activemask`).
154+
155+
Block-wide predicated sync
156+
- `T.syncthreads_count(predicate)` → `int32`: Sync all threads; return the number of threads with a non-zero predicate (`__syncthreads_count`).
157+
- `T.syncthreads_and(predicate)` → `int32`: Sync all threads; non-zero iff ALL threads have a non-zero predicate (`__syncthreads_and`).
158+
- `T.syncthreads_or(predicate)` → `int32`: Sync all threads; non-zero iff ANY thread has a non-zero predicate (`__syncthreads_or`).
159+
160+
Warp-shuffle (intra-warp data exchange)
161+
- `T.shfl_sync(mask, value, src_lane[, width])`: Broadcast value from `src_lane` to all lanes (`__shfl_sync`).
162+
- `T.shfl_xor(value, offset)`: XOR-swap across lanes (`__shfl_xor_sync`, full mask).
163+
- `T.shfl_down(value, offset)`: Shift down by `offset` lanes (`__shfl_down_sync`, full mask).
164+
- `T.shfl_up(value, offset)`: Shift up by `offset` lanes (`__shfl_up_sync`, full mask).
165+
166+
> **Note on HIP:** `any_sync`/`all_sync` ignore the mask and call `__any`/`__all` directly. `ballot_sync`, `ballot`, and `activemask` internally call `__ballot` (which returns `uint64` on a 64-thread wavefront) and cast the result to `uint32`, so only the low 32 bits (lanes 0–31) are kept. `syncthreads_count`, `syncthreads_and`, and `syncthreads_or` have identical signatures on both platforms.
167+
145168
Atomics
146169
- `T.atomic_add(dst, value, memory_order=None, return_prev=False, use_tma=False)`.
147170
- `T.atomic_addx2(dst, value, return_prev=False)`; `T.atomic_addx4(...)`.
Lines changed: 319 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,319 @@
1+
"""Tests for warp-vote / warp-ballot / block-sync-with-predicate intrinsics.
2+
3+
Covered intrinsics
4+
------------------
5+
T.any_sync – __any_sync / __any (HIP)
6+
T.all_sync – __all_sync / __all (HIP)
7+
T.ballot_sync – __ballot_sync / __ballot cast to uint32 (HIP)
8+
T.ballot – ballot with full-warp mask / __ballot (HIP)
9+
T.activemask – __activemask / __ballot(1) cast to uint32 (HIP)
10+
T.syncthreads_count – __syncthreads_count
11+
T.syncthreads_and – __syncthreads_and
12+
T.syncthreads_or – __syncthreads_or
13+
"""
14+
15+
import tilelang
16+
import tilelang.language as T
17+
import torch
18+
import tilelang.testing
19+
20+
21+
# ---------------------------------------------------------------------------
22+
# any_sync
23+
# ---------------------------------------------------------------------------
24+
25+
26+
@tilelang.jit
def kernel_any_sync():
    """Build a kernel (one block, 32 threads) that exercises ``T.any_sync``.

    Only the thread with binding 0 supplies a true predicate (``tx == 0``);
    every thread stores the vote result into its own element of ``B``.

    ``A`` is part of the kernel signature but is neither read nor written by
    the kernel body; callers still have to pass it.

    Returns:
        The ``T.prim_func`` describing the kernel.
    """

    @T.prim_func
    def main(
        A: T.Tensor((1,), "int32"),
        B: T.Tensor((32,), "int32"),
    ):
        with T.Kernel(1, threads=32):
            tx = T.get_thread_binding()
            # 0xFFFFFFFF is the member mask handed to the vote intrinsic.
            val = T.any_sync(0xFFFFFFFF, tx == 0)
            B[tx] = val

    return main
41+
42+
43+
@tilelang.testing.requires_cuda
def test_any_sync():
    """Compile the ``any_sync`` kernel, inspect its source, and run it.

    The generated source must mention the vote intrinsic, and after the launch
    every element of the output tensor must be non-zero.
    """
    unused_in = torch.zeros(1, dtype=torch.int32, device="cuda")
    votes = torch.zeros(32, dtype=torch.int32, device="cuda")
    compiled = kernel_any_sync()
    source = compiled.get_kernel_source()
    # "__any" is a prefix of "__any_sync", so one substring test accepts both.
    assert "__any" in source, f"Expected __any_sync/__any in source:\n{source}"
    compiled(unused_in, votes)
    # One thread voted true, so the vote result must be non-zero everywhere.
    assert bool((votes != 0).all()), f"Expected all non-zero, got {votes}"
53+
54+
55+
# ---------------------------------------------------------------------------
56+
# all_sync
57+
# ---------------------------------------------------------------------------
58+
59+
60+
@tilelang.jit
def kernel_all_sync():
    """Build a 32-thread kernel in which every thread votes true via ``T.all_sync``.

    Each thread writes the vote result to its own element of ``B``.
    """

    @T.prim_func
    def main(
        B: T.Tensor((32,), "int32"),
    ):
        with T.Kernel(1, threads=32):
            lane = T.get_thread_binding()
            # Constant predicate 1 on every thread, member mask 0xFFFFFFFF.
            unanimous = T.all_sync(0xFFFFFFFF, 1)
            B[lane] = unanimous

    return main
74+
75+
76+
@tilelang.testing.requires_cuda
def test_all_sync():
    """Compile the ``all_sync`` kernel, inspect its source, and run it.

    Passes when the source mentions the vote intrinsic and every output
    element is non-zero after the launch.
    """
    votes = torch.zeros(32, dtype=torch.int32, device="cuda")
    compiled = kernel_all_sync()
    source = compiled.get_kernel_source()
    # "__all" is a prefix of "__all_sync", so one substring test accepts both.
    assert "__all" in source, f"Expected __all_sync/__all in source:\n{source}"
    compiled(votes)
    assert bool((votes != 0).all()), f"Expected all non-zero, got {votes}"
84+
85+
86+
# ---------------------------------------------------------------------------
87+
# ballot_sync
88+
# ---------------------------------------------------------------------------
89+
90+
91+
@tilelang.jit
def kernel_ballot_sync():
    """Build a 32-thread kernel that exercises ``T.ballot_sync``.

    Only the thread with binding 0 passes a true predicate; each thread casts
    the returned ballot to ``int32`` and stores it in its own element of ``B``.
    """

    @T.prim_func
    def main(
        B: T.Tensor((32,), "int32"),
    ):
        with T.Kernel(1, threads=32):
            lane = T.get_thread_binding()
            bits = T.ballot_sync(0xFFFFFFFF, lane == 0)
            # The output buffer is int32, hence the explicit cast.
            B[lane] = T.cast(bits, "int32")

    return main
105+
106+
107+
@tilelang.testing.requires_cuda
def test_ballot_sync():
    """Compile and run the ``ballot_sync`` kernel and verify the exact ballot.

    The kernel's predicate is true only for thread 0, so the ballot observed
    by every one of the 32 threads must be exactly ``0x1``.  Checking the whole
    tensor for equality (instead of only bit 0 of element 0) also catches
    spurious extra bits and threads that received a different value.
    """
    b = torch.zeros((32,), device="cuda", dtype=torch.int32)
    kernel = kernel_ballot_sync()
    src = kernel.get_kernel_source()
    # "__ballot" is a prefix of "__ballot_sync", so this accepts both spellings.
    assert "__ballot" in src, f"Expected __ballot_sync/__ballot in source:\n{src}"
    kernel(b)
    assert torch.all(b == 1), f"Expected ballot 0x00000001 in every lane, got {b}"
116+
117+
118+
# ---------------------------------------------------------------------------
119+
# ballot (full-warp convenience wrapper)
120+
# ---------------------------------------------------------------------------
121+
122+
123+
@tilelang.jit
def kernel_ballot():
    """Build a 32-thread kernel that exercises ``T.ballot`` with predicate 1.

    Every thread passes the constant predicate 1, casts the returned ballot to
    ``int32`` and stores it in its own element of ``B``.
    """

    @T.prim_func
    def main(
        B: T.Tensor((32,), "int32"),
    ):
        with T.Kernel(1, threads=32):
            lane = T.get_thread_binding()
            bits = T.ballot(1)
            # The output buffer is int32, hence the explicit cast.
            B[lane] = T.cast(bits, "int32")

    return main
137+
138+
139+
@tilelang.testing.requires_cuda
def test_ballot():
    """Compile and run the ``ballot`` kernel and verify the full-warp mask.

    All 32 threads pass predicate 1, so the ballot is ``0xFFFFFFFF``.  The
    kernel stores it through an ``int32`` cast, where that bit pattern reads
    back as ``-1``; an ``int32`` element can never compare equal to
    ``0xFFFFFFFF`` itself, so only ``-1`` is checked.  Every element is
    verified, not just the first one.
    """
    b = torch.zeros((32,), device="cuda", dtype=torch.int32)
    kernel = kernel_ballot()
    src = kernel.get_kernel_source()
    # "__ballot" is a prefix of "__ballot_sync", so this accepts both spellings.
    assert "__ballot" in src, f"Expected __ballot_sync/__ballot in source:\n{src}"
    kernel(b)
    assert torch.all(b == -1), f"Expected 0xFFFFFFFF (-1 as int32) in every lane, got {b}"
149+
150+
151+
# ---------------------------------------------------------------------------
152+
# activemask
153+
# ---------------------------------------------------------------------------
154+
155+
156+
@tilelang.jit
def kernel_activemask():
    """Build a 32-thread kernel that exercises ``T.activemask``.

    Each thread queries the active mask, casts it to ``int32`` and stores it
    in its own element of ``B``.
    """

    @T.prim_func
    def main(
        B: T.Tensor((32,), "int32"),
    ):
        with T.Kernel(1, threads=32):
            lane = T.get_thread_binding()
            active = T.activemask()
            # The output buffer is int32, hence the explicit cast.
            B[lane] = T.cast(active, "int32")

    return main
170+
171+
172+
@tilelang.testing.requires_cuda
def test_activemask():
    """Compile and run the ``activemask`` kernel and verify the reported mask.

    The kernel launches 32 threads that all reach the query, so the expected
    mask is ``0xFFFFFFFF``; stored through the kernel's ``int32`` cast that bit
    pattern reads back as ``-1``.  An ``int32`` element can never compare equal
    to ``0xFFFFFFFF`` itself, so only ``-1`` is checked, and every element is
    verified rather than just the first one.
    """
    b = torch.zeros((32,), device="cuda", dtype=torch.int32)
    kernel = kernel_activemask()
    src = kernel.get_kernel_source()
    assert "__activemask" in src or "__ballot" in src, f"Expected __activemask/__ballot in source:\n{src}"
    kernel(b)
    assert torch.all(b == -1), f"Expected 0xFFFFFFFF (-1 as int32) in every lane, got {b}"
181+
182+
183+
# ---------------------------------------------------------------------------
184+
# syncthreads_count
185+
# ---------------------------------------------------------------------------
186+
187+
188+
@tilelang.jit
def kernel_syncthreads_count():
    """Build a 32-thread kernel that exercises ``T.syncthreads_count``.

    Threads with binding below 16 pass a true predicate; every thread stores
    the returned count in its own element of ``B``.
    """

    @T.prim_func
    def main(
        B: T.Tensor((32,), "int32"),
    ):
        with T.Kernel(1, threads=32):
            tid = T.get_thread_binding()
            # Predicate holds for thread bindings 0..15 only.
            num_true = T.syncthreads_count(tid < 16)
            B[tid] = num_true

    return main
202+
203+
204+
@tilelang.testing.requires_cuda
def test_syncthreads_count():
    """Compile the ``syncthreads_count`` kernel, inspect its source, and run it.

    Half of the 32 threads pass a true predicate, so every output element is
    expected to hold 16.
    """
    counts = torch.zeros(32, dtype=torch.int32, device="cuda")
    compiled = kernel_syncthreads_count()
    source = compiled.get_kernel_source()
    assert "__syncthreads_count" in source, f"Expected __syncthreads_count in source:\n{source}"
    compiled(counts)
    assert bool((counts == 16).all()), f"Expected all 16, got {counts}"
212+
213+
214+
# ---------------------------------------------------------------------------
215+
# syncthreads_and
216+
# ---------------------------------------------------------------------------
217+
218+
219+
@tilelang.jit
def kernel_syncthreads_and_true():
    """Build a 32-thread kernel calling ``T.syncthreads_and`` with predicate 1.

    Every thread stores the returned value in its own element of ``B``.
    """

    @T.prim_func
    def main(
        B: T.Tensor((32,), "int32"),
    ):
        with T.Kernel(1, threads=32):
            tid = T.get_thread_binding()
            # Constant true predicate on every thread.
            conj = T.syncthreads_and(1)
            B[tid] = conj

    return main
233+
234+
235+
@tilelang.jit
def kernel_syncthreads_and_false():
    """Build a 32-thread kernel where one thread fails ``T.syncthreads_and``.

    The predicate ``tid != 0`` is false only for the thread with binding 0;
    every thread stores the returned value in its own element of ``B``.
    """

    @T.prim_func
    def main(
        B: T.Tensor((32,), "int32"),
    ):
        with T.Kernel(1, threads=32):
            tid = T.get_thread_binding()
            # Thread 0 is the single dissenting vote.
            conj = T.syncthreads_and(tid != 0)
            B[tid] = conj

    return main
249+
250+
251+
@tilelang.testing.requires_cuda
def test_syncthreads_and():
    """Run both ``syncthreads_and`` kernels and check their results.

    The all-true kernel must emit the intrinsic in its source and produce only
    non-zero outputs; the kernel with one false predicate must produce zeros.
    """
    # Case 1: every thread passes a true predicate.
    out_true = torch.zeros(32, dtype=torch.int32, device="cuda")
    and_true = kernel_syncthreads_and_true()
    source = and_true.get_kernel_source()
    assert "__syncthreads_and" in source, f"Expected __syncthreads_and in source:\n{source}"
    and_true(out_true)
    assert bool((out_true != 0).all()), f"Expected all non-zero, got {out_true}"

    # Case 2: thread 0 passes a false predicate.
    out_false = torch.zeros(32, dtype=torch.int32, device="cuda")
    and_false = kernel_syncthreads_and_false()
    and_false(out_false)
    assert bool((out_false == 0).all()), f"Expected all 0, got {out_false}"
264+
265+
266+
# ---------------------------------------------------------------------------
267+
# syncthreads_or
268+
# ---------------------------------------------------------------------------
269+
270+
271+
@tilelang.jit
def kernel_syncthreads_or_true():
    """Build a 32-thread kernel where one thread satisfies ``T.syncthreads_or``.

    The predicate ``tid == 0`` is true only for the thread with binding 0;
    every thread stores the returned value in its own element of ``B``.
    """

    @T.prim_func
    def main(
        B: T.Tensor((32,), "int32"),
    ):
        with T.Kernel(1, threads=32):
            tid = T.get_thread_binding()
            # Thread 0 is the single true vote.
            disj = T.syncthreads_or(tid == 0)
            B[tid] = disj

    return main
285+
286+
287+
@tilelang.jit
def kernel_syncthreads_or_false():
    """Build a 32-thread kernel calling ``T.syncthreads_or`` with predicate 0.

    Every thread stores the returned value in its own element of ``B``.
    """

    @T.prim_func
    def main(
        B: T.Tensor((32,), "int32"),
    ):
        with T.Kernel(1, threads=32):
            tid = T.get_thread_binding()
            # Constant false predicate on every thread.
            disj = T.syncthreads_or(0)
            B[tid] = disj

    return main
301+
302+
303+
@tilelang.testing.requires_cuda
def test_syncthreads_or():
    """Run both ``syncthreads_or`` kernels and check their results.

    The kernel with one true predicate must emit the intrinsic in its source
    and produce only non-zero outputs; the all-false kernel must produce zeros.
    """
    # Case 1: thread 0 passes a true predicate.
    out_true = torch.zeros(32, dtype=torch.int32, device="cuda")
    or_true = kernel_syncthreads_or_true()
    source = or_true.get_kernel_source()
    assert "__syncthreads_or" in source, f"Expected __syncthreads_or in source:\n{source}"
    or_true(out_true)
    assert bool((out_true != 0).all()), f"Expected all non-zero, got {out_true}"

    # Case 2: no thread passes a true predicate.
    out_false = torch.zeros(32, dtype=torch.int32, device="cuda")
    or_false = kernel_syncthreads_or_false()
    or_false(out_false)
    assert bool((out_false == 0).all()), f"Expected all 0, got {out_false}"
316+
317+
318+
# Running this file as a script hands control to tilelang's test entry point.
if __name__ == "__main__":
    tilelang.testing.main()

tilelang/language/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,14 @@
100100
from .builtin import stg64 as stg64 # noqa: F401
101101
from .builtin import stg128 as stg128 # noqa: F401
102102
from .builtin import stg256 as stg256 # noqa: F401
103+
from .builtin import any_sync as any_sync # noqa: F401
104+
from .builtin import all_sync as all_sync # noqa: F401
105+
from .builtin import ballot_sync as ballot_sync # noqa: F401
106+
from .builtin import ballot as ballot # noqa: F401
107+
from .builtin import activemask as activemask # noqa: F401
108+
from .builtin import syncthreads_count as syncthreads_count # noqa: F401
109+
from .builtin import syncthreads_and as syncthreads_and # noqa: F401
110+
from .builtin import syncthreads_or as syncthreads_or # noqa: F401
103111

104112
from .utils import index_to_coordinates # noqa: F401
105113

0 commit comments

Comments
 (0)