
Add Qwen3.5 hybrid decoder export support (GatedDeltaNet + Attention)#2043

Open
apsonawane wants to merge 9 commits into main from asonawane/model-builder-qwen3.5

Conversation

apsonawane (Contributor):

This PR adds support for exporting Qwen3.5 hybrid decoder models to ONNX via the
onnxruntime-genai model builder. Qwen3.5 uses a novel hybrid architecture that
alternates between standard attention layers (with KV cache) and GatedDeltaNet
linear attention layers (with conv_state + recurrent_state).

Architecture

Qwen3.5 has two layer types controlled by full_attention_interval:

  • Full attention (every 4th layer): Doubled Q projection (Q + output gate),
    per-head QK OffsetRMSNorm, partial interleaved mRoPE, opset 23 Attention op
  • Linear attention (remaining layers): GatedDeltaNet with depthwise causal
    conv1d, L2-normalized Q/K, Scan-based recurrence, gated RMSNorm
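The alternating pattern can be sketched as follows. This is a hypothetical helper for illustration: the function name and the "every 4th layer (1-indexed)" convention are assumptions, since the real builder derives the layer types from the HF config's full_attention_interval.

```python
def make_layer_types(num_layers: int, full_attention_interval: int = 4) -> list[str]:
    """Every `full_attention_interval`-th layer (1-indexed) uses full attention;
    all other layers use GatedDeltaNet linear attention (assumed convention)."""
    return [
        "full_attention" if (i + 1) % full_attention_interval == 0 else "linear_attention"
        for i in range(num_layers)
    ]

layer_types = make_layer_types(8)
```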

Key design decisions

  • Opset 23 standard-domain ops (RMSNormalization, RotaryEmbedding,
    Attention) instead of com.microsoft contrib ops for portability
  • CumSum-based causal mask with [B, 1, S, total_S] shape — correctly
    handles both prefill and single-token decode modes
  • Pre-computed cos/sin cache tables with Where-mask interleaving for
    efficient mRoPE at runtime
  • Hybrid cache I/O: Per-layer KV cache for attention layers,
    conv_state + recurrent_state for GatedDeltaNet layers
  • SkipSimplifiedLayerNormalization with pre-baked weight + 1 offset
    (reuses base class)
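As an illustration of the CumSum-based causal mask above, here is a NumPy sketch of the intended semantics. The shapes follow the bullet; the exported graph builds this with ONNX CumSum/Where ops, and details such as the exact large-negative fill value and the no-left-padding assumption are mine, not the PR's.

```python
import numpy as np

def causal_mask(attention_mask: np.ndarray, seq_len: int) -> np.ndarray:
    """Sketch: build an additive float causal mask [B, 1, S, total_S] from a
    0/1 padding mask [B, total_S]. Handles prefill (seq_len == total_S) and
    single-token decode (seq_len == 1) uniformly via CumSum positions."""
    total_s = attention_mask.shape[1]
    positions = np.cumsum(attention_mask, axis=1) - 1        # token position per key
    query_pos = positions[:, total_s - seq_len:]             # last S positions are queries
    allowed = (positions[:, None, :] <= query_pos[:, :, None]) & (attention_mask[:, None, :] == 1)
    mask = np.where(allowed, 0.0, np.float32(np.finfo(np.float32).min))
    return mask[:, None, :, :].astype(np.float32)
```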

Changes

src/python/py/models/builders/qwen.py

  • Added Qwen35Model class with:
    • _make_full_attention_layer: Q/gate split, QK norm, mRoPE, Attention, gating
    • _make_linear_attention_layer: Conv1d, L2 norm, Scan recurrence, gated norm
    • _make_causal_mask: CumSum-based mask shared across attention layers
    • _make_rotary_caches: Pre-computed cos/sin + h/w interleaving masks
    • _build_scan_body: GatedDeltaNet delta rule in ONNX Scan body graph
    • make_genai_config: Text-only config (reuses base class via super())

src/python/py/models/builders/__init__.py

  • Export Qwen35Model

src/python/py/models/builder.py

  • Dispatch Qwen3_5ForConditionalGeneration to Qwen35Model

Export command

cd src/python/py/models
python builder.py -m "Qwen/Qwen3.5-0.8B" -o <output_dir> -p fp32 -e cpu
python builder.py -m "Qwen/Qwen3.5-0.8B" -o <output_dir> -p int4 -e cpu

@apsonawane apsonawane marked this pull request as ready for review March 24, 2026 06:36
Copilot AI review requested due to automatic review settings March 24, 2026 06:36
Copilot AI (Contributor) left a comment:

Pull request overview

Adds ONNX export support in the Python model builder for Qwen3.5 hybrid decoder models (alternating KV-cache attention layers and GatedDeltaNet recurrent/conv linear-attention layers), integrating with the existing onnxruntime-genai builder + config pipeline.

Changes:

  • Introduces Qwen35TextModel builder implementing full-attention + linear-attention layer graph construction, plus hybrid cache I/O (KV + conv/recurrent state).
  • Exposes the new builder from builders/__init__.py and wires up dispatch in builder.py for Qwen3_5ForConditionalGeneration.
  • Minor cleanups in builder.py option parsing / formatting.

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.

File Description
src/python/py/models/builders/qwen.py Adds the Qwen3.5 hybrid decoder builder implementation, rotary cache generation, causal mask construction, Scan recurrence body, and config patching.
src/python/py/models/builders/__init__.py Re-exports Qwen35TextModel.
src/python/py/models/builder.py Dispatches Qwen3.5 HF architecture to Qwen35TextModel and makes small option/formatting tweaks.



import json
import os
Contributor:

We shouldn't need to import json and os here to patch issues with the produced files. We can fix their contents before they are serialized to disk.

Author (apsonawane):

Done. make_genai_config now temporarily adjusts self attributes (num_layers, model_type, past_present_share_buffer, KV cache template keys) so the base class produces the correct config directly — no post-hoc patching.
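The "temporarily adjusts self attributes" pattern might look like the following sketch. The helper and the Builder stand-in are hypothetical; only the idea (override attributes while the base class serializes the config, then restore them) comes from the reply above.

```python
from contextlib import contextmanager

@contextmanager
def patched_attrs(obj, **overrides):
    """Temporarily override attributes on `obj`, restoring the originals on exit
    (hypothetical helper illustrating config generation without post-hoc patching)."""
    saved = {k: getattr(obj, k) for k in overrides}
    for k, v in overrides.items():
        setattr(obj, k, v)
    try:
        yield obj
    finally:
        for k, v in saved.items():
            setattr(obj, k, v)

class Builder:  # stand-in for the real model builder
    num_layers = 48
    model_type = "qwen3"

b = Builder()
with patched_attrs(b, num_layers=12, model_type="qwen3_5"):
    config_layers = b.num_layers  # a base-class make_genai_config would read 12 here
```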

del store[f"present.{suffix}"]

# Build per-layer cache I/O
for i in range(self.num_layers):
Contributor:

Why are we using new naming formats for the full attention layers? We can re-use the existing formats and remove the per-layer entries that aren't needed.

Author (apsonawane):

Fixed. The full-attention layers now reuse the base class's existing KV cache naming format (past_key_values.{i}.key / present.{i}.key); we just filter the template lists to exclude linear-attention layer indices instead of deleting and recreating them. Also removed the redundant is_layer and has_final_norm overrides (the base class handles both), and restructured make_layer/make_attention to follow the standard dispatch pattern used by other models.

"""Build one decoder layer. Dispatches to full attention or
GatedDeltaNet linear attention based on self.layer_types."""

if self.layer_types[layer_id] == "linear_attention":
Contributor:

We should ideally handle the alternating attention layer types by using the same logic that is done for other models: override the make_attention method and dispatch to the right make method based on the layer id.

Author (apsonawane):

Done. Added a make_attention override that dispatches to _make_full_attention or _make_linear_attention based on self.layer_types[layer_id]. The make_layer override only exists because linear-attention layers use layer.linear_attn instead of layer.self_attn; it picks the right module and delegates everything else (layernorm, MLP) to the base class pattern. The attention methods only handle the attention-specific subgraph and end by setting self.layernorm_attrs["skip_input"].
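A minimal sketch of that dispatch shape (class and call recording are illustrative; the real builder methods take richer arguments and emit graph nodes):

```python
class Qwen35ModelSketch:
    """Toy version of the make_attention dispatch described in the reply above."""

    def __init__(self, layer_types):
        self.layer_types = layer_types
        self.calls = []  # record which path each layer took

    def make_attention(self, layer_id, attention, root_input, **kwargs):
        if self.layer_types[layer_id] == "linear_attention":
            self._make_linear_attention(layer_id, attention, root_input)
        else:
            self._make_full_attention(layer_id, attention, root_input)

    def _make_full_attention(self, layer_id, attention, root_input):
        self.calls.append(("full", layer_id))

    def _make_linear_attention(self, layer_id, attention, root_input):
        self.calls.append(("linear", layer_id))

m = Qwen35ModelSketch(["linear_attention", "full_attention"])
m.make_attention(0, None, "x")
m.make_attention(1, None, "x")
```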

q_norm_name = f"/model/layers.{layer_id}/attn/q_norm/RMSNormalization"
q_norm_output = f"{q_norm_name}/output_0"
self.make_node(
"RMSNormalization",
Contributor:

Have we verified that the ONNX RMSNormalization op has the same performance as the ORT SimplifiedLayerNormalization op? Can we also check that other IHVs have implemented this op for their EPs? Otherwise, they will likely modify this code to use SimplifiedLayerNormalization instead. I would like to avoid EP-specific logic as much as possible.

Author (apsonawane):

Yes, true: RMSNormalization is a standard op whose EP coverage is uncertain. Replaced it with SimplifiedLayerNormalization.

self.attention_attrs["k_path"] = f"{k_reshape_2}/output_0"

def _make_causal_mask(self):
"""Build causal attention mask [B, 1, S, total_S] for Attention op.
Contributor:

Why is the existing causal mask that gets built insufficient?

def make_attention_mask_reformatting_for_mha(self):
    # Make nodes for the attention mask subgraphs that reformat the
    # 2D attention mask (B, S) to 4D causal attention mask (B, N, S, T)

Author (apsonawane, Mar 25, 2026):

Qwen3.5 uses the standard-domain Attention op, which expects a float mask, so a new function was created for that.
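The boolean-to-float conversion in question can be sketched as below. The large-negative fill value is an assumption; the exported graph performs the equivalent with a Where node.

```python
import numpy as np

NEG_INF = np.float32(np.finfo(np.float32).min)

def to_float_mask(attention_mask: np.ndarray) -> np.ndarray:
    """Convert a 0/1 padding mask [B, total_S] into the additive float form
    (0 for keep, large-negative for masked) that an additive attention-bias
    input expects. Sketch only; not the PR's exact code."""
    return np.where(attention_mask == 1, np.float32(0.0), NEG_INF)
```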

basename = "/model/causal_mask"
attn_mask = self.input_names["attention_mask"] # [B, total_S]

# Constants
Contributor:

Constant nodes have a special naming format. The make_constant method assists with this. Can we use that instead?

def make_constant(self, name):
    # Make constant ops for 0, 1, 2, 3, etc.
    # Format of name is "/model/constants/{dtype}/{num}"
    path = name.split("/")
    onnx_dtype = ir.DataType[path[-2]]
    num = ast.literal_eval(path[-1])
    assert isinstance(num, (int, float, list, tuple)), f"Invalid constant value: {num}"
    tensor = ir.tensor(num, dtype=onnx_dtype, name=name)
    node_name = name.replace("constants", "constant_nodes")
    self.make_node("Constant", inputs=[], outputs=[name], name=node_name, value=tensor)
    self.make_value(name, onnx_dtype, shape=[])

Author (apsonawane):

Done

)

# Step 2: Get query length S from inputs_embeds shape
embeds_input = self.input_names.get("inputs_embeds", "inputs_embeds")
Contributor:

Rather than having default names if the input name does not exist in the dictionary, let's access self.input_names["inputs_embeds"] directly. If an error occurs, that is an indication that something should have been changed earlier.

Author (apsonawane):

Updated

def _get_shared_q_scale(self, head_dim):
"""Return the name of a shared 1/sqrt(head_dim) constant (created once)."""
name = "/model/constants/q_scale"
if name not in self.node_names:
Contributor:

There's no need to do these checks. These node names are automatically checked for duplicates when inserted.

if name in self.node_names:
    # Note:
    #
    # This approach allows functions that make similar subgraphs with the same naming schema
    # to share existing nodes without needing to know whether the nodes already exist or not
    # (e.g. attention mask subgraphs).
    #
    # This means that the nodes can be created in those functions regardless of their actual
    # status in the graph. This check can then decide whether the proposed node actually
    # needs to be added into the graph or not.
    return

Author (apsonawane, Mar 25, 2026):

We are not calling make_node, which checks for duplicates. But make_initializer calls make_value, which uses setdefault and returns the existing entry, so the extra dedupe check is not necessary; removing the condition.
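The setdefault behavior the reply relies on, in isolation (the key name is illustrative):

```python
# dict.setdefault inserts the value only when the key is absent; when the key
# already exists it returns the stored value unchanged, so repeated
# registrations of the same value name are naturally de-duplicated.
values = {}
first = values.setdefault("/model/constants/q_scale", {"dtype": "float32"})
second = values.setdefault("/model/constants/q_scale", {"dtype": "float16"})
```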

def _get_shared_l2_eps(self):
"""Return the name of a shared L2 epsilon constant (created once)."""
name = "/model/constants/l2_eps"
if name not in self.node_names:
Contributor:

Same comment here

Author (apsonawane, Mar 25, 2026):

We are not calling make_node, which checks for duplicates. But make_initializer calls make_value, which uses setdefault and returns the existing entry, so the extra dedupe check is not necessary; removing the condition.

)
gc_name = f"{basename}/{suffix}/dim{dim_idx}/Gather_cache"
gc_out = f"{gc_name}/output_0"
self.make_node("Gather", [cache_name, sq_out], [gc_out], name=gc_name, axis=0)
Contributor:

Can we use make_gather here?

Author (apsonawane):

Updated to use make_gather

)
return unflat_out

def _make_l2_normalize(self, basename, input_name, last_dim, leading_dims=None):
Contributor:

Can we use the ReduceL2 op instead of these primitive ops?

Author (apsonawane):

Updated, using ReduceL2
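For reference, the semantics of the ReduceL2-based rewrite, in NumPy (the epsilon placement is an assumption about the decomposed graph being replaced):

```python
import numpy as np

def l2_normalize(x: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """One ReduceL2 (sqrt of the sum of squares over the last axis) replaces
    the decomposed Mul -> ReduceSum -> Sqrt chain; the division is unchanged."""
    norm = np.sqrt(np.sum(x * x, axis=-1, keepdims=True))  # what ReduceL2 computes
    return x / (norm + eps)

v = l2_normalize(np.array([[3.0, 4.0]]))
```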

# Rename reshape output to target present state name
if present_state_name:
self.make_node(
"Identity",
Contributor:

Same question here

Author (apsonawane):

Removed Identity

def set_io_dtype(precision, execution_provider, extra_options) -> ir.DataType:
int4_cpu = precision == "int4" and execution_provider == "cpu"
fp32_webgpu = execution_provider == "webgpu" and extra_options.get("use_webgpu_fp32", False)
bf16_cuda = precision == "int4" and execution_provider in {"cuda", "trt-rtx"} and extra_options.get("use_cuda_bf16", False)
Contributor:

Can we keep the spacings in this file as is?

Author (apsonawane):

I just ran lintrunner; that adjusted this.

Author (apsonawane):

Fixed the spacing.

outputs=(),
nodes=(),
opset_imports={"": 21, "com.microsoft": 1},
opset_imports={"": 23, "com.microsoft": 1},
Contributor:

Can we check which ops change their signature from 21 to 23 so that we don't break other models?


# 3D position_ids for mRoPE: [3, batch_size, sequence_length]
self.input_shapes["position_ids"] = [3, "batch_size", "sequence_length"]
if "position_ids" not in self.input_names:
Contributor:

I believe position_ids is always in the input names dictionary by default now. Can you check this and remove this if block if it always is?

Author (apsonawane):

Updated

self.mrope_rotary_dim = int(self.rope_attrs["partial_rotary_factor"] * self.head_size)

# Force RoPE computation in float32 for numerical stability
if "rope_cast" not in self.attention_attrs:
Contributor:

Given self.rope_attrs exists, can the casting setting be moved to that dictionary instead?

Author (apsonawane):

Updated

self.attention_attrs["rope_cast"]["use_fp32"] = True

# Pre-compute cos/sin cache tables and interleaving masks for mRoPE
self._make_rotary_caches()
Contributor:

There is a method called make_rotary_embedding_caches_from_scratch which already does this and should be overridden instead. However, given how much AI-generated code has been added to the model builder lately, all of it will need refactoring. This is fine for now; I will clean this up later.
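The precomputed cos/sin tables under discussion follow the standard RoPE recipe. This sketch shows only the plain case; Qwen3.5's partial-rotary and mRoPE interleaving masks are omitted, and the base theta is an assumption.

```python
import numpy as np

def rotary_caches(head_dim: int, max_pos: int, theta: float = 10000.0):
    """Standard precomputed RoPE cos/sin caches: one angle per (position,
    frequency) pair, with head_dim // 2 frequencies."""
    inv_freq = 1.0 / (theta ** (np.arange(0, head_dim, 2, dtype=np.float32) / head_dim))
    angles = np.outer(np.arange(max_pos, dtype=np.float32), inv_freq)  # [max_pos, head_dim/2]
    return np.cos(angles), np.sin(angles)

cos, sin = rotary_caches(head_dim=64, max_pos=128)
```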

self.attention_attrs["q_norm"] = True
self.attention_attrs["k_norm"] = True
# Disable packed matmul since Q projection is doubled (4096 vs normal 2048)
self.attention_attrs["use_packed_matmul"] = False
Contributor:

Why does packed MatMul have to be disabled for this reason? Even if the hidden sizes for Q, K, V differ, the packed MatMul should still work. The use of QK norm is a better reason for why this should be disabled.

Author (apsonawane):

Yeah, this is not required. The base class already disables use_packed_matmul when q_norm is true.

# full-attention layer indices. The base class creates entries for
# all num_layers — we remove linear-attention layers that use
# conv/recurrent state instead of KV cache.
attn_indices = {i for i, lt in enumerate(self.layer_types) if lt == "full_attention"}
Contributor:

If self.layer_types exists, why do we need to calculate attn_indices separately? I think it would be better to loop through self.layer_types and merge both of the following for-loops together.

Author (apsonawane):

Merged into a single loop over self.layer_types. Each iteration either keeps the KV cache entry (full_attention) or adds conv/recurrent state entries (linear_attention). Eliminated the separate attn_indices set and the second loop.
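The merged loop might look like the following sketch. The past/present names for the recurrent states are illustrative (the PR only shows present.{i}.conv_state); the KV names match the base class format quoted earlier in the thread.

```python
# Single pass over layer_types: each iteration registers either KV cache I/O
# (full attention) or conv/recurrent state I/O (GatedDeltaNet).
layer_types = ["linear_attention"] * 3 + ["full_attention"]

inputs, outputs = [], []
for i, lt in enumerate(layer_types):
    if lt == "full_attention":
        inputs += [f"past_key_values.{i}.key", f"past_key_values.{i}.value"]
        outputs += [f"present.{i}.key", f"present.{i}.value"]
    else:
        inputs += [f"past.{i}.conv_state", f"past.{i}.recurrent_state"]
        outputs += [f"present.{i}.conv_state", f"present.{i}.recurrent_state"]
```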

self.linear_value_head_dim,
]

self.output_names[f"present_state.{i}.conv"] = f"present.{i}.conv_state"
Contributor:

Let's make sure that the past/present naming schemas match the same schemas that will be used in the LFM2 models so that we don't have to duplicate work across both PRs.

self.make_node("Mul", [attn_output, sigmoid_output], [gated_output], name=gated_name)
self.make_value(gated_output, self.io_dtype, ["batch_size", "sequence_length", q_size])

# 8. Output projection
Contributor:

I think we can re-use make_attention_output_proj here.

def make_attention_output_proj(self, layer_id, attention, root_input, **kwargs):
    attn_name = f"/model/layers.{layer_id}/attn/{self.attention_attrs['op_type']}"
    attn_output = f"{attn_name}/output_0"
    # Make MatMul node (output projection weight node)
    o_proj = (
        "o_proj" if hasattr(attention, "o_proj")
        else "out_proj" if hasattr(attention, "out_proj")
        else "dense"
    )
    o_matmul_basename = f"/model/layers.{layer_id}/attn/o_proj/MatMul"
    o_weight = getattr(attention, o_proj)
    o_matmul_name = self.make_matmul(o_weight, o_matmul_basename, attn_output)
    # Make Add node (output projection bias node if bias exists)
    o_bias_exists = getattr(attention, o_proj).bias is not None
    if o_bias_exists:
        o_add_name = f"/model/layers.{layer_id}/attn/o_proj/Add"
        o_bias = getattr(attention, o_proj).bias
        self.make_add_bias(o_bias, o_add_name, root_input=f"{o_matmul_name}/output_0")
    # Assign output 0 of previous output node as skip input to next SkipLayerNorm
    self.layernorm_attrs["skip_input"] = f"{o_matmul_name if not o_bias_exists else o_add_name}/output_0"

# SimplifiedLayerNormalization (com.microsoft, no offset for gated norm)
norm_name = f"{basename}/SimplifiedLayerNormalization"
norm_output = f"{norm_name}/output_0"
self.make_node(
Contributor:

Let's use make_layernorm so that the right domain for the op gets registered.

Author (apsonawane):

make_layernorm can't be used here because it's tightly coupled to the decoder-layer pipeline (it expects [B, S, hidden_size] shapes). The gated RMSNorm operates on [B*S*N, hv] tensors with no skip connection, so using make_layernorm would require a major refactor; this is simpler. We can add the domain here so that it is registered.
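For context, the gated RMSNorm semantics on those flattened tensors can be sketched as follows. The SiLU activation on the gate is an assumption (it matches common GatedDeltaNet formulations); the PR excerpt only shows the SimplifiedLayerNormalization and gating nodes.

```python
import numpy as np

def gated_rms_norm(x: np.ndarray, weight: np.ndarray, gate: np.ndarray,
                   eps: float = 1e-6) -> np.ndarray:
    """Plain RMSNorm over the last axis of a flattened [B*S*N, hv] tensor,
    scaled by a learned weight, then multiplied by a SiLU-activated gate.
    No skip connection, unlike the decoder-layer layernorm path."""
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    normed = (x / rms) * weight
    return normed * (gate / (1.0 + np.exp(-gate)))  # SiLU(gate) * normed

out = gated_rms_norm(np.ones((4, 8)), np.ones(8), np.zeros((4, 8)))
```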

Comment on lines +1459 to +1461
# if self.ep in {"cuda"}:
# self._make_linear_attention_fused(layer_id, linear_attn, root_input)
# else:

Code scanning / CodeQL notice: Commented-out code

This comment appears to contain commented-out code.

Copilot Autofix (AI, about 18 hours ago):

To fix the problem, remove the commented-out control-flow code and keep only active, real code plus an explanatory TODO. This eliminates misleading “dead” logic while retaining documentation of the intended future change.

Concretely, in src/python/py/models/builders/qwen.py, in the _make_linear_attention method around lines 1458–1462, delete the commented if self.ep in {"cuda"}: / _make_linear_attention_fused / else: lines, leaving the TODO and the active call to _make_linear_attention_decomposed. No new imports or definitions are needed, and existing functionality remains unchanged because the decomposed path is already the only one executed.

Suggested changeset 1
src/python/py/models/builders/qwen.py

Autofix patch
Run the following command in your local git repository to apply this patch
cat << 'EOF' | git apply
diff --git a/src/python/py/models/builders/qwen.py b/src/python/py/models/builders/qwen.py
--- a/src/python/py/models/builders/qwen.py
+++ b/src/python/py/models/builders/qwen.py
@@ -1456,9 +1456,6 @@
         ``LinearAttention`` ops once ORT registers kernels for them.
         """
         # TODO: Enable fused path when ORT has CausalConvWithState/LinearAttention kernels
-        # if self.ep in {"cuda"}:
-        #     self._make_linear_attention_fused(layer_id, linear_attn, root_input)
-        # else:
         self._make_linear_attention_decomposed(layer_id, linear_attn, root_input)
 
     def _make_linear_attention_fused(self, layer_id, linear_attn, root_input):
EOF