[0.18.0][Bugfix][KV Pool]Fix KV transfer put logic #7718
Pz1116 wants to merge 4 commits into vllm-project:releases/v0.18.0
Conversation
Summary of Changes (Gemini Code Assist): This pull request improves the efficiency of the KV pool transfer mechanism by replacing a sequential block-check approach with a precise lookup that identifies and transfers only the missing blocks. This change prevents redundant data transfers and reduces unnecessary logging in the master service. Additionally, the PR includes minor refactoring to fix naming inconsistencies and adds unit tests to ensure the robustness of the new transfer logic.
Signed-off-by: Pz1116 <zpbzpb123123@gmail.com> Co-authored-by: DreamerLeader <2270923832@qq.com> Co-authored-by: fems14 <1804143737@qq.com>
Force-pushed from f72bc7d to 4136ef1
Code Review
Suggested PR Title: [Distributed][BugFix] Fix KV cache lookup and correct LayerMultiBlockReqMeta typo

Suggested PR Summary:
### What this PR does / why we need it?
This PR refactors the KV cache lookup mechanism to return a boolean existence list, allowing for the identification and storage of specific missing blocks instead of stopping at the first missing one. It also corrects a typo in the `LayerMultiBlockReqMeta` class name. Review feedback points out a potential `ImportError` if the class definition itself isn't updated and suggests optimizing list filtering logic using `zip` and `map` for better performance.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
New unit tests were added in `tests/ut/distributed/mooncake/test_kv_transfer.py` to verify the new lookup and storage logic.

```diff
 from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.config_data import (
     ChunkedTokenDatabase,
-    LasyerMultiBlockReqMeta,
+    LayerMultiBlockReqMeta,
```
This change corrects a typo in the import (LasyerMultiBlockReqMeta -> LayerMultiBlockReqMeta). However, the corresponding class definition in vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/config_data.py still has the typo (class LasyerMultiBlockReqMeta). This will lead to an ImportError. Please ensure you also correct the class name at its definition to LayerMultiBlockReqMeta in vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/config_data.py to fix this.
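For concreteness, the definition-side fix would rename the class where it is declared. This is only a sketch: the field names below are illustrative assumptions, not the real attributes of the class in `config_data.py`.

```python
# Hypothetical sketch of the definition-side rename in config_data.py.
# The field names are illustrative; the real dataclass may differ.
from dataclasses import dataclass, field
from typing import List


@dataclass
class LayerMultiBlockReqMeta:  # was: LasyerMultiBlockReqMeta
    req_id: str
    layer_id: int
    block_hashes: List[str] = field(default_factory=list)
```

With the definition renamed, the corrected import above resolves and no `ImportError` is raised.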
```diff
+missing_indices = [index for index, exists in enumerate(exists_states) if not exists]
-if skip_block_num == len(keys):
+if not missing_indices:
     self.dec_stored_request(req_id)
     return
-starts = starts[skip_block_num:]
-ends = ends[skip_block_num:]
-keys = keys[skip_block_num:]
+starts = [starts[index] for index in missing_indices]
+ends = [ends[index] for index in missing_indices]
+keys = [keys[index] for index in missing_indices]
+block_hashes = [block_hashes[index] for index in missing_indices]
```
The current implementation first creates a list of missing indices and then iterates over it four times to filter other lists. This can be optimized for performance and readability by creating the filtered lists in a single pass.
```diff
-missing_indices = [index for index, exists in enumerate(exists_states) if not exists]
-if not missing_indices:
-    self.dec_stored_request(req_id)
-    return
-starts = [starts[index] for index in missing_indices]
-ends = [ends[index] for index in missing_indices]
-keys = [keys[index] for index in missing_indices]
-block_hashes = [block_hashes[index] for index in missing_indices]
+missing_data = [
+    (starts[i], ends[i], keys[i], block_hashes[i])
+    for i, exists in enumerate(exists_states) if not exists
+]
+if not missing_data:
+    self.dec_stored_request(req_id)
+    return
+starts, ends, keys, block_hashes = map(list, zip(*missing_data))
```
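Standalone, the suggested single-pass idiom behaves like this on toy data (variable names match the diff; the values are made up for illustration):

```python
# Toy demonstration of the single-pass filtering idiom: build the
# missing-block tuples once, then unzip them back into parallel lists.
exists_states = [True, False, True, False]
starts = [0, 16, 32, 48]
ends = [16, 32, 48, 64]
keys = ["k0", "k1", "k2", "k3"]
block_hashes = ["h0", "h1", "h2", "h3"]

missing_data = [
    (starts[i], ends[i], keys[i], block_hashes[i])
    for i, exists in enumerate(exists_states) if not exists
]
if missing_data:  # guard: the 4-target unpack would raise on an empty list
    starts, ends, keys, block_hashes = map(list, zip(*missing_data))

print(keys)  # ['k1', 'k3']
```

Note the empty-list guard matters: `zip(*[])` yields nothing, so unpacking into four names raises `ValueError`, which is why the suggestion returns early when `missing_data` is empty.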
```diff
+missing_indices = [index for index, exists in enumerate(exists_states) if not exists]
-if skip_block_num == len(key_list):
+if not missing_indices:
     if is_last_chunk and layer_id == self.final_layer_id:
         self.set_finished_request(req_meta.req_id)
     return
-starts = starts[skip_block_num:]
-ends = ends[skip_block_num:]
-key_list = key_list[skip_block_num:]
+starts = [starts[index] for index in missing_indices]
+ends = [ends[index] for index in missing_indices]
+key_list = [key_list[index] for index in missing_indices]
```
Similar to the other _handle_request method, this section reconstructs lists by iterating over missing_indices multiple times. This can be refactored for better performance and readability.
```diff
-missing_indices = [index for index, exists in enumerate(exists_states) if not exists]
-if not missing_indices:
-    if is_last_chunk and layer_id == self.final_layer_id:
-        self.set_finished_request(req_meta.req_id)
-    return
-starts = [starts[index] for index in missing_indices]
-ends = [ends[index] for index in missing_indices]
-key_list = [key_list[index] for index in missing_indices]
+missing_data = [
+    (starts[i], ends[i], key_list[i])
+    for i, exists in enumerate(exists_states) if not exists
+]
+if not missing_data:
+    if is_last_chunk and layer_id == self.final_layer_id:
+        self.set_finished_request(req_meta.req_id)
+    return
+starts, ends, key_list = map(list, zip(*missing_data))
```
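Both handlers rely on `exists_states` being a boolean list aligned one-to-one with the keys being checked. A minimal sketch of producing such a list, using a plain dict in place of the real KV pool backend (an assumption; the actual lookup API may differ):

```python
def lookup_exists(store, keys):
    """Return one bool per key, aligned with `keys` (True = already stored)."""
    return [key in store for key in keys]


# A dict stands in for the KV pool here; "k1" and "k3" were evicted.
store = {"k0": b"block0", "k2": b"block2"}
exists_states = lookup_exists(store, ["k0", "k1", "k2", "k3"])
print(exists_states)  # [True, False, True, False]
```

Because the list is index-aligned with the keys, the missing-block filters in both handlers can safely index `starts`, `ends`, and the key lists by the same positions.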
What this PR does / why we need it?
Previously, when doing a put to the KV pool, we found the first non-existing key and put all blocks starting from that index. However, if the prefix-cache blocks come from another request and some of them have since been evicted by LRU, we end up putting blocks that still exist in the pool, causing MooncakeStore to print unnecessary logs in the master service.
What this PR does:
- Look up all the keys and put only the ones that are missing.
- Fix lookup_scheduler in pool_worker so it handles GQA correctly.
- Fix a few existing typos.
- Add unit tests (written by Codex).
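The behavioral change can be sketched with toy helpers. The function and variable names below are illustrative only, not the real vllm-ascend API; a dict stands in for the pool.

```python
def put_old(store, keys, blocks):
    # Old behavior: find the first missing key and put everything from
    # there on, re-putting blocks that still exist in the pool.
    for skip, key in enumerate(keys):
        if key not in store:
            break
    else:
        return []  # nothing missing
    put = keys[skip:]
    store.update(zip(put, blocks[skip:]))
    return put


def put_new(store, keys, blocks):
    # New behavior: look up every key and put only the missing ones.
    missing = [i for i, key in enumerate(keys) if key not in store]
    store.update((keys[i], blocks[i]) for i in missing)
    return [keys[i] for i in missing]


# "k1" was evicted by LRU while "k2" (later in the sequence) survived.
store = {"k0": b"b0", "k2": b"b2"}
keys = ["k0", "k1", "k2", "k3"]
blocks = [b"b0", b"b1", b"b2", b"b3"]
print(put_new(dict(store), keys, blocks))  # ['k1', 'k3']
print(put_old(dict(store), keys, blocks))  # ['k1', 'k2', 'k3'] (re-puts k2)
```

The redundant put of `k2` in the old path is exactly what triggered the unnecessary master-service logs.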
Does this PR introduce any user-facing change?
How was this patch tested?