
[FEAT][Ops] add paged KV-cache GQA prefill #1099

@superAngGao

Description

Parent: #1096

Related feature issues: #1097, #1098, #1100, #1101, #1102, #1103

Scope

Track paged KV-cache GQA prefill support.

This issue covers the public paged cache-aware prefill OP and its manifest contract. It requires a manifest entry because the OP exposes packed THD current chunks, caller-owned physical page storage, block_table, page-size metadata, and in-place append semantics.

Public OP

GroupedQueryAttentionPrefillPagedWithKVCacheFwdOp

Purpose and semantics:

  • Packed current chunk THD layout.
  • Paged physical KV cache storage.
  • block_table[b, logical_page] -> physical_page mapping.
  • Per-request cache_seqlens before append.
  • Reads old KV by gathering physical pages.
  • Reads current chunk from k_new/v_new.
  • Computes attention over old prefix + current chunk.
  • Appends current K/V into caller-owned physical pages.
  • Does not allocate pages and does not update cache_seqlens.
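The semantics above can be sketched as a NumPy reference. This is an illustrative model of the contract only, not the TileOps kernel; the function name and helper structure are assumptions, and softcap is omitted for brevity:

```python
import numpy as np

def gqa_prefill_paged_ref(q, k_new, v_new, k_pages, v_pages,
                          cu_seqlens_q, cache_seqlens, block_table,
                          page_size, is_causal=True):
    # q: [T_q, H, D] packed THD; k_new/v_new: [T_q, H_kv, D]
    # k_pages/v_pages: [P_tokens, H_kv, D] caller-owned physical pages
    T_q, H, D = q.shape
    H_kv = k_new.shape[1]
    G = H // H_kv                                # query heads per KV head
    o = np.empty_like(q)
    for b in range(cache_seqlens.shape[0]):
        s, e = int(cu_seqlens_q[b]), int(cu_seqlens_q[b + 1])
        old_len, new_len = int(cache_seqlens[b]), e - s
        # Read the old prefix by gathering physical pages via block_table.
        tok = np.arange(old_len)
        phys = block_table[b, tok // page_size] * page_size + tok % page_size
        k = np.concatenate([k_pages[phys], k_new[s:e]], axis=0)
        v = np.concatenate([v_pages[phys], v_new[s:e]], axis=0)
        L = old_len + new_len
        for h in range(H):
            kh, vh = k[:, h // G], v[:, h // G]      # GQA head mapping
            scores = q[s:e, h] @ kh.T / np.sqrt(D)   # [new_len, L]
            if is_causal:
                # query at absolute position old_len+i sees keys <= old_len+i
                allowed = (np.arange(L)[None, :]
                           <= (old_len + np.arange(new_len))[:, None])
                scores = np.where(allowed, scores, -np.inf)
            p = np.exp(scores - scores.max(axis=-1, keepdims=True))
            p /= p.sum(axis=-1, keepdims=True)
            o[s:e, h] = p @ vh
        # In-place append of the current chunk; cache_seqlens is NOT updated.
        tok = old_len + np.arange(new_len)
        phys = block_table[b, tok // page_size] * page_size + tok % page_size
        k_pages[phys] = k_new[s:e]
        v_pages[phys] = v_new[s:e]
    return o
```

Note that the reference mutates only the caller-owned `k_pages`/`v_pages`, matching the "does not allocate, does not update cache_seqlens" contract.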

Manifest Change Plan

This issue requires manifest changes.

Required manifest entry:

  • Add GroupedQueryAttentionPrefillPagedWithKVCacheFwdOp.

The entry must declare:

  • family: attention
  • ref_api: none unless a reference exactly matches TileOps paged cache update semantics
  • ordered signature.inputs, signature.outputs, and signature.params
  • packed THD + flattened physical page shape rules
  • workloads with page-size and block-table dimensions
  • roofline
  • source.kernel, source.kernel_map, source.op, source.test, and source.bench

Manifest Field Sketch

This is not copy-paste YAML; it is the field contract the manifest PR should encode using the repo's ordered-dict schema.

Manifest key: GroupedQueryAttentionPrefillPagedWithKVCacheFwdOp

Inputs:

  • q: dtype float16 | bfloat16, shape [T_q, H, D]
  • k_new: dtype same_as(q), shape [T_q, H_kv, D]
  • v_new: dtype same_as(q), shape [T_q, H_kv, D]
  • k_pages: dtype same_as(q), shape [P_tokens, H_kv, D]
  • v_pages: dtype same_as(q), shape [P_tokens, H_kv, D]
  • cu_seqlens_q: dtype int32, shape [N_cu]
  • cache_seqlens: dtype int32, shape [B]
  • block_table: dtype int32, shape [B, max_pages_per_req]

Output:

  • o: dtype same_as(q), shape [T_q, H, D]

Params:

Softcap semantics must match #1101: softcap: null and softcap: 0 both disable score capping and preserve existing behavior; softcap > 0 applies score capping before softmax; negative values are rejected by OP runtime validation.
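A minimal sketch of the parameter contract. The exact capping formula belongs to #1101; a tanh-based cap (as used by several public models) is assumed here purely for illustration, and the helper name is hypothetical:

```python
import math

def apply_softcap(score: float, softcap) -> float:
    # Hypothetical helper; the exact capping formula is owned by the
    # #1101 contract. tanh-based capping is an assumption for illustration.
    if softcap is None or softcap == 0:
        return score                     # disabled: preserve existing behavior
    if softcap < 0:
        raise ValueError("softcap must be None or >= 0")
    return softcap * math.tanh(score / softcap)
```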

Shape rules:

  • H % H_kv == 0
  • v_new.shape == k_new.shape
  • k_new.shape[0] == q.shape[0]
  • v_pages.shape == k_pages.shape
  • P_tokens % page_size == 0
  • N_cu == B + 1
  • cu_seqlens_q[0] == 0 and cu_seqlens_q[-1] == T_q (runtime validation)
  • cache_seqlens + q_lens <= max_pages_per_req * page_size (runtime validation)
  • block_table entries must be valid physical page ids
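The rules above translate directly into a validator; this is a sketch of the checks (function and variable names are illustrative, not repo API):

```python
import numpy as np

def validate_paged_prefill_shapes(q, k_new, v_new, k_pages, v_pages,
                                  cu_seqlens_q, cache_seqlens, block_table,
                                  page_size):
    # Hypothetical validator mirroring the listed shape rules.
    T_q, H, D = q.shape
    H_kv = k_new.shape[1]
    assert H % H_kv == 0
    assert v_new.shape == k_new.shape
    assert k_new.shape[0] == T_q
    assert v_pages.shape == k_pages.shape
    assert k_pages.shape[0] % page_size == 0        # P_tokens % page_size == 0
    B = cache_seqlens.shape[0]
    assert cu_seqlens_q.shape[0] == B + 1           # N_cu == B + 1
    assert cu_seqlens_q[0] == 0 and cu_seqlens_q[-1] == T_q
    max_pages_per_req = block_table.shape[1]
    q_lens = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
    assert np.all(cache_seqlens + q_lens <= max_pages_per_req * page_size)
    num_pages = k_pages.shape[0] // page_size
    # Sketch checks every slot; the OP need only validate slots actually used.
    assert np.all((block_table >= 0) & (block_table < num_pages))
```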

Proposed source:

  • source.kernel: tileops/kernels/attention/gqa_fwd.py
  • source.kernel_map:
  • source.op: tileops/ops/attention/gqa.py
  • source.test: tests/ops/attention/test_gqa_prefill_paged.py
  • source.bench: benchmarks/ops/attention/bench_gqa.py or a dedicated paged benchmark file if introduced

Page Layout Contract

Initial paged layout uses flattened physical page storage:

physical token id = physical_page * page_size + page_offset
k_pages/v_pages shape = [P_tokens, H_kv, D]
P_tokens % page_size == 0

The OP consumes caller-owned storage and caller-provided block_table; it does not allocate pages and does not update cache_seqlens.
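The address math above, as a tiny helper (illustrative only, not repo API):

```python
def physical_token_id(block_table_row, logical_token, page_size):
    # Map a request-local token index to its flattened physical token id
    # in k_pages/v_pages: physical_page * page_size + page_offset.
    logical_page, page_offset = divmod(logical_token, page_size)
    return block_table_row[logical_page] * page_size + page_offset
```

For example, with page_size 64, logical token 70 is offset 6 of logical page 1; if that logical page maps to physical page 2, the flattened id is 2 * 64 + 6 = 134.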

Page sizes should be powers of two. The first implementation should cover the release page sizes selected by the implementation PR and benchmark plan, with page-boundary correctness tests.

Workloads / Roofline

Manifest workloads should use <tensor_name>_shape, dtypes, and label according to docs/manifest.md.

Example workload fields:

- q_shape: [512, 32, 128]
  k_new_shape: [512, 8, 128]
  v_new_shape: [512, 8, 128]
  k_pages_shape: [4608, 8, 128]
  v_pages_shape: [4608, 8, 128]
  cu_seqlens_q_shape: [2]
  cache_seqlens_shape: [1]
  block_table_shape: [1, 72]
  max_pages_per_req: 72
  page_size: 64
  max_seqlen_q: 512
  is_causal: true
  dtypes: [float16]
  label: "llama-3.1-8b-prefill-paged-pg64-4k-old-512-new"

Unless the PR defines another manifest-backed convention, the initial benchmark generator should derive cache_seqlens = max_pages_per_req * page_size - max_seqlen_q and construct a valid block_table covering the old prefix plus the current chunk.

Roofline should account for:

  • Q read
  • current K/V read
  • old page K/V gather reads
  • appended page K/V writes
  • output write
  • block_table / cu_seqlens / cache_seqlens metadata reads
  • attention QK^T and score-times-V (PV) FLOPs using per-request lengths
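A hedged accounting sketch of the terms above. It ignores the causal-mask reduction in score FLOPs and any kernel-level reuse, assumes 2-byte fp16/bf16 elements and int32 metadata, and all names are illustrative:

```python
def paged_prefill_roofline(q_lens, cache_seqlens, H, H_kv, D,
                           max_pages_per_req, dtype_bytes=2):
    # Illustrative sketch: sums the listed traffic/FLOP terms per request.
    # Ignores causal masking and on-chip reuse; metadata is int32 (4 bytes).
    bytes_moved, flops = 0, 0
    B = len(q_lens)
    for nq, nold in zip(q_lens, cache_seqlens):
        L = nold + nq
        bytes_moved += nq * H * D * dtype_bytes           # Q read
        bytes_moved += 2 * nq * H_kv * D * dtype_bytes    # current K/V read
        bytes_moved += 2 * nold * H_kv * D * dtype_bytes  # old page K/V gather
        bytes_moved += 2 * nq * H_kv * D * dtype_bytes    # appended K/V writes
        bytes_moved += nq * H * D * dtype_bytes           # output write
        bytes_moved += max_pages_per_req * 4              # block_table row read
        flops += 2 * 2 * nq * L * D * H                   # QK^T + PV (2 flops/MAC)
    bytes_moved += (B + 1) * 4 + B * 4                    # cu_seqlens, cache_seqlens
    return bytes_moved, flops
```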

Validation Expectations

The corresponding PR should show:

  • output correctness with paged old cache
  • append correctness into physical pages
  • old page prefix remains unchanged
  • batch with different page tables
  • old length not page-aligned
  • append across page boundaries
  • multiple page sizes
  • invalid page id validation
  • capacity overflow validation
  • fp16 / bf16 smoke coverage

PR Notes

Implementation details should be discussed in the PR that references this issue.
Do not close #1096 from that PR.

Metadata

Labels: feature (New feature or new operator)