Skip to content

Commit 98dc3b6

Browse files
committed
Fix 1-element ec ambiguities
1 parent bf7fe2a commit 98dc3b6

File tree

1 file changed

+25
-11
lines changed

1 file changed

+25
-11
lines changed

include/simsycl/sycl/vec.hh

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include "../detail/check.hh"
88
#include "../detail/utils.hh"
99

10+
#include <concepts>
1011
#include <cstdint>
1112
#include <cstdlib>
1213
#include <type_traits>
@@ -234,8 +235,9 @@ class swizzled_vec {
234235
swizzled_vec &operator=(const swizzled_vec &) = delete;
235236
swizzled_vec &operator=(swizzled_vec &&) = delete;
236237

237-
swizzled_vec &operator=(const value_type &rhs)
238-
requires(allow_assign)
238+
template<typename T>
239+
swizzled_vec &operator=(const T &rhs)
240+
requires(allow_assign && std::convertible_to<T, value_type>)
239241
{
240242
for(size_t i = 0; i < num_elements; ++i) { m_elems[indices[i]] = rhs; }
241243
return *this;
@@ -540,7 +542,10 @@ class alignas(detail::vec_alignment_v<DataT, NumElements>) vec {
540542
vec(const vec &) = default;
541543
vec &operator=(const vec &rhs) = default;
542544

543-
vec &operator=(const DataT &rhs) {
545+
template<typename T>
546+
vec &operator=(const T &rhs)
547+
requires(std::convertible_to<T, DataT>)
548+
{
544549
for(int i = 0; i < NumElements; ++i) { m_elems[i] = rhs; }
545550
return *this;
546551
}
@@ -755,15 +760,17 @@ class alignas(detail::vec_alignment_v<DataT, NumElements>) vec {
755760
for(int i = 0; i < NumElements; ++i) { result.m_elems[i] = lhs.m_elems[i] op rhs.m_elems[i]; } \
756761
return result; \
757762
} \
758-
friend vec operator op(const vec &lhs, const DataT &rhs) \
759-
requires(enable_if) \
763+
template<typename T> \
764+
friend vec operator op(const vec &lhs, const T &rhs) \
765+
requires(enable_if && std::convertible_to<T, DataT>) \
760766
{ \
761767
vec result; \
762768
for(int i = 0; i < NumElements; ++i) { result.m_elems[i] = lhs.m_elems[i] op rhs; } \
763769
return result; \
764770
} \
765-
friend vec operator op(const DataT &lhs, const vec &rhs) \
766-
requires(enable_if) \
771+
template<typename T> \
772+
friend vec operator op(const T &lhs, const vec &rhs) \
773+
requires(enable_if && std::convertible_to<T, DataT>) \
767774
{ \
768775
vec result; \
769776
for(int i = 0; i < NumElements; ++i) { result.m_elems[i] = lhs op rhs.m_elems[i]; } \
@@ -797,8 +804,9 @@ class alignas(detail::vec_alignment_v<DataT, NumElements>) vec {
797804
for(int i = 0; i < NumElements; ++i) { lhs.m_elems[i] op rhs.m_elems[rhs.indices[i]]; } \
798805
return lhs; \
799806
} \
800-
friend vec &operator op(vec & lhs, const DataT & rhs) \
801-
requires(enable_if) \
807+
template<typename T> \
808+
friend vec &operator op(vec & lhs, const T & rhs) \
809+
requires(enable_if && std::convertible_to<T, DataT>) \
802810
{ \
803811
for(int i = 0; i < NumElements; ++i) { lhs.m_elems[i] op rhs; } \
804812
return lhs; \
@@ -866,12 +874,18 @@ class alignas(detail::vec_alignment_v<DataT, NumElements>) vec {
866874
for(int i = 0; i < NumElements; ++i) { result.m_elems[i] = lhs.m_elems[i] op rhs.m_elems[i]; } \
867875
return result; \
868876
} \
869-
friend vec<decltype(DataT {} op DataT{}), NumElements> operator op(const vec & lhs, const DataT & rhs) { \
877+
template<typename T> \
878+
friend vec<decltype(DataT {} op DataT{}), NumElements> operator op(const vec & lhs, const T & rhs) \
879+
requires(std::convertible_to<T, DataT>) \
880+
{ \
870881
vec<decltype(DataT {} op DataT{}), NumElements> result; \
871882
for(int i = 0; i < NumElements; ++i) { result.m_elems[i] = lhs.m_elems[i] op rhs; } \
872883
return result; \
873884
} \
874-
friend vec<decltype(DataT {} op DataT{}), NumElements> operator op(const DataT & lhs, const vec & rhs) { \
885+
template<typename T> \
886+
friend vec<decltype(DataT {} op DataT{}), NumElements> operator op(const T & lhs, const vec & rhs) \
887+
requires(std::convertible_to<T, DataT>) \
888+
{ \
875889
vec<decltype(DataT {} op DataT{}), NumElements> result; \
876890
for(int i = 0; i < NumElements; ++i) { result.m_elems[i] = lhs op rhs.m_elems[i]; } \
877891
return result; \

0 commit comments

Comments
 (0)