Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 127 additions & 0 deletions tests/ut/distributed/mooncake/test_kv_transfer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
import threading
import unittest
from types import SimpleNamespace

import torch

# Stub out ``torch.npu`` when running on hosts without Ascend NPU support,
# presumably so the ``vllm_ascend`` imports below can resolve attributes such
# as ``torch.npu.Event`` — TODO confirm against the imported modules.
if not hasattr(torch, "npu"):
    torch.npu = SimpleNamespace(Event=object)  # type: ignore[attr-defined]

from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.config_data import (
LayerMultiBlockReqMeta,
ReqMeta,
)
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.kv_transfer import (
KVCacheStoreLayerSendingThread,
KVCacheStoreSendingThread,
)


class _FakeKey:
def __init__(self, value: str):
self._value = value

def to_string(self) -> str:
return self._value


class _FakeStore:
def __init__(self, exists_result: list[int]):
self.exists_result = exists_result
self.put_calls: list[tuple[list[str], list[list[int]], list[list[int]]]] = []

def set_device(self):
return None

def exists(self, keys: list[str]) -> list[int]:
# Return exact number of states for requested keys.
return self.exists_result[: len(keys)]

def put(self, keys, addrs, sizes):
self.put_calls.append((list(keys), list(addrs), list(sizes)))


class _FakeTokenDatabase:
def process_tokens(self, token_len, block_hashes):
for i, _ in enumerate(block_hashes):
yield i * 16, (i + 1) * 16, _FakeKey(f"k{i}")

def prepare_value(self, start, end, block_ids):
block_id = start // 16
return [1000 + block_id], [end - start], block_id

def prepare_value_layer(self, start, end, block_ids, layer_id):
block_id = start // 16
return [2000 + layer_id * 100 + block_id], [end - start]


class TestKVTransferMissingKeyPut(unittest.TestCase):
    """Verify that sending threads upload only blocks whose keys are absent."""

    def test_sending_thread_only_puts_missing_keys(self):
        # Keys k0/k2 already exist in the store (state 1); k1/k3 are missing.
        fake_store = _FakeStore(exists_result=[1, 0, 1, 0])
        sender = KVCacheStoreSendingThread(
            m_store=fake_store,
            token_database=_FakeTokenDatabase(),
            block_size=16,
            tp_rank=0,
            dcp_size=1,
            put_step=1,
            kv_role="kv_producer",
            ready_event=threading.Event(),
            enable_kv_event=False,
        )

        meta = ReqMeta(
            req_id="req-1",
            token_len_chunk=64,
            block_ids=[0, 1, 2, 3],
            block_hashes=[b"h0", b"h1", b"h2", b"h3"],  # type: ignore[arg-type]
            current_event=None,
        )
        sender.add_stored_request("req-1")
        sender.request_queue.put(meta)
        sender._handle_request(meta)

        # Exactly one upload, containing only the two missing blocks.
        self.assertEqual(len(fake_store.put_calls), 1)
        sent_keys, sent_addrs, sent_sizes = fake_store.put_calls[0]
        self.assertEqual(sent_keys, ["k1", "k3"])
        self.assertEqual(sent_addrs, [[1001], [1003]])
        self.assertEqual(sent_sizes, [[16], [16]])

    def test_layer_sending_thread_only_puts_missing_keys(self):
        # Same existence pattern as above, but exercised per layer.
        fake_store = _FakeStore(exists_result=[1, 0, 1, 0])
        sender = KVCacheStoreLayerSendingThread(
            m_store=fake_store,
            token_database=_FakeTokenDatabase(),
            block_size=16,
            tp_rank=0,
            dcp_size=1,
            put_step=1,
            ready_event=threading.Event(),
            num_layers=2,
            enable_kv_event=False,
        )

        meta = LayerMultiBlockReqMeta(
            req_id="req-2",
            keys=[_FakeKey("k0"), _FakeKey("k1"), _FakeKey("k2"), _FakeKey("k3")],  # type: ignore[arg-type]
            starts=[0, 16, 32, 48],
            ends=[16, 32, 48, 64],
            block_ids=[0, 1, 2, 3],
            layer_id=1,
            is_last_chunk=False,
            current_event=None,
        )
        sender.request_queue.put(meta)
        sender._handle_request(meta)

        # Only the missing blocks are uploaded; addresses encode layer_id=1.
        self.assertEqual(len(fake_store.put_calls), 1)
        sent_keys, sent_addrs, sent_sizes = fake_store.put_calls[0]
        self.assertEqual(sent_keys, ["k1", "k3"])
        self.assertEqual(sent_addrs, [[2101], [2103]])
        self.assertEqual(sent_sizes, [[16], [16]])


# Allow running this test module directly (e.g. ``python test_kv_transfer.py``).
if __name__ == "__main__":
    unittest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ def add_request(self, req_meta: ReqMeta) -> None:


