Skip to content

Commit 9634025

Browse files
committed
code review fixes
1 parent fac5879 commit 9634025

File tree

3 files changed

+54
-41
lines changed

3 files changed

+54
-41
lines changed

src/core/dev_api/openvino/core/bound_evaluation_util.hpp

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,17 +37,19 @@ OPENVINO_API Tensor evaluate_upper_bound(const Output<Node>& output);
3737
/// \return pair with Tensors for lower and upper value estimation.
3838
OPENVINO_API std::pair<Tensor, Tensor> evaluate_both_bounds(const Output<Node>& output);
3939

40-
/// \brief Compares two tensors for equality (shape, type, and data).
41-
/// Used for bounds comparison in replace_source_output.
42-
/// \param a First tensor to compare.
43-
/// \param b Second tensor to compare.
44-
/// \return True if tensors are equal (or both empty), false otherwise.
45-
OPENVINO_API bool tensors_equal(const Tensor& a, const Tensor& b);
46-
47-
/// \brief Force invalidates bounds on a tensor, bypassing SkipInvalidation attribute.
40+
/// \brief Checks if two outputs have the same bounds (lower and upper values).
41+
/// Returns true if both have no bounds, or if both have bounds with equal values.
42+
/// \note Unlike internal are_equal(), returns true when BOTH outputs have no bounds
43+
/// (semantically: "no bounds == no bounds").
44+
/// \param lhs First output to compare bounds from.
45+
/// \param rhs Second output to compare bounds to.
46+
/// \return True if bounds are equivalent, false otherwise.
47+
OPENVINO_API bool have_same_bounds(const Output<Node>& lhs, const Output<Node>& rhs);
48+
49+
/// \brief Sets force invalidation on tensor bounds, bypassing SkipInvalidation attribute.
4850
/// Temporarily removes SkipInvalidation, calls invalidate_values(), then restores it.
4951
/// \param tensor Tensor descriptor to invalidate bounds on.
50-
OPENVINO_API void force_invalidate_bounds(descriptor::Tensor& tensor);
52+
OPENVINO_API void set_bounds_to_invalidate(descriptor::Tensor& tensor);
5153

5254
} // namespace util
5355
} // namespace ov

src/core/src/bound_evaluate.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -739,7 +739,9 @@ bool ov::default_symbol_evaluator(const Node* node,
739739
return false;
740740
}
741741

742-
bool ov::util::tensors_equal(const Tensor& a, const Tensor& b) {
742+
namespace {
743+
// Internal helper - compares tensor data, returns true if both empty
744+
bool bounds_data_equal(const ov::Tensor& a, const ov::Tensor& b) {
743745
// Both empty = equal
744746
if (!a && !b) {
745747
return true;
@@ -759,8 +761,16 @@ bool ov::util::tensors_equal(const Tensor& a, const Tensor& b) {
759761
// Compare data
760762
return std::memcmp(a.data(), b.data(), a.get_byte_size()) == 0;
761763
}
764+
} // namespace
765+
766+
bool ov::util::have_same_bounds(const Output<Node>& lhs, const Output<Node>& rhs) {
767+
const auto& lhs_tensor = lhs.get_tensor();
768+
const auto& rhs_tensor = rhs.get_tensor();
769+
return bounds_data_equal(lhs_tensor.get_lower_value(), rhs_tensor.get_lower_value()) &&
770+
bounds_data_equal(lhs_tensor.get_upper_value(), rhs_tensor.get_upper_value());
771+
}
762772

763-
void ov::util::force_invalidate_bounds(descriptor::Tensor& tensor) {
773+
void ov::util::set_bounds_to_invalidate(descriptor::Tensor& tensor) {
764774
const auto& skip_type = ov::SkipInvalidation::get_type_info_static();
765775
auto& rt_info = tensor.get_rt_info();
766776
bool had_skip = rt_info.count(skip_type) > 0;

src/core/src/descriptor/input.cpp

Lines changed: 31 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,13 @@ ov::descriptor::Input::~Input() {
3333
}
3434

3535
void ov::descriptor::Input::replace_output(Output& new_output) {
36-
Output* old_output = m_output;
37-
36+
// Save old bounds BEFORE remove_input() because the old output's node may be destroyed
37+
// after we disconnect (if this was the last reference to it)
38+
ov::Tensor old_lower, old_upper;
3839
if (m_output != nullptr) {
40+
const auto& old_tensor = m_output->get_tensor();
41+
old_lower = old_tensor.get_lower_value();
42+
old_upper = old_tensor.get_upper_value();
3943
m_output->remove_input(this);
4044
}
4145
new_output.add_input(this);
@@ -54,35 +58,32 @@ void ov::descriptor::Input::replace_output(Output& new_output) {
5458
// This ensures:
5559
// - OptimizeSymbolsUsedAsValues: Same bounds → no invalidation → optimization works
5660
// - AbsSinking: New Abs node has no bounds yet → invalidation → correct recalculation
57-
if (old_output != nullptr && m_node != nullptr) {
58-
const auto& old_tensor = old_output->get_tensor();
59-
const auto& new_tensor = new_output.get_tensor();
60-
61-
const auto& old_lower = old_tensor.get_lower_value();
62-
const auto& old_upper = old_tensor.get_upper_value();
63-
const auto& new_lower = new_tensor.get_lower_value();
64-
const auto& new_upper = new_tensor.get_upper_value();
65-
61+
if (m_node != nullptr) {
6662
bool old_has_bounds = old_lower && old_upper;
67-
bool new_has_bounds = new_lower && new_upper;
68-
69-
// Invalidate if:
70-
// 1. Old had bounds but new doesn't have bounds (node was replaced with newly created one)
71-
// 2. Both have bounds but they differ
72-
bool should_invalidate = false;
73-
if (old_has_bounds && !new_has_bounds) {
74-
// New source doesn't have bounds yet (e.g., newly created Abs in AbsSinking)
75-
should_invalidate = true;
76-
} else if (old_has_bounds && new_has_bounds) {
77-
// Both have bounds - check if they differ
78-
bool bounds_differ =
79-
!ov::util::tensors_equal(old_lower, new_lower) || !ov::util::tensors_equal(old_upper, new_upper);
80-
should_invalidate = bounds_differ;
81-
}
82-
83-
if (should_invalidate) {
84-
for (size_t port = 0; port < m_node->get_output_size(); ++port) {
85-
ov::util::force_invalidate_bounds(m_node->get_output_tensor(port));
63+
if (old_has_bounds) {
64+
const auto& new_tensor = new_output.get_tensor();
65+
const auto& new_lower = new_tensor.get_lower_value();
66+
const auto& new_upper = new_tensor.get_upper_value();
67+
68+
// Internal comparison lambda (can't use have_same_bounds since old output is already gone)
69+
auto tensors_match = [](const ov::Tensor& a, const ov::Tensor& b) {
70+
if (!a && !b)
71+
return true;
72+
if (!a || !b)
73+
return false;
74+
if (a.get_shape() != b.get_shape())
75+
return false;
76+
if (a.get_element_type() != b.get_element_type())
77+
return false;
78+
return std::memcmp(a.data(), b.data(), a.get_byte_size()) == 0;
79+
};
80+
81+
bool same_bounds = tensors_match(old_lower, new_lower) && tensors_match(old_upper, new_upper);
82+
83+
if (!same_bounds) {
84+
for (size_t port = 0; port < m_node->get_output_size(); ++port) {
85+
ov::util::set_bounds_to_invalidate(m_node->get_output_tensor(port));
86+
}
8687
}
8788
}
8889
}

0 commit comments

Comments
 (0)