Skip to content

Commit 4a69750

Browse files
authored
Add Half (float16) support to slim ScalarType enum (#18959) (#18959)
Summary: The CUDA runtime shims for sort operations use Half (float16) dtype, but it was not defined in the slim ScalarType enum, causing compiler warnings treated as errors (-Werror=switch). This adds proper Half support to the slim ScalarType enum so switch statements can use the enum value directly instead of casting to the underlying type. Differential Revision: D101218928
1 parent f9f29e7 commit 4a69750

3 files changed

Lines changed: 48 additions & 6 deletions

File tree

backends/aoti/slim/c10/core/ScalarType.h

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ enum class ScalarType : int8_t {
2828
Short = 2, // int16_t
2929
Int = 3, // int32_t
3030
Long = 4, // int64_t
31-
// Half = 5, // float16 - not currently needed
31+
Half = 5, // float16
3232
Float = 6, // float
3333
// Double = 7, // double - not currently needed
3434
// ComplexHalf = 8,
@@ -48,6 +48,7 @@ constexpr ScalarType kChar = ScalarType::Char;
4848
constexpr ScalarType kShort = ScalarType::Short;
4949
constexpr ScalarType kInt = ScalarType::Int;
5050
constexpr ScalarType kLong = ScalarType::Long;
51+
constexpr ScalarType kHalf = ScalarType::Half;
5152
constexpr ScalarType kFloat = ScalarType::Float;
5253
constexpr ScalarType kBool = ScalarType::Bool;
5354
constexpr ScalarType kBFloat16 = ScalarType::BFloat16;
@@ -67,6 +68,8 @@ inline size_t elementSize(ScalarType t) {
6768
return sizeof(int32_t);
6869
case ScalarType::Long:
6970
return sizeof(int64_t);
71+
case ScalarType::Half:
72+
return 2; // sizeof(__half) = 2 bytes
7073
case ScalarType::Float:
7174
return sizeof(float);
7275
case ScalarType::Bool:
@@ -93,6 +96,8 @@ inline const char* toString(ScalarType t) {
9396
return "Int";
9497
case ScalarType::Long:
9598
return "Long";
99+
case ScalarType::Half:
100+
return "Half";
96101
case ScalarType::Float:
97102
return "Float";
98103
case ScalarType::Bool:
@@ -110,7 +115,8 @@ inline const char* toString(ScalarType t) {
110115
/// @param t The scalar type to check.
111116
/// @return true if the scalar type is floating point, false otherwise.
112117
inline bool isFloatingType(ScalarType t) {
113-
return t == ScalarType::Float || t == ScalarType::BFloat16;
118+
return t == ScalarType::Half || t == ScalarType::Float ||
119+
t == ScalarType::BFloat16;
114120
}
115121

116122
/// Checks if the scalar type is an integral type (including bool optionally).
@@ -149,6 +155,7 @@ inline bool isValidScalarType(ScalarType t) {
149155
case ScalarType::Short:
150156
case ScalarType::Int:
151157
case ScalarType::Long:
158+
case ScalarType::Half:
152159
case ScalarType::Float:
153160
case ScalarType::Bool:
154161
case ScalarType::BFloat16:

backends/aoti/slim/c10/core/test/test_scalar_type.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ const std::vector<ScalarTypeTestData> kAllScalarTypes = {
3636
{ScalarType::Short, 2, 2, "Short", false, true, true, false},
3737
{ScalarType::Int, 3, 4, "Int", false, true, true, false},
3838
{ScalarType::Long, 4, 8, "Long", false, true, true, false},
39+
{ScalarType::Half, 5, 2, "Half", true, false, false, false},
3940
{ScalarType::Float, 6, 4, "Float", true, false, false, false},
4041
{ScalarType::Bool, 11, 1, "Bool", false, false, true, true},
4142
{ScalarType::BFloat16, 15, 2, "BFloat16", true, false, false, false},
@@ -128,6 +129,10 @@ TEST_F(ScalarTypeConstantsTest, KLongConstant) {
128129
EXPECT_EQ(kLong, ScalarType::Long);
129130
}
130131

132+
TEST_F(ScalarTypeConstantsTest, KHalfConstant) {
133+
EXPECT_EQ(kHalf, ScalarType::Half);
134+
}
135+
131136
TEST_F(ScalarTypeConstantsTest, KFloatConstant) {
132137
EXPECT_EQ(kFloat, ScalarType::Float);
133138
}
@@ -185,6 +190,10 @@ TEST_F(ElementSizeConsistencyTest, LongMatchesSizeofInt64) {
185190
EXPECT_EQ(elementSize(ScalarType::Long), sizeof(int64_t));
186191
}
187192

193+
TEST_F(ElementSizeConsistencyTest, HalfIs2Bytes) {
194+
EXPECT_EQ(elementSize(ScalarType::Half), 2);
195+
}
196+
188197
TEST_F(ElementSizeConsistencyTest, FloatMatchesSizeofFloat) {
189198
EXPECT_EQ(elementSize(ScalarType::Float), sizeof(float));
190199
}
@@ -196,3 +205,29 @@ TEST_F(ElementSizeConsistencyTest, BoolMatchesSizeofBool) {
196205
TEST_F(ElementSizeConsistencyTest, BFloat16MatchesSizeofBFloat16) {
197206
EXPECT_EQ(elementSize(ScalarType::BFloat16), sizeof(BFloat16));
198207
}
208+
209+
// =============================================================================
210+
// isValidScalarType Tests
211+
// =============================================================================
212+
213+
class IsValidScalarTypeTest : public ::testing::Test {};
214+
215+
TEST_F(IsValidScalarTypeTest, HalfIsValid) {
216+
EXPECT_TRUE(isValidScalarType(ScalarType::Half));
217+
}
218+
219+
TEST_F(IsValidScalarTypeTest, AllSupportedTypesAreValid) {
220+
EXPECT_TRUE(isValidScalarType(ScalarType::Byte));
221+
EXPECT_TRUE(isValidScalarType(ScalarType::Char));
222+
EXPECT_TRUE(isValidScalarType(ScalarType::Short));
223+
EXPECT_TRUE(isValidScalarType(ScalarType::Int));
224+
EXPECT_TRUE(isValidScalarType(ScalarType::Long));
225+
EXPECT_TRUE(isValidScalarType(ScalarType::Half));
226+
EXPECT_TRUE(isValidScalarType(ScalarType::Float));
227+
EXPECT_TRUE(isValidScalarType(ScalarType::Bool));
228+
EXPECT_TRUE(isValidScalarType(ScalarType::BFloat16));
229+
}
230+
231+
TEST_F(IsValidScalarTypeTest, UndefinedIsNotValid) {
232+
EXPECT_FALSE(isValidScalarType(ScalarType::Undefined));
233+
}

backends/cuda/runtime/shims/sort.cu

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ namespace executorch::backends::cuda {
2424

2525
namespace c10_slim = executorch::backends::aoti::slim::c10;
2626

27-
// PyTorch ScalarType::Half = 5, not defined in slim ScalarType enum.
28-
constexpr auto kHalf = static_cast<c10_slim::ScalarType>(5);
27+
// PyTorch ScalarType::Half = 5, now defined in slim ScalarType enum.
28+
using c10_slim::kHalf;
2929

3030
namespace {
3131

@@ -188,7 +188,7 @@ AOTITorchError aoti_torch_cuda_sort_stable(
188188
case c10_slim::ScalarType::BFloat16:
189189
elem_size = sizeof(__nv_bfloat16);
190190
break;
191-
case kHalf:
191+
case c10_slim::ScalarType::Half:
192192
elem_size = sizeof(__half);
193193
break;
194194
default:
@@ -387,7 +387,7 @@ AOTITorchError aoti_torch_cuda_sort_stable(
387387
stream);
388388
break;
389389
}
390-
case kHalf: {
390+
case c10_slim::ScalarType::Half: {
391391
sort_slice_impl(
392392
static_cast<__half*>(values_base) + offset,
393393
idx_ptr,

0 commit comments

Comments
 (0)