Skip to content

Commit cfff4ed

Browse files
committed
Issue/888 - Add gt, lt, ge, le, eq, ne, logical_and, logical_or, logical_xor, sin, bitwise_and, bitwise_or, bitwise_xor, bitwise_left_shift, bitwise_right_shift, floor_divide, atan2, exp2, log2, log10, rsqrt, square, hypot, copysign, remainder, isnan, isfinite, isinf, sinc, fmin, fmax, log1p binary and unary operators.
1 parent 05096ea commit cfff4ed

File tree

202 files changed

+4771
-55
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

202 files changed

+4771
-55
lines changed

include/infiniop/ops/binary_ops_api.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,30 @@
1515

1616
// Declare all binary operator APIs
1717
BINARY_OP_API_DECLARE(div, Div)
18+
BINARY_OP_API_DECLARE(floor_divide, FloorDivide)
1819
BINARY_OP_API_DECLARE(pow, Pow)
20+
BINARY_OP_API_DECLARE(copysign, CopySign)
21+
BINARY_OP_API_DECLARE(hypot, Hypot)
22+
BINARY_OP_API_DECLARE(atan2, Atan2)
1923
BINARY_OP_API_DECLARE(mod, Mod)
24+
BINARY_OP_API_DECLARE(remainder, Remainder)
2025
BINARY_OP_API_DECLARE(max, Max)
2126
BINARY_OP_API_DECLARE(min, Min)
27+
BINARY_OP_API_DECLARE(fmax, Fmax)
28+
BINARY_OP_API_DECLARE(fmin, Fmin)
29+
BINARY_OP_API_DECLARE(gt, Gt)
30+
BINARY_OP_API_DECLARE(lt, Lt)
31+
BINARY_OP_API_DECLARE(ge, Ge)
32+
BINARY_OP_API_DECLARE(le, Le)
33+
BINARY_OP_API_DECLARE(eq, Eq)
34+
BINARY_OP_API_DECLARE(ne, Ne)
35+
BINARY_OP_API_DECLARE(logical_and, LogicalAnd)
36+
BINARY_OP_API_DECLARE(logical_or, LogicalOr)
37+
BINARY_OP_API_DECLARE(logical_xor, LogicalXor)
38+
BINARY_OP_API_DECLARE(bitwise_and, BitwiseAnd)
39+
BINARY_OP_API_DECLARE(bitwise_or, BitwiseOr)
40+
BINARY_OP_API_DECLARE(bitwise_xor, BitwiseXor)
41+
BINARY_OP_API_DECLARE(bitwise_left_shift, BitwiseLeftShift)
42+
BINARY_OP_API_DECLARE(bitwise_right_shift, BitwiseRightShift)
2243

2344
#endif // __INFINIOP_BINARY_OPS_API_H__

include/infiniop/ops/unary_ops_api.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,12 @@
1717
// Declare all unary operator APIs
1818
UNARY_OP_API_DECLARE(abs, Abs)
1919
UNARY_OP_API_DECLARE(log, Log)
20+
UNARY_OP_API_DECLARE(log2, Log2)
21+
UNARY_OP_API_DECLARE(log10, Log10)
22+
UNARY_OP_API_DECLARE(log1p, Log1p)
2023
UNARY_OP_API_DECLARE(sqrt, Sqrt)
24+
UNARY_OP_API_DECLARE(square, Square)
25+
UNARY_OP_API_DECLARE(rsqrt, Rsqrt)
2126
UNARY_OP_API_DECLARE(reciprocal, Reciprocal)
2227
UNARY_OP_API_DECLARE(neg, Neg)
2328
UNARY_OP_API_DECLARE(round, Round)
@@ -36,6 +41,12 @@ UNARY_OP_API_DECLARE(atan, Atan)
3641
UNARY_OP_API_DECLARE(acos, Acos)
3742
UNARY_OP_API_DECLARE(ceil, Ceil)
3843
UNARY_OP_API_DECLARE(exp, Exp)
44+
UNARY_OP_API_DECLARE(exp2, Exp2)
3945
UNARY_OP_API_DECLARE(hardswish, Hardswish)
46+
UNARY_OP_API_DECLARE(isnan, IsNan)
47+
UNARY_OP_API_DECLARE(isinf, IsInf)
48+
UNARY_OP_API_DECLARE(isfinite, IsFinite)
49+
UNARY_OP_API_DECLARE(sinc, Sinc)
50+
UNARY_OP_API_DECLARE(sin, Sin)
4051

4152
#endif // __INFINIOP_UNARY_OPS_API_H__

src/infiniop/elementwise/binary.h

Lines changed: 542 additions & 11 deletions
Large diffs are not rendered by default.

src/infiniop/elementwise/cpu/elementwise_cpu_impl.h

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,18 @@
4848
case INFINI_DTYPE_BF16: \
4949
return _device_info->template calculate<Op, bf16_t>(_info, output, inputs, stream);
5050

