Skip to content

Commit 954e64d

Browse files
Merge branch 'main' into add-fp8-placeholder-support-for-serialization
2 parents 28f770a + 32a6cec commit 954e64d

109 files changed

Lines changed: 6377 additions & 1091 deletions

File tree

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

.ci/docker/build.sh

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#!/bin/bash
22
# Copyright (c) Meta Platforms, Inc. and affiliates.
33
# All rights reserved.
4+
# Copyright 2026 Arm Limited and/or its affiliates.
45
#
56
# This source code is licensed under the BSD-style license found in the
67
# LICENSE file in the root directory of this source tree.
@@ -94,11 +95,6 @@ BUILD_DOCS=1
9495
# Copy requirements-lintrunner.txt from root to here
9596
cp ../../requirements-lintrunner.txt ./
9697

97-
# Copy arm setup script from root to here
98-
# TODO(huydhn): Figure out a way to rebuild the Docker image automatically
99-
# with a new image hash when the content here is updated
100-
cp -r ../../examples/arm/ ./arm
101-
10298
docker build \
10399
--no-cache \
104100
--progress=plain \

.ci/scripts/test_model_e2e.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,7 @@ EOF
354354
fi
355355
;;
356356
qwen3_5_moe)
357-
RUNNER_ARGS="$RUNNER_ARGS --tokenizer_path ${MODEL_DIR}/$TOKENIZER_FILE --prompt 'What is the capital of France?' --max_new_tokens 128 --temperature 0"
357+
RUNNER_ARGS="$RUNNER_ARGS --tokenizer_path ${MODEL_DIR}/$TOKENIZER_FILE --prompt 'What is the capital of France?' --max_new_tokens 128 --temperature 0 --cuda_graph"
358358
;;
359359
voxtral_realtime)
360360
RUNNER_ARGS="--model_path ${MODEL_DIR}/model.pte --tokenizer_path ${MODEL_DIR}/$TOKENIZER_FILE --preprocessor_path ${MODEL_DIR}/$PREPROCESSOR --audio_path ${MODEL_DIR}/$AUDIO_FILE --temperature 0"

.github/workflows/cuda.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,8 +145,8 @@ jobs:
145145
# Run CUDA backend Python tests
146146
python -m pytest backends/cuda/tests backends/cuda/passes/tests -v -o "addopts="
147147
148-
# Run Qwen 3.5 MoE tests (quantize roundtrip + TurboQuant KV cache)
149-
python -m pytest examples/models/qwen3_5_moe/test_quantize_roundtrip.py examples/models/qwen3_5_moe/test_turboquant.py -v -o "addopts="
148+
# Run Qwen 3.5 MoE tests (quantize roundtrip + TurboQuant KV cache + sampler)
149+
python -m pytest examples/models/qwen3_5_moe/test_quantize_roundtrip.py examples/models/qwen3_5_moe/test_turboquant.py examples/models/qwen3_5_moe/test_sampler.py -v -o "addopts="
150150
151151
export-model-cuda-artifact:
152152
name: export-model-cuda-artifact