@dataclass
class LasyerMultiBlockReqMeta:
class LayerMultiBlockReqMeta:
req_id: str
keys: list[LayerPoolKey]
starts: list[int]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# isort: off
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.config_data import (
ChunkedTokenDatabase,
LasyerMultiBlockReqMeta,
LayerMultiBlockReqMeta,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

This change corrects a typo in the import (LasyerMultiBlockReqMeta -> LayerMultiBlockReqMeta). However, the corresponding class definition in vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/config_data.py still has the typo (class LasyerMultiBlockReqMeta). This will lead to an ImportError. Please ensure you also correct the class name at its definition to LayerMultiBlockReqMeta in vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/config_data.py to fix this.

ReqMeta,
)
# isort: on
Expand Down Expand Up @@ -48,7 +48,7 @@ def __init__(

def add_request(
self,
request: ReqMeta | LasyerMultiBlockReqMeta,
request: ReqMeta | LayerMultiBlockReqMeta,
) -> torch.Tensor:
self.request_queue.put(request)

Expand Down Expand Up @@ -88,22 +88,22 @@ def _handle_request(self, req_meta: Any):
def lookup(
self,
keys: list[str],
) -> int:
) -> list[bool]:
"""
Checks the existence of KV cache of the tokens from the cache engine.
:param tokens: the input tokens, with shape [seq_len]
:return: An int indicating how many prefix tokens are cached.
Check the existence of all keys from the cache engine.
:return: A bool list where True means the key exists in store.
"""
try:
res = self.m_store.exists(keys) # type: ignore[assignment]
exists_list = [False] * len(keys)
for index, value in enumerate(res): # type: ignore[arg-type]
if value != 1:
return index
# all tokens where found, return the maximal end
if index >= len(exists_list):
break
exists_list[index] = value == 1
return exists_list
except Exception as e:
logger.error(f"Remote connection failed in contains: {e}")
return 0
return len(keys)
return [False] * len(keys)

def update_kv_event(self, event: list[BlockStored]):
with self.kv_event_lock:
Expand Down Expand Up @@ -159,39 +159,44 @@ def _handle_request(self, req_meta: ReqMeta):
starts = []
ends = []
keys = []
block_hashes = []
if req_id not in self.stored_requests:
self.request_queue.task_done()
return

for start, end, key in self.token_database.process_tokens(token_len, req_meta.block_hashes):
for index, (start, end, key) in enumerate(self.token_database.process_tokens(token_len, req_meta.block_hashes)):
starts.append(start)
ends.append(end)
keys.append(key.to_string())
block_hashes.append(req_meta.block_hashes[index])

if not self.dcp_size > 1:
starts = starts[self.tp_rank % self.put_step :: self.put_step]
ends = ends[self.tp_rank % self.put_step :: self.put_step]
keys = keys[self.tp_rank % self.put_step :: self.put_step]
block_hashes = block_hashes[self.tp_rank % self.put_step :: self.put_step]

if not keys:
self.dec_stored_request(req_id)
return

skip_block_num = self.lookup(keys)
exists_states = self.lookup(keys)
missing_indices = [index for index, exists in enumerate(exists_states) if not exists]

if skip_block_num == len(keys):
if not missing_indices:
self.dec_stored_request(req_id)
return

starts = starts[skip_block_num:]
ends = ends[skip_block_num:]
keys = keys[skip_block_num:]
starts = [starts[index] for index in missing_indices]
ends = [ends[index] for index in missing_indices]
keys = [keys[index] for index in missing_indices]
block_hashes = [block_hashes[index] for index in missing_indices]
Comment on lines +184 to +193
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The current implementation first creates a list of missing indices and then iterates over it four times to filter other lists. This can be optimized for performance and readability by creating the filtered lists in a single pass.

Suggested change
missing_indices = [index for index, exists in enumerate(exists_states) if not exists]
if skip_block_num == len(keys):
if not missing_indices:
self.dec_stored_request(req_id)
return
starts = starts[skip_block_num:]
ends = ends[skip_block_num:]
keys = keys[skip_block_num:]
starts = [starts[index] for index in missing_indices]
ends = [ends[index] for index in missing_indices]
keys = [keys[index] for index in missing_indices]
block_hashes = [block_hashes[index] for index in missing_indices]
missing_data = [
(starts[i], ends[i], keys[i], block_hashes[i])
for i, exists in enumerate(exists_states) if not exists
]
if not missing_data:
self.dec_stored_request(req_id)
return
starts, ends, keys, block_hashes = map(list, zip(*missing_data))


logger.debug(
"Storing KV cache for %d out of %d blocks (skip_block_num=%d) for request %s",
"Storing KV cache for %d out of %d blocks (missing_count=%d) for request %s",
len(keys),
token_len // self.block_size,
skip_block_num,
len(missing_indices),
req_id,
)

