Skip to content

Commit 697d5ff

Browse files
committed
dispatch avx512bf16
1 parent f2f1626 commit 697d5ff

File tree

1 file changed

+41
-21
lines changed

1 file changed

+41
-21
lines changed

src/layer/x86/convolution_im2col_gemm_bf16s.h

Lines changed: 41 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
// Copyright 2026 Tencent
22
// SPDX-License-Identifier: BSD-3-Clause
33

4+
#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__
5+
void convolution_im2col_gemm_transform_kernel_bf16s_avx512bf16(const Mat& kernel, Mat& AT, int inch, int outch, int kernel_w, int kernel_h, const Option& opt);
6+
int convolution_im2col_gemm_bf16s_avx512bf16(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, const Mat& bias, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h, int activation_type, const Mat& activation_params, int nT, const Option& opt);
7+
#endif
8+
49
static void convolution_im2col_pack_A_tile_bf16s(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk)
510
{
611
// A = (pa, maxk, inch/pa), outch
@@ -282,8 +287,8 @@ static void convolution_im2col_pack_A_tile_bf16s(const Mat& A, Mat& AT, int i, i
282287
const float* p3 = (const float*)A + (i + ii + 3) * A_hstep + k;
283288

284289
int kk = 0;
285-
#if __AVX__
286290
#if !__AVX512BF16__
291+
#if __AVX__
287292
for (; kk + 7 < max_kk; kk += 8)
288293
{
289294
__m256 _r0 = _mm256_loadu_ps(p0);
@@ -359,8 +364,8 @@ static void convolution_im2col_pack_A_tile_bf16s(const Mat& A, Mat& AT, int i, i
359364

360365
int kk = 0;
361366
#if __SSE2__
362-
#if __AVX__
363367
#if !__AVX512BF16__
368+
#if __AVX__
364369
for (; kk + 7 < max_kk; kk += 8)
365370
{
366371
__m256 _r0 = _mm256_loadu_ps(p0);
@@ -414,8 +419,8 @@ static void convolution_im2col_pack_A_tile_bf16s(const Mat& A, Mat& AT, int i, i
414419

415420
int kk = 0;
416421
#if __SSE2__
417-
#if __AVX__
418422
#if !__AVX512BF16__
423+
#if __AVX__
419424
for (; kk + 7 < max_kk; kk += 8)
420425
{
421426
_mm_storeu_si128((__m128i*)pp, float2bfloat_avx(_mm256_loadu_ps(p0)));
@@ -2549,7 +2554,7 @@ static void convolution_im2col_input_tile_conv1x1s1d1_bf16s(const Mat& bottom_bl
25492554
_mm256_storeu_si256((__m256i*)(pp + 16 * 5), float2bfloat_avx512(_r5));
25502555
_mm256_storeu_si256((__m256i*)(pp + 16 * 6), float2bfloat_avx512(_r6));
25512556
_mm256_storeu_si256((__m256i*)(pp + 16 * 7), float2bfloat_avx512(_r7));
2552-
#endif // !__AVX512BF16__
2557+
#endif // __AVX512BF16__
25532558
pp += 128;
25542559
p0 += bottom_blob.cstep * 16;
25552560
}
@@ -2598,7 +2603,7 @@ static void convolution_im2col_input_tile_conv1x1s1d1_bf16s(const Mat& bottom_bl
25982603
_mm_storeu_si128((__m128i*)(pp + 8 * 5), float2bfloat_avx(_r5));
25992604
_mm_storeu_si128((__m128i*)(pp + 8 * 6), float2bfloat_avx(_r6));
26002605
_mm_storeu_si128((__m128i*)(pp + 8 * 7), float2bfloat_avx(_r7));
2601-
#endif // !__AVX512BF16__
2606+
#endif // __AVX512BF16__
26022607
pp += 64;
26032608
p0 += bottom_blob.cstep * 8;
26042609
}
@@ -2640,7 +2645,7 @@ static void convolution_im2col_input_tile_conv1x1s1d1_bf16s(const Mat& bottom_bl
26402645
_mm_storel_epi64((__m128i*)(pp + 4 * 5), float2bfloat_sse(_r6));
26412646
_mm_storel_epi64((__m128i*)(pp + 4 * 6), float2bfloat_sse(_r3));
26422647
_mm_storel_epi64((__m128i*)(pp + 4 * 7), float2bfloat_sse(_r7));
2643-
#endif // !__AVX512BF16__
2648+
#endif // __AVX512BF16__
26442649
pp += 32;
26452650
p0 += bottom_blob.cstep * 4;
26462651
}
@@ -2723,7 +2728,7 @@ static void convolution_im2col_input_tile_conv1x1s1d1_bf16s(const Mat& bottom_bl
27232728
_mm256_storeu_si256((__m256i*)(pp + 16 * 1), float2bfloat_avx512(_r1));
27242729
_mm256_storeu_si256((__m256i*)(pp + 16 * 2), float2bfloat_avx512(_r2));
27252730
_mm256_storeu_si256((__m256i*)(pp + 16 * 3), float2bfloat_avx512(_r3));
2726-
#endif // !__AVX512BF16__
2731+
#endif // __AVX512BF16__
27272732
pp += 64;
27282733
p0 += bottom_blob.cstep * 16;
27292734
}
@@ -2756,7 +2761,7 @@ static void convolution_im2col_input_tile_conv1x1s1d1_bf16s(const Mat& bottom_bl
27562761
_mm_storeu_si128((__m128i*)(pp + 8 * 1), float2bfloat_avx(_r1));
27572762
_mm_storeu_si128((__m128i*)(pp + 8 * 2), float2bfloat_avx(_r2));
27582763
_mm_storeu_si128((__m128i*)(pp + 8 * 3), float2bfloat_avx(_r3));
2759-
#endif // !__AVX512BF16__
2764+
#endif // __AVX512BF16__
27602765
pp += 32;
27612766
p0 += bottom_blob.cstep * 8;
27622767
}
@@ -2787,7 +2792,7 @@ static void convolution_im2col_input_tile_conv1x1s1d1_bf16s(const Mat& bottom_bl
27872792
_mm_storel_epi64((__m128i*)(pp + 4 * 1), float2bfloat_sse(_r1));
27882793
_mm_storel_epi64((__m128i*)(pp + 4 * 2), float2bfloat_sse(_r2));
27892794
_mm_storel_epi64((__m128i*)(pp + 4 * 3), float2bfloat_sse(_r3));
2790-
#endif // !__AVX512BF16__
2795+
#endif // __AVX512BF16__
27912796
pp += 16;
27922797
p0 += bottom_blob.cstep * 4;
27932798
}
@@ -2999,7 +3004,7 @@ static inline void convolution_im2col_input_tile_impl_bf16s(const Mat& bottom_bl
29993004
_mm256_storeu_si256((__m256i*)(pp + 16 * 5), float2bfloat_avx512(_r5));
30003005
_mm256_storeu_si256((__m256i*)(pp + 16 * 6), float2bfloat_avx512(_r6));
30013006
_mm256_storeu_si256((__m256i*)(pp + 16 * 7), float2bfloat_avx512(_r7));
3002-
#endif // !__AVX512BF16__
3007+
#endif // __AVX512BF16__
30033008
pp += 128;
30043009
}
30053010
#endif // __AVX512F__
@@ -3041,7 +3046,7 @@ static inline void convolution_im2col_input_tile_impl_bf16s(const Mat& bottom_bl
30413046
_mm_storeu_si128((__m128i*)(pp + 8 * 5), float2bfloat_avx(_r5));
30423047
_mm_storeu_si128((__m128i*)(pp + 8 * 6), float2bfloat_avx(_r6));
30433048
_mm_storeu_si128((__m128i*)(pp + 8 * 7), float2bfloat_avx(_r7));
3044-
#endif // !__AVX512BF16__
3049+
#endif // __AVX512BF16__
30453050
pp += 64;
30463051
}
30473052
#endif // __AVX__
@@ -3076,7 +3081,7 @@ static inline void convolution_im2col_input_tile_impl_bf16s(const Mat& bottom_bl
30763081
_mm_storel_epi64((__m128i*)(pp + 4 * 5), float2bfloat_sse(_r6));
30773082
_mm_storel_epi64((__m128i*)(pp + 4 * 6), float2bfloat_sse(_r3));
30783083
_mm_storel_epi64((__m128i*)(pp + 4 * 7), float2bfloat_sse(_r7));
3079-
#endif // !__AVX512BF16__
3084+
#endif // __AVX512BF16__
30803085
pp += 32;
30813086
}
30823087
if (elempack == 1)
@@ -3195,7 +3200,7 @@ static inline void convolution_im2col_input_tile_impl_bf16s(const Mat& bottom_bl
31953200
_mm256_storeu_si256((__m256i*)(pp + 16 * 5), float2bfloat_avx512(_r5));
31963201
_mm256_storeu_si256((__m256i*)(pp + 16 * 6), float2bfloat_avx512(_r6));
31973202
_mm256_storeu_si256((__m256i*)(pp + 16 * 7), float2bfloat_avx512(_r7));
3198-
#endif // !__AVX512BF16__
3203+
#endif // __AVX512BF16__
31993204
pp += 128;
32003205
}
32013206
#endif // __AVX512F__
@@ -3237,7 +3242,7 @@ static inline void convolution_im2col_input_tile_impl_bf16s(const Mat& bottom_bl
32373242
_mm_storeu_si128((__m128i*)(pp + 8 * 5), float2bfloat_avx(_r5));
32383243
_mm_storeu_si128((__m128i*)(pp + 8 * 6), float2bfloat_avx(_r6));
32393244
_mm_storeu_si128((__m128i*)(pp + 8 * 7), float2bfloat_avx(_r7));
3240-
#endif // !__AVX512BF16__
3245+
#endif // __AVX512BF16__
32413246
pp += 64;
32423247
}
32433248
#endif // __AVX__
@@ -3272,7 +3277,7 @@ static inline void convolution_im2col_input_tile_impl_bf16s(const Mat& bottom_bl
32723277
_mm_storel_epi64((__m128i*)(pp + 4 * 5), float2bfloat_sse(_r6));
32733278
_mm_storel_epi64((__m128i*)(pp + 4 * 6), float2bfloat_sse(_r3));
32743279
_mm_storel_epi64((__m128i*)(pp + 4 * 7), float2bfloat_sse(_r7));
3275-
#endif // !__AVX512BF16__
3280+
#endif // __AVX512BF16__
32763281
pp += 32;
32773282
}
32783283
if (elempack == 1)
@@ -3364,7 +3369,7 @@ static inline void convolution_im2col_input_tile_impl_bf16s(const Mat& bottom_bl
33643369
_mm256_storeu_si256((__m256i*)(pp + 16 * 1), float2bfloat_avx512(_r1));
33653370
_mm256_storeu_si256((__m256i*)(pp + 16 * 2), float2bfloat_avx512(_r2));
33663371
_mm256_storeu_si256((__m256i*)(pp + 16 * 3), float2bfloat_avx512(_r3));
3367-
#endif // !__AVX512BF16__
3372+
#endif // __AVX512BF16__
33683373
pp += 64;
33693374
}
33703375
#endif // __AVX512F__
@@ -3390,7 +3395,7 @@ static inline void convolution_im2col_input_tile_impl_bf16s(const Mat& bottom_bl
33903395
_mm_storeu_si128((__m128i*)(pp + 8 * 1), float2bfloat_avx(_r1));
33913396
_mm_storeu_si128((__m128i*)(pp + 8 * 2), float2bfloat_avx(_r2));
33923397
_mm_storeu_si128((__m128i*)(pp + 8 * 3), float2bfloat_avx(_r3));
3393-
#endif // !__AVX512BF16__
3398+
#endif // __AVX512BF16__
33943399
pp += 32;
33953400
}
33963401
#endif // __AVX__
@@ -3414,7 +3419,7 @@ static inline void convolution_im2col_input_tile_impl_bf16s(const Mat& bottom_bl
34143419
_mm_storel_epi64((__m128i*)(pp + 4 * 1), float2bfloat_sse(_r1));
34153420
_mm_storel_epi64((__m128i*)(pp + 4 * 2), float2bfloat_sse(_r2));
34163421
_mm_storel_epi64((__m128i*)(pp + 4 * 3), float2bfloat_sse(_r3));
3417-
#endif // !__AVX512BF16__
3422+
#endif // __AVX512BF16__
34183423
pp += 16;
34193424
}
34203425
if (elempack == 1)
@@ -3497,7 +3502,7 @@ static inline void convolution_im2col_input_tile_impl_bf16s(const Mat& bottom_bl
34973502
_mm256_storeu_si256((__m256i*)(pp + 16 * 1), float2bfloat_avx512(_r1));
34983503
_mm256_storeu_si256((__m256i*)(pp + 16 * 2), float2bfloat_avx512(_r2));
34993504
_mm256_storeu_si256((__m256i*)(pp + 16 * 3), float2bfloat_avx512(_r3));
3500-
#endif // !__AVX512BF16__
3505+
#endif // __AVX512BF16__
35013506
pp += 64;
35023507
}
35033508
#endif // __AVX512F__
@@ -3523,7 +3528,7 @@ static inline void convolution_im2col_input_tile_impl_bf16s(const Mat& bottom_bl
35233528
_mm_storeu_si128((__m128i*)(pp + 8 * 1), float2bfloat_avx(_r1));
35243529
_mm_storeu_si128((__m128i*)(pp + 8 * 2), float2bfloat_avx(_r2));
35253530
_mm_storeu_si128((__m128i*)(pp + 8 * 3), float2bfloat_avx(_r3));
3526-
#endif // !__AVX512BF16__
3531+
#endif // __AVX512BF16__
35273532
pp += 32;
35283533
}
35293534
#endif // __AVX__
@@ -3547,7 +3552,7 @@ static inline void convolution_im2col_input_tile_impl_bf16s(const Mat& bottom_bl
35473552
_mm_storel_epi64((__m128i*)(pp + 4 * 1), float2bfloat_sse(_r1));
35483553
_mm_storel_epi64((__m128i*)(pp + 4 * 2), float2bfloat_sse(_r2));
35493554
_mm_storel_epi64((__m128i*)(pp + 4 * 3), float2bfloat_sse(_r3));
3550-
#endif // !__AVX512BF16__
3555+
#endif // __AVX512BF16__
35513556
pp += 16;
35523557
}
35533558
if (elempack == 1)
@@ -3732,6 +3737,14 @@ static void convolution_im2col_input_tile_bf16s(const Mat& bottom_blob, Mat& B,
37323737

37333738
static void convolution_im2col_gemm_transform_kernel_bf16s(const Mat& kernel, Mat& AT, int inch, int outch, int kernel_w, int kernel_h, const Option& opt)
37343739
{
3740+
#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__
3741+
if (ncnn::cpu_support_x86_avx512_bf16())
3742+
{
3743+
convolution_im2col_gemm_transform_kernel_bf16s_avx512bf16(kernel, AT, inch, outch, kernel_w, kernel_h, opt);
3744+
return;
3745+
}
3746+
#endif
3747+
37353748
// NCNN_LOGE("convolution_im2col_gemm_transform_kernel");
37363749
const int maxk = kernel_w * kernel_h;
37373750

@@ -3810,6 +3823,13 @@ static void convolution_im2col_gemm_transform_kernel_bf16s(const Mat& kernel, Ma
38103823

38113824
static int convolution_im2col_gemm_bf16s(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, const Mat& bias, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h, int activation_type, const Mat& activation_params, int nT, const Option& opt)
38123825
{
3826+
#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__
3827+
if (ncnn::cpu_support_x86_avx512_bf16())
3828+
{
3829+
return convolution_im2col_gemm_bf16s_avx512bf16(bottom_blob, top_blob, AT, bias, kernel_w, kernel_h, dilation_w, dilation_h, stride_w, stride_h, activation_type, activation_params, nT, opt);
3830+
}
3831+
#endif
3832+
38133833
const int maxk = kernel_w * kernel_h;
38143834

38153835
const int M = top_blob.c * top_blob.elempack;

0 commit comments

Comments
 (0)