Make phi model of webgpu ep always use long RoPE to improve tps performance for long context scenario #1930
Conversation
@kunal-vaishnavi @baijumeswani Could you please advise whether always using cos_cache_large and sin_cache_large would be an acceptable approach?
The reasoning for why the KV caches need to be re-computed, and the impact on quality, can be found here. You should use both the small and the large caches for the best output quality; ideally, the KV cache re-computation should not be avoided here. A similar issue occurred with the DML EP, where the initial fix was to use just the large caches. However, quality issues soon emerged and a new fix was made: the small and large caches are combined into one tensor, and the position ids are updated to index accordingly. The KV cache re-computation is still skipped there, however.
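For illustration, here is a minimal sketch of the combined-cache idea described above. All function names, shapes, and the 4096-token threshold are assumptions for the example, not the actual DML EP implementation:

```python
import numpy as np

def combine_rope_caches(cos_small, sin_small, cos_large, sin_large):
    # Stack the short-factor ("small") cache first and the long-factor
    # ("large") cache second along the sequence axis, so both variants
    # live in one tensor of shape (2 * max_seq_len, head_dim // 2).
    cos_combined = np.concatenate([cos_small, cos_large], axis=0)
    sin_combined = np.concatenate([sin_small, sin_large], axis=0)
    return cos_combined, sin_combined

def adjust_position_ids(position_ids, total_seq_len, original_context_len=4096):
    # Once the total sequence length crosses the original context window,
    # offset the position ids past the small half so they index into the
    # long-factor entries of the combined cache. Assumes the small cache
    # holds exactly original_context_len rows.
    if total_seq_len > original_context_len:
        return position_ids + original_context_len
    return position_ids
```

With this layout both cache variants remain available at runtime, so output quality is preserved and the factor switch reduces to an index offset, although, as noted above, that fix still skips the KV cache re-computation.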
We found a perf regression for long context, described in ticket #1910. This PR proposes a fix that avoids the switch from the short factor to the long factor in the WebGPU EP of the Phi model, mitigating the issues above caused by the re-computation of position IDs and the KV cache when switching to the long factor. Experimental results from benchmark_e2e.py show that response tokens are generated without perf regression for short (< 1000), middle (> 1000 and < 4097), and long (> 4097) sequence lengths. The fix includes two pieces:
```python
cos_cache = cos_cache_large
sin_cache = sin_cache_large
```

Please note this is not the final fix for the issue; it just demonstrates a reasonable direction to discuss and move forward.
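For context, a hedged sketch of how these two assignments land in the model builder; the surrounding function and variable names are illustrative assumptions, and only the two assignments themselves come from this PR:

```python
def make_rotary_caches(cos_cache_small, sin_cache_small,
                       cos_cache_large, sin_cache_large):
    # Conceptually, the builder previously picked the short-factor caches
    # for sequences within the original context window and the long-factor
    # caches beyond it, which forces position IDs and the KV cache to be
    # re-computed at the switch point.
    #
    # This PR pins both caches to the long-factor variant, so the switch
    # (and the costly re-computation) never happens on the WebGPU EP:
    cos_cache = cos_cache_large
    sin_cache = sin_cache_large
    return cos_cache, sin_cache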
I have tested this change with the commands below.
Convert model:
```
python3 builder.py -m /Volumes/workspace/ort-web-perf/models/Phi-4-mini-instruct -o /Volumes/workspace/ort-web-perf/models/Phi-4-mini-instruct-onnx-optimized -p int4 -e webgpu --extra_options int4_accuracy_level=4 int4_algo_config=k_quant_last
```

Run the benchmark with the commands below:
```
benchmark/python/benchmark_e2e.py -i /Volumes/workspace/ort-web-perf/models/Phi-4-mini-instruct-onnx-optimized -l 7936 -g 6000 --use_prompt_set -mo
benchmark/python/benchmark_e2e.py -i /Volumes/workspace/ort-web-perf/models/Phi-4-mini-instruct-onnx-optimized -l 4096 -g 6000 --use_prompt_set -mo
benchmark/python/benchmark_e2e.py -i /Volumes/workspace/ort-web-perf/models/Phi-4-mini-instruct-onnx-optimized -l 4096 -g 2000 --use_prompt_set -mo
benchmark/python/benchmark_e2e.py -i /Volumes/workspace/ort-web-perf/models/Phi-4-mini-instruct-onnx-optimized -l 2048 -g 1000 --use_prompt_set -mo
benchmark/python/benchmark_e2e.py -i /Volumes/workspace/ort-web-perf/models/Phi-4-mini-instruct-onnx-optimized -l 256 -g 300 --use_prompt_set -mo
```

Perf and Correctness Comparison (metric: tps, the average tokens generated per second; "Original Solution" means the original model + generator, "Proposed Solution" means the model + generator updated by this PR):