@@ -12,62 +12,56 @@ infiniStatus_t Descriptor::create(
1212 Descriptor **desc_ptr,
1313 infiniopTensorDescriptor_t c_desc,
1414 infiniopTensorDescriptor_t a_desc,
15- infiniopTensorDescriptor_t b_desc
16- ) {
15+ infiniopTensorDescriptor_t b_desc) {
1716 auto handle = reinterpret_cast <device::cpu::Handle *>(handle_);
1817
19- // --------------------- start: check data type and calculate workspace size ----------------------
18+ // --------------------- start: check data type and calculate workspace size ----------------------
2019 auto dtype = c_desc->dtype ();
2120 CHECK_DTYPE (dtype, INFINI_DTYPE_BOOL);
2221 CHECK_OR_RETURN (b_desc->dtype () == a_desc->dtype (), INFINI_STATUS_BAD_TENSOR_DTYPE);
2322 size_t WorkSpaceSize = 0 ;
24- // ---------------------- end: check data type and calculate workspace size -----------------------
23+ // ---------------------- end: check data type and calculate workspace size -----------------------
2524
2625 auto result = AllEqualInfo::createAllEqualInfo (
2726 c_desc,
2827 a_desc,
29- b_desc
30- );
28+ b_desc);
3129 CHECK_RESULT (result);
3230 const AllEqualInfo &info = result.take ();
33-
31+
3432 *desc_ptr = new Descriptor (
3533 dtype, std::move (info), WorkSpaceSize,
3634 nullptr ,
37- handle->device , handle->device_id
38- );
35+ handle->device , handle->device_id );
3936
4037 return INFINI_STATUS_SUCCESS;
4138}
4239
43-
4440infiniStatus_t Descriptor::calculate (
4541 void *workspace,
4642 size_t workspace_size,
47- void * c,
48- const void * a,
49- const void * b,
50- void *stream
51- ) const {
43+ void *c,
44+ const void *a,
45+ const void *b,
46+ void *stream) const {
5247 std::vector<ptrdiff_t > contiguous_strides (_info.ndim );
53- ptrdiff_t last_dim = 1 ;
48+ ptrdiff_t last_dim = 1 ;
5449 ptrdiff_t last_stride = 1 ;
55- for (size_t d = 0 ; d < _info.ndim ; d ++)
56- {
57- contiguous_strides[d] = last_dim * last_stride;
50+ for (size_t d = 0 ; d < _info.ndim ; d++) {
51+ contiguous_strides[d] = last_dim * last_stride;
5852 last_dim = _info.a_shape [d];
5953 last_stride = contiguous_strides[d];
6054 }
6155 size_t total_size = last_dim * last_stride;
6256 size_t elem_size = infiniSizeOf (_info.dtype );
63- auto c_ptr = reinterpret_cast <bool *>(c);
57+ auto c_ptr = reinterpret_cast <bool *>(c);
6458 *c_ptr = true ;
65- #pragma omp parallel for
66- for (size_t i = 0 ; i < total_size; i ++) {
67- auto a_ptr = reinterpret_cast <const char *>(a);
68- auto b_ptr = reinterpret_cast <const char *>(b);
59+ #pragma omp parallel for
60+ for (size_t i = 0 ; i < total_size; i++) {
61+ auto a_ptr = reinterpret_cast <const char *>(a);
62+ auto b_ptr = reinterpret_cast <const char *>(b);
6963 size_t rem = i;
70- for (int d = _info.ndim - 1 ; d >= 0 ; d --) {
64+ for (int d = _info.ndim - 1 ; d >= 0 ; d--) {
7165 size_t dim_index = rem / contiguous_strides[d];
7266 rem = rem % contiguous_strides[d];
7367 a_ptr += dim_index * _info.a_strides [d];
@@ -79,4 +73,4 @@ infiniStatus_t Descriptor::calculate(
7973 }
8074 return INFINI_STATUS_SUCCESS;
8175}
82- }
76+ } // namespace op::all_equal::cpu
0 commit comments