Skip to content

Commit 5884aeb

Browse files
authored
Fix operator tests (openvinotoolkit#34476)
### Details: - Fixed operator tests for ExecuTorch to make OpenVINO production-ready in ExecuTorch. - Enabled ops for yolov26, gemma, etc. ### Tickets: - ([https://jira.devtools.intel.com/browse/CVS-176207](https://jira.devtools.intel.com/browse/CVS-176207)) ### AI Assistance: - *AI assistance used: yes* - For research about some ops. - Tested with PyTorch layer tests using the precommit_fx_backend tag.
1 parent 303e882 commit 5884aeb

File tree

11 files changed

+65
-33
lines changed

11 files changed

+65
-33
lines changed

src/bindings/python/src/openvino/frontend/pytorch/fx_decoder.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,11 @@ def get_input_strides(self, index: int) -> list:
367367
strides = list(meta["tensor_meta"].stride)
368368
if strides:
369369
return strides
370+
# Fallback for Edge dialect nodes where tensor_meta is absent but val is present
371+
if "val" in meta and hasattr(meta["val"], "stride"):
372+
strides = list(meta["val"].stride())
373+
if strides:
374+
return strides
370375
return []
371376

372377
def get_input_type(self, index):

src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/backend.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -102,13 +102,13 @@ def _call(*args):
102102
example_inputs = [example_inputs[ind] for ind in preserved_arg_indices]
103103
model = subgraph
104104
else:
105-
from torch._subclasses.fake_tensor import FakeTensorMode
106-
107105
decompositions = _get_decompositions(options) + get_inf_decomposition_list()
108-
with FakeTensorMode(allow_non_fake_inputs=True) as fakemode:
109-
fake_inputs = [fakemode.from_tensor(x) for x in example_inputs]
110-
model = make_fx(subgraph, decomposition_table=get_decompositions(decompositions))(*fake_inputs)
111-
106+
model = make_fx(
107+
subgraph,
108+
decomposition_table=get_decompositions(decompositions),
109+
tracing_mode="fake",
110+
_allow_non_fake_inputs=True,
111+
)(*example_inputs)
112112
with torch.no_grad():
113113
model.eval()
114114
partitioner = Partitioner(options)

src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/op_support.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,8 +208,13 @@ def __init__(self, options):
208208
"torch.ops.aten.reciprocal.default": None,
209209
"torch.ops.aten.relu.default": None,
210210
"torch.ops.aten.relu_.default": None,
211+
"torch.ops.aten.remainder.default": None,
212+
"torch.ops.aten.remainder.Scalar": None,
213+
"torch.ops.aten.remainder.Tensor": None,
211214
"torch.ops.aten.repeat.default": None,
212215
"torch.ops.aten.roll.default": None,
216+
"torch.ops.aten.round.default": None,
217+
"torch.ops.aten.round.out": None,
213218
"torch.ops.aten.rsqrt.default": None,
214219
"torch.ops.aten.rsub.Scalar": None,
215220
"torch.ops.aten.rsub.Tensor": None,
@@ -257,6 +262,11 @@ def __init__(self, options):
257262
"torch.ops.aten.unsqueeze_copy.default": None,
258263
"torch.ops.aten.upsample_nearest2d.default": None,
259264
"torch.ops.aten.upsample_nearest2d.vec": None,
265+
"torch.ops.aten.upsample_bicubic2d.default": None,
266+
"torch.ops.aten.upsample_bicubic2d.vec": None,
267+
"torch.ops.aten.upsample_bilinear2d": None,
268+
"torch.ops.aten.upsample_bilinear2d.default": None,
269+
"torch.ops.aten.upsample_bilinear2d.vec": None,
260270
"torch.ops.aten.upsample_nearest3d.vec": None,
261271
"torch.ops.aten.var.correction": None,
262272
"torch.ops.aten.var_mean.correction": None,
@@ -272,7 +282,7 @@ def __init__(self, options):
272282
"torch.ops.quantized_decomposed.quantize_per_tensor.default": None,
273283
"torch.ops.quantized_decomposed.quantize_per_channel.default": None,
274284
"torch.ops.quantized_decomposed.dequantize_per_tensor.default": None,
275-
"torch.ops.quantized_decomposed.dequantize_per_channel.default": None
285+
"torch.ops.quantized_decomposed.dequantize_per_channel.default": None,
276286
}
277287

