Skip to content

Commit 6d33a6a

Browse files
test: Fix tests and add write starvation test case
Signed-off-by: Alberto Perdomo <aperdomo@redhat.com>
1 parent 7737041 commit 6d33a6a

File tree

2 files changed

+206
-5
lines changed

2 files changed

+206
-5
lines changed

kv_connectors/llmd_fs_backend/tests/test_fs_backend.py

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,10 +72,19 @@ def make_gpu_specs(block_ids: list[int]) -> GPULoadStoreSpec:
7272

7373
def make_storage_specs(
    num_files: int,
    start_offset: int = 0,
) -> tuple[SharedStorageLoadStoreSpec, list[BlockHash]]:
    """Build a SharedStorageLoadStoreSpec plus the hashes that back it.

    Each file gets a distinct token range 100 wide, shifted by
    ``start_offset`` so that independent callers never collide on hashes.

    Args:
        num_files: Number of file hashes to generate
        start_offset: Starting index for hash generation (prevents conflicts)
    """
    hashes = []
    for i in range(num_files):
        base = (start_offset + i) * 100
        # Token range [100 + base, 117 + base) uniquely identifies file i.
        hashes.append(get_prefix_hash(range(100 + base, 117 + base)))
    return SharedStorageLoadStoreSpec(hashes), hashes
8190

@@ -153,16 +162,43 @@ def wait_for(
153162
handler,
154163
job_id: int,
155164
timeout: float = 2.0,
165+
_finished_cache: dict = None,
156166
) -> bool:
157-
"""Wait for a specific job in handler.get_finished() up to timeout seconds."""
167+
"""
168+
Wait for a specific job in handler.get_finished() up to timeout seconds.
169+
170+
Args:
171+
handler: The handler object (put or get) to poll for finished jobs
172+
job_id: The specific job ID to wait for
173+
timeout: Max time to wait in seconds
174+
_finished_cache: Optional dict to cache finished jobs. Required when
175+
multiple handlers share the same engine, since get_finished() erases
176+
jobs from the map and we need to remember them across calls.
177+
178+
Returns:
179+
True if job succeeded, False if it failed
180+
"""
181+
# If no cache provided, create a local one (for backward compatibility)
182+
if _finished_cache is None:
183+
_finished_cache = {}
184+
185+
if job_id in _finished_cache:
186+
return _finished_cache[job_id]
187+
158188
start = time.time()
159189
while time.time() - start < timeout:
160190
finished = handler.get_finished()
191+
# Cache ALL finished jobs we see (important when handlers share an engine)
161192
for jid, ok in finished:
193+
_finished_cache[jid] = ok
162194
if jid == job_id:
163195
return ok
164196
time.sleep(0.01) # avoid busy-spin
165-
raise TimeoutError(f"Job {job_id} did not finish within {timeout}s")
197+
198+
raise TimeoutError(
199+
f"Job {job_id} did not finish within {timeout}s. "
200+
f"Cached jobs: {list(_finished_cache.keys())}"
201+
)
166202

167203

