
Add XAttention reference operation#31864

Merged
vshampor merged 6 commits into openvinotoolkit:master from vshampor:xattn_reference on Oct 1, 2025

Conversation

@vshampor
Contributor

@vshampor vshampor commented Aug 25, 2025

Details:

  • Implements the XAttention operation using STL and OV reference operations.
  • Meant for testing and debugging purposes only; it does not bring performance benefits if used on its own. An actual HW-accelerated implementation would use this code as a reference.
  • [DO NOT MERGE] Example of the XAttention integration in the CPU plugin #31955 contains an example of integrating with the CPU plugin.
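For orientation, the flow the reference implements can be sketched roughly as follows. This is a hedged numpy illustration of the X-Attention block-selection idea (arXiv:2503.16428) under the simplifying assumption that the block size equals the stride; all names are hypothetical, not the PR's C++ API.

```python
import numpy as np

def xattention_block_select(q, k, stride, threshold):
    # Hedged sketch of the X-Attention block-selection idea (arXiv:2503.16428),
    # simplified by assuming block size == stride. All names here are
    # illustrative; this is not the PR's C++ API.
    # q, k: [num_tokens, head_size] for a single attention head.
    n, d = q.shape
    # Strided reshape: fold `stride` consecutive tokens into one row,
    # [n, d] -> [n // stride, d * stride].
    q_s = q.reshape(n // stride, d * stride)
    k_s = k.reshape(n // stride, d * stride)
    # Low-resolution attention scores, softmaxed over the key axis.
    logits = (q_s @ k_s.T) / np.sqrt(d)
    e = np.exp(logits - logits.max(axis=-1, keepdims=True))
    probs = e / e.sum(axis=-1, keepdims=True)
    # Per query row, keep the highest-scoring key blocks until their
    # cumulative probability mass reaches `threshold`.
    selected = set()
    for qb, row in enumerate(probs):
        cumsum = 0.0
        for kb in np.argsort(row)[::-1]:
            if cumsum >= threshold:
                break
            cumsum += float(row[kb])
            selected.add((qb, int(kb)))
    return selected
```

The actual reference separates these phases into individual functions (strided reshape, softmax, block selection), as discussed in the review below, so that each phase can be inspected and tested in isolation.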

Tickets:

@github-actions github-actions bot added category: Core OpenVINO Core (aka ngraph) category: CPU OpenVINO CPU plugin labels Aug 25, 2025
@vshampor vshampor force-pushed the xattn_reference branch 2 times, most recently from 9a119ee to e29e853 on September 1, 2025 14:26
@github-actions github-actions bot removed the category: CPU OpenVINO CPU plugin label Sep 3, 2025
@vshampor vshampor force-pushed the xattn_reference branch 2 times, most recently from 460a365 to 6be527b on September 3, 2025 13:12
@vshampor vshampor marked this pull request as ready for review September 3, 2025 13:13
@vshampor vshampor requested a review from a team as a code owner September 3, 2025 13:13
@yuxu42 yuxu42 requested a review from liubo-intel September 4, 2025 01:03
Contributor

The reference tests should be stored in:
src/plugins/template/tests/functional/op_reference/

as the others are.

Contributor Author

I am aware of these tests. All of these utilize unnecessary abstractions of OV opset ops and involve the sample plugin to do unit tests. Seeing that there will never be a counterpart to this operation in the OV opset, there is no reason to proliferate the bad design decision to have all reference ops tested through the template plugin just for the sake of consistency.

Comment on lines +72 to +77
OPENVINO_ASSERT(input_shape.size() == 3); // [num_heads, num_tokens, head_size]
OPENVINO_ASSERT(out_shape.size() == 3);
OPENVINO_ASSERT(input_shape[0] == out_shape[0]);
OPENVINO_ASSERT(input_shape[1] % m_stride == 0);
OPENVINO_ASSERT(input_shape[1] / m_stride == out_shape[1]);
OPENVINO_ASSERT(input_shape[2] * m_stride == out_shape[2]);
Contributor

Remove these asserts; shape and input validation is done during shape inference in the operator implementation.

Contributor Author

Not relevant, since by design there is no shape inference involved.
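For reference, the quoted asserts encode the shape contract of this strided-reshape phase: out = [num_heads, num_tokens / stride, head_size * stride]. A minimal numpy check of that contract, with hypothetical sizes (illustrative, not the PR's code):

```python
import numpy as np

# Hypothetical sizes; the contract from the quoted asserts is
# out[1] = in[1] / stride and out[2] = in[2] * stride.
num_heads, num_tokens, head_size, stride = 2, 8, 4, 4
x = np.arange(num_heads * num_tokens * head_size, dtype=np.float32)
x = x.reshape(num_heads, num_tokens, head_size)
# Fold `stride` consecutive tokens into one row along the feature axis.
out = x.reshape(num_heads, num_tokens // stride, head_size * stride)
assert out.shape == (2, 2, 16)
```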

* @param out_shape Shape of the output tensor data. Expected shape is strictly equal to
* `reshaped_qk_product_shape`.
*/
void softmax(const T* reshaped_qk_product_data,
Contributor

Why is this required? It is just a call to ov::reference::softmax(reshaped_qk_product_data, out, reshaped_qk_product_shape, {2});

Contributor Author

This is a reference operation. It is designed to be understandable and testable first and foremost. I provide a set of functions that each directly map to a phase or sub-phase of the intended HW-accelerated kernel. Softmax is one of these phases and therefore has a separate function. It also has an extra shape check, so having a function wrapper has some utility.
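For readers, the wrapped phase behaves like a plain softmax over axis 2 of a rank-3 tensor; a minimal numpy stand-in (illustrative, not the OV call itself):

```python
import numpy as np

def softmax_axis2(x: np.ndarray) -> np.ndarray:
    # Numerically stable softmax over axis 2 of a rank-3 tensor,
    # mirroring ov::reference::softmax(..., {2}).
    assert x.ndim == 3  # the extra shape check the wrapper provides
    shifted = x - x.max(axis=2, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=2, keepdims=True)
```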

@vshampor vshampor requested a review from l-bat September 5, 2025 09:06
}
}

/** Selects the elements of the input tensor along the last two dimensions, indepently along the first dimension, so
Contributor

Suggested change
/** Selects the elements of the input tensor along the last two dimensions, indepently along the first dimension, so
/** Selects the elements of the input tensor along the last two dimensions, independently along the first dimension, so

Contributor Author

Done

@vshampor vshampor requested a review from praasz September 24, 2025 09:20
@vshampor vshampor added this pull request to the merge queue Sep 25, 2025
@github-merge-queue github-merge-queue bot removed this pull request from the merge queue due to failed status checks Sep 25, 2025
cumsum += index_and_largest_score.score;
retval[head_idx].insert(index_and_largest_score.idx);
}
}
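Taken per head, the quoted loop performs cumulative-threshold selection over blocks sorted by score. A hedged single-head numpy sketch (names are illustrative, not the PR's C++ code):

```python
import numpy as np

def select_blocks_for_head(scores: np.ndarray, threshold: float) -> set:
    # scores: flattened, softmax-normalized block scores for one head.
    # Accumulate the largest scores until `threshold` of the total mass
    # is covered, as the quoted cumsum loop does.
    selected = set()
    cumsum = 0.0
    for idx in np.argsort(scores)[::-1]:
        if cumsum >= threshold:
            break
        cumsum += float(scores[idx])
        selected.add(int(idx))
    return selected
```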
Contributor

Hi @vshampor It looks like this implementation selects blocks by globally sorting all blocks per head and accumulating their scores until the threshold is reached. In the X-Attention paper (Section 2.2, https://arxiv.org/abs/2503.16428), the algorithm selects blocks per query block based on antidiagonal scoring. Could this global accumulation approach lead to differences from the paper's intended block coverage for each query (https://github.com/mit-han-lab/x-attention/blob/e37988770b9d1bebd489eba011d615f35587ba08/xattn/src/utils.py#L44)?

Contributor Author

Please consider the relative position of this function in the frontend select_blocks function and observe that the block selection is ultimately done among the antidiagonal subset of the full attention matrix. I didn't understand the concerns about globality: in what sense do you consider the sorting to be global, and why does it not match the original algorithm?

Contributor

To simplify the problem, I used https://github.com/mit-han-lab/x-attention/tree/main (the quick-use path from its README) with the test case data from your OV test file:

The selected blocks are inconsistent:

import torch

# Xattention_prefill is imported from the mit-han-lab/x-attention repository
THRESHOLD = 0.8
BLOCK_SIZE = 2
STRIDE = 2

# batch=1, head=2, seq_len=4, dim=4
bsz = 1
heads = 2
seq_len = 4
dim = 4

q_data = [
     3.144,  8.512,  8.518, -8.386,
     7.889, -5.721,  5.507,  4.295,
    -6.624, -8.463,  7.474,  9.879,
     4.534, -5.908, -9.388,  2.356,

     7.497,  8.186, -8.658, -4.796,
    -8.248, -9.797, -7.907, -4.513,
     3.469,  7.633,  7.244, -6.844,
    -7.173,  4.450,  6.705, -7.035
]

q = torch.tensor(q_data, dtype=torch.bfloat16).reshape(bsz, heads, seq_len, dim).to("cuda")
k = q.clone()  # K = Q
v = torch.randn((bsz, heads, seq_len, dim), dtype=torch.bfloat16).to("cuda")

attention_output = Xattention_prefill(
    query_states=q,
    key_states=k,
    value_states=v,
    stride=STRIDE,
    block_size=BLOCK_SIZE,
    use_triton=True,
    chunk_size=4
)

your code:
Head 0 selected blocks: (0,1) (1,0)
Head 1 selected blocks: (0,1) (1,0)
mit-han-lab/x-attention:
Head 0 selected blocks: (0,0) (1,0) (1,1)
Head 1 selected blocks: (0,0) (1,0) (1,1)
So may I know what is the source of the reference results for your test?

@vshampor vshampor added this pull request to the merge queue Oct 1, 2025
Merged via the queue into openvinotoolkit:master with commit 175e5d7 Oct 1, 2025
206 checks passed
@vshampor vshampor deleted the xattn_reference branch October 1, 2025 12:01

Labels

category: Core OpenVINO Core (aka ngraph)
