@@ -33,9 +33,13 @@ ov::descriptor::Input::~Input() {
3333}
3434
3535void ov::descriptor::Input::replace_output (Output& new_output) {
36- Output* old_output = m_output;
37-
36+ // Save old bounds BEFORE remove_input() because the old output's node may be destroyed
37+ // after we disconnect (if this was the last reference to it)
38+ ov::Tensor old_lower, old_upper;
3839 if (m_output != nullptr ) {
40+ const auto & old_tensor = m_output->get_tensor ();
41+ old_lower = old_tensor.get_lower_value ();
42+ old_upper = old_tensor.get_upper_value ();
3943 m_output->remove_input (this );
4044 }
4145 new_output.add_input (this );
@@ -54,35 +58,32 @@ void ov::descriptor::Input::replace_output(Output& new_output) {
5458 // This ensures:
5559 // - OptimizeSymbolsUsedAsValues: Same bounds → no invalidation → optimization works
5660 // - AbsSinking: New Abs node has no bounds yet → invalidation → correct recalculation
57- if (old_output != nullptr && m_node != nullptr ) {
58- const auto & old_tensor = old_output->get_tensor ();
59- const auto & new_tensor = new_output.get_tensor ();
60-
61- const auto & old_lower = old_tensor.get_lower_value ();
62- const auto & old_upper = old_tensor.get_upper_value ();
63- const auto & new_lower = new_tensor.get_lower_value ();
64- const auto & new_upper = new_tensor.get_upper_value ();
65-
61+ if (m_node != nullptr ) {
6662 bool old_has_bounds = old_lower && old_upper;
67- bool new_has_bounds = new_lower && new_upper;
68-
69- // Invalidate if:
70- // 1. Old had bounds but new doesn't have bounds (node was replaced with newly created one)
71- // 2. Both have bounds but they differ
72- bool should_invalidate = false ;
73- if (old_has_bounds && !new_has_bounds) {
74- // New source doesn't have bounds yet (e.g., newly created Abs in AbsSinking)
75- should_invalidate = true ;
76- } else if (old_has_bounds && new_has_bounds) {
77- // Both have bounds - check if they differ
78- bool bounds_differ =
79- !ov::util::tensors_equal (old_lower, new_lower) || !ov::util::tensors_equal (old_upper, new_upper);
80- should_invalidate = bounds_differ;
81- }
82-
83- if (should_invalidate) {
84- for (size_t port = 0 ; port < m_node->get_output_size (); ++port) {
85- ov::util::force_invalidate_bounds (m_node->get_output_tensor (port));
63+ if (old_has_bounds) {
64+ const auto & new_tensor = new_output.get_tensor ();
65+ const auto & new_lower = new_tensor.get_lower_value ();
66+ const auto & new_upper = new_tensor.get_upper_value ();
67+
68+ // Internal comparison lambda (can't use have_same_bounds since old output is already gone)
69+ auto tensors_match = [](const ov::Tensor& a, const ov::Tensor& b) {
70+ if (!a && !b)
71+ return true ;
72+ if (!a || !b)
73+ return false ;
74+ if (a.get_shape () != b.get_shape ())
75+ return false ;
76+ if (a.get_element_type () != b.get_element_type ())
77+ return false ;
78+ return std::memcmp (a.data (), b.data (), a.get_byte_size ()) == 0 ;
79+ };
80+
81+ bool same_bounds = tensors_match (old_lower, new_lower) && tensors_match (old_upper, new_upper);
82+
83+ if (!same_bounds) {
84+ for (size_t port = 0 ; port < m_node->get_output_size (); ++port) {
85+ ov::util::set_bounds_to_invalidate (m_node->get_output_tensor (port));
86+ }
8687 }
8788 }
8889 }
0 commit comments