[Analysis] Prohibiting using non-parallel iterators in Fragment access#1884
[Analysis] Prohibiting using non-parallel iterators in Fragment access#1884SiriusNEO wants to merge 2 commits intotile-ai:mainfrom
Conversation
|
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀 |
📝 WalkthroughWalkthroughRefactors the fragment loop checker to use a loop-stack traversal, tightens rules forbidding use of serial or symbolic loop iterators to index fragment buffers, renames collect_local_buffer_accesses to collect_fragment_accesses, and adds new tests covering valid and invalid serial/symbolic indexing scenarios. Changes
Estimated code review effort🎯 4 (Complex) | ⏱️ ~45 minutes Possibly related PRs
Poem
🚥 Pre-merge checks | ✅ 2 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
🧹 Nitpick comments (3)
testing/python/analysis/test_tilelang_fragment_loop_checker.py (2)
71-85: Parameterblockis shadowed andlength,dtypeare unused.The function signature includes
length,block, anddtypeparameters, butblockis immediately reassigned on line 75, andlength/dtypeare never used. This appears to be copy-paste from other test functions. Consider removing unused parameters or using them consistently.♻️ Suggested cleanup
`@tilelang.jit` def invalid_indexing_with_serial_sp( - length=256, block=16, dtype: T.dtype = T.bfloat16, accum_dtype: T.dtype = T.float32, num_threads: int = 128 + accum_dtype: T.dtype = T.float32, num_threads: int = 128 ): block = 16🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@testing/python/analysis/test_tilelang_fragment_loop_checker.py` around lines 71 - 85, The test function invalid_indexing_with_serial_sp declares parameters length, block, and dtype but then reassigns block and never uses length or dtype; fix by removing unused parameters from the signature (keep only num_threads and accum_dtype if needed) or alternatively use the declared parameters inside the body (e.g., use the passed-in block/length/dtype instead of hardcoding block = 16 and the alloc_fragment shape/type), and update the inner T.prim_func main accordingly so there is no shadowing of block and no unused parameters.
88-102: Same unused parameter issue asinvalid_indexing_with_serial_sp.Same feedback applies:
length,block(shadowed), anddtypeare unused.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@testing/python/analysis/test_tilelang_fragment_loop_checker.py` around lines 88 - 102, The function invalid_indexing_with_serial_ps has unused parameters (length, block, dtype) and additionally reassigns/shadows block with block = 16 before using T.Parallel(block); either remove the unused parameters from the signature or actually use them (e.g., use the passed-in length and block values for allocations/loop bounds and dtype for alloc_fragment), and eliminate the shadowing assignment so the parameter value drives T.Parallel(block) and other uses in main; update references to data_frag = T.alloc_fragment([128], accum_dtype) to use length/ dtype if you keep those parameters and ensure no local variable reassigns the parameter names (remove or rename block = 16).tilelang/analysis/fragment_loop_checker.py (1)
59-105: Consider extracting repeated analyzer logic into a helper.The analyzer creation and iteration pattern is duplicated in Check 1 (lines 82-91) and Check 2 (lines 92-102). While the current implementation is correct and readable, you could reduce duplication by extracting a helper function.
♻️ Optional refactor to reduce duplication
+def _is_loop_var_used_in_indices(loop_var: Var, indices) -> bool: + """Check if a loop variable is used in any of the given indices.""" + analyzer = _LoopVarUseAnalyzer(loop_var) + for index in indices: + analyzer.visit_expr(index) + return analyzer.used + `@tir.functor.visitor` class _FragmentLoopCheckVisitor(PyStmtExprVisitor): # ... (in visit_for_ method) for buffer_access in buffer_accesses: indices = buffer_access.indices # Check 1 for loop in loops_with_symbolic_ranges: - analyzer = _LoopVarUseAnalyzer(loop.loop_var) - for index in indices: - analyzer.visit_expr(index) - if analyzer.used: + if _is_loop_var_used_in_indices(loop.loop_var, indices): raise ValueError(...) # Check 2 for loop in non_parallel_loops: - analyzer = _LoopVarUseAnalyzer(loop.loop_var) - for index in indices: - analyzer.visit_expr(index) - if analyzer.used: + if _is_loop_var_used_in_indices(loop.loop_var, indices): raise ValueError(...)🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tilelang/analysis/fragment_loop_checker.py` around lines 59 - 105, The visit_for_ method duplicates the pattern of creating an _LoopVarUseAnalyzer, visiting each index, and raising on analyzer.used in both the loops_with_symbolic_ranges and non_parallel_loops checks; extract a small helper (e.g., check_loop_index_usage(loop, indices, error_message)) that instantiates _LoopVarUseAnalyzer(loop.loop_var), iterates analyzer.visit_expr over indices, and raises the provided error_message when analyzer.used is true, then call this helper from the two places (passing the appropriate loop and the formatted error text referencing loop.loop_var, loop.min, loop.extent as needed) to remove the duplicated logic.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Nitpick comments:
In `@testing/python/analysis/test_tilelang_fragment_loop_checker.py`:
- Around line 71-85: The test function invalid_indexing_with_serial_sp declares
parameters length, block, and dtype but then reassigns block and never uses
length or dtype; fix by removing unused parameters from the signature (keep only
num_threads and accum_dtype if needed) or alternatively use the declared
parameters inside the body (e.g., use the passed-in block/length/dtype instead
of hardcoding block = 16 and the alloc_fragment shape/type), and update the
inner T.prim_func main accordingly so there is no shadowing of block and no
unused parameters.
- Around line 88-102: The function invalid_indexing_with_serial_ps has unused
parameters (length, block, dtype) and additionally reassigns/shadows block with
block = 16 before using T.Parallel(block); either remove the unused parameters
from the signature or actually use them (e.g., use the passed-in length and
block values for allocations/loop bounds and dtype for alloc_fragment), and
eliminate the shadowing assignment so the parameter value drives
T.Parallel(block) and other uses in main; update references to data_frag =
T.alloc_fragment([128], accum_dtype) to use length/ dtype if you keep those
parameters and ensure no local variable reassigns the parameter names (remove or
rename block = 16).
In `@tilelang/analysis/fragment_loop_checker.py`:
- Around line 59-105: The visit_for_ method duplicates the pattern of creating
an _LoopVarUseAnalyzer, visiting each index, and raising on analyzer.used in
both the loops_with_symbolic_ranges and non_parallel_loops checks; extract a
small helper (e.g., check_loop_index_usage(loop, indices, error_message)) that
instantiates _LoopVarUseAnalyzer(loop.loop_var), iterates analyzer.visit_expr
over indices, and raises the provided error_message when analyzer.used is true,
then call this helper from the two places (passing the appropriate loop and the
formatted error text referencing loop.loop_var, loop.min, loop.extent as needed)
to remove the duplicated logic.
ℹ️ Review info
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
testing/python/analysis/test_tilelang_fragment_loop_checker.pytilelang/analysis/fragment_loop_checker.py
There was a problem hiding this comment.
Pull request overview
Adds stricter semantic validation to TileLang’s pre-lowering analysis to forbid indexing local/fragment buffers using non-parallel loop iterators, addressing invalid loop-iterator usage patterns reported in issue #1868.
Changes:
- Update
FragmentLoopCheckerto track nested loop context via a loop stack and reject local/fragment buffer indexing that uses non-parallel loop vars. - Extend validation to also reject indexing with
T.Parallelloop vars when the parallel loop range is symbolic. - Add new regression tests covering serial/parallel loop nesting patterns for fragment indexing.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
tilelang/analysis/fragment_loop_checker.py |
Reworks fragment/local buffer access validation using a loop stack; adds new non-parallel iterator restriction and updates documentation. |
testing/python/analysis/test_tilelang_fragment_loop_checker.py |
Adds new invalid/valid cases for serial+parallel iterator combinations when indexing fragment buffers. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| def valid_indexing_with_serial(length=256, block=16, dtype: T.dtype = T.bfloat16, accum_dtype: T.dtype = T.float32, num_threads: int = 128): | ||
| block = 16 | ||
|
|
||
| @T.prim_func | ||
| def main(): | ||
| with T.Kernel(128, threads=num_threads) as _: | ||
| data_frag = T.alloc_fragment([128], accum_dtype) | ||
| for i in T.serial(8): # noqa: B007 | ||
| for j in T.Parallel(block): |
There was a problem hiding this comment.
valid_indexing_with_serial also defines length/block parameters but then overwrites block and never uses length. Consider removing these parameters or using them so the test case remains clear and consistent with its signature.
| def valid_indexing_with_serial(length=256, block=16, dtype: T.dtype = T.bfloat16, accum_dtype: T.dtype = T.float32, num_threads: int = 128): | |
| block = 16 | |
| @T.prim_func | |
| def main(): | |
| with T.Kernel(128, threads=num_threads) as _: | |
| data_frag = T.alloc_fragment([128], accum_dtype) | |
| for i in T.serial(8): # noqa: B007 | |
| for j in T.Parallel(block): | |
| def valid_indexing_with_serial(dtype: T.dtype = T.bfloat16, accum_dtype: T.dtype = T.float32, num_threads: int = 128): | |
| @T.prim_func | |
| def main(): | |
| with T.Kernel(128, threads=num_threads) as _: | |
| data_frag = T.alloc_fragment([128], accum_dtype) | |
| for i in T.serial(8): # noqa: B007 | |
| for j in T.Parallel(16): |
| @@ -76,10 +89,20 @@ def visit_for_(self, op: For) -> None: | |||
| f"Loop variable {loop.loop_var} in a T.Parallel loop with symbolic range (min={loop.min}, extent={loop.extent}) is used to index " | |||
| "a local/fragment buffer, which is not allowed in Tilelang." | |||
| ) | |||
|
|
|||
| return | |||
| # Check 2 | |||
| for loop in non_parallel_loops: | |||
| analyzer = _LoopVarUseAnalyzer(loop.loop_var) | |||
| for index in indices: | |||
| analyzer.visit_expr(index) | |||
| if analyzer.used: | |||
| raise ValueError( | |||
| "[Tilelang Semantic Check] " | |||
| f"A non-parallel loop iterator {loop.loop_var} is used to index " | |||
| "a local/fragment buffer, which is not allowed in Tilelang." | |||
| ) | |||
|
|
|||
| self.visit_stmt(op.body) | |||
| self.loop_stack.pop() | |||
There was a problem hiding this comment.
visit_for_ mutates self.loop_stack but does not guarantee pop() executes if visit_stmt(op.body) raises (e.g., from a nested invalid access). Wrapping the body traversal in try/finally would keep the visitor state consistent and make the checker safer to reuse/extend.
| """ | ||
| Check whether the fragment accesses are valid. | ||
|
|
||
| This checker will recursively visit all the for loops until it reaches certain "inner most loop". | ||
| Then it will start to check the validity of fragment access in the loop body. We need to maintain a stack of | ||
| loops during the traversal since this is the context/scope of the fragment access. | ||
| """ | ||
|
|
||
| def __init__(self) -> None: | ||
| super().__init__() | ||
| self.loop_stack = [] | ||
|
|
||
| def visit_for_(self, op: For) -> None: | ||
| if op.kind == tir.ForKind.PARALLEL: | ||
| # Fuse consecutive parallel loops | ||
| # Other nested cases are all invalid in TileLang. | ||
| loops = [op] | ||
| child = op.body | ||
| while isinstance(child, For) and child.kind == tir.ForKind.PARALLEL: | ||
| loops.append(child) | ||
| child = child.body | ||
| self.loop_stack.append(op) | ||
| child = op.body | ||
|
|
||
| # Reach the the innermost loop | ||
| # This may cause repeated check for cases like: For1{Stmt1; For2{}; For3{};}; | ||
| # But it's OK since the check is idempotent. | ||
| if not isinstance(child, For): |
There was a problem hiding this comment.
Docstring/comment grammar: "inner most loop" and "Reach the the innermost loop" read as typos. Consider changing to "innermost loop" and "Reach the innermost loop" for clarity.
| @tilelang.jit | ||
| def invalid_indexing_with_serial_sp( | ||
| length=256, block=16, dtype: T.dtype = T.bfloat16, accum_dtype: T.dtype = T.float32, num_threads: int = 128 | ||
| ): | ||
| block = 16 | ||
|
|
There was a problem hiding this comment.
These new test helpers accept parameters (length, block, dtype) that are unused, and block is immediately overwritten. This makes it harder to understand what the test is validating; consider removing the unused parameters or using them so the test expresses the intended shape/configuration.
| @tilelang.jit | ||
| def invalid_indexing_with_serial_ps( | ||
| length=256, block=16, dtype: T.dtype = T.bfloat16, accum_dtype: T.dtype = T.float32, num_threads: int = 128 | ||
| ): | ||
| block = 16 | ||
|
|
There was a problem hiding this comment.
Same issue here: the function signature parameters are unused and block is overwritten immediately. Consider dropping the unused args or making the test configurable by using them.
|
Found some examples which violate the check rules. Need to verify |
There was a problem hiding this comment.
🧹 Nitpick comments (2)
tilelang/analysis/fragment_loop_checker.py (2)
59-106: Traversal logic is correct.The push/pop pattern around
visit_stmt(op.body)correctly maintains loop context. The condition at line 66 ensures checks occur at the innermost loop level, and the acknowledged repeated checking (comment on lines 64-65) is indeed idempotent.One optional refinement: consider extracting the long error messages into module-level constants or a dedicated exception class to improve readability and satisfy TRY003.
♻️ Optional: Extract error messages
# At module level _ERR_SYMBOLIC_RANGE = ( "[Tilelang Semantic Check] " "Loop variable {var} in a T.Parallel loop with symbolic range " "(min={min}, extent={extent}) is used to index a fragment buffer, " "which is not allowed in Tilelang." ) _ERR_NON_PARALLEL = ( "[Tilelang Semantic Check] " "A non-parallel loop iterator {var} is used to index a fragment buffer, " "which is not allowed in Tilelang." ) # Then in the check: raise ValueError(_ERR_SYMBOLIC_RANGE.format(var=loop.loop_var, min=loop.min, extent=loop.extent)) raise ValueError(_ERR_NON_PARALLEL.format(var=loop.loop_var))🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tilelang/analysis/fragment_loop_checker.py` around lines 59 - 106, The long literal error messages inside visit_for_ make the function noisy; extract them to module-level constants (e.g., _ERR_SYMBOLIC_RANGE, _ERR_NON_PARALLEL) or wrap them in a small dedicated exception class and use formatted messages when raising ValueError in visit_for_; update the raise sites that reference loop.loop_var, loop.min, loop.extent accordingly so visit_for_ remains functionally identical but with cleaner, reusable message constants and improved readability.
109-119: Consider clarifying the docstring for rule 1.The current wording "The range of loop can not be symbolic" could be interpreted as forbidding symbolic ranges entirely. The actual behavior is more nuanced: parallel loops with symbolic ranges are allowed, but their loop variables cannot be used to index fragment buffers.
📝 Suggested docstring clarification
""" When using T.Parallel over a local/fragment buffer, there are several restrictions: to ensure that the parallelization is valid. - 1. The range of loop can not be symbolic. + 1. Loop variables from T.Parallel loops with symbolic ranges cannot be used to index fragment buffers. 2. Any access/indexing of the fragment buffer should not contain other types of iterators (like loop variables from T.Serial). Returns: A prim_func_pass that applies the transformation """🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tilelang/analysis/fragment_loop_checker.py` around lines 109 - 119, The docstring for FragmentLoopChecker is misleading about rule 1; update it to state that parallel loops with symbolic ranges are permitted but their loop variables must not be used to index fragment/local buffers (i.e., symbolic loop bounds are allowed only if the loop var is not used to access fragment buffers). Edit the docstring in the FragmentLoopChecker function to replace "The range of loop can not be symbolic." with a clarified sentence reflecting this nuance and optionally add a short example or note referencing fragment buffer indexing restrictions.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Nitpick comments:
In `@tilelang/analysis/fragment_loop_checker.py`:
- Around line 59-106: The long literal error messages inside visit_for_ make the
function noisy; extract them to module-level constants (e.g.,
_ERR_SYMBOLIC_RANGE, _ERR_NON_PARALLEL) or wrap them in a small dedicated
exception class and use formatted messages when raising ValueError in
visit_for_; update the raise sites that reference loop.loop_var, loop.min,
loop.extent accordingly so visit_for_ remains functionally identical but with
cleaner, reusable message constants and improved readability.
- Around line 109-119: The docstring for FragmentLoopChecker is misleading about
rule 1; update it to state that parallel loops with symbolic ranges are
permitted but their loop variables must not be used to index fragment/local
buffers (i.e., symbolic loop bounds are allowed only if the loop var is not used
to access fragment buffers). Edit the docstring in the FragmentLoopChecker
function to replace "The range of loop can not be symbolic." with a clarified
sentence reflecting this nuance and optionally add a short example or note
referencing fragment buffer indexing restrictions.
Check the cases: #1868
Summary by CodeRabbit
Tests
Bug Fixes
Documentation