51+
/**
52+
* @brief Integral Calculate Switch Cases (I32, I64, U8)
53+
* For bitwise operations that only support integral types
54+
*/
55+
#define _IMPL_CALC_CASES_INTEGRAL \
56+
case INFINI_DTYPE_I32: \
57+
return _device_info->template calculate<Op, int32_t>(_info, output, inputs, stream); \
58+
case INFINI_DTYPE_I64: \
59+
return _device_info->template calculate<Op, int64_t>(_info, output, inputs, stream); \
60+
case INFINI_DTYPE_U8: \
61+
return _device_info->template calculate<Op, uint8_t>(_info, output, inputs, stream);
62+
5163
/**
5264
* @brief Generic Template for the Calculate method
5365
* @param CASES_MACRO The macro containing the switch cases to use
@@ -156,4 +168,27 @@
156168
) \
157169
_IMPL_CALCULATE_METHOD(_IMPL_CALC_CASES_EXTENDED)
158170

171+
/**
172+
* @brief Implementation for Binary Operators with Integral Types (I32, I64, U8)
173+
*
174+
* This macro generates the Descriptor destructor, create, and calculate methods
175+
* for binary operators that only support integral types (e.g., bitwise operations).
176+
*
177+
* Usage:
178+
* namespace op::bitwise_and::cpu {
179+
* using Op = op::elementwise::binary::BinaryOp<BinaryMode::BitwiseAnd>;
180+
* ELEMENTWISE_CPU_IMPL_BINARY_INTEGRAL(bitwise_and)
181+
* }
182+
*/
183+
#define ELEMENTWISE_CPU_IMPL_BINARY_INTEGRAL(OP) \
184+
_IMPL_CREATE_METHOD( \
185+
const auto &a_desc = input_desc_vec.at(0); \
186+
const auto &b_desc = input_desc_vec.at(1); \
187+
const auto &a_shape = a_desc->shape(); \
188+
const auto &b_shape = b_desc->shape(); \
189+
CHECK_SAME_SHAPE(out_shape, a_shape, b_shape);, \
190+
INFINI_DTYPE_I32, INFINI_DTYPE_I64, INFINI_DTYPE_U8 \
191+
) \
192+
_IMPL_CALCULATE_METHOD(_IMPL_CALC_CASES_INTEGRAL)
193+
159194
#endif // __INFINIOP_ELEMENTWISE_CPU_IMPL_H__

src/infiniop/elementwise/nvidia/elementwise_nvidia_impl.cuh

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,18 @@
5252
case INFINI_DTYPE_F64: \
5353
return _device_info->calculate<256, cuda::Op, double>(_info, workspace, output, inputs, stream);
5454

55+
/**
56+
* @brief Integral Calculate Switch Cases (I32, I64, U8)
57+
* For bitwise operations that only support integral types
58+
*/
59+
#define _IMPL_CALC_CASES_INTEGRAL \
60+
case INFINI_DTYPE_I32: \
61+
return _device_info->calculate<256, cuda::Op, int32_t>(_info, workspace, output, inputs, stream); \
62+
case INFINI_DTYPE_I64: \
63+
return _device_info->calculate<256, cuda::Op, int64_t>(_info, workspace, output, inputs, stream); \
64+
case INFINI_DTYPE_U8: \
65+
return _device_info->calculate<256, cuda::Op, uint8_t>(_info, workspace, output, inputs, stream);
66+
5567
/**
5668
* @brief Generic Template for the Calculate method
5769
* @param CASES_MACRO The macro containing the switch cases to use
@@ -160,4 +172,26 @@
160172
) \
161173
_IMPL_CALCULATE_METHOD(_IMPL_CALC_CASES_EXTENDED)
162174

175+
/**
176+
* @brief Implementation for Binary Operators with Integral Types (I32, I64, U8)
177+
*
178+
* This macro generates the Descriptor destructor, create, and calculate methods
179+
* for binary operators that only support integral types (e.g., bitwise operations).
180+
*
181+
* Usage:
182+
* namespace op::bitwise_and::nvidia {
183+
* ELEMENTWISE_NVIDIA_IMPL_BINARY_INTEGRAL(bitwise_and)
184+
* }
185+
*/
186+
#define ELEMENTWISE_NVIDIA_IMPL_BINARY_INTEGRAL(OP) \
187+
_IMPL_CREATE_METHOD( \
188+
const auto &a_desc = input_desc_vec.at(0); \
189+
const auto &b_desc = input_desc_vec.at(1); \
190+
const auto &a_shape = a_desc->shape(); \
191+
const auto &b_shape = b_desc->shape(); \
192+
CHECK_SAME_SHAPE(out_shape, a_shape, b_shape);, \
193+
INFINI_DTYPE_I32, INFINI_DTYPE_I64, INFINI_DTYPE_U8 \
194+
) \
195+
_IMPL_CALCULATE_METHOD(_IMPL_CALC_CASES_INTEGRAL)
196+
163197
#endif // __INFINIOP_ELEMENTWISE_NVIDIA_IMPL_CUH__

0 commit comments

Comments
 (0)