[Torch] Handle dynamic head dimensions for attention #23636

Merged
IanWood1 merged 6 commits into iree-org:main from IanWood1:support_dyn_head_dim
Apr 2, 2026

Member

IanWood1 commented Mar 3, 2026

Uses createOrFold to handle a dynamic head dimension: in that case the attention scale is computed at runtime with a tensor.dim op followed by a math.rsqrt op.
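The fold-or-emit idea behind this description can be illustrated with a small standalone model. This is only a sketch under stated assumptions: `ToyValue`, `ToyBuilder`, `createOrFoldDim`, `createOrFoldRsqrt`, and `buildScale` are hypothetical names invented for illustration and are not the MLIR or IREE API. The model folds to a constant when the head dimension is static and records the ops it would otherwise have to emit.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Toy stand-in for an SSA value (hypothetical, not an MLIR type): either a
// compile-time constant or an opaque runtime value identified by name.
struct ToyValue {
  std::optional<double> constant; // set when the value is known at compile time
  std::string name;               // name of the runtime value otherwise
};

// Toy builder (hypothetical) that records every op it had to emit.
struct ToyBuilder {
  std::vector<std::string> emitted;

  // A static size folds to a constant; a dynamic size needs a runtime op.
  ToyValue createOrFoldDim(std::optional<int64_t> staticSize) {
    if (staticSize)
      return {static_cast<double>(*staticSize), ""};
    emitted.push_back("dim(%query, head)");
    return {std::nullopt, "%head_dim"};
  }

  // A constant operand folds to 1/sqrt(x); otherwise an rsqrt op is emitted.
  ToyValue createOrFoldRsqrt(const ToyValue &x) {
    if (x.constant) {
      assert(*x.constant > 0.0 && "head dimension must be positive");
      return {1.0 / std::sqrt(*x.constant), ""};
    }
    emitted.push_back("rsqrt(" + x.name + ")");
    return {std::nullopt, "%scale"};
  }
};

// Single code path for both cases: scale = rsqrt(head_dim).
ToyValue buildScale(ToyBuilder &b, std::optional<int64_t> headDim) {
  return b.createOrFoldRsqrt(b.createOrFoldDim(headDim));
}
```

In this model, calling `buildScale` with a static size of 4 yields a constant and emits nothing, while calling it with `std::nullopt` records two ops. Whether the real lowering folds completely depends on each op in the chain having a folder.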

Contributor

keshavvinayak01 left a comment

Ah, I mistakenly opened up a duplicate #23680 as well. This implementation LGTM!

@keshavvinayak01
Contributor

nit though, would it make sense to add this?

if (headDim != ShapedType::kDynamic) {
  double dk = 1.0 / std::sqrt(static_cast<double>(headDim));
  scale = arith::ConstantOp::create(rewriter, loc, targetType,
                                    rewriter.getFloatAttr(targetType, dk));
} else {

IanWood1 added 3 commits March 6, 2026 07:57

Compute the attention scale as rsqrt(head_dim) using a single code path
for both static and dynamic head dimensions. For static dims,
createOrFold constant-folds the dim/cast/sitofp chain; for dynamic dims,
the full runtime computation is emitted.

Signed-off-by: Ian Wood <ianwood@u.northwestern.edu>

math::RsqrtOp lacks a constant folder, so the static head dim case
was no longer fully folded to a single constant. Switch to
math::SqrtOp + arith::DivFOp which both have folders, restoring
the original folded output (e.g. arith.constant 5.000000e-01).

Signed-off-by: Ian Wood <ianwood@u.northwestern.edu>

This reverts commit d556726.

Signed-off-by: Ian Wood <ianwood@u.northwestern.edu>

IanWood1 force-pushed the support_dyn_head_dim branch from d556726 to 837958e March 6, 2026 16:02
Member Author

IanWood1 commented Mar 6, 2026

I added a folder for math.rsqrt (llvm/llvm-project#184443), so I'll just use that op here.

Member Author

IanWood1 commented Mar 6, 2026

> nit though, would it make sense to add this?
>
> if (headDim != ShapedType::kDynamic) {
>   double dk = 1.0 / std::sqrt(static_cast<double>(headDim));
>   scale = arith::ConstantOp::create(rewriter, loc, targetType,
>                                     rewriter.getFloatAttr(targetType, dk));
> } else {

My thought was that using createOrFold would do basically the same thing with only a small overhead.
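For comparison, the explicit static branch suggested in the quoted snippet can be modeled in isolation. This is a hedged sketch, not the PR's code: `kToyDynamic` and `staticAttentionScale` are hypothetical names, and the sentinel is a local stand-in for a "dynamic size" marker. It shows that the explicit branch computes the same constant a fully folded rsqrt(head_dim) chain would produce for a static head dimension.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <limits>
#include <optional>

// Local stand-in for a "dynamic size" sentinel (an assumption of this sketch).
constexpr int64_t kToyDynamic = std::numeric_limits<int64_t>::min();

// Returns the compile-time scale 1/sqrt(headDim) for a static head dimension,
// or nullopt when the dimension is dynamic and must be computed at runtime.
std::optional<double> staticAttentionScale(int64_t headDim) {
  if (headDim == kToyDynamic)
    return std::nullopt;
  assert(headDim > 0 && "static head dimension must be positive");
  return 1.0 / std::sqrt(static_cast<double>(headDim));
}
```

The design question in the thread is whether this extra branch is worth keeping when a single fold-capable code path produces the same constant for static shapes.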

Signed-off-by: Ian Wood <ianwood@u.northwestern.edu>
Contributor

keshavvinayak01 left a comment

Shouldn't we be merging this?

IanWood1 enabled auto-merge (squash) April 1, 2026 19:12
@keshavvinayak01
Contributor

Seems like a genuine test failure:

The following tests FAILED:
	 63 - iree/compiler/plugins/input/Torch/InputConversion/test/attention.mlir.test (Failed) iree/compiler/plugins/input/Torch/InputConversion/test test-type=lit-test
Error: Process completed with exit code 8.

createOrFold folds the entire rsqrt(head_dim) computation at compile
time when the head dimension is static, producing a single constant
(e.g. 0.5 for head_dim=4) rather than the unfolded arith.constant +
math.rsqrt pattern. Update CHECK lines accordingly.

Signed-off-by: Ian Wood <ianwood@u.northwestern.edu>

IanWood1 force-pushed the support_dyn_head_dim branch from 0ea142b to f9bda38 April 2, 2026 15:45
Member Author

IanWood1 commented Apr 2, 2026

> Seems like a genuine test failure:
>
> The following tests FAILED:
> 	 63 - iree/compiler/plugins/input/Torch/InputConversion/test/attention.mlir.test (Failed) iree/compiler/plugins/input/Torch/InputConversion/test test-type=lit-test
> Error: Process completed with exit code 8.

Fixed. I forgot about my change that added a folder for math.rsqrt.

IanWood1 merged commit a81ad4d into iree-org:main Apr 2, 2026
61 of 62 checks passed