168204
def roundtrip_once(

kv_connectors/llmd_fs_backend/tests/test_priority_queue.py

Lines changed: 166 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,168 @@ def pct(p):
374374
del handler, put, get
375375

376376

377+
def test_write_starvation_prevention(default_vllm_config):
    """
    Test that writes aren't starved under continuous read pressure.

    In a pure priority system, if reads keep arriving, writes might never
    execute. This test validates that the implementation has fairness
    guarantees to prevent write starvation.

    Strategy:
    1. Submit 10 write operations
    2. Immediately start flooding queue with continuous reads
    3. Submit new read every 0.02s (faster than workers can process)
    4. Track write completion times
    5. Verify all writes complete within reasonable bounds

    Expected behavior:
    - All writes should complete successfully (no infinite starvation)
    - Write latency should stay bounded (e.g., <3s per write)
    - Writes should get ~20-30% of worker cycles despite read pressure
    """
    threads_per_gpu = 4
    num_writes = 10
    read_submission_interval = 0.02
    max_acceptable_write_latency = 3.0

    blocks_per_file = TEST_CONFIG["gpu_blocks_per_file"]
    num_read_files = 5
    num_blocks = (num_writes + num_read_files) * blocks_per_file

    handler, context = create_test_handler(
        num_blocks=num_blocks,
        threads_per_gpu=threads_per_gpu,
        model_suffix="-starvation",
    )

    file_mapper = context["file_mapper"]
    put = handler.gpu_to_storage_handler
    get = handler.storage_to_gpu_handler
    # Shared cache: put/get share an engine and get_finished() drains jobs.
    finished_cache = {}

    # Prepare files for continuous reading
    read_block_ids = list(range(num_read_files * blocks_per_file))
    read_put_gpu = make_gpu_specs(read_block_ids)
    read_put_storage, read_hashes = make_storage_specs(num_read_files)
    cleanup_files(file_mapper, read_hashes)

    put.transfer_async(job_id=0, spec=(read_put_gpu, read_put_storage))
    ok = wait_for(put, job_id=0, timeout=30.0, _finished_cache=finished_cache)
    assert ok, "Initial file preparation failed"

    # Submit all writes at once. Block ids and hash offsets start after the
    # ranges used by the read files so nothing collides.
    write_offset = num_read_files
    write_block_start = num_read_files * blocks_per_file
    write_start_times = {}

    for i in range(num_writes):
        block_ids = list(
            range(
                write_block_start + i * blocks_per_file,
                write_block_start + (i + 1) * blocks_per_file,
            )
        )
        write_gpu = make_gpu_specs(block_ids)
        write_storage, write_hashes = make_storage_specs(
            1, start_offset=write_offset + i
        )
        cleanup_files(file_mapper, write_hashes)

        job_id = 1 + i
        write_start_times[job_id] = time.time()
        put.transfer_async(job_id=job_id, spec=(write_gpu, write_storage))

    # Start continuous read flood. Read job ids start at 1000 to stay
    # disjoint from the write job ids.
    read_job_counter = 1000
    reads_submitted = 0
    start_time = time.time()

    # Continue flooding with reads until all writes complete
    write_jobs = set(range(1, num_writes + 1))
    write_completion_times = {}

    while write_jobs:
        # Round-robin over the prepared files so reads always hit valid data.
        file_idx = reads_submitted % num_read_files
        block_ids = list(
            range(file_idx * blocks_per_file, (file_idx + 1) * blocks_per_file)
        )
        read_gpu = make_gpu_specs(block_ids)
        read_storage = SharedStorageLoadStoreSpec(
            [read_put_storage.block_hashes[file_idx]]
        )

        get.transfer_async(job_id=read_job_counter, spec=(read_storage, read_gpu))
        reads_submitted += 1
        read_job_counter += 1

        finished = put.get_finished()
        for job_id, ok in finished:
            if job_id in write_jobs:
                # A write that finished with ok=False did NOT complete
                # successfully; silently counting it as done would let a
                # broken backend pass the starvation check.
                assert ok, f"Write job {job_id} failed under read pressure"
                latency = time.time() - write_start_times[job_id]
                write_completion_times[job_id] = latency
                write_jobs.remove(job_id)
                finished_cache[job_id] = ok

        # Timeout if writes don't complete
        elapsed = time.time() - start_time
        if elapsed > 30.0:
            raise TimeoutError(
                f"Write starvation detected! {len(write_jobs)} writes "
                f"did not complete after 30s under read pressure. "
                f"Submitted {reads_submitted} reads."
            )

        time.sleep(read_submission_interval)

    write_latencies = list(write_completion_times.values())
    max_write_latency = max(write_latencies)
    avg_write_latency = sum(write_latencies) / len(write_latencies)
    total_duration = time.time() - start_time

    # Estimate throughput. NOTE: read_throughput counts *submitted* reads,
    # not completed ones, so write_pct is a conservative lower bound.
    write_throughput = num_writes / total_duration
    read_throughput = reads_submitted / total_duration
    write_pct = write_throughput / (write_throughput + read_throughput) * 100

    print(f"\n{'=' * 70}")
    print("Write Starvation Prevention Test Results")
    print(f"{'=' * 70}")
    print("Configuration:")
    print(f"  Threads: {threads_per_gpu}")
    print(f"  Writes submitted: {num_writes}")
    print(f"  Reads submitted: {reads_submitted} (continuous flood)")
    print(f"  Total duration: {total_duration:.2f}s")
    print("\nWrite Latencies:")
    print(f"  Min: {min(write_latencies):.3f}s")
    print(f"  Avg: {avg_write_latency:.3f}s")
    print(f"  Max: {max_write_latency:.3f}s")
    print(f"  Target: <{max_acceptable_write_latency}s")
    print("\nThroughput:")
    print(f"  Writes: {write_throughput:.2f} ops/s")
    print(f"  Reads: {read_throughput:.2f} ops/s")
    print(f"  Write percentage: {write_pct:.1f}% (target >15%)")
    print(f"{'=' * 70}")

    assert max_write_latency < max_acceptable_write_latency, (
        f"Write starvation detected! Max write latency {max_write_latency:.3f}s "
        f"exceeds {max_acceptable_write_latency}s under continuous read pressure."
    )

    assert write_pct > 15, (
        f"Writes got only {write_pct:.1f}% of throughput under read pressure. "
        f"This indicates unfair scheduling - writes should get >15% even when "
        f"reads are continuously submitted."
    )

    # Clean up all files created by this test (reads + every write).
    cleanup_files(file_mapper, read_hashes)
    for i in range(num_writes):
        _, write_hashes = make_storage_specs(1, start_offset=write_offset + i)
        cleanup_files(file_mapper, write_hashes)

    del handler, put, get
537+
538+
377539
if __name__ == "__main__":
378540
import sys
379541

@@ -386,11 +548,14 @@ def pct(p):
386548
test_priority_completion_order()
387549
elif test_name == "percentiles":
388550
test_read_latency_percentiles()
551+
elif test_name == "starvation":
552+
test_write_starvation_prevention()
389553
else:
390554
print(f"Unknown test: {test_name}")
391-
print("Available tests: order, percentiles")
555+
print("Available tests: order, percentiles, starvation")
392556
else:
393557
print("Running all priority queue tests...")
394558
test_priority_completion_order()
395559
test_read_latency_percentiles()
560+
test_write_starvation_prevention()
396561
print("\nAll tests passed!")

0 commit comments

Comments
 (0)