Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/generators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -497,8 +497,10 @@ void Generator::GenerateNextToken() {
// at this stage which is achieved by rewinding to zero and appending the current sequence
// Scenarios where this solution works: Batch size = 1, Num beams = 1, decoder model, EP is either CPU or CUDA
// Scenarios where it doesn't work: Batch size > 1 OR Num beams > 1 OR Multimodal model (like phi3 vision) OR EP is DML
// Skip this logic for the WebGPU EP, as it already uses the long RoPE scaling graph for all sequence lengths
if (search_->params_->BatchBeamSize() == 1 && !epUsesSingleRopeFactor) {
if (((search_->GetSequenceLength() == 4097) && (model_->config_->model.type == "phi3" || model_->config_->model.type == "phimoe")) || ((search_->GetSequenceLength() == 8193) && (model_->config_->model.type == "phi3small"))) {
if (model_->p_device_->GetType() != DeviceType::WEBGPU &&
(((search_->GetSequenceLength() == 4097) && (model_->config_->model.type == "phi3" || model_->config_->model.type == "phimoe")) || ((search_->GetSequenceLength() == 8193) && (model_->config_->model.type == "phi3small")))) {
auto current_seq = cpu_span<int32_t>(GetSequence(0).CopyDeviceToCpu());
RewindToLength(0);
AppendTokens(current_seq);
Expand Down
12 changes: 10 additions & 2 deletions src/python/py/models/builders/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1921,8 +1921,6 @@ def make_rotary_embedding_multi_cache(self, **kwargs):

# Determine which EPs don't support the If operator
self.eps_without_if_support = ["dml"]
if self.extra_options.get("enable_webgpu_graph", False):
self.eps_without_if_support.append("webgpu")

if self.ep in self.eps_without_if_support:
cos_cache = torch.cat((cos_cache_small, cos_cache_large), dim=0)
Expand All @@ -1932,6 +1930,16 @@ def make_rotary_embedding_multi_cache(self, **kwargs):
self.make_initializer(sin_cache, sin_cache_name)
# Do NOT make the subgraph with the If node for DML EP.
return

# WebGPU: Always use the large caches, avoiding the switch from the short to the
# long RoPE factors, since there is no correctness issue for any sequence length
if self.ep == "webgpu":
cos_cache = cos_cache_large
sin_cache = sin_cache_large
# Save cos/sin caches to disk
self.make_initializer(cos_cache, cos_cache_name)
self.make_initializer(sin_cache, sin_cache_name)
return

# TRT-RTX: Apply padding and create split If nodes with early return
if self.ep == "trt-rtx":
Expand Down
Loading