Skip to content

Commit 9230762

Browse files
zhuyuegongchensu
authored and committed
Change operator name from equal to all_equal.
1 parent 5c9c0e1 commit 9230762

File tree

19 files changed

+152
-134
lines changed

19 files changed

+152
-134
lines changed

include/infiniop/ops/all_equal.h

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
#ifndef __INFINIOP_ALL_EQUAL_API_H__ // include guard for the all_equal public C API header
2+
#define __INFINIOP_ALL_EQUAL_API_H__
3+
4+
#include "../operator_descriptor.h"
5+
6+
typedef struct InfiniopDescriptor *infiniopAllEqualDescriptor_t; // descriptor handle type taken by every all_equal entry point below
7+
8+
__C __export infiniStatus_t infiniopCreateAllEqualDescriptor( // creates an all_equal descriptor from the three tensor descriptors; returns a status code
9+
infiniopHandle_t handle,
10+
infiniopAllEqualDescriptor_t *desc_ptr, // out-parameter; presumably receives the created descriptor -- confirm against the dispatcher
11+
infiniopTensorDescriptor_t c_desc, // result tensor; AllEqualInfo::createAllEqualInfo requires ndim() == 1 and dim(0) == 1
12+
infiniopTensorDescriptor_t a_desc, // first input; AllEqualInfo::createAllEqualInfo checks its shape against b_desc with CHECK_SAME_SHAPE
13+
infiniopTensorDescriptor_t b_desc // second input; must have the same shape as a_desc
14+
);
15+
16+
__C __export infiniStatus_t infiniopGetAllEqualWorkspaceSize(infiniopAllEqualDescriptor_t desc, size_t *size); // writes the workspace size to *size; the test passes it to infinirtMalloc before running the operator
17+
18+
__C __export infiniStatus_t infiniopAllEqual( // runs the comparison described by desc; returns a status code
19+
infiniopAllEqualDescriptor_t desc,
20+
void *workspace, // scratch buffer; the test allocates it with the size reported by infiniopGetAllEqualWorkspaceSize
21+
size_t workspace_size,
22+
void * c, // output buffer; the CUDA kernel writes a single bool through it
23+
const void * a, // first input buffer, read-only
24+
const void * b, // second input buffer, read-only
25+
void *stream // NOTE(review): presumably the backend stream/queue to launch on -- confirm
26+
);
27+
28+
__C __export infiniStatus_t infiniopDestroyAllEqualDescriptor(infiniopAllEqualDescriptor_t desc); // presumably releases a descriptor made by infiniopCreateAllEqualDescriptor -- confirm
29+
30+
#endif

include/infiniop/ops/equal.h

Lines changed: 0 additions & 30 deletions
This file was deleted.

src/infiniop-test/include/ops.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ DECLARE_INFINIOP_TEST(add)
1616
DECLARE_INFINIOP_TEST(causal_softmax)
1717
DECLARE_INFINIOP_TEST(rearrange)
1818
DECLARE_INFINIOP_TEST(sub)
19-
DECLARE_INFINIOP_TEST(equal)
19+
DECLARE_INFINIOP_TEST(all_equal)
2020

2121
#define REGISTER_INFINIOP_TEST(name) \
2222
{ \
@@ -44,7 +44,7 @@ DECLARE_INFINIOP_TEST(equal)
4444
REGISTER_INFINIOP_TEST(causal_softmax) \
4545
REGISTER_INFINIOP_TEST(rearrange) \
4646
REGISTER_INFINIOP_TEST(sub) \
47-
REGISTER_INFINIOP_TEST(equal) \
47+
REGISTER_INFINIOP_TEST(all_equal) \
4848
}
4949