278288
self.enabled_op_names = []

src/frontends/pytorch/src/op/min_max.cpp

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// SPDX-License-Identifier: Apache-2.0
33
//
44

5+
#include "openvino/core/validation_util.hpp"
56
#include "openvino/frontend/pytorch/node_context.hpp"
67
#include "openvino/op/constant.hpp"
78
#include "openvino/op/convert.hpp"
@@ -21,6 +22,17 @@ namespace op {
2122

2223
using namespace ov::op;
2324

25+
// Returns true when the output is an empty list
26+
static bool is_empty_axes(const Output<Node>& dims_input) {
27+
if (is_empty_list(dims_input)) {
28+
return true;
29+
}
30+
if (const auto constant_src = ov::util::get_constant_from_source(dims_input)) {
31+
return constant_src->get_shape() == Shape{0};
32+
}
33+
return false;
34+
}
35+
2436
OutputVector translate_max(const NodeContext& context) {
2537
// torch.max (same for torch.min) actually has two interfaces smashed together:
2638
// torch.max(x, dim, keepdim) and torch.max(x, y)
@@ -175,14 +187,16 @@ OutputVector translate_amin(const NodeContext& context) {
175187
// aten::amin(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor
176188

177189
// aten::amin.out(Tensor self, int[1] dim=[], bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
178-
num_inputs_check(context, 2, 4);
190+
num_inputs_check(context, 1, 4);
179191
auto x = context.get_input(0);
180-
auto dims = context.get_input(1);
192+
// From torch 2.8, amin dim input is optional; an empty list also means reduce all dims
193+
auto dim_input = !context.input_is_none(1) ? context.get_input(1) : Output<Node>{};
194+
auto dim = (dim_input.get_node() == nullptr || is_empty_axes(dim_input)) ? get_axes_range(context, 0) : dim_input;
181195
bool keep_dims = false;
182196
if (!context.input_is_none(2)) {
183197
keep_dims = context.const_input<bool>(2);
184198
}
185-
auto res = context.mark_node(std::make_shared<v1::ReduceMin>(x, dims, keep_dims));
199+
auto res = context.mark_node(std::make_shared<v1::ReduceMin>(x, dim, keep_dims));
186200
if (!context.input_is_none(3)) {
187201
context.mutate_input(3, res);
188202
}
@@ -193,14 +207,16 @@ OutputVector translate_amax(const NodeContext& context) {
193207
// aten::amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor
194208

195209
// aten::amax.out(Tensor self, int[1] dim=[], bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
196-
num_inputs_check(context, 2, 4);
210+
num_inputs_check(context, 1, 4);
197211
auto x = context.get_input(0);
198-
auto dims = context.get_input(1);
212+
// From torch 2.8, amax dim input is optional; an empty list also means reduce all dims
213+
auto dim_input = !context.input_is_none(1) ? context.get_input(1) : Output<Node>{};
214+
auto dim = (dim_input.get_node() == nullptr || is_empty_axes(dim_input)) ? get_axes_range(context, 0) : dim_input;
199215
bool keep_dims = false;
200216
if (!context.input_is_none(2)) {
201217
keep_dims = context.const_input<bool>(2);
202218
}
203-
auto res = context.mark_node(std::make_shared<v1::ReduceMax>(x, dims, keep_dims));
219+
auto res = context.mark_node(std::make_shared<v1::ReduceMax>(x, dim, keep_dims));
204220
if (!context.input_is_none(3)) {
205221
context.mutate_input(3, res);
206222
}
@@ -213,8 +229,9 @@ OutputVector translate_aminmax(const NodeContext& context) {
213229

214230
auto input = context.get_input(0);
215231

216-
// check if dim is provided, if not, get the range of axes to compute min and max
217-
auto dim = !context.input_is_none(1) ? context.get_input(1) : get_axes_range(context, 0);
232+
// From torch 2.8, amax dim input is optional; an empty list also means reduce all dims
233+
auto dim_input = !context.input_is_none(1) ? context.get_input(1) : Output<Node>{};
234+
auto dim = (dim_input.get_node() == nullptr || is_empty_axes(dim_input)) ? get_axes_range(context, 0) : dim_input;
218235

219236
// check if keepdim is provided, if not, set it to false like PyTorch
220237
bool keep_dims = !context.input_is_none(2) ? context.const_input<bool>(2) : false;

src/frontends/pytorch/src/op/upsample.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ OutputVector base_translate_upsample(const NodeContext& context,
2121
v11::Interpolate::InterpolateMode interpolate_mode,
2222
size_t dims,
2323
bool antialias = false) {
24-
num_inputs_check(context, 1, 4);
24+
num_inputs_check(context, 1, 5);
2525
auto data = context.get_input(0);
2626
std::vector<size_t> pad(dims, 0);
2727
auto size_mode = v11::Interpolate::ShapeCalcMode::SIZES;

src/frontends/pytorch/src/op_table.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1044,7 +1044,7 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_fx() {
10441044
{"aten.one_hot.default", op::translate_one_hot},
10451045
{"aten.outer.default", op::translate_outer},
10461046
{"aten.permute.default", op::translate_permute},
1047-
{"aten.permute_copy.default", op::translate_1to1_match_2_inputs<opset10::Transpose>},
1047+
{"aten.permute_copy.default", op::translate_permute},
10481048
{"aten.pow.Scalar", op::translate_pow},
10491049
{"aten.pow.Tensor_Scalar", op::translate_pow},
10501050
{"aten.pow.Tensor_Tensor", op::translate_pow},
@@ -1059,6 +1059,7 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_fx() {
10591059
{"aten.reflection_pad3d.default", op::translate_reflection_pad_nd},
10601060
{"aten.relu.default", op::translate_1to1_match_1_inputs<opset10::Relu>},
10611061
{"aten.relu_.default", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Relu>>},
1062+
{"aten.remainder.default", op::translate_remainder},
10621063
{"aten.remainder.Scalar", op::translate_remainder},
10631064
{"aten.remainder.Tensor", op::translate_remainder},
10641065
{"aten.repeat.default", op::translate_repeat_fx},
@@ -1068,6 +1069,7 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_fx() {
10681069
{"aten.roll.default", op::translate_roll},
10691070
{"aten.rad2deg.default", op::translate_rad2deg},
10701071
{"aten.round.default", op::translate_round},
1072+
{"aten.round.out", op::translate_round},
10711073
{"aten.rsqrt.default", op::translate_rsqrt},
10721074
{"aten.rsub.Scalar", op::translate_rsub_fx},
10731075
{"aten.rsub.Tensor", op::translate_rsub_fx},
@@ -1120,8 +1122,11 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_fx() {
11201122
{"aten.unfold.default", op::translate_unfold},
11211123
{"aten.unsqueeze.default", common_translators::translate_unsqueeze},
11221124
{"aten.unsqueeze_copy.default", op::translate_1to1_match_2_inputs<opset10::Unsqueeze>},
1125+
{"aten.upsample_bicubic2d.default", op::translate_upsample_bicubic2d},
11231126
{"aten.upsample_bicubic2d.vec", op::translate_upsample_bicubic2d},
1127+
{"aten.upsample_bilinear2d", op::translate_upsample_bilinear2d},
11241128
{"aten.upsample_bilinear2d.vec", op::translate_upsample_bilinear2d},
1129+
{"aten.upsample_bilinear2d.default", op::translate_upsample_bilinear2d},
11251130
{"aten.upsample_linear1d.vec", op::translate_upsample_linear1d},
11261131
{"aten.upsample_nearest1d.default", op::translate_upsample_nearest1d},
11271132
{"aten.upsample_nearest1d.vec", op::translate_upsample_nearest1d},

src/frontends/pytorch/src/transforms/aten_index_put_replacer.cpp

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -36,16 +36,6 @@ namespace pass {
3636

3737
using namespace ov::op;
3838

39-
namespace {
40-
Output<Node> generate_zeros_with_convertlike(ov::pass::NodeRegistry& rg,
41-
const Output<Node> sizes,
42-
const Output<Node> tensor_of_type) {
43-
auto const_0 = v0::Constant::create(element::i32, Shape{}, {0});
44-
auto zeros = rg.make<v3::Broadcast>(const_0, sizes);
45-
return rg.make<v1::ConvertLike>(zeros, tensor_of_type);
46-
}
47-
} // namespace
48-
4939
AtenIndexPutReplacer::AtenIndexPutReplacer() {
5040
auto index_op = ov::pass::pattern::wrap_type<ov::op::util::FrameworkNode>(
5141
fw_node_predicate({"aten::index_put_", "aten.index_put.default"}));
@@ -254,9 +244,7 @@ AtenIndexPutReplacer::AtenIndexPutReplacer() {
254244

255245
std::shared_ptr<ov::Node> result;
256246
if (accumulate) {
257-
auto zeros = generate_zeros_with_convertlike(rg, input_shape, input);
258-
auto scatter = rg.make<v3::ScatterNDUpdate>(zeros, index, values);
259-
result = rg.make<v1::Add>(input, scatter);
247+
result = rg.make<v15::ScatterNDUpdate>(input, index, values, v15::ScatterNDUpdate::Reduction::SUM);
260248
} else {
261249
result = rg.make<v3::ScatterNDUpdate>(input, index, values);
262250
}

tests/layer_tests/pytorch_tests/test_min_max.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -322,9 +322,13 @@ def __init__(self, op, axis, keep_dims, out):
322322
self.forward = self.forward_out
323323

324324
def forward_out(self, x, y):
325+
if self.axis is None:
326+
return self.op(x, out=y), y
325327
return self.op(x, self.axis, self.keep_dims, out=y), y
326328

327329
def forward(self, x):
330+
if self.axis is None:
331+
return self.op(x)
328332
return self.op(x, self.axis, self.keep_dims)
329333

330334

@@ -333,7 +337,7 @@ def forward(self, x):
333337
return model_cls, f"aten::{op_type}"
334338

335339
@pytest.mark.parametrize("op_type", ["amin", "amax"])
336-
@pytest.mark.parametrize("axis", [0, -1, 1, [1, 2], [-1, -2], [2, 0, -1], [0, 1, 2, 3]])
340+
@pytest.mark.parametrize("axis", [None, 0, -1, 1, [1, 2], [-1, -2], [2, 0, -1], [0, 1, 2, 3]])
337341
@pytest.mark.parametrize("keep_dims", [True, False])
338342
@pytest.mark.parametrize("out", [skip_if_export(True), False])
339343
@pytest.mark.parametrize("input_dtype", ['float32', 'int32', 'int64', 'float64'])

tests/layer_tests/pytorch_tests/test_remainder.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def forward(self, lhs, rhs):
2828

2929
@pytest.mark.nightly
3030
@pytest.mark.precommit
31+
@pytest.mark.precommit_fx_backend
3132
def test_remainder(self, ie_device, precision, ir_version, input_shape_rhs):
3233
self.input_rhs = self.random.randn(*input_shape_rhs)
3334
self._test(*self.create_model(), ie_device, precision, ir_version, use_convert_model=True)

tests/layer_tests/pytorch_tests/test_round.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def forward_out(self, x, y):
4040
@pytest.mark.nightly
4141
@pytest.mark.precommit
4242
@pytest.mark.precommit_torch_export
43+
@pytest.mark.precommit_fx_backend
4344
@pytest.mark.parametrize("out", [skip_if_export(True), False])
4445
@pytest.mark.parametrize("dtype", ["float32", "float64", "int32", "int64"])
4546
def test_round(self, out, dtype, ie_device, precision, ir_version):
@@ -88,4 +89,4 @@ def test_round(self, input_type, ie_device, precision, ir_version):
8889
self._prepare_input = self._prepare_input_int
8990
else:
9091
self._prepare_input = self._prepare_input_float
91-
self._test(*self.create_model(input_type), ie_device, precision, ir_version, trace_model=True)
92+
self._test(*self.create_model(input_type), ie_device, precision, ir_version, trace_model=True)

0 commit comments

Comments
 (0)