Expand All @@ -206,7 +211,7 @@ def _handle_request(self, req_meta: ReqMeta):
sizes = []
stored_events: list[BlockStored] = []
prev_key = None
new_block_hashes = [maybe_convert_block_hash(bh) for bh in req_meta.block_hashes[skip_block_num:]]
new_block_hashes = [maybe_convert_block_hash(bh) for bh in block_hashes]
for index, start in enumerate(starts):
addr, size, _ = self.token_database.prepare_value(start, ends[index], block_ids)
addrs.append(addr)
Expand Down Expand Up @@ -307,7 +312,7 @@ def add_request( # type: ignore[override]
self.request_queue.put(req_meta)

def _handle_request( # type: ignore[override]
self, req_meta: LasyerMultiBlockReqMeta
self, req_meta: LayerMultiBlockReqMeta
):
starts = req_meta.starts
ends = req_meta.ends
Expand All @@ -330,16 +335,17 @@ def _handle_request( # type: ignore[override]
for key in keys:
key_list.append(key.to_string())

skip_block_num = self.lookup(key_list)
exists_states = self.lookup(key_list)
missing_indices = [index for index, exists in enumerate(exists_states) if not exists]

if skip_block_num == len(key_list):
if not missing_indices:
if is_last_chunk and layer_id == self.final_layer_id:
self.set_finished_request(req_meta.req_id)
return

starts = starts[skip_block_num:]
ends = ends[skip_block_num:]
key_list = key_list[skip_block_num:]
starts = [starts[index] for index in missing_indices]
ends = [ends[index] for index in missing_indices]
key_list = [key_list[index] for index in missing_indices]
Comment on lines +339 to +348
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Similar to the other _handle_request method, this section reconstructs lists by iterating over missing_indices multiple times. This can be refactored for better performance and readability.

Suggested change
missing_indices = [index for index, exists in enumerate(exists_states) if not exists]
if skip_block_num == len(key_list):
if not missing_indices:
if is_last_chunk and layer_id == self.final_layer_id:
self.set_finished_request(req_meta.req_id)
return
starts = starts[skip_block_num:]
ends = ends[skip_block_num:]
key_list = key_list[skip_block_num:]
starts = [starts[index] for index in missing_indices]
ends = [ends[index] for index in missing_indices]
key_list = [key_list[index] for index in missing_indices]
missing_data = [
(starts[i], ends[i], key_list[i])
for i, exists in enumerate(exists_states) if not exists
]
if not missing_data:
if is_last_chunk and layer_id == self.final_layer_id:
self.set_finished_request(req_meta.req_id)
return
starts, ends, key_list = map(list, zip(*missing_data))


addr_list = []
size_list = []
Expand All @@ -359,10 +365,10 @@ def _handle_request( # type: ignore[override]
self.request_queue.task_done()

logger.info(
"Storing KV cache for %d out of %d blocks (skip_block_num=%d) for request %s",
len(keys),
"Storing KV cache for %d out of %d blocks (missing_count=%d) for request %s",
len(key_list),
total_block,
skip_block_num,
len(missing_indices),
req_meta.req_id,
)

Expand All @@ -384,12 +390,12 @@ def __init__(
self.get_event = get_event

def add_request( # type: ignore[override]
self, req_meta: LasyerMultiBlockReqMeta
self, req_meta: LayerMultiBlockReqMeta
) -> torch.Tensor:
self.request_queue.put(req_meta)

def _handle_request( # type: ignore[override]
self, req_meta: LasyerMultiBlockReqMeta
self, req_meta: LayerMultiBlockReqMeta
):
addr_list = []
size_list = []
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
AscendConnectorMetadata,
ChunkedTokenDatabase,
KeyMetadata,
LasyerMultiBlockReqMeta,
LayerMultiBlockReqMeta,
ReqMeta,
)
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.kv_transfer import (
Expand Down Expand Up @@ -399,7 +399,7 @@ def retrieve_layer(
if not is_finish:
logger.info("Layerwise get failed")
self.get_event.clear()
req_meta = LasyerMultiBlockReqMeta(
req_meta = LayerMultiBlockReqMeta(
request.req_id, keys_multi_chunk, starts, ends, request.block_ids, layer_id
)
self.kv_recv_thread.add_request( # type: ignore[union-attr, call-arg]
Expand Down Expand Up @@ -455,7 +455,7 @@ def store_layer(
if keys:
keys = [list(row) for row in zip(*keys)] # [layer_num,block_num]
for layer_id, keys_multi_chunk in enumerate(keys):
req_meta = LasyerMultiBlockReqMeta(
req_meta = LayerMultiBlockReqMeta(
request.req_id,
keys_multi_chunk,
starts,
Expand Down
Loading