5050
namespace infiniop_test {
Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
#include <infinirt.h>
44
#include <iomanip>
55
#include <iostream>
6+
#include "../../../include/infiniop/ops/all_equal.h"
67

7-
namespace infiniop_test::equal {
8+
namespace infiniop_test::all_equal {
89
struct Test::Attributes {
910
std::shared_ptr<Tensor> a;
1011
std::shared_ptr<Tensor> b;
@@ -35,22 +36,22 @@ std::shared_ptr<Test> Test::build(
3536

3637
std::shared_ptr<infiniop_test::Result> Test::run(
3738
infiniopHandle_t handle, infiniDevice_t device, int device_id, size_t warm_ups, size_t iterations) {
38-
infiniopEqualDescriptor_t op_desc;
39+
infiniopAllEqualDescriptor_t op_desc;
3940
auto a = _attributes->a->to(device, device_id);
4041
auto b = _attributes->b->to(device, device_id);
4142
auto c = _attributes->c->to(device, device_id);
42-
CHECK_OR(infiniopCreateEqualDescriptor(handle, &op_desc,
43+
CHECK_OR(infiniopCreateAllEqualDescriptor(handle, &op_desc,
4344
c->desc(),
4445
a->desc(),
4546
b->desc()),
4647
return TEST_FAILED(OP_CREATION_FAILED, "Failed to create op descriptor."));
4748
size_t workspace_size;
48-
CHECK_OR(infiniopGetEqualWorkspaceSize(op_desc, &workspace_size),
49+
CHECK_OR(infiniopGetAllEqualWorkspaceSize(op_desc, &workspace_size),
4950
return TEST_FAILED(OP_CREATION_FAILED, "Failed to get workspace size."));
5051
void *workspace;
5152
CHECK_OR(infinirtMalloc(&workspace, workspace_size),
5253
return TEST_FAILED(OP_CREATION_FAILED, "Failed to allocate workspace."));
53-
CHECK_OR(infiniopEqual(op_desc, workspace, workspace_size,
54+
CHECK_OR(infiniopAllEqual(op_desc, workspace, workspace_size,
5455
c->data(),
5556
a->data(),
5657
b->data(),
@@ -67,7 +68,7 @@ std::shared_ptr<infiniop_test::Result> Test::run(
6768

6869
elapsed_time = benchmark(
6970
[=]() {
70-
infiniopEqual(
71+
infiniopAllEqual(
7172
op_desc, workspace, workspace_size,
7273
c->data(),
7374
a->data(),
@@ -106,4 +107,4 @@ Test::~Test() {
106107
delete _attributes;
107108
}
108109

109-
} // namespace infiniop_test::equal
110+
} // namespace infiniop_test::all_equal
Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,21 @@
1-
#ifndef __EQUAL_H__
2-
#define __EQUAL_H__
1+
#ifndef __ALL_EQUAL_H__
2+
#define __ALL_EQUAL_H__
33

44
#include "../../../utils.h"
55
#include "../../operator.h"
66
#include "../../tensor.h"
77
#include "info.h"
88

99
#define DESCRIPTOR(NAMESPACE) \
10-
namespace op::equal::NAMESPACE { \
10+
namespace op::all_equal::NAMESPACE { \
1111
class Descriptor final : public InfiniopDescriptor { \
1212
struct Opaque; \
1313
Opaque *_opaque; \
14-
EqualInfo _info; \
14+
op::all_equal::AllEqualInfo _info; \
1515
size_t _workspace_size; \
1616
Descriptor( \
1717
infiniDtype_t dtype, \
18-
EqualInfo info, \
18+
op::all_equal::AllEqualInfo info, \
1919
size_t workspace_size_, \
2020
Opaque *opaque, \
2121
infiniDevice_t device_type, \

src/infiniop/ops/equal/cpu/equal_cpu.cc renamed to src/infiniop/ops/all_equal/cpu/all_equal_cpu.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
#include "equal_cpu.h"
1+
#include "all_equal_cpu.h"
22
#include "../../../devices/cpu/common_cpu.h"
33
#include "../../../reduce/cpu/reduce.h"
44
#include "../info.h"
55

6-
namespace op::equal::cpu {
6+
namespace op::all_equal::cpu {
77

88
Descriptor::~Descriptor() = default;
99

@@ -23,13 +23,13 @@ infiniStatus_t Descriptor::create(
2323
size_t WorkSpaceSize = 0;
2424
// ---------------------- end: check data type and calculate workspace size -----------------------
2525

26-
auto result = EqualInfo::createEqualInfo(
26+
auto result = AllEqualInfo::createAllEqualInfo(
2727
c_desc,
2828
a_desc,
2929
b_desc
3030
);
3131
CHECK_RESULT(result);
32-
const EqualInfo &info = result.take();
32+
const AllEqualInfo &info = result.take();
3333

3434
*desc_ptr = new Descriptor(
3535
dtype, std::move(info), WorkSpaceSize,
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
#ifndef __ALL_EQUAL_CPU_H__
2+
#define __ALL_EQUAL_CPU_H__
3+
4+
#include "../all_equal.h"
5+
6+
DESCRIPTOR(cpu)
7+
8+
9+
#endif // __ALL_EQUAL_CPU_H__
Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
#ifndef __EQUAL_KERNEL_CUH__
2-
#define __EQUAL_KERNEL_CUH__
1+
#ifndef __ALL_EQUAL_KERNEL_CUH__
2+
#define __ALL_EQUAL_KERNEL_CUH__
33
// ------------------------------- start: perform operator on CUDA --------------------------------
44
template <unsigned int BLOCK_SIZE, typename Tdata>
5-
__device__ void equalKernel(
5+
__device__ void allEqualKernel(
66
bool * c,
77
const Tdata * a,
88
const Tdata * b,
@@ -12,11 +12,16 @@ __device__ void equalKernel(
1212
ptrdiff_t* a_strides,
1313
ptrdiff_t* b_strides
1414
) {
15-
if (threadIdx.x == 0)
16-
{
17-
*c = true;
15+
// 使用共享内存来避免竞态条件
16+
__shared__ bool block_result;
17+
18+
if (threadIdx.x == 0) {
19+
block_result = true;
1820
}
1921
__syncthreads();
22+
23+
// 每个线程检查自己负责的元素
24+
bool thread_result = true;
2025
for(size_t i = threadIdx.x; i < total_size; i += BLOCK_SIZE) {
2126
auto a_ptr = a;
2227
auto b_ptr = b;
@@ -27,12 +32,24 @@ __device__ void equalKernel(
2732
a_ptr += dim_index * a_strides[d];
2833
b_ptr += dim_index * b_strides[d];
2934
}
30-
if ((*a_ptr != *b_ptr) && (*c == true)) {
31-
*c = false;
35+
if (*a_ptr != *b_ptr) {
36+
thread_result = false;
37+
break; // 发现不匹配,提前退出
3238
}
33-
39+
}
40+
41+
// Publish a mismatch with a plain store. block_result is a __shared__ bool, narrower than the
// 32-bit word atomicAnd operates on, so atomicAnd((int*)&block_result, 0) would also modify the
// bytes next to the flag. Every writer stores the same value (false), and the __syncthreads()
// that follows makes the store visible before thread 0 reads block_result.
42+
if (!thread_result) {
43+
block_result = false;
44+
}
45+
46+
__syncthreads();
47+
48+
// 只有第一个线程写入最终结果
49+
if (threadIdx.x == 0) {
50+
*c = block_result;
3451
}
3552
}
3653
// -------------------------------- end: perform operator on CUDA ---------------------------------
3754

38-
#endif // __EQUAL_KERNEL_CUH__
55+
#endif // __ALL_EQUAL_KERNEL_CUH__
Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@
55
#include "../../operator.h"
66
#include "../../tensor.h"
77

8-
namespace op::equal {
8+
namespace op::all_equal {
99

10-
class EqualInfo {
10+
class AllEqualInfo {
1111
private:
12-
EqualInfo() = default;
12+
AllEqualInfo() = default;
1313

1414
public:
1515
// ---------------------------- start: define member variables of Info ----------------------------
@@ -21,7 +21,7 @@ class EqualInfo {
2121

2222
// ----------------------------- end: define member variables of Info -----------------------------
2323

24-
static utils::Result<EqualInfo> createEqualInfo(
24+
static utils::Result<AllEqualInfo> createAllEqualInfo(
2525
infiniopTensorDescriptor_t c_desc,
2626
infiniopTensorDescriptor_t a_desc,
2727
infiniopTensorDescriptor_t b_desc
@@ -30,7 +30,7 @@ class EqualInfo {
3030
CHECK_OR_RETURN(c_desc->ndim() == 1 && c_desc->dim(0) == 1, INFINI_STATUS_BAD_TENSOR_SHAPE);
3131
CHECK_SAME_SHAPE(a_desc->shape(), b_desc->shape());
3232
// -------------------------- end: check tensor shape and input validity --------------------------
33-
return utils::Result<EqualInfo>(EqualInfo{
33+
return utils::Result<AllEqualInfo>(AllEqualInfo{
3434
// ------------------------------ start: create an instance of Info -------------------------------
3535
a_desc->ndim(),
3636
a_desc->dtype(),
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
#ifndef __ALL_EQUAL_METAX_H__
2+
#define __ALL_EQUAL_METAX_H__
3+
4+
#include "../all_equal.h"
5+
6+
DESCRIPTOR(metax)
7+
8+
#endif // __ALL_EQUAL_METAX_H__

0 commit comments

Comments
 (0)