
Conversation

@lizhenyun01 (Collaborator) commented Dec 30, 2025

Motivation

  • Attention optimization and refactor of the V1 batch path; currently only decode (D) nodes are supported.
  • Usage: export FD_ATTENTION_BACKEND=DECODE_APPEND_ATTN

Modifications

Usage or Command

Accuracy Tests

Checklist

  • Add at least one tag to the PR title.
    • Tag list: [[FDConfig],[APIServer],[Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • Format your code and run pre-commit before committing.
  • Add unit tests; if none are added, please explain the reason in this PR.
  • Provide accuracy results.
  • If the current PR is submitting to the release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

@codecov-commenter commented Dec 30, 2025

Codecov Report

❌ Patch coverage is 29.57746% with 100 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (release/online/20251131@b018c49). Learn more about missing BASE report.

Files with missing lines Patch % Lines
...ayers/attention/decode_append_attention_backend.py 21.90% 81 Missing and 1 partial ⚠️
...or/layers/attention/ops/decode_append_attention.py 55.55% 3 Missing and 1 partial ⚠️
...ers/attention/ops/decoder_write_cache_with_rope.py 55.55% 3 Missing and 1 partial ⚠️
...cutor/layers/attention/ops/config_for_attention.py 57.14% 2 Missing and 1 partial ⚠️
fastdeploy/platforms/cuda.py 0.00% 2 Missing and 1 partial ⚠️
fastdeploy/spec_decode/mtp.py 0.00% 1 Missing and 1 partial ⚠️
fastdeploy/worker/gpu_model_runner.py 0.00% 1 Missing and 1 partial ⚠️
Additional details and impacted files
@@                    Coverage Diff                     @@
##             release/online/20251131    #5833   +/-   ##
==========================================================
  Coverage                           ?   58.39%           
==========================================================
  Files                              ?      324           
  Lines                              ?    39333           
  Branches                           ?     5931           
==========================================================
  Hits                               ?    22969           
  Misses                             ?    14522           
  Partials                           ?     1842           
Flag Coverage Δ
GPU 58.39% <29.57%> (?)

Flags with carried forward coverage won't be shown.


@paddle-bot bot commented Dec 30, 2025

Thanks for your contribution!

const int max_just_dec_len_this_time = set_max_lengths.data<int>()[4];

if (max_just_dec_len_this_time > 0) {
  if (speculate_decoder) {
Collaborator

These two branches need to be unified in a follow-up.

Collaborator Author

As a next step this operator will be split out, unifying the two paths and removing the speculate_decoder branch.

cudaDeviceGetAttribute(&sm_cout, cudaDevAttrMultiProcessorCount, device));

dim3 grids(
sm_cout *
Collaborator

It would be best to reserve a configurable slot here: if it is not set, fall back to sm_count. This may be needed under TBO.

Collaborator Author

Currently the kernel needs to run an adaptive chunk search ahead of time, and that search depends on this block count, so it cannot be obtained in advance. I think we could support this via an empirical value.

bool IsDynamicC8 = false>
__global__ void decode_append_attention_c8_kernel(
const __grid_constant__ AttentionParams<T, CacheT> params
// const __grid_constant__ CUtensorMap key_tensor_map,
Collaborator

No need to delete the TMA code: keep the parameters, add a template parameter, and use a constexpr check inside the kernel to decide whether to take the TMA load path. That makes later debugging easier.

Collaborator Author

The TMA code affects the ordering of sync operations; supporting both paths at once would introduce many switches and hurt code clarity. It is kept as comments; for debugging you can uncomment it and get it running with minor changes.

@yongqiangma (Collaborator) left a comment

LGTM

@heavengate heavengate merged commit 2e04b4e into PaddlePaddle:release/online/20251131 Jan 9, 2026
34 of 46 checks passed
