@@ -45,52 +45,61 @@ struct unary_tensor_expression;
4545
4646namespace boost ::numeric::ublas::detail {
4747
48- template <class T , class E >
48+ template <typename T>
49+ struct is_tensor_type
50+ : std::false_type
51+ {};
52+
53+ template <typename E>
54+ struct is_tensor_type < tensor_core<E> >
55+ : std::true_type
56+ {};
57+
58+ template <class T >
59+ static constexpr bool is_tensor_type_v = is_tensor_type< std::decay_t <T> >::value;
60+
61+ template <typename T>
4962struct has_tensor_types
50- : std::integral_constant< bool , same_exp<T,E> >
63+ : is_tensor_type<T >
5164{};
5265
53- template <class T , class E >
54- static constexpr bool has_tensor_types_v = has_tensor_types< std::decay_t <T>, std:: decay_t <E> >::value;
66+ template <class T >
67+ static constexpr bool has_tensor_types_v = has_tensor_types< std::decay_t <T> >::value;
5568
5669template <class T , class D >
57- struct has_tensor_types <T, tensor_expression<T,D>>
58- {
59- static constexpr bool value =
60- same_exp<T,D> ||
61- has_tensor_types<T, std::decay_t <D> >::value;
62- };
70+ struct has_tensor_types < tensor_expression<T,D> >
71+ : has_tensor_types< std::decay_t <D> >
72+ {};
6373
6474template <class T , class EL , class ER , class OP >
65- struct has_tensor_types <T, binary_tensor_expression<T,EL,ER,OP>>
66- {
67- static constexpr bool value =
68- same_exp<T,EL> ||
69- same_exp<T,ER> ||
70- has_tensor_types<T, std::decay_t <EL> >::value ||
71- has_tensor_types<T, std::decay_t <ER> >::value;
72- };
75+ struct has_tensor_types < binary_tensor_expression<T,EL,ER,OP> >
76+ : std::integral_constant< bool , has_tensor_types_v<EL> || has_tensor_types_v<ER> >
77+ {};
7378
7479template <class T , class E , class OP >
75- struct has_tensor_types <T, unary_tensor_expression<T,E,OP>>
76- {
77- static constexpr bool value =
78- same_exp<T,E> ||
79- has_tensor_types<T, std::decay_t <E> >::value;
80- };
80+ struct has_tensor_types < unary_tensor_expression<T,E,OP> >
81+ : has_tensor_types< std::decay_t <E> >
82+ {};
8183
8284} // namespace boost::numeric::ublas::detail
8385
8486
8587namespace boost ::numeric::ublas::detail
8688{
8789
90+
91+ // TODO: remove this place holder for the old ublas expression after we remove the
92+ // support for them.
93+ template <class E >
94+ [[nodiscard]]
95+ constexpr auto & retrieve_extents (ublas_expression<E> const &) noexcept ;
96+
8897/* * @brief Retrieves extents of the tensor_core
8998 *
9099*/
91100template <class TensorEngine >
92101[[nodiscard]]
93- constexpr auto & retrieve_extents (tensor_core<TensorEngine> const & t)
102+ constexpr auto & retrieve_extents (tensor_core<TensorEngine> const & t) noexcept
94103{
95104 return t.extents ();
96105}
@@ -103,17 +112,14 @@ constexpr auto& retrieve_extents(tensor_core<TensorEngine> const& t)
103112*/
104113template <class T , class D >
105114[[nodiscard]]
106- constexpr auto & retrieve_extents (tensor_expression<T,D> const & expr)
115+ constexpr auto & retrieve_extents (tensor_expression<T,D> const & expr) noexcept
107116{
108- static_assert (has_tensor_types_v<T, tensor_expression<T,D>>,
117+ static_assert (has_tensor_types_v<tensor_expression<T,D>>,
109118 " Error in boost::numeric::ublas::detail::retrieve_extents: Expression to evaluate should contain tensors." );
110119
111120 auto const & cast_expr = expr ();
112-
113- if constexpr ( same_exp<T,D> )
114- return cast_expr.extents ();
115- else
116- return retrieve_extents (cast_expr);
121+
122+ return retrieve_extents (cast_expr);
117123}
118124
119125// Disable warning for unreachable code for MSVC compiler
@@ -129,24 +135,24 @@ constexpr auto& retrieve_extents(tensor_expression<T,D> const& expr)
129135*/
130136template <class T , class EL , class ER , class OP >
131137[[nodiscard]]
132- constexpr auto & retrieve_extents (binary_tensor_expression<T,EL,ER,OP> const & expr)
138+ constexpr auto & retrieve_extents (binary_tensor_expression<T,EL,ER,OP> const & expr) noexcept
133139{
134- static_assert (has_tensor_types_v<T, binary_tensor_expression<T,EL,ER,OP>>,
140+ static_assert (has_tensor_types_v<binary_tensor_expression<T,EL,ER,OP>>,
135141 " Error in boost::numeric::ublas::detail::retrieve_extents: Expression to evaluate should contain tensors." );
136142
137143 auto const & lexpr = expr.left_expr ();
138144 auto const & rexpr = expr.right_expr ();
139145
140- if constexpr ( same_exp<T, EL> )
141- return lexpr. extents ( );
146+ if constexpr ( is_tensor_type_v< EL> )
147+ return retrieve_extents (lexpr );
142148
143- else if constexpr ( same_exp<T, ER> )
144- return rexpr. extents ( );
149+ else if constexpr ( is_tensor_type_v< ER> )
150+ return retrieve_extents (rexpr );
145151
146- else if constexpr ( has_tensor_types_v<T, EL> )
152+ else if constexpr ( has_tensor_types_v<EL> )
147153 return retrieve_extents (lexpr);
148154
149- else if constexpr ( has_tensor_types_v<T, ER> )
155+ else if constexpr ( has_tensor_types_v<ER> )
150156 return retrieve_extents (rexpr);
151157}
152158
@@ -162,19 +168,15 @@ constexpr auto& retrieve_extents(binary_tensor_expression<T,EL,ER,OP> const& exp
162168*/
163169template <class T , class E , class OP >
164170[[nodiscard]]
165- constexpr auto & retrieve_extents (unary_tensor_expression<T,E,OP> const & expr)
171+ constexpr auto & retrieve_extents (unary_tensor_expression<T,E,OP> const & expr) noexcept
166172{
167173
168- static_assert (has_tensor_types_v<T, unary_tensor_expression<T,E,OP>>,
174+ static_assert (has_tensor_types_v<unary_tensor_expression<T,E,OP>>,
169175 " Error in boost::numeric::ublas::detail::retrieve_extents: Expression to evaluate should contain tensors." );
170176
171177 auto const & uexpr = expr.expr ();
172178
173- if constexpr ( same_exp<T,E> )
174- return uexpr.extents ();
175-
176- else if constexpr ( has_tensor_types_v<T,E> )
177- return retrieve_extents (uexpr);
179+ return retrieve_extents (uexpr);
178180}
179181
180182} // namespace boost::numeric::ublas::detail
@@ -184,89 +186,67 @@ constexpr auto& retrieve_extents(unary_tensor_expression<T,E,OP> const& expr)
184186
185187namespace boost ::numeric::ublas::detail {
186188
189+ // TODO: remove this place holder for the old ublas expression after we remove the
190+ // support for them.
191+ template <class E , std::size_t ... es>
192+ [[nodiscard]] inline
193+ constexpr auto all_extents_equal (ublas_expression<E> const &, extents<es...> const &) noexcept
194+ {
195+ return true ;
196+ }
197+
187198template <class EN , std::size_t ... es>
188199[[nodiscard]] inline
189- constexpr auto all_extents_equal (tensor_core<EN> const & t, extents<es...> const & e)
200+ constexpr auto all_extents_equal (tensor_core<EN> const & t, extents<es...> const & e) noexcept
190201{
191202 return ::operator ==(e,t.extents ());
192203}
193204
194205template <class T , class D , std::size_t ... es>
195206[[nodiscard]]
196- constexpr auto all_extents_equal (tensor_expression<T,D> const & expr, extents<es...> const & e)
207+ constexpr auto all_extents_equal (tensor_expression<T,D> const & expr, extents<es...> const & e) noexcept
197208{
198209
199- static_assert (has_tensor_types_v<T, tensor_expression<T,D>>,
210+ static_assert (has_tensor_types_v<tensor_expression<T,D>>,
200211 " Error in boost::numeric::ublas::all_extents_equal: Expression to evaluate should contain tensors." );
201212
202213 auto const & cast_expr = expr ();
203214
204- using ::operator ==;
205- using ::operator !=;
206-
207- if constexpr ( same_exp<T,D> )
208- if ( e != cast_expr.extents () )
209- return false ;
210-
211- if constexpr ( has_tensor_types_v<T,D> )
212- if ( !all_extents_equal (cast_expr, e))
213- return false ;
215+ if ( !all_extents_equal (cast_expr, e) )
216+ return false ;
214217
215218 return true ;
216219
217220}
218221
219222template <class T , class EL , class ER , class OP , std::size_t ... es>
220223[[nodiscard]]
221- constexpr auto all_extents_equal (binary_tensor_expression<T,EL,ER,OP> const & expr, extents<es...> const & e)
224+ constexpr auto all_extents_equal (binary_tensor_expression<T,EL,ER,OP> const & expr, extents<es...> const & e) noexcept
222225{
223- static_assert (has_tensor_types_v<T, binary_tensor_expression<T,EL,ER,OP>>,
226+ static_assert (has_tensor_types_v<binary_tensor_expression<T,EL,ER,OP>>,
224227 " Error in boost::numeric::ublas::all_extents_equal: Expression to evaluate should contain tensors." );
225228
226- using ::operator ==;
227- using ::operator !=;
228-
229229 auto const & lexpr = expr.left_expr ();
230230 auto const & rexpr = expr.right_expr ();
231231
232- if constexpr ( same_exp<T,EL> )
233- if (e != lexpr.extents ())
234- return false ;
235-
236- if constexpr ( same_exp<T,ER> )
237- if (e != rexpr.extents ())
238- return false ;
239-
240- if constexpr ( has_tensor_types_v<T,EL> )
241- if (!all_extents_equal (lexpr, e))
242- return false ;
243-
244- if constexpr ( has_tensor_types_v<T,ER> )
245- if (!all_extents_equal (rexpr, e))
246- return false ;
232+ if ( !all_extents_equal (lexpr, e) || !all_extents_equal (rexpr, e) )
233+ return false ;
247234
248235 return true ;
249236}
250237
251238
252239template <class T , class E , class OP , std::size_t ... es>
253240[[nodiscard]]
254- constexpr auto all_extents_equal (unary_tensor_expression<T,E,OP> const & expr, extents<es...> const & e)
241+ constexpr auto all_extents_equal (unary_tensor_expression<T,E,OP> const & expr, extents<es...> const & e) noexcept
255242{
256- static_assert (has_tensor_types_v<T, unary_tensor_expression<T,E,OP>>,
243+ static_assert (has_tensor_types_v<unary_tensor_expression<T,E,OP>>,
257244 " Error in boost::numeric::ublas::all_extents_equal: Expression to evaluate should contain tensors." );
258245
259- using ::operator ==;
260-
261246 auto const & uexpr = expr.expr ();
262247
263- if constexpr ( same_exp<T,E> )
264- if (e != uexpr.extents ())
265- return false ;
266-
267- if constexpr ( has_tensor_types_v<T,E> )
268- if (!all_extents_equal (uexpr, e))
269- return false ;
248+ if ( !all_extents_equal (uexpr, e) )
249+ return false ;
270250
271251 return true ;
272252}
0 commit comments