backends/aoti/slim/c10/core/ScalarType.h

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ enum class ScalarType : int8_t {
2828
Short = 2, // int16_t
2929
Int = 3, // int32_t
3030
Long = 4, // int64_t
31-
// Half = 5, // float16 - not currently needed
31+
Half = 5, // float16
3232
Float = 6, // float
3333
// Double = 7, // double - not currently needed
3434
// ComplexHalf = 8,
@@ -48,6 +48,7 @@ constexpr ScalarType kChar = ScalarType::Char;
4848
constexpr ScalarType kShort = ScalarType::Short;
4949
constexpr ScalarType kInt = ScalarType::Int;
5050
constexpr ScalarType kLong = ScalarType::Long;
51+
constexpr ScalarType kHalf = ScalarType::Half;
5152
constexpr ScalarType kFloat = ScalarType::Float;
5253
constexpr ScalarType kBool = ScalarType::Bool;
5354
constexpr ScalarType kBFloat16 = ScalarType::BFloat16;
@@ -67,6 +68,8 @@ inline size_t elementSize(ScalarType t) {
6768
return sizeof(int32_t);
6869
case ScalarType::Long:
6970
return sizeof(int64_t);
71+
case ScalarType::Half:
72+
return 2; // sizeof(__half) = 2 bytes
7073
case ScalarType::Float:
7174
return sizeof(float);
7275
case ScalarType::Bool:
@@ -93,6 +96,8 @@ inline const char* toString(ScalarType t) {
9396
return "Int";
9497
case ScalarType::Long:
9598
return "Long";
99+
case ScalarType::Half:
100+
return "Half";
96101
case ScalarType::Float:
97102
return "Float";
98103
case ScalarType::Bool:
@@ -110,7 +115,8 @@ inline const char* toString(ScalarType t) {
110115
/// @param t The scalar type to check.
111116
/// @return true if the scalar type is floating point, false otherwise.
112117
inline bool isFloatingType(ScalarType t) {
113-
return t == ScalarType::Float || t == ScalarType::BFloat16;
118+
return t == ScalarType::Half || t == ScalarType::Float ||
119+
t == ScalarType::BFloat16;
114120
}
115121

116122
/// Checks if the scalar type is an integral type (including bool optionally).
@@ -149,6 +155,7 @@ inline bool isValidScalarType(ScalarType t) {
149155
case ScalarType::Short:
150156
case ScalarType::Int:
151157
case ScalarType::Long:
158+
case ScalarType::Half:
152159
case ScalarType::Float:
153160
case ScalarType::Bool:
154161
case ScalarType::BFloat16:

backends/aoti/slim/c10/core/test/test_scalar_type.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ const std::vector<ScalarTypeTestData> kAllScalarTypes = {
3636
{ScalarType::Short, 2, 2, "Short", false, true, true, false},
3737
{ScalarType::Int, 3, 4, "Int", false, true, true, false},
3838
{ScalarType::Long, 4, 8, "Long", false, true, true, false},
39+
{ScalarType::Half, 5, 2, "Half", true, false, false, false},
3940
{ScalarType::Float, 6, 4, "Float", true, false, false, false},
4041
{ScalarType::Bool, 11, 1, "Bool", false, false, true, true},
4142
{ScalarType::BFloat16, 15, 2, "BFloat16", true, false, false, false},
@@ -128,6 +129,10 @@ TEST_F(ScalarTypeConstantsTest, KLongConstant) {
128129
EXPECT_EQ(kLong, ScalarType::Long);
129130
}
130131

132+
TEST_F(ScalarTypeConstantsTest, KHalfConstant) {
133+
EXPECT_EQ(kHalf, ScalarType::Half);
134+
}
135+
131136
TEST_F(ScalarTypeConstantsTest, KFloatConstant) {
132137
EXPECT_EQ(kFloat, ScalarType::Float);
133138
}
@@ -185,6 +190,10 @@ TEST_F(ElementSizeConsistencyTest, LongMatchesSizeofInt64) {
185190
EXPECT_EQ(elementSize(ScalarType::Long), sizeof(int64_t));
186191
}
187192

193+
TEST_F(ElementSizeConsistencyTest, HalfIs2Bytes) {
194+
EXPECT_EQ(elementSize(ScalarType::Half), 2);
195+
}
196+
188197
TEST_F(ElementSizeConsistencyTest, FloatMatchesSizeofFloat) {
189198
EXPECT_EQ(elementSize(ScalarType::Float), sizeof(float));
190199
}
@@ -196,3 +205,29 @@ TEST_F(ElementSizeConsistencyTest, BoolMatchesSizeofBool) {
196205
TEST_F(ElementSizeConsistencyTest, BFloat16MatchesSizeofBFloat16) {
197206
EXPECT_EQ(elementSize(ScalarType::BFloat16), sizeof(BFloat16));
198207
}
208+
209+
// =============================================================================
210+
// isValidScalarType Tests
211+
// =============================================================================
212+
213+
class IsValidScalarTypeTest : public ::testing::Test {};
214+
215+
TEST_F(IsValidScalarTypeTest, HalfIsValid) {
216+
EXPECT_TRUE(isValidScalarType(ScalarType::Half));
217+
}
218+
219+
TEST_F(IsValidScalarTypeTest, AllSupportedTypesAreValid) {
220+
EXPECT_TRUE(isValidScalarType(ScalarType::Byte));
221+
EXPECT_TRUE(isValidScalarType(ScalarType::Char));
222+
EXPECT_TRUE(isValidScalarType(ScalarType::Short));
223+
EXPECT_TRUE(isValidScalarType(ScalarType::Int));
224+
EXPECT_TRUE(isValidScalarType(ScalarType::Long));
225+
EXPECT_TRUE(isValidScalarType(ScalarType::Half));
226+
EXPECT_TRUE(isValidScalarType(ScalarType::Float));
227+
EXPECT_TRUE(isValidScalarType(ScalarType::Bool));
228+
EXPECT_TRUE(isValidScalarType(ScalarType::BFloat16));
229+
}
230+
231+
TEST_F(IsValidScalarTypeTest, UndefinedIsNotValid) {
232+
EXPECT_FALSE(isValidScalarType(ScalarType::Undefined));
233+
}

backends/apple/metal/runtime/metal_backend.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9+
#include <c10/util/safe_numerics.h>
910
#include <dlfcn.h>
1011
#include <executorch/runtime/backend/interface.h>
1112
#include <executorch/runtime/core/error.h>
@@ -459,8 +460,10 @@ class ET_EXPERIMENTAL MetalBackend final
459460

460461
ET_LOG(Debug, "MetalBackend n_outputs %zd generated", n_outputs);
461462

463+
size_t n_io_sum = 0;
462464
ET_CHECK_OR_RETURN_ERROR(
463-
n_inputs + n_outputs == args.size(),
465+
!c10::add_overflows(n_inputs, n_outputs, &n_io_sum) &&
466+
n_io_sum == args.size(),
464467
InvalidArgument,
465468
"number of user input %zd and output %zd generated from AOT Inductor does not match ET runner's %zd. Exit.",
466469
n_inputs,

backends/arm/_passes/fuse_constant_ops_pass.py

Lines changed: 51 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
# LICENSE file in the root directory of this source tree.
55

66
import logging
7-
from typing import Set, Type
7+
from collections.abc import Mapping
8+
from typing import Sequence, Set, Type
89

910
import torch._export.utils
1011
import torch.fx
@@ -18,6 +19,7 @@
1819
from executorch.backends.arm._passes.fuse_equal_placeholders_pass import (
1920
FuseEqualPlaceholdersPass,
2021
)
22+
from executorch.backends.arm.tosa.dialect.shape import meta_has_shape_mark
2123
from executorch.backends.arm.tosa.mapping import TosaSpecialDtype
2224
from executorch.backends.transforms.utils import (
2325
create_constant_placeholder,
@@ -53,6 +55,36 @@ def __init__(self, exported_program: ExportedProgram, *args, **kwargs) -> None:
5355
super().__init__(*args, **kwargs)
5456
self.exported_program = exported_program
5557

58+
@staticmethod
59+
def _is_tosa_dialect_op(target) -> bool:
60+
target_str = str(target)
61+
return (
62+
"executorch.exir.dialects.backend._ops.tosa." in target_str
63+
or "<EdgeOpOverload: tosa." in target_str
64+
)
65+
66+
@staticmethod
67+
def _arg_contains_symbolic_shape(arg) -> bool:
68+
if isinstance(arg, torch.fx.Node):
69+
if meta_has_shape_mark(arg.meta):
70+
return True
71+
return FuseConstantArgsPass._arg_contains_symbolic_shape(
72+
arg.meta.get("val")
73+
)
74+
if isinstance(arg, torch.SymInt):
75+
return True
76+
if isinstance(arg, Mapping):
77+
return any(
78+
FuseConstantArgsPass._arg_contains_symbolic_shape(k)
79+
or FuseConstantArgsPass._arg_contains_symbolic_shape(v)
80+
for k, v in arg.items()
81+
)
82+
if isinstance(arg, Sequence) and not isinstance(arg, (str, bytes)):
83+
return any(
84+
FuseConstantArgsPass._arg_contains_symbolic_shape(v) for v in arg
85+
)
86+
return False
87+
5688
def _propagate_special_dtype(self, from_nodes, to_node, data):
5789
"""Propagate special dtype meta if it exists."""
5890
special_dtypes = set()
@@ -83,21 +115,24 @@ def _fuse_nodes(self, node) -> bool:
83115
input_nodes = list(node.all_input_nodes)
84116
qparams = node.meta.get("input_qparams", None)
85117

86-
def resolve_arg(arg):
118+
def resolve_arg(arg, arg_index=None):
119+
qparam = (
120+
qparams.get(arg_index) if qparams and arg_index is not None else None
121+
)
87122
if isinstance(arg, torch.fx.Node) and arg in input_nodes:
88-
idx = input_nodes.index(arg)
89123
t = get_param_tensor(self.exported_program, arg)
90-
# Check if qparams exist for this arg
91-
if qparams and idx in qparams.keys():
92-
t = qparams[idx].dequantize_value(t)
124+
if qparam is not None:
125+
t = qparam.dequantize_value(t)
93126
return t
94127
if isinstance(arg, tuple):
95-
return tuple(resolve_arg(x) for x in arg)
128+
return tuple(resolve_arg(x, arg_index) for x in arg)
96129
if isinstance(arg, list):
97-
return [resolve_arg(x) for x in arg]
130+
return [resolve_arg(x, arg_index) for x in arg]
98131
return arg
99132

100-
new_args = tuple(resolve_arg(a) for a in node.args)
133+
new_args = tuple(
134+
resolve_arg(arg, arg_index) for arg_index, arg in enumerate(node.args)
135+
)
101136
new_kwargs = {k: resolve_arg(v) for k, v in node.kwargs.items()}
102137

103138
data = node.target(*new_args, **new_kwargs)
@@ -139,13 +174,13 @@ def call(self, graph_module):
139174
for node in graph_module.graph.nodes:
140175
if node.op != "call_function":
141176
continue
142-
if node.target in [
143-
exir_ops.backend.tosa.MATMUL.default,
144-
exir_ops.backend.tosa.RESCALE.default,
145-
exir_ops.backend.tosa.RESIZE.default,
146-
exir_ops.backend.tosa.TABLE.default,
147-
exir_ops.backend.tosa.TRANSPOSE.default,
148-
]:
177+
# Don't fuse TOSA dialect ops as they do not have eager forward functions.
178+
# Also don't fuse ops whose explicit args/kwargs include symbolic shape values.
179+
if (
180+
self._is_tosa_dialect_op(node.target)
181+
or self._arg_contains_symbolic_shape(node.args)
182+
or self._arg_contains_symbolic_shape(node.kwargs)
183+
):
149184
continue
150185

151186
input_nodes = node.all_input_nodes
@@ -161,7 +196,6 @@ def call(self, graph_module):
161196
)
162197
if not all(input_nodes_constant):
163198
continue
164-
165199
try:
166200
did_fuse = self._fuse_nodes(node)
167201
if did_fuse:

backends/arm/_passes/insert_table_ops.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,17 @@ def register_buffer(self, buffer_name: str, buffer: torch.Tensor) -> None:
139139
"""Add buffer to self.exported_program.state_dict."""
140140
self.exported_program.state_dict[buffer_name] = buffer
141141

142+
@staticmethod
143+
def _get_8bit_table_domain() -> torch.Tensor:
144+
"""Return the canonical 8-bit TOSA TABLE input domain."""
145+
int8_info = torch.iinfo(torch.int8)
146+
# torch.arange excludes the end value, so use max + 1 to include 127.
147+
return torch.arange(
148+
int8_info.min,
149+
int8_info.max + 1,
150+
dtype=torch.int8,
151+
)
152+
142153
def generate_8bit_table_values(
143154
self,
144155
torch_op: Callable[[torch.Tensor], torch.Tensor],
@@ -157,17 +168,10 @@ def f(x: torch.Tensor) -> torch.Tensor:
157168
x = torch_op(x)
158169
return out_quantargs.quantize_value(x)
159170

160-
return (
161-
f(
162-
torch.linspace(
163-
start=in_quantargs.qmin,
164-
end=in_quantargs.qmax,
165-
steps=256,
166-
dtype=torch.int8,
167-
)
168-
).to(dtype=torch.int8),
169-
0,
171+
effective_codes = self._get_8bit_table_domain().clamp(
172+
in_quantargs.qmin, in_quantargs.qmax
170173
)
174+
return (f(effective_codes).to(dtype=torch.int8), 0)
171175

172176
def generate_16_bit_table_values(
173177
self,

backends/arm/_passes/rewrite_conv_pass.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@
2121
get_input_qparams,
2222
get_output_qparams,
2323
)
24+
from executorch.backends.arm._passes.symbolic_value_range import (
25+
evaluate_symbolic_expr_values,
26+
)
2427
from executorch.backends.arm.constants import HWCM_ORDER, NHWC_INVERSE_ORDER
2528
from executorch.backends.arm.tosa.mapping import TosaSpecialDtype
2629
from executorch.backends.arm.tosa.specification import get_context_shape_env
@@ -83,16 +86,22 @@ def _adjust_pad_if_needed(
8386

8487
if isinstance(mod_remainder, torch.SymInt):
8588
shape_env = get_context_shape_env()
86-
value_ranges = shape_env.bound_sympy(mod_remainder.node.expr)
87-
mod_remainder_upper = int(value_ranges.upper)
89+
exact_values = evaluate_symbolic_expr_values(
90+
mod_remainder.node.expr, shape_env
91+
)
92+
if exact_values is not None:
93+
mod_remainder_upper = max(exact_values)
94+
else:
95+
value_ranges = shape_env.bound_sympy(mod_remainder.node.expr)
96+
mod_remainder_upper = int(value_ranges.upper)
8897
if mod_remainder_upper == 0:
8998
mod_remainder = 0
9099
else:
91100
mod_remainder_upper = mod_remainder
92101

93102
if mod_remainder_upper > pad:
94103
raise RuntimeError(
95-
"This case should be handled by the SizeAdjustInputPass, is it enabled?"
104+
"This case should be handled by the SizeAdjustInputPass, is it enabled?\n"
96105
)
97106
return pad - mod_remainder
98107

backends/arm/_passes/size_adjust_input_pass.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# LICENSE file in the root directory of this source tree.
55
from typing import cast, Sequence, Set, Type, TypeAlias
66

7+
import torch
78
import torch.fx
89
from executorch.backends.arm._passes import ArmPass
910
from executorch.backends.arm._passes.arm_pass_utils import (
@@ -12,6 +13,9 @@
1213
)
1314
from executorch.backends.arm._passes.rewrite_conv_pass import RewriteConvPass
1415
from executorch.backends.arm._passes.rewrite_max_pool2d_pass import RewriteMaxPool2dPass
16+
from executorch.backends.arm._passes.symbolic_value_range import (
17+
evaluate_symbolic_expr_values,
18+
)
1519
from executorch.backends.arm.tosa.specification import get_context_shape_env
1620
from executorch.exir.dialects._ops import ops as exir_ops
1721
from executorch.exir.pass_base import ExportPass, PassResult
@@ -49,6 +53,9 @@ def _greater_than(input: SymIntLike, other: int) -> bool | torch.SymBool:
4953
"""Returns whether an int or SymInt is greater than another value."""
5054
if isinstance(input, torch.SymInt):
5155
shape_env = get_context_shape_env()
56+
exact_values = evaluate_symbolic_expr_values(input.node.expr, shape_env)
57+
if exact_values is not None:
58+
return max(exact_values) > other
5259
value_ranges = shape_env.bound_sympy(input.node.expr)
5360
return value_ranges.upper > other
5461
else:

0 commit comments

Comments
 (0)