[Feature] Add TIR builtins for warp-level vote and block-level predicate sync#1858
[Feature] Add TIR builtins for warp-level vote and block-level predicate sync#1858sepcnt wants to merge 1 commit intotile-ai:mainfrom
Conversation
|
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀 |
📝 WalkthroughWalkthroughAdds wide-ranging GPU synchronization, warp vote/ballot, intra-warp shuffle, and expanded atomic primitives to TileLang's public API, with CUDA/HIP backend mappings and documentation examples; also adds unit tests exercising warp-vote intrinsics and exports the new symbols. Changes
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~25 minutes Possibly related PRs
Suggested reviewers
Poem
🚥 Pre-merge checks | ✅ 2 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches
🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Actionable comments posted: 2
🧹 Nitpick comments (1)
tilelang/language/builtin.py (1)
941-959:ballot()duplicates logic instead of delegating toballot_sync().The docstring says "convenience wrapper around
ballot_sync" but the implementation reimplements the HIP/CUDA branching rather than callingballot_sync(0xFFFFFFFF, predicate). If theballot_syncimplementation ever changes (e.g., different cast strategy),ballotwill silently diverge.♻️ Proposed refactor to delegate to ballot_sync
def ballot(predicate: int | PrimExpr) -> PrimExpr: ... - if _IS_HIP_AVAILABLE: - return tir.cast("uint32", tir.call_extern("uint64", "__ballot", predicate)) - else: - return tir.call_extern("uint32", "__ballot_sync", tir.const(0xFFFFFFFF, "uint32"), predicate) + return ballot_sync(tir.const(0xFFFFFFFF, "uint32"), predicate)Similarly,
activemask()on HIP (line 972) reimplements the same__ballot+ cast pattern. Consider delegating:def activemask() -> PrimExpr: ... - if _IS_HIP_AVAILABLE: - return tir.cast("uint32", tir.call_extern("uint64", "__ballot", tir.const(1, "int32"))) - else: + if not _IS_HIP_AVAILABLE: return tir.call_extern("uint32", "__activemask") + return ballot(tir.const(1, "int32"))🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tilelang/language/builtin.py` around lines 941 - 959, The ballot() function duplicates the HIP/CUDA branching instead of delegating to ballot_sync; change ballot to simply return ballot_sync(tir.const(0xFFFFFFFF, "uint32"), predicate) (ensuring types match) so any future ballot_sync changes apply here; likewise update activemask() to delegate to ballot_sync/activemask_sync equivalent instead of reimplementing the __ballot + cast pattern (use the same uint32/uint64 casting behavior that ballot_sync already implements) and remove the duplicated call_extern/cast logic in ballot and activemask while keeping their public signatures unchanged.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@testing/python/language/test_tilelang_language_warp_vote.py`:
- Around line 26-40: The test declares an unused tensor parameter A in
kernel_any_sync.main and has an incorrect docstring; remove A from the kernel
signature (delete the T.Tensor((1,), "int32") parameter and any references to
A), update the docstring to accurately describe that lane 0 participates in the
any_sync predicate (e.g., "Lane 0 sets the predicate; all lanes use any_sync to
see if any lane matched"), and update any test invocation that passed an A
argument so it only supplies the B tensor to kernel_any_sync.
In `@tilelang/language/builtin.py`:
- Around line 918-938: The current ballot_sync function truncates HIP's 64-bit
(__ballot) result to uint32; modify the HIP branch (where _IS_HIP_AVAILABLE is
checked) to return the raw uint64 result via tir.call_extern("uint64",
"__ballot", predicate) instead of tir.cast("uint32", ...), and update the
function docstring and the inline comment to state that on HIP the function
returns uint64 (callers should narrow to uint32 when they know wavefronts are
32-wide); reference symbols: ballot_sync, _IS_HIP_AVAILABLE, tir.call_extern,
tir.cast.
---
Nitpick comments:
In `@tilelang/language/builtin.py`:
- Around line 941-959: The ballot() function duplicates the HIP/CUDA branching
instead of delegating to ballot_sync; change ballot to simply return
ballot_sync(tir.const(0xFFFFFFFF, "uint32"), predicate) (ensuring types match)
so any future ballot_sync changes apply here; likewise update activemask() to
delegate to ballot_sync/activemask_sync equivalent instead of reimplementing the
__ballot + cast pattern (use the same uint32/uint64 casting behavior that
ballot_sync already implements) and remove the duplicated call_extern/cast logic
in ballot and activemask while keeping their public signatures unchanged.
| @tilelang.jit | ||
| def kernel_any_sync(): | ||
| """Lane 0 writes 1 to A; all lanes use any_sync to see if any lane wrote.""" | ||
|
|
||
| @T.prim_func | ||
| def main( | ||
| A: T.Tensor((1,), "int32"), | ||
| B: T.Tensor((32,), "int32"), | ||
| ): | ||
| with T.Kernel(1, threads=32): | ||
| tx = T.get_thread_binding() | ||
| val = T.any_sync(0xFFFFFFFF, tx == 0) | ||
| B[tx] = val | ||
|
|
||
| return main |
There was a problem hiding this comment.
Unused A parameter in kernel_any_sync.
The A tensor is declared in the kernel signature but never read or written. The docstring also incorrectly states "Lane 0 writes 1 to A." This adds a needless allocation in the test and a confusing docstring.
🧹 Proposed fix
`@tilelang.jit`
def kernel_any_sync():
- """Lane 0 writes 1 to A; all lanes use any_sync to see if any lane wrote."""
+ """Lane 0 has predicate 1; all lanes use any_sync to check."""
`@T.prim_func`
def main(
- A: T.Tensor((1,), "int32"),
B: T.Tensor((32,), "int32"),
):
with T.Kernel(1, threads=32):
tx = T.get_thread_binding()
val = T.any_sync(0xFFFFFFFF, tx == 0)
B[tx] = val
return mainAnd update the test call accordingly:
def test_any_sync():
- a = torch.zeros((1,), device="cuda", dtype=torch.int32)
b = torch.zeros((32,), device="cuda", dtype=torch.int32)
kernel = kernel_any_sync()
src = kernel.get_kernel_source()
assert "__any_sync" in src or "__any" in src, f"Expected __any_sync/__any in source:\n{src}"
- kernel(a, b)
+ kernel(b)📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| @tilelang.jit | |
| def kernel_any_sync(): | |
| """Lane 0 writes 1 to A; all lanes use any_sync to see if any lane wrote.""" | |
| @T.prim_func | |
| def main( | |
| A: T.Tensor((1,), "int32"), | |
| B: T.Tensor((32,), "int32"), | |
| ): | |
| with T.Kernel(1, threads=32): | |
| tx = T.get_thread_binding() | |
| val = T.any_sync(0xFFFFFFFF, tx == 0) | |
| B[tx] = val | |
| return main | |
| `@tilelang.jit` | |
| def kernel_any_sync(): | |
| """Lane 0 has predicate 1; all lanes use any_sync to check.""" | |
| `@T.prim_func` | |
| def main( | |
| B: T.Tensor((32,), "int32"), | |
| ): | |
| with T.Kernel(1, threads=32): | |
| tx = T.get_thread_binding() | |
| val = T.any_sync(0xFFFFFFFF, tx == 0) | |
| B[tx] = val | |
| return main |
🧰 Tools
🪛 Ruff (0.15.1)
[warning] 32-32: Unused function argument: A
(ARG001)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@testing/python/language/test_tilelang_language_warp_vote.py` around lines 26
- 40, The test declares an unused tensor parameter A in kernel_any_sync.main and
has an incorrect docstring; remove A from the kernel signature (delete the
T.Tensor((1,), "int32") parameter and any references to A), update the docstring
to accurately describe that lane 0 participates in the any_sync predicate (e.g.,
"Lane 0 sets the predicate; all lanes use any_sync to see if any lane matched"),
and update any test invocation that passed an A argument so it only supplies the
B tensor to kernel_any_sync.
There was a problem hiding this comment.
Actionable comments posted: 2
🧹 Nitpick comments (1)
tilelang/language/__init__.py (1)
103-110: Optional: Remove redundant# noqa: F401directives.Ruff flags these as unused since
F401is not enabled in the project's Ruff configuration. The same pattern exists on lines 94–102 (pre-existing, not flagged here). Consider a one-time cleanup of all such directives across lines 94–110 once the existing pattern is addressed.🧹 Proposed cleanup (lines 103–110 only)
-from .builtin import any_sync as any_sync # noqa: F401 -from .builtin import all_sync as all_sync # noqa: F401 -from .builtin import ballot_sync as ballot_sync # noqa: F401 -from .builtin import ballot as ballot # noqa: F401 -from .builtin import activemask as activemask # noqa: F401 -from .builtin import syncthreads_count as syncthreads_count # noqa: F401 -from .builtin import syncthreads_and as syncthreads_and # noqa: F401 -from .builtin import syncthreads_or as syncthreads_or # noqa: F401 +from .builtin import any_sync as any_sync +from .builtin import all_sync as all_sync +from .builtin import ballot_sync as ballot_sync +from .builtin import ballot as ballot +from .builtin import activemask as activemask +from .builtin import syncthreads_count as syncthreads_count +from .builtin import syncthreads_and as syncthreads_and +from .builtin import syncthreads_or as syncthreads_or🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tilelang/language/__init__.py` around lines 103 - 110, Remove the redundant "# noqa: F401" directives from the import statements that re-export builtin symbols; specifically edit the import lines that reference any_sync, all_sync, ballot_sync, ballot, activemask, syncthreads_count, syncthreads_and, and syncthreads_or in __init__.py and delete the trailing " # noqa: F401" from each import; ensure the names remain imported (no other code changes) and run the linter/flake check to confirm no warnings remain.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@docs/programming_guides/instructions.md`:
- Around line 155-158: Add a brief HIP version note stating that
T.syncthreads_count, T.syncthreads_and, and T.syncthreads_or require ROCm/HIP ≥
7.0 (they are only under development in HIP ≤ 6.2); update the three entries
(T.syncthreads_count(predicate), T.syncthreads_and(predicate),
T.syncthreads_or(predicate)) to append “HIP: ROCm/hip ≥ 7.0” or expand the
existing HIP note that covers vote/ballot caveats to explicitly call out these
predicated block-wide sync functions and their minimum HIP version requirement.
In `@tilelang/language/builtin.py`:
- Around line 982-1024: The docstrings for syncthreads_count, syncthreads_and,
and syncthreads_or incorrectly claim the intrinsics map to CUDA and HIP
unconditionally; update each function's docstring (syncthreads_count,
syncthreads_and, syncthreads_or) to note the HIP/ROCm version constraints: that
the __syncthreads_count/and/or variants are under development in HIP 6.2 and
that __sync variants became available in ROCm 7.0 (and are enabled by default
there), so on older HIP/ROCm stacks these intrinsics may not be available; keep
the CUDA mapping text unchanged and add a brief sentence about the HIP version
caveat.
---
Duplicate comments:
In `@testing/python/language/test_tilelang_language_warp_vote.py`:
- Around line 26-52: The kernel declares A in kernel_any_sync -> main but never
uses it and the docstring is wrong; either remove the unused parameter A from
main and update the docstring and test_any_sync to stop allocating/passing a, or
implement the intended behavior: have lane 0 write 1 into A (e.g., inside
T.Kernel when tx == 0 set A[0]=1) and change the any_sync predicate to read
A[0]==1 so the call sites (test_any_sync's allocation/passing of a) remain
correct; update kernel_any_sync's docstring accordingly.
In `@tilelang/language/builtin.py`:
- Around line 918-974: The HIP implementations of ballot_sync, ballot, and
activemask incorrectly cast the uint64 result of __ballot to uint32, discarding
lanes 32–63 on wave-64 targets; update the HIP paths in ballot_sync, ballot, and
activemask to return 64-bit masks (use tir.call_extern with "uint64" and return
a uint64 PrimExpr, and avoid casting to "uint32"), and ensure ballot_sync's mask
parameter and any callers that expect a 32-bit return are adjusted to use 64-bit
mask types so full wave-64 lane data is preserved.
---
Nitpick comments:
In `@tilelang/language/__init__.py`:
- Around line 103-110: Remove the redundant "# noqa: F401" directives from the
import statements that re-export builtin symbols; specifically edit the import
lines that reference any_sync, all_sync, ballot_sync, ballot, activemask,
syncthreads_count, syncthreads_and, and syncthreads_or in __init__.py and delete
the trailing " # noqa: F401" from each import; ensure the names remain imported
(no other code changes) and run the linter/flake check to confirm no warnings
remain.
| Block-wide predicated sync | ||
| - `T.syncthreads_count(predicate)` → `int32`: Sync all threads; return count with non-zero predicate (`__syncthreads_count`). | ||
| - `T.syncthreads_and(predicate)` → `int32`: Sync; non-zero iff ALL threads have non-zero predicate (`__syncthreads_and`). | ||
| - `T.syncthreads_or(predicate)` → `int32`: Sync; non-zero iff ANY thread has non-zero predicate (`__syncthreads_or`). |
There was a problem hiding this comment.
Missing HIP version note for syncthreads_count/and/or.
The block-wide predicated sync entries do not mention the ROCm ≥ 7.0 requirement noted in the PR description. The __syncthreads_count(int), __syncthreads_and(int), and __syncthreads_or(int) functions are under development in HIP 6.2 and below. The HIP note at line 166 addresses vote/ballot HIP caveats but omits this. Consider appending a brief version note to the syncthreads_count/and/or entries or expanding the HIP note to cover them.
📝 Suggested doc patch
Block-wide predicated sync
-- `T.syncthreads_count(predicate)` → `int32`: Sync all threads; return count with non-zero predicate (`__syncthreads_count`).
-- `T.syncthreads_and(predicate)` → `int32`: Sync; non-zero iff ALL threads have non-zero predicate (`__syncthreads_and`).
-- `T.syncthreads_or(predicate)` → `int32`: Sync; non-zero iff ANY thread has non-zero predicate (`__syncthreads_or`).
+- `T.syncthreads_count(predicate)` → `int32`: Sync all threads; return count with non-zero predicate (`__syncthreads_count`). *(HIP: requires ROCm ≥ 7.0)*
+- `T.syncthreads_and(predicate)` → `int32`: Sync; non-zero iff ALL threads have non-zero predicate (`__syncthreads_and`). *(HIP: requires ROCm ≥ 7.0)*
+- `T.syncthreads_or(predicate)` → `int32`: Sync; non-zero iff ANY thread has non-zero predicate (`__syncthreads_or`). *(HIP: requires ROCm ≥ 7.0)*📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| Block-wide predicated sync | |
| - `T.syncthreads_count(predicate)` → `int32`: Sync all threads; return count with non-zero predicate (`__syncthreads_count`). | |
| - `T.syncthreads_and(predicate)` → `int32`: Sync; non-zero iff ALL threads have non-zero predicate (`__syncthreads_and`). | |
| - `T.syncthreads_or(predicate)` → `int32`: Sync; non-zero iff ANY thread has non-zero predicate (`__syncthreads_or`). | |
| Block-wide predicated sync | |
| - `T.syncthreads_count(predicate)` → `int32`: Sync all threads; return count with non-zero predicate (`__syncthreads_count`). *(HIP: requires ROCm ≥ 7.0)* | |
| - `T.syncthreads_and(predicate)` → `int32`: Sync; non-zero iff ALL threads have non-zero predicate (`__syncthreads_and`). *(HIP: requires ROCm ≥ 7.0)* | |
| - `T.syncthreads_or(predicate)` → `int32`: Sync; non-zero iff ANY thread has non-zero predicate (`__syncthreads_or`). *(HIP: requires ROCm ≥ 7.0)* |
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@docs/programming_guides/instructions.md` around lines 155 - 158, Add a brief
HIP version note stating that T.syncthreads_count, T.syncthreads_and, and
T.syncthreads_or require ROCm/HIP ≥ 7.0 (they are only under development in HIP
≤ 6.2); update the three entries (T.syncthreads_count(predicate),
T.syncthreads_and(predicate), T.syncthreads_or(predicate)) to append “HIP:
ROCm/hip ≥ 7.0” or expand the existing HIP note that covers vote/ballot caveats
to explicitly call out these predicated block-wide sync functions and their
minimum HIP version requirement.
|
@sepcnt Thanks and would be better to wrap the instruction with |
This pull request adds builtin support for warp-level vote/ballot intrinsics and block-wide predicated synchronization operations.
These primitives are fundamental building blocks for high-performance GPU programming. They enable warp-wide early termination, fast consensus evaluation, and efficient divergence control.
Such capabilities are critical in optimized sorting algorithms, parallel partitioning, and other performance-sensitive kernels, where eliminating redundant computation at the warp level can substantially reduce execution time.
Note: On the ROCm backend, these intrinsics rely on HIP language extensions available in ROCm/HIP >= 7.0. Older ROCm releases may not provide full support for these builtins.
Summary by CodeRabbit
New Features
Documentation
Tests