Skip to content

Commit e6e2cc2

Browse files
authored
1 parent c6aeb91 commit e6e2cc2

File tree

3 files changed: 24 additions and 19 deletions.

include/cutlass/detail/collective/mixed_input_utils.hpp

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -347,7 +347,7 @@ struct LayoutAwareConvertImpl<
347347
// Specialization for INT8 -> BF16 with [3120] value order
348348
template <>
349349
struct LayoutAwareConvertImpl<
350-
cutlass::int8_t,
350+
int8_t,
351351
cutlass::bfloat16_t,
352352
cute::Layout<cute::Shape<_2,_2>, cute::Stride<_2,_1>>,
353353
cute::Layout<_4>
@@ -362,9 +362,9 @@ struct LayoutAwareConvertImpl<
362362
cute::Layout<_4>
363363
>& dst) {
364364

365-
static_assert(cute::is_same_v<cutlass::int8_t, typename EngineIn::value_type> &&
365+
static_assert(cute::is_same_v<int8_t, typename EngineIn::value_type> &&
366366
cute::is_same_v<cutlass::bfloat16_t, typename EngineOut::value_type>);
367-
using SrcArray = cutlass::Array<cutlass::int8_t, 8>;
367+
using SrcArray = cutlass::Array<int8_t, 8>;
368368
using DstArray = cutlass::Array<cutlass::bfloat16_t, 8>;
369369
using RegArray = cutlass::AlignedArray<uint32_t, 4, sizeof(DstArray)>;
370370

@@ -402,7 +402,7 @@ struct LayoutAwareConvertImpl<
402402
// Specialization for INT8 -> FP16 with [3120] value order
403403
template <>
404404
struct LayoutAwareConvertImpl<
405-
cutlass::int8_t,
405+
int8_t,
406406
cutlass::half_t,
407407
cute::Layout<cute::Shape<_2,_2>, cute::Stride<_2,_1>>,
408408
cute::Layout<_4>
@@ -417,9 +417,9 @@ struct LayoutAwareConvertImpl<
417417
cute::Layout<_4>
418418
>& dst) {
419419

420-
static_assert(cute::is_same_v<cutlass::int8_t, typename EngineIn::value_type> &&
420+
static_assert(cute::is_same_v<int8_t, typename EngineIn::value_type> &&
421421
cute::is_same_v<cutlass::half_t, typename EngineOut::value_type>);
422-
using SrcArray = cutlass::Array<cutlass::int8_t, 8>;
422+
using SrcArray = cutlass::Array<int8_t, 8>;
423423
using DstArray = cutlass::Array<cutlass::half_t, 8>;
424424
using RegArray = cutlass::AlignedArray<uint32_t, 4, sizeof(DstArray)>;
425425

include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl

Lines changed: 16 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -52,21 +52,21 @@ namespace cutlass::gemm::collective {
5252
namespace detail {
5353

5454
// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count.
55-
template<int capacity_bytes, class ElementA, class ElementB, class TileShapeMNK, int stages, int alignment = 128>
55+
template<int capacity_bytes, class ElementA, class ElementB, class TileShapeMNK, int alignment = 128, int stages>
5656
constexpr int
5757
compute_stage_count_or_override(StageCount<stages> stage_count) {
5858
return stages;
5959
}
6060

6161
// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count.
62-
template<int capacity_bytes, class ElementA, class ElementB, class TileShapeMNK, int stages, int alignment = 128>
62+
template<int capacity_bytes, class ElementA, class ElementB, class TileShapeMNK, int alignment = 128, int stages>
6363
constexpr int
6464
compute_stage_count_or_override(cute::Int<stages> stage_count) {
6565
return stages;
6666
}
6767

6868
// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count.
69-
template<int capacity_bytes_, class ElementA, class ElementB, class TileShapeMNK, int carveout_bytes_, int alignment = 128>
69+
template<int capacity_bytes_, class ElementA, class ElementB, class TileShapeMNK, int alignment = 128, int carveout_bytes_>
7070
constexpr int
7171
compute_stage_count_or_override(StageCountAutoCarveout<carveout_bytes_> stage_count) {
7272
constexpr auto mainloop_pipeline_bytes = sizeof(typename cutlass::PipelineTmaAsync<1>::SharedStorage);
@@ -85,7 +85,7 @@ compute_stage_count_or_override(StageCountAutoCarveout<carveout_bytes_> stage_co
8585
}
8686

8787
// Returns the maximum number of smem tiles that can be used with a given smem capacity in gemm of blockwise/groupwise scale.
88-
template<int capacity_bytes_, class ElementA, class ElementB, class ElementBlockScale, class TileShapeMNK, int ScaleMsPerTile, int ScaleNsPerTile, int carveout_bytes_, int alignment = 128>
88+
template<int capacity_bytes_, class ElementA, class ElementB, class ElementBlockScale, class TileShapeMNK, int ScaleMsPerTile, int ScaleNsPerTile, int alignment = 128, int carveout_bytes_>
8989
constexpr int
9090
compute_stage_count_with_blockwise_scale(StageCountAutoCarveout<carveout_bytes_> stage_count) {
9191
constexpr auto mainloop_pipeline_bytes = sizeof(typename cutlass::PipelineTmaAsync<1>::SharedStorage);
@@ -107,7 +107,14 @@ compute_stage_count_with_blockwise_scale(StageCountAutoCarveout<carveout_bytes_>
107107
}
108108

109109
// Returns the maximum number of smem tiles that can be used with a given smem capacity (with an optional scale matrix), or overrides with manual count.
110-
template<int capacity_bytes, class ElementA, class ElementB, class ElementScale, class ElementZero, class TileShapeMNK, int stages, int alignment = 128>
110+
template<int capacity_bytes, class ElementA, class ElementB, class ElementScale, class ElementZero, class TileShapeMNK, int alignment = 128, int stages>
111+
constexpr int
112+
compute_stage_count_or_override_single_affine_transformed_input(cute::Int<stages> stage_count) {
113+
return stages;
114+
}
115+
116+
// Returns the maximum number of smem tiles that can be used with a given smem capacity (with an optional scale matrix), or overrides with manual count.
117+
template<int capacity_bytes, class ElementA, class ElementB, class ElementScale, class ElementZero, class TileShapeMNK, int alignment = 128, int stages>
111118
constexpr int
112119
compute_stage_count_or_override_single_affine_transformed_input(StageCount<stages> stage_count) {
113120
return stages;
@@ -124,7 +131,7 @@ constexpr int get_bits_for_possibly_void_element() {
124131
}
125132

126133
// Returns the maximum number of smem tiles that can be used with a given smem capacity (with an optional scale matrix), or overrides with manual count.
127-
template<int capacity_bytes_, class ElementA, class ElementB, class ElementScale, class ElementZero, class TileShapeMNK, int carveout_bytes_, int alignment = 128>
134+
template<int capacity_bytes_, class ElementA, class ElementB, class ElementScale, class ElementZero, class TileShapeMNK, int alignment = 128, int carveout_bytes_>
128135
constexpr int
129136
compute_stage_count_or_override_single_affine_transformed_input(StageCountAutoCarveout<carveout_bytes_> stage_count) {
130137

@@ -456,12 +463,12 @@ public:
456463
static constexpr int PipelineStages = IsMixedInput ?
457464
( IsArrayOfPointersGemm ?
458465
detail::compute_stage_count_or_override_single_affine_transformed_input<Sm90ReducedSmemCapacityBytes,
459-
RealElementA, RealElementB, ElementScale, ElementZero, TileShape_MNK, StageCountType::bytes, SmemAlignment>(StageCountType{}) :
466+
RealElementA, RealElementB, ElementScale, ElementZero, TileShape_MNK, SmemAlignment>(StageCountType{}) :
460467
detail::compute_stage_count_or_override_single_affine_transformed_input<detail::sm90_smem_capacity_bytes,
461-
RealElementA, RealElementB, ElementScale, ElementZero, TileShape_MNK, StageCountType::bytes, SmemAlignment>(StageCountType{})
468+
RealElementA, RealElementB, ElementScale, ElementZero, TileShape_MNK, SmemAlignment>(StageCountType{})
462469
)
463470
: detail::compute_stage_count_or_override<detail::sm90_smem_capacity_bytes,
464-
ElementAMma, ElementBMma, TileShape_MNK, StageCountType::bytes, SmemAlignment>(StageCountType{});
471+
ElementAMma, ElementBMma, TileShape_MNK, SmemAlignment>(StageCountType{});
465472

466473
using DispatchPolicy = cute::conditional_t<IsMixedInput,
467474
cute::conditional_t<IsArrayOfPointersGemm,

tools/util/include/cutlass/util/mixed_dtype_utils.hpp

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -42,6 +42,7 @@
4242
#include "cutlass/util/device_memory.h"
4343
#include "cutlass/util/reference/device/tensor_fill.h"
4444
#include "cute/util/type_traits.hpp"
45+
#include "cute/numeric/numeric_types.hpp"
4546

4647
namespace cutlass {
4748

@@ -177,10 +178,7 @@ static void dequantize(DequantizedElement* dq_buffer,
177178
template <typename T>
178179
class packed_scale_t {
179180
public:
180-
static_assert(cute::is_same_v<T, cutlass::int8_t> ||
181-
cute::is_same_v<T, cutlass::uint8_t> ||
182-
cute::is_same_v<T, cutlass::float_e4m3_t> ||
183-
cute::is_same_v<T, cutlass::float_e5m2_t>,
181+
static_assert(cute::sizeof_bits_v<T> == 8,
184182
"only 8 bit arithmetic types are supported.");
185183
CUTLASS_HOST_DEVICE
186184
explicit packed_scale_t(T val) {

0 commit comments

Comments
 (0)