11// Copyright 2026 Tencent
22// SPDX-License-Identifier: BSD-3-Clause
33
4+ #if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__
5+ void convolution_im2col_gemm_transform_kernel_bf16s_avx512bf16 (const Mat & kernel , Mat & AT , int inch , int outch , int kernel_w , int kernel_h , const Option & opt );
6+ int convolution_im2col_gemm_bf16s_avx512bf16 (const Mat & bottom_blob , Mat & top_blob , const Mat & AT , const Mat & bias , int kernel_w , int kernel_h , int dilation_w , int dilation_h , int stride_w , int stride_h , int activation_type , const Mat & activation_params , int nT , const Option & opt );
7+ #endif
8+
49static void convolution_im2col_pack_A_tile_bf16s (const Mat & A , Mat & AT , int i , int max_ii , int k , int max_kk )
510{
611 // A = (pa, maxk, inch/pa), outch
@@ -282,8 +287,8 @@ static void convolution_im2col_pack_A_tile_bf16s(const Mat& A, Mat& AT, int i, i
282287 const float * p3 = (const float * )A + (i + ii + 3 ) * A_hstep + k ;
283288
284289 int kk = 0 ;
285- #if __AVX__
286290#if !__AVX512BF16__
291+ #if __AVX__
287292 for (; kk + 7 < max_kk ; kk += 8 )
288293 {
289294 __m256 _r0 = _mm256_loadu_ps (p0 );
@@ -359,8 +364,8 @@ static void convolution_im2col_pack_A_tile_bf16s(const Mat& A, Mat& AT, int i, i
359364
360365 int kk = 0 ;
361366#if __SSE2__
362- #if __AVX__
363367#if !__AVX512BF16__
368+ #if __AVX__
364369 for (; kk + 7 < max_kk ; kk += 8 )
365370 {
366371 __m256 _r0 = _mm256_loadu_ps (p0 );
@@ -414,8 +419,8 @@ static void convolution_im2col_pack_A_tile_bf16s(const Mat& A, Mat& AT, int i, i
414419
415420 int kk = 0 ;
416421#if __SSE2__
417- #if __AVX__
418422#if !__AVX512BF16__
423+ #if __AVX__
419424 for (; kk + 7 < max_kk ; kk += 8 )
420425 {
421426 _mm_storeu_si128 ((__m128i * )pp , float2bfloat_avx (_mm256_loadu_ps (p0 )));
@@ -2549,7 +2554,7 @@ static void convolution_im2col_input_tile_conv1x1s1d1_bf16s(const Mat& bottom_bl
25492554 _mm256_storeu_si256 ((__m256i * )(pp + 16 * 5 ), float2bfloat_avx512 (_r5 ));
25502555 _mm256_storeu_si256 ((__m256i * )(pp + 16 * 6 ), float2bfloat_avx512 (_r6 ));
25512556 _mm256_storeu_si256 ((__m256i * )(pp + 16 * 7 ), float2bfloat_avx512 (_r7 ));
2552- #endif // ! __AVX512BF16__
2557+ #endif // __AVX512BF16__
25532558 pp += 128 ;
25542559 p0 += bottom_blob .cstep * 16 ;
25552560 }
@@ -2598,7 +2603,7 @@ static void convolution_im2col_input_tile_conv1x1s1d1_bf16s(const Mat& bottom_bl
25982603 _mm_storeu_si128 ((__m128i * )(pp + 8 * 5 ), float2bfloat_avx (_r5 ));
25992604 _mm_storeu_si128 ((__m128i * )(pp + 8 * 6 ), float2bfloat_avx (_r6 ));
26002605 _mm_storeu_si128 ((__m128i * )(pp + 8 * 7 ), float2bfloat_avx (_r7 ));
2601- #endif // ! __AVX512BF16__
2606+ #endif // __AVX512BF16__
26022607 pp += 64 ;
26032608 p0 += bottom_blob .cstep * 8 ;
26042609 }
@@ -2640,7 +2645,7 @@ static void convolution_im2col_input_tile_conv1x1s1d1_bf16s(const Mat& bottom_bl
26402645 _mm_storel_epi64 ((__m128i * )(pp + 4 * 5 ), float2bfloat_sse (_r6 ));
26412646 _mm_storel_epi64 ((__m128i * )(pp + 4 * 6 ), float2bfloat_sse (_r3 ));
26422647 _mm_storel_epi64 ((__m128i * )(pp + 4 * 7 ), float2bfloat_sse (_r7 ));
2643- #endif // ! __AVX512BF16__
2648+ #endif // __AVX512BF16__
26442649 pp += 32 ;
26452650 p0 += bottom_blob .cstep * 4 ;
26462651 }
@@ -2723,7 +2728,7 @@ static void convolution_im2col_input_tile_conv1x1s1d1_bf16s(const Mat& bottom_bl
27232728 _mm256_storeu_si256 ((__m256i * )(pp + 16 * 1 ), float2bfloat_avx512 (_r1 ));
27242729 _mm256_storeu_si256 ((__m256i * )(pp + 16 * 2 ), float2bfloat_avx512 (_r2 ));
27252730 _mm256_storeu_si256 ((__m256i * )(pp + 16 * 3 ), float2bfloat_avx512 (_r3 ));
2726- #endif // ! __AVX512BF16__
2731+ #endif // __AVX512BF16__
27272732 pp += 64 ;
27282733 p0 += bottom_blob .cstep * 16 ;
27292734 }
@@ -2756,7 +2761,7 @@ static void convolution_im2col_input_tile_conv1x1s1d1_bf16s(const Mat& bottom_bl
27562761 _mm_storeu_si128 ((__m128i * )(pp + 8 * 1 ), float2bfloat_avx (_r1 ));
27572762 _mm_storeu_si128 ((__m128i * )(pp + 8 * 2 ), float2bfloat_avx (_r2 ));
27582763 _mm_storeu_si128 ((__m128i * )(pp + 8 * 3 ), float2bfloat_avx (_r3 ));
2759- #endif // ! __AVX512BF16__
2764+ #endif // __AVX512BF16__
27602765 pp += 32 ;
27612766 p0 += bottom_blob .cstep * 8 ;
27622767 }
@@ -2787,7 +2792,7 @@ static void convolution_im2col_input_tile_conv1x1s1d1_bf16s(const Mat& bottom_bl
27872792 _mm_storel_epi64 ((__m128i * )(pp + 4 * 1 ), float2bfloat_sse (_r1 ));
27882793 _mm_storel_epi64 ((__m128i * )(pp + 4 * 2 ), float2bfloat_sse (_r2 ));
27892794 _mm_storel_epi64 ((__m128i * )(pp + 4 * 3 ), float2bfloat_sse (_r3 ));
2790- #endif // ! __AVX512BF16__
2795+ #endif // __AVX512BF16__
27912796 pp += 16 ;
27922797 p0 += bottom_blob .cstep * 4 ;
27932798 }
@@ -2999,7 +3004,7 @@ static inline void convolution_im2col_input_tile_impl_bf16s(const Mat& bottom_bl
29993004 _mm256_storeu_si256 ((__m256i * )(pp + 16 * 5 ), float2bfloat_avx512 (_r5 ));
30003005 _mm256_storeu_si256 ((__m256i * )(pp + 16 * 6 ), float2bfloat_avx512 (_r6 ));
30013006 _mm256_storeu_si256 ((__m256i * )(pp + 16 * 7 ), float2bfloat_avx512 (_r7 ));
3002- #endif // ! __AVX512BF16__
3007+ #endif // __AVX512BF16__
30033008 pp += 128 ;
30043009 }
30053010#endif // __AVX512F__
@@ -3041,7 +3046,7 @@ static inline void convolution_im2col_input_tile_impl_bf16s(const Mat& bottom_bl
30413046 _mm_storeu_si128 ((__m128i * )(pp + 8 * 5 ), float2bfloat_avx (_r5 ));
30423047 _mm_storeu_si128 ((__m128i * )(pp + 8 * 6 ), float2bfloat_avx (_r6 ));
30433048 _mm_storeu_si128 ((__m128i * )(pp + 8 * 7 ), float2bfloat_avx (_r7 ));
3044- #endif // ! __AVX512BF16__
3049+ #endif // __AVX512BF16__
30453050 pp += 64 ;
30463051 }
30473052#endif // __AVX__
@@ -3076,7 +3081,7 @@ static inline void convolution_im2col_input_tile_impl_bf16s(const Mat& bottom_bl
30763081 _mm_storel_epi64 ((__m128i * )(pp + 4 * 5 ), float2bfloat_sse (_r6 ));
30773082 _mm_storel_epi64 ((__m128i * )(pp + 4 * 6 ), float2bfloat_sse (_r3 ));
30783083 _mm_storel_epi64 ((__m128i * )(pp + 4 * 7 ), float2bfloat_sse (_r7 ));
3079- #endif // ! __AVX512BF16__
3084+ #endif // __AVX512BF16__
30803085 pp += 32 ;
30813086 }
30823087 if (elempack == 1 )
@@ -3195,7 +3200,7 @@ static inline void convolution_im2col_input_tile_impl_bf16s(const Mat& bottom_bl
31953200 _mm256_storeu_si256 ((__m256i * )(pp + 16 * 5 ), float2bfloat_avx512 (_r5 ));
31963201 _mm256_storeu_si256 ((__m256i * )(pp + 16 * 6 ), float2bfloat_avx512 (_r6 ));
31973202 _mm256_storeu_si256 ((__m256i * )(pp + 16 * 7 ), float2bfloat_avx512 (_r7 ));
3198- #endif // ! __AVX512BF16__
3203+ #endif // __AVX512BF16__
31993204 pp += 128 ;
32003205 }
32013206#endif // __AVX512F__
@@ -3237,7 +3242,7 @@ static inline void convolution_im2col_input_tile_impl_bf16s(const Mat& bottom_bl
32373242 _mm_storeu_si128 ((__m128i * )(pp + 8 * 5 ), float2bfloat_avx (_r5 ));
32383243 _mm_storeu_si128 ((__m128i * )(pp + 8 * 6 ), float2bfloat_avx (_r6 ));
32393244 _mm_storeu_si128 ((__m128i * )(pp + 8 * 7 ), float2bfloat_avx (_r7 ));
3240- #endif // ! __AVX512BF16__
3245+ #endif // __AVX512BF16__
32413246 pp += 64 ;
32423247 }
32433248#endif // __AVX__
@@ -3272,7 +3277,7 @@ static inline void convolution_im2col_input_tile_impl_bf16s(const Mat& bottom_bl
32723277 _mm_storel_epi64 ((__m128i * )(pp + 4 * 5 ), float2bfloat_sse (_r6 ));
32733278 _mm_storel_epi64 ((__m128i * )(pp + 4 * 6 ), float2bfloat_sse (_r3 ));
32743279 _mm_storel_epi64 ((__m128i * )(pp + 4 * 7 ), float2bfloat_sse (_r7 ));
3275- #endif // ! __AVX512BF16__
3280+ #endif // __AVX512BF16__
32763281 pp += 32 ;
32773282 }
32783283 if (elempack == 1 )
@@ -3364,7 +3369,7 @@ static inline void convolution_im2col_input_tile_impl_bf16s(const Mat& bottom_bl
33643369 _mm256_storeu_si256 ((__m256i * )(pp + 16 * 1 ), float2bfloat_avx512 (_r1 ));
33653370 _mm256_storeu_si256 ((__m256i * )(pp + 16 * 2 ), float2bfloat_avx512 (_r2 ));
33663371 _mm256_storeu_si256 ((__m256i * )(pp + 16 * 3 ), float2bfloat_avx512 (_r3 ));
3367- #endif // ! __AVX512BF16__
3372+ #endif // __AVX512BF16__
33683373 pp += 64 ;
33693374 }
33703375#endif // __AVX512F__
@@ -3390,7 +3395,7 @@ static inline void convolution_im2col_input_tile_impl_bf16s(const Mat& bottom_bl
33903395 _mm_storeu_si128 ((__m128i * )(pp + 8 * 1 ), float2bfloat_avx (_r1 ));
33913396 _mm_storeu_si128 ((__m128i * )(pp + 8 * 2 ), float2bfloat_avx (_r2 ));
33923397 _mm_storeu_si128 ((__m128i * )(pp + 8 * 3 ), float2bfloat_avx (_r3 ));
3393- #endif // ! __AVX512BF16__
3398+ #endif // __AVX512BF16__
33943399 pp += 32 ;
33953400 }
33963401#endif // __AVX__
@@ -3414,7 +3419,7 @@ static inline void convolution_im2col_input_tile_impl_bf16s(const Mat& bottom_bl
34143419 _mm_storel_epi64 ((__m128i * )(pp + 4 * 1 ), float2bfloat_sse (_r1 ));
34153420 _mm_storel_epi64 ((__m128i * )(pp + 4 * 2 ), float2bfloat_sse (_r2 ));
34163421 _mm_storel_epi64 ((__m128i * )(pp + 4 * 3 ), float2bfloat_sse (_r3 ));
3417- #endif // ! __AVX512BF16__
3422+ #endif // __AVX512BF16__
34183423 pp += 16 ;
34193424 }
34203425 if (elempack == 1 )
@@ -3497,7 +3502,7 @@ static inline void convolution_im2col_input_tile_impl_bf16s(const Mat& bottom_bl
34973502 _mm256_storeu_si256 ((__m256i * )(pp + 16 * 1 ), float2bfloat_avx512 (_r1 ));
34983503 _mm256_storeu_si256 ((__m256i * )(pp + 16 * 2 ), float2bfloat_avx512 (_r2 ));
34993504 _mm256_storeu_si256 ((__m256i * )(pp + 16 * 3 ), float2bfloat_avx512 (_r3 ));
3500- #endif // ! __AVX512BF16__
3505+ #endif // __AVX512BF16__
35013506 pp += 64 ;
35023507 }
35033508#endif // __AVX512F__
@@ -3523,7 +3528,7 @@ static inline void convolution_im2col_input_tile_impl_bf16s(const Mat& bottom_bl
35233528 _mm_storeu_si128 ((__m128i * )(pp + 8 * 1 ), float2bfloat_avx (_r1 ));
35243529 _mm_storeu_si128 ((__m128i * )(pp + 8 * 2 ), float2bfloat_avx (_r2 ));
35253530 _mm_storeu_si128 ((__m128i * )(pp + 8 * 3 ), float2bfloat_avx (_r3 ));
3526- #endif // ! __AVX512BF16__
3531+ #endif // __AVX512BF16__
35273532 pp += 32 ;
35283533 }
35293534#endif // __AVX__
@@ -3547,7 +3552,7 @@ static inline void convolution_im2col_input_tile_impl_bf16s(const Mat& bottom_bl
35473552 _mm_storel_epi64 ((__m128i * )(pp + 4 * 1 ), float2bfloat_sse (_r1 ));
35483553 _mm_storel_epi64 ((__m128i * )(pp + 4 * 2 ), float2bfloat_sse (_r2 ));
35493554 _mm_storel_epi64 ((__m128i * )(pp + 4 * 3 ), float2bfloat_sse (_r3 ));
3550- #endif // ! __AVX512BF16__
3555+ #endif // __AVX512BF16__
35513556 pp += 16 ;
35523557 }
35533558 if (elempack == 1 )
@@ -3732,6 +3737,14 @@ static void convolution_im2col_input_tile_bf16s(const Mat& bottom_blob, Mat& B,
37323737
37333738static void convolution_im2col_gemm_transform_kernel_bf16s (const Mat & kernel , Mat & AT , int inch , int outch , int kernel_w , int kernel_h , const Option & opt )
37343739{
3740+ #if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__
3741+ if (ncnn ::cpu_support_x86_avx512_bf16 ())
3742+ {
3743+ convolution_im2col_gemm_transform_kernel_bf16s_avx512bf16 (kernel , AT , inch , outch , kernel_w , kernel_h , opt );
3744+ return ;
3745+ }
3746+ #endif
3747+
37353748 // NCNN_LOGE("convolution_im2col_gemm_transform_kernel");
37363749 const int maxk = kernel_w * kernel_h ;
37373750
@@ -3810,6 +3823,13 @@ static void convolution_im2col_gemm_transform_kernel_bf16s(const Mat& kernel, Ma
38103823
38113824static int convolution_im2col_gemm_bf16s (const Mat & bottom_blob , Mat & top_blob , const Mat & AT , const Mat & bias , int kernel_w , int kernel_h , int dilation_w , int dilation_h , int stride_w , int stride_h , int activation_type , const Mat & activation_params , int nT , const Option & opt )
38123825{
3826+ #if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__
3827+ if (ncnn ::cpu_support_x86_avx512_bf16 ())
3828+ {
3829+ return convolution_im2col_gemm_bf16s_avx512bf16 (bottom_blob , top_blob , AT , bias , kernel_w , kernel_h , dilation_w , dilation_h , stride_w , stride_h , activation_type , activation_params , nT , opt );
3830+ }
3831+ #endif
3832+
38133833 const int maxk = kernel_w * kernel_h ;
38143834
38153835 const int M = top_blob .c * top_blob .elempack ;
0 commit comments