2525#include " openvino/pass/pattern/op/wrap_type.hpp"
2626#include " ov_ops/augru_cell.hpp"
2727
28- using namespace std ;
29- using namespace ov ::element;
30- using namespace ov ::pass::pattern;
28+ namespace ov ::pass {
29+
30+ namespace v0 = ov::op::v0;
31+ namespace v1 = ov::op::v1;
32+
33+ namespace {
3134
3235// The 1st input to the Add op is automatically broadcasted
3336// from 1d to 2d tensor, but to be compatible with what
@@ -37,8 +40,7 @@ static std::shared_ptr<ov::Node> get_bias_add(const std::shared_ptr<ov::Node>& b
3740 auto input_source_1_ps = bias_add->input_value (1 ).get_partial_shape ();
3841 if (input_source_1_ps.is_static () && input_source_1_ps.rank ().get_length () == 1 ) {
3942 auto unsqueeze =
40- rg.make <ov::op::v0::Unsqueeze>(bias_add->input_value (1 ),
41- ov::op::v0::Constant::create (ov::element::i32 , ov::Shape{}, {0 }));
43+ rg.make <v0::Unsqueeze>(bias_add->input_value (1 ), v0::Constant::create (ov::element::i32 , ov::Shape{}, {0 }));
4244 bias_add->input (1 ).replace_source_output (unsqueeze);
4345 }
4446
@@ -54,19 +56,20 @@ static std::shared_ptr<ov::Node> get_bias_add(const std::shared_ptr<ov::Node>& b
5456// compatible with the code of the transformation.
5557static std::shared_ptr<ov::Node> get_weights_matmul (const std::shared_ptr<ov::Node>& mat_mul,
5658 ov::pass::NodeRegistry& rg) {
57- if (auto matmul = ov::as_type_ptr<ov::op:: v0::MatMul>(mat_mul)) {
59+ if (auto matmul = ov::as_type_ptr<v0::MatMul>(mat_mul)) {
5860 if (!matmul->get_transpose_b ()) {
59- auto transpose =
60- rg.make <ov::op::v1::Transpose>(matmul->input_value (1 ),
61- ov::op::v0::Constant::create (ov::element::i32 , ov::Shape{2 }, {1 , 0 }));
61+ auto transpose = rg.make <v1::Transpose>(matmul->input_value (1 ),
62+ v0::Constant::create (ov::element::i32 , ov::Shape{2 }, {1 , 0 }));
6263 matmul->input (1 ).replace_source_output (transpose);
6364 }
6465 }
6566
6667 return mat_mul;
6768}
6869
69- ov::pass::AUGRUCellFusion::AUGRUCellFusion () {
70+ } // namespace
71+
72+ AUGRUCellFusion::AUGRUCellFusion () {
7073 MATCHER_SCOPE (AUGRUCellFusion);
7174
7275 // we can't determine hidden_size or input_size in this case
@@ -75,27 +78,28 @@ ov::pass::AUGRUCellFusion::AUGRUCellFusion() {
7578 return !(p_shape.rank ().is_dynamic () || p_shape[1 ].is_dynamic ());
7679 };
7780
78- auto concat_1 = wrap_type<ov::op::v0::Concat>({any_input (is_first_dim_static), any_input (is_first_dim_static)});
79- auto matmul_1 = wrap_type<ov::op::v0::MatMul>({concat_1, any_input (is_first_dim_static)});
80- auto add_1 = wrap_type<ov::op::v1::Add>({matmul_1, any_input ()});
81+ auto concat_1 = pattern::wrap_type<v0::Concat>(
82+ {pattern::any_input (is_first_dim_static), pattern::any_input (is_first_dim_static)});
83+ auto matmul_1 = pattern::wrap_type<v0::MatMul>({concat_1, pattern::any_input (is_first_dim_static)});
84+ auto add_1 = pattern::wrap_type<v1::Add>({matmul_1, pattern::any_input ()});
8185 // only Sigmoid is supported in the current version of AUGRUCell
82- auto sigmoid = wrap_type<ov::op:: v0::Sigmoid>({add_1});
83- auto split = wrap_type<ov::op:: v1::Split>({sigmoid, any_input ()});
84- auto multiply = wrap_type<ov::op:: v1::Multiply>({split, any_input ()});
86+ auto sigmoid = pattern::wrap_type< v0::Sigmoid>({add_1});
87+ auto split = pattern::wrap_type< v1::Split>({sigmoid, pattern:: any_input ()});
88+ auto multiply = pattern::wrap_type< v1::Multiply>({split, pattern:: any_input ()});
8589
86- auto concat_2 = wrap_type<ov::op:: v0::Concat>({any_input (), multiply});
87- auto matmul_2 = wrap_type<ov::op:: v0::MatMul>({concat_2, any_input (is_first_dim_static)});
88- auto add_2 = wrap_type<ov::op:: v1::Add>({matmul_2, any_input ()});
90+ auto concat_2 = pattern::wrap_type< v0::Concat>({pattern:: any_input (), multiply});
91+ auto matmul_2 = pattern::wrap_type< v0::MatMul>({concat_2, pattern:: any_input (is_first_dim_static)});
92+ auto add_2 = pattern::wrap_type< v1::Add>({matmul_2, pattern:: any_input ()});
8993 // only Tanh is supported in the current version of AUGRUCell
90- auto tanh = wrap_type<ov::op:: v0::Tanh>({add_2});
94+ auto tanh = pattern::wrap_type< v0::Tanh>({add_2});
9195
92- auto subtract_1 = wrap_type<ov::op:: v1::Subtract>({any_input (), any_input ()});
93- auto multiply_2 = wrap_type<ov::op:: v1::Multiply>({subtract_1, split});
94- auto subtract_2 = wrap_type<ov::op:: v1::Subtract>({any_input (), multiply_2});
95- auto multiply_3 = wrap_type<ov::op:: v1::Multiply>({subtract_2, tanh});
96+ auto subtract_1 = pattern::wrap_type< v1::Subtract>({pattern:: any_input (), pattern:: any_input ()});
97+ auto multiply_2 = pattern::wrap_type< v1::Multiply>({subtract_1, split});
98+ auto subtract_2 = pattern::wrap_type< v1::Subtract>({pattern:: any_input (), multiply_2});
99+ auto multiply_3 = pattern::wrap_type< v1::Multiply>({subtract_2, tanh});
96100
97- auto multiply_4 = wrap_type<ov::op:: v1::Multiply>({multiply_2, any_input ()});
98- auto add_3 = wrap_type<ov::op:: v1::Add>({multiply_4, multiply_3});
101+ auto multiply_4 = pattern::wrap_type< v1::Multiply>({multiply_2, pattern:: any_input ()});
102+ auto add_3 = pattern::wrap_type< v1::Add>({multiply_4, multiply_3});
99103
100104 matcher_pass_callback callback = [=](pattern::Matcher& m) {
101105 NodeRegistry rg;
@@ -110,35 +114,34 @@ ov::pass::AUGRUCellFusion::AUGRUCellFusion() {
110114 auto hidden_size = h_pshape[1 ].get_length ();
111115 auto input_size = x_pshape[1 ].get_length ();
112116
113- auto axis_0 = rg.make <ov::op:: v0::Constant>(i64 , Shape{}, 0 );
114- auto axis_1 = rg.make <ov::op:: v0::Constant>(i64 , Shape{}, 1 );
117+ auto axis_0 = rg.make <v0::Constant>(element:: i64 , Shape{}, 0 );
118+ auto axis_1 = rg.make <v0::Constant>(element:: i64 , Shape{}, 1 );
115119
116120 auto A = pattern_map.at (subtract_1)->input_value (1 );
117121 // biases are required
118122 auto bias_add_1 = get_bias_add (pattern_map.at (add_1), rg);
119- auto split_bias_r_z = rg.make <ov::op:: v1::Split>(bias_add_1->input_value (1 ), axis_1, 2 );
123+ auto split_bias_r_z = rg.make <v1::Split>(bias_add_1->input_value (1 ), axis_1, 2 );
120124 auto bias_add_2 = get_bias_add (pattern_map.at (add_2), rg);
121125
122- auto B = rg.make <ov::op:: v0::Concat>(
126+ auto B = rg.make <v0::Concat>(
123127 OutputVector{split_bias_r_z->output (1 ), split_bias_r_z->output (0 ), bias_add_2->input_value (1 )},
124128 1 );
125129
126130 auto WRrz = get_weights_matmul (pattern_map.at (matmul_1), rg)->input_value (1 );
127131 auto WRh = get_weights_matmul (pattern_map.at (matmul_2), rg)->input_value (1 );
128132
129- auto split_lenghts = rg.make <ov::op::v0::Constant>(i64 , Shape{2 }, vector<int64_t >{input_size, hidden_size});
130- auto split_WRrz = rg.make <ov::op::v1::VariadicSplit>(WRrz, axis_1, split_lenghts);
131- auto split_W_r_z = rg.make <ov::op::v1::Split>(split_WRrz->output (0 ), axis_0, 2 );
132- auto split_R_r_z = rg.make <ov::op::v1::Split>(split_WRrz->output (1 ), axis_0, 2 );
133- auto split_WRh = rg.make <ov::op::v1::VariadicSplit>(WRh, axis_1, split_lenghts);
134- auto Wzrh = rg.make <ov::op::v0::Concat>(
135- OutputVector{split_W_r_z->output (1 ), split_W_r_z->output (0 ), split_WRh->output (0 )},
136- 0 );
137- auto Rzrh = rg.make <ov::op::v0::Concat>(
138- OutputVector{split_R_r_z->output (1 ), split_R_r_z->output (0 ), split_WRh->output (1 )},
139- 0 );
140-
141- auto squeeze_B = rg.make <ov::op::v0::Squeeze>(B, axis_0);
133+ auto split_lenghts =
134+ rg.make <v0::Constant>(element::i64 , Shape{2 }, std::vector<int64_t >{input_size, hidden_size});
135+ auto split_WRrz = rg.make <v1::VariadicSplit>(WRrz, axis_1, split_lenghts);
136+ auto split_W_r_z = rg.make <v1::Split>(split_WRrz->output (0 ), axis_0, 2 );
137+ auto split_R_r_z = rg.make <v1::Split>(split_WRrz->output (1 ), axis_0, 2 );
138+ auto split_WRh = rg.make <v1::VariadicSplit>(WRh, axis_1, split_lenghts);
139+ auto Wzrh =
140+ rg.make <v0::Concat>(OutputVector{split_W_r_z->output (1 ), split_W_r_z->output (0 ), split_WRh->output (0 )}, 0 );
141+ auto Rzrh =
142+ rg.make <v0::Concat>(OutputVector{split_R_r_z->output (1 ), split_R_r_z->output (0 ), split_WRh->output (1 )}, 0 );
143+
144+ auto squeeze_B = rg.make <v0::Squeeze>(B, axis_0);
142145 auto cell =
143146 rg.make <ov::op::internal::AUGRUCell>(X, H, Wzrh, Rzrh, squeeze_B, A, H.get_partial_shape ()[1 ].get_length ());
144147
@@ -148,6 +151,8 @@ ov::pass::AUGRUCellFusion::AUGRUCellFusion() {
148151 return true ;
149152 };
150153
151- auto m = make_shared<Matcher>(add_3, matcher_name);
154+ auto m = std:: make_shared<pattern:: Matcher>(add_3, matcher_name);
152155 this ->register_matcher (m, callback);
153156}
157+
158+ } // namespace ov::pass
0 commit comments