
[Codegen][GPU] Add DistributeArgCompare pattern #23793

Open
bangtianliu wants to merge 3 commits into iree-org:main from bangtianliu:argcompare-distribute-pattern

Conversation

@bangtianliu
Contributor

@bangtianliu bangtianliu commented Mar 16, 2026

This PR adds the DistributeArgCompare pattern to distribute iree_vector_ext.arg_compare operations across GPU threads and subgroups.

For supported comparators, we use a ballot-based approach that leverages gpu.subgroup_reduce + gpu.ballot for reduction.
Supported comparators include:

  1. Direct comparison on values (e.g., arith.cmpf ogt for argmax)
  2. Same unary op applied to both arguments before comparison (e.g., math.absf for argmax of absolute values)
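
To make the ballot-based path easier to picture, here is a host-side C++ simulation of a single subgroup. It is a sketch under stated assumptions, not the code in this PR: the names `ArgResult`, `ballotArgMax`, `identityKey`, and `absKey` are hypothetical, the loops stand in for `gpu.subgroup_reduce` and the ballot, ties are assumed to go to the lowest lane, and inputs are assumed to be NaN-free. The `key` parameter models the "same unary op on both arguments" case.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical result type: the winning value and its index.
struct ArgResult {
  float value;
  int index;
};

// Hypothetical keys: direct comparison, and comparison on absolute values.
inline float identityKey(float v) { return v; }
inline float absKey(float v) { return std::fabs(v); }

// Simulates a ballot-based argmax over one subgroup of at most 64 lanes.
// laneValues[i] / laneIndices[i] are what lane i holds before the
// cross-lane step. Assumes NaN-free input.
template <typename KeyFn>
ArgResult ballotArgMax(const std::vector<float> &laneValues,
                       const std::vector<int> &laneIndices, KeyFn key) {
  assert(!laneValues.empty() && laneValues.size() <= 64);
  assert(laneValues.size() == laneIndices.size());

  // Step 1: reduce the keys only (stand-in for a subgroup max reduction).
  float best = key(laneValues[0]);
  for (float v : laneValues)
    best = std::max(best, key(v));

  // Step 2: each lane votes on whether it holds the reduced key (the ballot).
  std::uint64_t mask = 0;
  for (std::size_t lane = 0; lane < laneValues.size(); ++lane)
    if (key(laneValues[lane]) == best)
      mask |= std::uint64_t{1} << lane;
  assert(mask != 0 && "no lane matched; input probably contained NaN");

  // Step 3: the lowest set bit selects the winning lane (assumed tie-break).
  int winner = 0;
  while (((mask >> winner) & 1u) == 0)
    ++winner;
  return {laneValues[winner], laneIndices[winner]};
}
```

For example, `ballotArgMax({1.f, 7.f, 3.f, 7.f}, {10, 11, 12, 13}, identityKey)` picks lane 1 and returns index 11, while passing `absKey` over `{-9.f, 2.f, 5.f, -1.f}` selects the lane holding `-9`. The point of the scheme is that only the key goes through the reduction; the index is recovered from the ballot mask rather than being carried through every reduction step.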

Unsupported comparators fall back to the portable butterfly shuffle approach. Currently, this is mainly used for argmax/argmin operations, but we can extend support for additional comparators as needed.
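
For comparison, the butterfly fallback can be sketched the same way. This is again a host-side C++ simulation with hypothetical names (`Candidate`, `butterflyArgCompare`), assuming a power-of-two lane count; indexing `lane ^ offset` stands in for a subgroup XOR shuffle, and the smaller-index tie-break is an assumption added so that all lanes converge on the same result.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical per-lane state: a value and the index it came from.
struct Candidate {
  float value;
  int index;
};

// Butterfly reduction driven by an arbitrary "a beats b" predicate.
// Every round, each lane reads its XOR partner's (value, index) pair and
// keeps whichever candidate wins under the comparator.
template <typename BetterFn>
Candidate butterflyArgCompare(std::vector<Candidate> lanes, BetterFn better) {
  const std::size_t n = lanes.size();
  assert(n > 0 && (n & (n - 1)) == 0 && "lane count must be a power of two");

  for (std::size_t offset = n / 2; offset > 0; offset /= 2) {
    std::vector<Candidate> next(n);
    for (std::size_t lane = 0; lane < n; ++lane) {
      const Candidate &mine = lanes[lane];
      const Candidate &theirs = lanes[lane ^ offset]; // simulated shuffle
      // Take the partner's candidate if it strictly wins, or if neither
      // wins and the partner has the smaller index (assumed tie-break).
      const bool takeTheirs =
          better(theirs.value, mine.value) ||
          (!better(mine.value, theirs.value) && theirs.index < mine.index);
      next[lane] = takeTheirs ? theirs : mine;
    }
    lanes = std::move(next);
  }
  // After log2(n) rounds every lane holds the same candidate.
  return lanes[0];
}
```

Because both the value and the index travel through every round, this works for any comparator the caller supplies, which is why it serves as the portable path. The cost is two shuffled values per lane per round, versus a single value reduction plus one ballot in the specialized path.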

Issue: #23005
Assisted-by: Claude Code

@bangtianliu bangtianliu marked this pull request as draft March 16, 2026 06:03
@bangtianliu bangtianliu force-pushed the argcompare-distribute-pattern branch 5 times, most recently from 4c874b1 to 9e800bb Compare March 17, 2026 00:30
@bangtianliu bangtianliu marked this pull request as ready for review March 17, 2026 00:35
@bangtianliu bangtianliu requested a review from Max191 March 17, 2026 00:35
@bangtianliu bangtianliu force-pushed the argcompare-distribute-pattern branch 2 times, most recently from 8bb34ff to f46ba08 Compare March 18, 2026 05:04
@bangtianliu bangtianliu marked this pull request as draft March 18, 2026 19:48
@bangtianliu bangtianliu force-pushed the argcompare-distribute-pattern branch 4 times, most recently from 06be48b to 3f67dda Compare March 30, 2026 19:28
Signed-off-by: Bangtian Liu <liubangtian@gmail.com>
@bangtianliu bangtianliu force-pushed the argcompare-distribute-pattern branch from 3f67dda to 69599c5 Compare March 30, 2026 19:37
@bangtianliu bangtianliu marked this pull request as ready for review March 30, 2026 21:18
@bangtianliu bangtianliu requested a review from sommerlukas March 30, 2026 21:18
Contributor

@sommerlukas sommerlukas left a comment


Some comments on style and tests. For the distribution logic, it would be good to also get @Groverkss's eyes on this.

@bangtianliu bangtianliu requested a review from sommerlukas March 31, 2026 16:17
@bangtianliu
Contributor Author

> Some comments on style and tests. For the distribution logic, it would be good to also get @Groverkss's eyes on this.

Sure, thanks for your time reviewing this PR.

Signed-off-by: Bangtian Liu <liubangtian@gmail.com>
@bangtianliu bangtianliu force-pushed the argcompare-distribute-pattern branch from 14e99c0 to a2f50ff Compare March 31, 2026 19:34
@bangtianliu
Contributor Author

cc @Groverkss for review

Contributor

@sommerlukas sommerlukas left a comment


Some more nits and questions.

/// Returns arith.cmpf for floating-point types and arith.cmpi for integers.
static Value createEqualityComparison(RewriterBase &rewriter, Location loc,
                                      Value lhs, Value rhs) {
  if (isa<FloatType>(lhs.getType())) {
Contributor


Should we assert here that lhs and rhs have the same type?

Contributor Author


The arith.cmpf and arith.cmpi verifiers will catch any type mismatch.

  - Extract shared helpers to deduplicate shape computation logic between DistributeMultiReduction and DistributeArgCompare
  - Simplify broadcastShape computation by deriving from distributed input shape instead of manual construction
  - Reuse existing elemTy/indexElemTy variables instead of repeated getElementType() calls
  - Simplify resultValue initialization by setting common case first
  - Fix inaccurate comment about yielded value check

Signed-off-by: Bangtian Liu <liubangtian@gmail.com>