Add XAttention reference operation #31864
Conversation
The reference tests should be stored in:
src/plugins/template/tests/functional/op_reference/
as others.
I am aware of these tests. All of them utilize unnecessary abstractions of OV opset ops and involve the sample plugin to do unit tests. Seeing that there will never be a counterpart to this operation in the OV opset, there is no reason to perpetuate the bad design decision of testing all reference ops through the template plugin just for the sake of consistency.
OPENVINO_ASSERT(input_shape.size() == 3);  // [num_heads, num_tokens, head_size]
OPENVINO_ASSERT(out_shape.size() == 3);
OPENVINO_ASSERT(input_shape[0] == out_shape[0]);
OPENVINO_ASSERT(input_shape[1] % m_stride == 0);
OPENVINO_ASSERT(input_shape[1] / m_stride == out_shape[1]);
OPENVINO_ASSERT(input_shape[2] * m_stride == out_shape[2]);
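For illustration, the shape contract that these asserts encode (a stride-based reshape from `[num_heads, num_tokens, head_size]` to `[num_heads, num_tokens / stride, head_size * stride]`) can be sketched in plain Python. The function name and example shapes below are illustrative, not part of the OpenVINO code:

```python
# Hypothetical sketch of the shape contract checked by the quoted asserts;
# not the OpenVINO implementation itself.
def check_reshape_contract(input_shape, out_shape, stride):
    assert len(input_shape) == 3 and len(out_shape) == 3
    assert input_shape[0] == out_shape[0]            # num_heads preserved
    assert input_shape[1] % stride == 0              # token count divisible by stride
    assert input_shape[1] // stride == out_shape[1]  # tokens shrink by stride
    assert input_shape[2] * stride == out_shape[2]   # head_size grows by stride
    return True

# e.g. 8 heads, 1024 tokens, head_size 64, stride 16:
check_reshape_contract((8, 1024, 64), (8, 64, 1024), 16)  # passes
```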
Remove these asserts; the shape and input validation is done during shape inference in the operator implementation.
Not relevant, since there is no shape inference involved by design.
 * @param out_shape Shape of the output tensor data. Expected shape is strictly equal to
 * `reshaped_qk_product_shape`.
 */
void softmax(const T* reshaped_qk_product_data,
Why is this required? It is just a call to ov::reference::softmax(reshaped_qk_product_data, out, reshaped_qk_product_shape, {2});
This is a reference operation. It is designed to be understandable and testable first and foremost. I provide a set of functions that each directly map to a phase or sub-phase of the intended HW-accelerated kernel. Softmax is one of these phases and therefore has a separate function. It also has an extra shape check, so having a function wrapper has some utility.
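Conceptually, the wrapper described here is a softmax over the last axis plus an explicit shape check. A minimal sketch in Python/NumPy (names are illustrative, assuming the wrapper's check is that input and output shapes match):

```python
import numpy as np

# Illustrative sketch, not the actual OpenVINO reference code:
# softmax over the last axis, with the "extra shape check" mentioned above.
def softmax_phase(reshaped_qk_product, out_shape):
    assert reshaped_qk_product.shape == tuple(out_shape)
    # numerically stable softmax along the last dimension
    x = reshaped_qk_product - reshaped_qk_product.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

scores = np.array([[[1.0, 2.0, 3.0]]])
probs = softmax_phase(scores, (1, 1, 3))
# each row of probs sums to 1 along the last axis
```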
}
}
/** Selects the elements of the input tensor along the last two dimensions, indepently along the first dimension, so

Suggested change:
/** Selects the elements of the input tensor along the last two dimensions, independently along the first dimension, so
cumsum += index_and_largest_score.score;
retval[head_idx].insert(index_and_largest_score.idx);
}
}
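The selection loop quoted above can be sketched as follows: per head, block scores are sorted in descending order and block indices are accumulated until the running score reaches the threshold. This is a hedged illustration of the quoted behavior (function name, score layout, and normalization are assumptions, not the actual OpenVINO code):

```python
import numpy as np

# Illustrative sketch of the cumulative-score block selection discussed above.
# Assumes block_scores is [num_heads, num_blocks], normalized to sum to 1 per head.
def select_blocks(block_scores, threshold):
    selected = []
    for head_scores in block_scores:
        order = np.argsort(head_scores)[::-1]  # indices by descending score
        kept, cumsum = set(), 0.0
        for idx in order:
            if cumsum >= threshold:
                break
            cumsum += head_scores[idx]
            kept.add(int(idx))
        selected.append(kept)
    return selected

scores = np.array([[0.5, 0.3, 0.15, 0.05]])
select_blocks(scores, 0.8)  # returns [{0, 1}] since 0.5 + 0.3 >= 0.8
```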
Hi @vshampor It looks like this implementation selects blocks by globally sorting all blocks per head and accumulating their scores until the threshold is reached. In the X-Attention paper (Section 2.2, https://arxiv.org/abs/2503.16428), the algorithm selects blocks per query block based on antidiagonal scoring. Could this global accumulation approach lead to differences from the paper's intended block coverage for each query (https://github.com/mit-han-lab/x-attention/blob/e37988770b9d1bebd489eba011d615f35587ba08/xattn/src/utils.py#L44)?
Please consider the relative position of this function in the frontend select_blocks function and observe that the block selection is ultimately done among the antidiagonal subset of the full attention matrix. I didn't understand the concerns about globality - in which sense do you consider the sorting to be global and why does it not match to the original algo?
To simplify the problem, I used https://github.com/mit-han-lab/x-attention/tree/main (following the "quick use" section of the README) with your test case data from the OV test file. The selected blocks are inconsistent:

import torch
# Xattention_prefill is provided by the mit-han-lab/x-attention repo;
# the exact import path depends on how the repo is checked out/installed.

THRESHOLD = 0.8
BLOCK_SIZE = 2
STRIDE = 2

# batch=1, head=2, seq_len=4, dim=4
bsz = 1
heads = 2
seq_len = 4
dim = 4

q_data = [
    3.144, 8.512, 8.518, -8.386,
    7.889, -5.721, 5.507, 4.295,
    -6.624, -8.463, 7.474, 9.879,
    4.534, -5.908, -9.388, 2.356,
    7.497, 8.186, -8.658, -4.796,
    -8.248, -9.797, -7.907, -4.513,
    3.469, 7.633, 7.244, -6.844,
    -7.173, 4.450, 6.705, -7.035,
]

q = torch.tensor(q_data, dtype=torch.bfloat16).reshape(bsz, heads, seq_len, dim).to("cuda")
k = q.clone()  # K = Q
v = torch.randn((bsz, heads, seq_len, dim), dtype=torch.bfloat16).to("cuda")

attention_output = Xattention_prefill(
    query_states=q,
    key_states=k,
    value_states=v,
    stride=STRIDE,
    block_size=BLOCK_SIZE,
    use_triton=True,
    chunk_size=4,
)
Your code:
Head 0 selected blocks: (0,1) (1,0)
Head 1 selected blocks: (0,1) (1,0)
mit-han-lab/x-attention:
Head 0 selected blocks: (0,0) (1,0) (1,1)
Head 1 selected blocks: (0,0) (1,0) (1,1)
So may I know what is the source of the reference results for your test?
Details:
Tickets: