Add Qwen3.5 hybrid decoder export support (GatedDeltaNet + Attention) #2043
apsonawane wants to merge 9 commits into main from
Conversation
Pull request overview
Adds ONNX export support in the Python model builder for Qwen3.5 hybrid decoder models (alternating KV-cache attention layers and GatedDeltaNet recurrent/conv linear-attention layers), integrating with the existing onnxruntime-genai builder + config pipeline.
Changes:
- Introduces `Qwen35TextModel` builder implementing full-attention + linear-attention layer graph construction, plus hybrid cache I/O (KV + conv/recurrent state).
- Exposes the new builder from `builders/__init__.py` and wires up dispatch in `builder.py` for `Qwen3_5ForConditionalGeneration`.
- Minor cleanups in `builder.py` option parsing / formatting.
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| src/python/py/models/builders/qwen.py | Adds the Qwen3.5 hybrid decoder builder implementation, rotary cache generation, causal mask construction, Scan recurrence body, and config patching. |
| src/python/py/models/builders/__init__.py | Re-exports Qwen35TextModel. |
| src/python/py/models/builder.py | Dispatches Qwen3.5 HF architecture to Qwen35TextModel and makes small option/formatting tweaks. |
    import json
    import os
We shouldn't need to import json and os here to patch issues with the produced files. We can fix their contents before they are serialized to disk.
Done. `make_genai_config` now temporarily adjusts `self` attributes (`num_layers`, `model_type`, `past_present_share_buffer`, KV cache template keys) so the base class produces the correct config directly, with no post-hoc patching.
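The save/override/restore pattern described above can be sketched as follows. The class names, attribute values, and config contents here are hypothetical stand-ins for the real builder; only the shape of the pattern is the point:

```python
class BaseBuilder:
    def make_genai_config(self):
        # The base class reads self attributes directly while building the config.
        return {"model_type": self.model_type, "num_layers": self.num_layers}

class HybridBuilder(BaseBuilder):
    def __init__(self):
        self.model_type = "qwen3_5_hybrid"   # hypothetical value
        self.num_layers = 48                 # total layers (full + linear attention)
        self.num_full_attn_layers = 12       # hypothetical split

    def make_genai_config(self):
        saved = (self.model_type, self.num_layers)
        try:
            # Present only the values the base class should serialize.
            self.model_type = "qwen3"
            self.num_layers = self.num_full_attn_layers
            return super().make_genai_config()
        finally:
            # Restore the real attributes regardless of what happens above.
            self.model_type, self.num_layers = saved
```

The `try`/`finally` guarantees the builder's real attributes survive even if serialization raises, which is what makes the temporary override safe.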
    del store[f"present.{suffix}"]

    # Build per-layer cache I/O
    for i in range(self.num_layers):
Why are we using new naming formats for the full attention layers? We can re-use the existing formats and remove the per-layer entries that aren't needed.
Fixed. The full-attention layers now reuse the base class's existing KV cache naming format `past_key_values.{i}.key` / `present.{i}.key`; we just filter the template lists to exclude linear-attention layer indices instead of deleting and recreating them. Also removed the redundant `is_layer` and `has_final_norm` overrides (the base class handles both), and restructured `make_layer`/`make_attention` to follow the standard dispatch pattern used by other models.
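A minimal sketch of the filtering approach, using a made-up layer schedule and the naming format quoted above:

```python
# Hypothetical 6-layer schedule; the real one comes from the HF config.
layer_types = ["linear_attention", "linear_attention", "full_attention",
               "linear_attention", "linear_attention", "full_attention"]

full_attn = {i for i, lt in enumerate(layer_types) if lt == "full_attention"}

# The base class creates template entries for every layer; rather than
# deleting and recreating them, filter down to full-attention indices.
past_names = [f"past_key_values.{i}.key" for i in range(len(layer_types))]
past_names = [n for n in past_names if int(n.split(".")[1]) in full_attn]
```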
    """Build one decoder layer. Dispatches to full attention or
    GatedDeltaNet linear attention based on self.layer_types."""

    if self.layer_types[layer_id] == "linear_attention":
We should ideally handle the alternating attention layer types by using the same logic that is done for other models: override the make_attention method and dispatch to the right make method based on the layer id.
Done. Added a `make_attention` override that dispatches to `_make_full_attention` or `_make_linear_attention` based on `self.layer_types[layer_id]`. The `make_layer` override only exists because linear-attention layers use `layer.linear_attn` instead of `layer.self_attn`; it picks the right module and delegates everything else (layernorm, MLP) to the base class pattern. The attention methods only handle the attention-specific subgraph and end by setting `self.layernorm_attrs["skip_input"]`.
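The dispatch described above could look roughly like this. This is a toy builder with stub methods standing in for the real make methods, which build ONNX subgraphs:

```python
class Builder:
    def __init__(self, layer_types):
        self.layer_types = layer_types
        self.calls = []  # record dispatch for illustration

    def _make_full_attention(self, layer_id, attention, root_input):
        self.calls.append(("full", layer_id))

    def _make_linear_attention(self, layer_id, attention, root_input):
        self.calls.append(("linear", layer_id))

    def make_attention(self, layer_id, attention, root_input):
        # Dispatch on the per-layer type, mirroring how other models in
        # the builder handle alternating attention variants.
        if self.layer_types[layer_id] == "linear_attention":
            self._make_linear_attention(layer_id, attention, root_input)
        else:
            self._make_full_attention(layer_id, attention, root_input)
```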
    q_norm_name = f"/model/layers.{layer_id}/attn/q_norm/RMSNormalization"
    q_norm_output = f"{q_norm_name}/output_0"
    self.make_node(
        "RMSNormalization",
Have we verified that the ONNX RMSNormalization op has the same performance as the ORT SimplifiedLayerNormalization op? Can we also check that other IHVs have implemented this op for their EPs? Otherwise, they will likely modify this code to use SimplifiedLayerNormalization instead. I would like to avoid EP-specific logic as much as possible.
Yes, true. RMSNormalization is a standard op, but it has been replaced with SimplifiedLayerNormalization.
    self.attention_attrs["k_path"] = f"{k_reshape_2}/output_0"

    def _make_causal_mask(self):
        """Build causal attention mask [B, 1, S, total_S] for Attention op.
Why is the existing causal mask that gets built insufficient?
onnxruntime-genai/src/python/py/models/builders/base.py
Lines 4037 to 4039 in e8c80b6
Qwen3.5 uses the standard-domain Attention op, which expects a float mask, so a new function was created for that.
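A NumPy sketch (not the builder's actual ONNX graph) of the kind of float mask the standard Attention op expects: 0.0 where attention is allowed and a large negative value elsewhere, covering both prefill and single-token decode. The function name and the choice of fill value are assumptions for illustration:

```python
import numpy as np

def causal_float_mask(attention_mask: np.ndarray, S: int, neg: float = -3.4e38):
    """attention_mask: [B, total_S] of 0/1; S: query length. Returns [B, 1, S, total_S]."""
    B, total_S = attention_mask.shape
    past = total_S - S
    # Query position i (global index past + i) may attend to keys j <= past + i.
    j = np.arange(total_S)[None, :]           # [1, total_S]
    i = np.arange(S)[:, None] + past          # [S, 1]
    causal = j <= i                           # [S, total_S]
    allowed = causal[None] & (attention_mask[:, None, :] == 1)  # [B, S, total_S]
    mask = np.where(allowed, 0.0, neg).astype(np.float32)
    return mask[:, None, :, :]                # [B, 1, S, total_S]
```

In the decode case (S = 1) the single query row attends to all unmasked past positions, so the mask reduces to zeros wherever the attention_mask is 1.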
    basename = "/model/causal_mask"
    attn_mask = self.input_names["attention_mask"]  # [B, total_S]

    # Constants
Constant nodes have a special naming format. The make_constant method assists with this. Can we use that instead?
onnxruntime-genai/src/python/py/models/builders/base.py
Lines 883 to 895 in e8c80b6
    )

    # Step 2: Get query length S from inputs_embeds shape
    embeds_input = self.input_names.get("inputs_embeds", "inputs_embeds")
Rather than having default names if the input name does not exist in the dictionary, let's access self.input_names["inputs_embeds"] directly. If an error occurs, that is an indication that something should have been changed earlier.
    def _get_shared_q_scale(self, head_dim):
        """Return the name of a shared 1/sqrt(head_dim) constant (created once)."""
        name = "/model/constants/q_scale"
        if name not in self.node_names:
There's no need to do these checks. These node names are automatically checked for duplicates when inserted.
onnxruntime-genai/src/python/py/models/builders/base.py
Lines 806 to 816 in e8c80b6
We are not calling make_node, which checks for duplicates, but make_initializer calls make_value, which uses setdefault and returns the existing entry, so the manual dedupe is not necessary. Removed the condition.
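A minimal illustration of the `setdefault` behavior this relies on (the key is taken from the quoted code; the values are made up):

```python
store = {}
first = store.setdefault("/model/constants/q_scale", "scale_v1")
# A second call with the same key is a no-op: setdefault returns the
# existing entry, so no explicit "if name not in store" guard is needed.
second = store.setdefault("/model/constants/q_scale", "scale_v2")
```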
    def _get_shared_l2_eps(self):
        """Return the name of a shared L2 epsilon constant (created once)."""
        name = "/model/constants/l2_eps"
        if name not in self.node_names:
Same as above: we are not calling make_node, which checks for duplicates, but make_initializer calls make_value, which uses setdefault and returns the existing entry, so the manual dedupe is not necessary. Removed the condition.
    )
    gc_name = f"{basename}/{suffix}/dim{dim_idx}/Gather_cache"
    gc_out = f"{gc_name}/output_0"
    self.make_node("Gather", [cache_name, sq_out], [gc_out], name=gc_name, axis=0)
Can we use make_gather here?
Updated to use make_gather.
    )
    return unflat_out

    def _make_l2_normalize(self, basename, input_name, last_dim, leading_dims=None):
Can we use the ReduceL2 op instead of these primitive ops?
Updated; now using ReduceL2.
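For reference, what the single ReduceL2-based normalization computes, with NumPy standing in for the ONNX ops. The epsilon handling (clamping the norm from below) is an assumption for illustration:

```python
import numpy as np

def l2_normalize(x: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    # sqrt(sum(x^2)) over the last axis is exactly what ONNX ReduceL2
    # computes, replacing a Mul -> ReduceSum -> Sqrt chain of primitives.
    norm = np.sqrt(np.sum(x * x, axis=-1, keepdims=True))
    return x / np.maximum(norm, eps)
```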
    # Rename reshape output to target present state name
    if present_state_name:
        self.make_node(
            "Identity",
Same question here
Removed the Identity node.
    def set_io_dtype(precision, execution_provider, extra_options) -> ir.DataType:
        int4_cpu = precision == "int4" and execution_provider == "cpu"
        fp32_webgpu = execution_provider == "webgpu" and extra_options.get("use_webgpu_fp32", False)
        bf16_cuda = precision == "int4" and execution_provider in {"cuda", "trt-rtx"} and extra_options.get("use_cuda_bf16", False)
Can we keep the spacings in this file as is?
I just ran lintrunner; that adjusted this.
Fixed the spacing.
    outputs=(),
    nodes=(),
    - opset_imports={"": 21, "com.microsoft": 1},
    + opset_imports={"": 23, "com.microsoft": 1},
Can we check which ops change their signature from 21 to 23 so that we don't break other models?
    # 3D position_ids for mRoPE: [3, batch_size, sequence_length]
    self.input_shapes["position_ids"] = [3, "batch_size", "sequence_length"]
    if "position_ids" not in self.input_names:
I believe position_ids is always in the input names dictionary by default now. Can you check this and remove this if block if it always is?
    self.mrope_rotary_dim = int(self.rope_attrs["partial_rotary_factor"] * self.head_size)

    # Force RoPE computation in float32 for numerical stability
    if "rope_cast" not in self.attention_attrs:
Given self.rope_attrs exists, can the casting setting be moved to that dictionary instead?
    self.attention_attrs["rope_cast"]["use_fp32"] = True

    # Pre-compute cos/sin cache tables and interleaving masks for mRoPE
    self._make_rotary_caches()
There is a method called make_rotary_embedding_caches_from_scratch which already does this and should be overridden instead. However, given how much AI-generated code has been added to the model builder lately, all of it will need refactoring. This is fine for now; I will clean it up later.
    self.attention_attrs["q_norm"] = True
    self.attention_attrs["k_norm"] = True
    # Disable packed matmul since Q projection is doubled (4096 vs normal 2048)
    self.attention_attrs["use_packed_matmul"] = False
Why does packed MatMul have to be disabled for this reason? Even if the hidden sizes for Q, K, V differ, the packed MatMul should still work. The use of QK norm is a better reason for why this should be disabled.
Yeah, this isn't required. The base class already disables packed_matmul when q_norm is true.
    # full-attention layer indices. The base class creates entries for
    # all num_layers — we remove linear-attention layers that use
    # conv/recurrent state instead of KV cache.
    attn_indices = {i for i, lt in enumerate(self.layer_types) if lt == "full_attention"}
If self.layer_types exists, why do we need to calculate attn_indices separately? I think it would be better to loop through self.layer_types and merge both of the following for-loops together.
Merged into a single loop over `self.layer_types`. Each iteration either keeps the KV cache entry (full_attention) or adds conv/recurrent state entries (linear_attention). Eliminated the separate `attn_indices` set and the second loop.
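The merged loop could be sketched like this. The present-side state names follow the quoted diff below; the past-side `past_state.*` names are assumptions for illustration:

```python
layer_types = ["linear_attention", "full_attention", "linear_attention"]

inputs, outputs = [], []
for i, lt in enumerate(layer_types):
    if lt == "full_attention":
        # Keep the base class's standard KV cache entries for this layer.
        inputs += [f"past_key_values.{i}.key", f"past_key_values.{i}.value"]
        outputs += [f"present.{i}.key", f"present.{i}.value"]
    else:
        # Linear-attention layers carry conv/recurrent state instead.
        inputs += [f"past_state.{i}.conv", f"past_state.{i}.recurrent"]
        outputs += [f"present.{i}.conv_state", f"present.{i}.recurrent_state"]
```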
        self.linear_value_head_dim,
    ]

    self.output_names[f"present_state.{i}.conv"] = f"present.{i}.conv_state"
Let's make sure that the past/present naming schemas match the same schemas that will be used in the LFM2 models so that we don't have to duplicate work across both PRs.
    self.make_node("Mul", [attn_output, sigmoid_output], [gated_output], name=gated_name)
    self.make_value(gated_output, self.io_dtype, ["batch_size", "sequence_length", q_size])

    # 8. Output projection
I think we can re-use make_attention_output_proj here.
onnxruntime-genai/src/python/py/models/builders/base.py
Lines 2944 to 2966 in 07d21ba
    # SimplifiedLayerNormalization (com.microsoft, no offset for gated norm)
    norm_name = f"{basename}/SimplifiedLayerNormalization"
    norm_output = f"{norm_name}/output_0"
    self.make_node(
Let's use make_layernorm so that the right domain for the op gets registered.
make_layernorm can't be used here because it's tightly coupled to the decoder-layer pipeline (it expects [B, S, hidden_size] shapes). The gated RMSNorm operates on [B*S*N, hv] tensors with no skip connection, so using make_layernorm would require a major refactor; this is simpler. We can add the domain here so that it gets registered.
    # if self.ep in {"cuda"}:
    #     self._make_linear_attention_fused(layer_id, linear_attn, root_input)
    # else:
Code scanning / CodeQL check notice: Commented-out code (Note)
Copilot Autofix (AI, about 18 hours ago)
To fix the problem, remove the commented-out control-flow code and keep only active, real code plus an explanatory TODO. This eliminates misleading “dead” logic while retaining documentation of the intended future change.
Concretely, in src/python/py/models/builders/qwen.py, in the _make_linear_attention method around lines 1458–1462, delete the commented if self.ep in {"cuda"}: / _make_linear_attention_fused / else: lines, leaving the TODO and the active call to _make_linear_attention_decomposed. No new imports or definitions are needed, and existing functionality remains unchanged because the decomposed path is already the only one executed.
    @@ -1456,9 +1456,6 @@
         ``LinearAttention`` ops once ORT registers kernels for them.
         """
         # TODO: Enable fused path when ORT has CausalConvWithState/LinearAttention kernels
    -    # if self.ep in {"cuda"}:
    -    #     self._make_linear_attention_fused(layer_id, linear_attn, root_input)
    -    # else:
         self._make_linear_attention_decomposed(layer_id, linear_attn, root_input)
    def _make_linear_attention_fused(self, layer_id, linear_attn, root_input):
This PR adds support for exporting Qwen3.5 hybrid decoder models to ONNX via the
onnxruntime-genai model builder. Qwen3.5 uses a novel hybrid architecture that
alternates between standard attention layers (with KV cache) and GatedDeltaNet
linear attention layers (with conv_state + recurrent_state).
Architecture

Qwen3.5 has two layer types, controlled by `full_attention_interval`:
- Full attention: per-head QK OffsetRMSNorm, partial interleaved mRoPE, opset 23 Attention op
- Linear attention (GatedDeltaNet): conv1d, L2-normalized Q/K, Scan-based recurrence, gated RMSNorm
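The `full_attention_interval` split could be sketched with a hypothetical helper. The actual schedule is defined by the HF config; treating every Nth layer as full attention is an assumption here:

```python
def derive_layer_types(num_layers: int, full_attention_interval: int):
    """Assumed rule: every full_attention_interval-th layer is full attention."""
    return [
        "full_attention" if (i + 1) % full_attention_interval == 0
        else "linear_attention"
        for i in range(num_layers)
    ]
```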
Key design decisions
- Standard ONNX ops (`RMSNormalization`, `RotaryEmbedding`, `Attention`) instead of `com.microsoft` contrib ops, for portability
- Causal mask with `[B, 1, S, total_S]` shape correctly handles both prefill and single-token decode modes
- Pre-computed cos/sin cache tables for efficient mRoPE at runtime
- Hybrid cache I/O: KV cache for full-attention layers, conv_state + recurrent_state for GatedDeltaNet layers
- QK OffsetRMSNorm as `weight + 1` offset (reuses base class)
Changes
- `src/python/py/models/builders/qwen.py`: new `Qwen35Model` class with:
  - `_make_full_attention_layer`: Q/gate split, QK norm, mRoPE, Attention, gating
  - `_make_linear_attention_layer`: Conv1d, L2 norm, Scan recurrence, gated norm
  - `_make_causal_mask`: CumSum-based mask shared across attention layers
  - `_make_rotary_caches`: pre-computed cos/sin + h/w interleaving masks
  - `_build_scan_body`: GatedDeltaNet delta rule in ONNX Scan body graph
  - `make_genai_config`: text-only config (reuses base class via `super()`)
- `src/python/py/models/builders/__init__.py`: exports `Qwen35Model`
- `src/python/py/models/builder.py`: maps `Qwen3_5ForConditionalGeneration` → `Qwen35Model`