@@ -1356,15 +1356,11 @@ static void convolution_gemm_transB_packed_tile_bf16s(const Mat& AT_tile, const
13561356#if __AVX512BF16__
13571357 for (; kk + 1 < max_kk ; kk += 2 )
13581358 {
1359- __m256i _pA0 = _mm256_loadu_si256 ((const __m256i * )pA );
1360- __m128i _pB = _mm_loadu_si128 ((const __m128i * )pB );
1361- __m256i _pB0 = combine4x2_epi32 (_pB , _pB );
1362- __m256i _pA1 = _mm256_alignr_epi8 (_pA0 , _pA0 , 8 );
1363- __m256i _pB1 = _mm256_alignr_epi8 (_pB0 , _pB0 , 4 );
1364- _sum0 = _mm256_dpbf16_ps (_sum0 , (__m256bh )_pA0 , (__m256bh )_pB0 );
1365- _sum1 = _mm256_dpbf16_ps (_sum1 , (__m256bh )_pA0 , (__m256bh )_pB1 );
1366- _sum2 = _mm256_dpbf16_ps (_sum2 , (__m256bh )_pA1 , (__m256bh )_pB0 );
1367- _sum3 = _mm256_dpbf16_ps (_sum3 , (__m256bh )_pA1 , (__m256bh )_pB1 );
1359+ __m256i _pA = _mm256_loadu_si256 ((const __m256i * )pA );
1360+ _sum0 = _mm256_dpbf16_ps (_sum0 , (__m256bh )_pA , (__m256bh )_mm256_set1_epi32 (((const int * )pB )[0 ]));
1361+ _sum1 = _mm256_dpbf16_ps (_sum1 , (__m256bh )_pA , (__m256bh )_mm256_set1_epi32 (((const int * )pB )[1 ]));
1362+ _sum2 = _mm256_dpbf16_ps (_sum2 , (__m256bh )_pA , (__m256bh )_mm256_set1_epi32 (((const int * )pB )[2 ]));
1363+ _sum3 = _mm256_dpbf16_ps (_sum3 , (__m256bh )_pA , (__m256bh )_mm256_set1_epi32 (((const int * )pB )[3 ]));
13681364 pA += 16 ;
13691365 pB += 8 ;
13701366 }
@@ -1716,14 +1712,11 @@ static void convolution_gemm_transB_packed_tile_bf16s(const Mat& AT_tile, const
17161712#if __AVX512BF16__
17171713 for (; kk + 1 < max_kk ; kk += 2 )
17181714 {
1719- __m128i _pA0 = _mm_loadu_si128 ((const __m128i * )pA );
1720- __m128i _pB0 = _mm_loadu_si128 ((const __m128i * )pB );
1721- __m128i _pA1 = _mm_alignr_epi8 (_pA0 , _pA0 , 8 );
1722- __m128i _pB1 = _mm_alignr_epi8 (_pB0 , _pB0 , 4 );
1723- _sum0 = _mm_dpbf16_ps (_sum0 , (__m128bh )_pA0 , (__m128bh )_pB0 );
1724- _sum1 = _mm_dpbf16_ps (_sum1 , (__m128bh )_pA0 , (__m128bh )_pB1 );
1725- _sum2 = _mm_dpbf16_ps (_sum2 , (__m128bh )_pA1 , (__m128bh )_pB0 );
1726- _sum3 = _mm_dpbf16_ps (_sum3 , (__m128bh )_pA1 , (__m128bh )_pB1 );
1715+ __m128i _pA = _mm_loadu_si128 ((const __m128i * )pA );
1716+ _sum0 = _mm_dpbf16_ps (_sum0 , (__m128bh )_pA , (__m128bh )_mm_set1_epi32 (((const int * )pB )[0 ]));
1717+ _sum1 = _mm_dpbf16_ps (_sum1 , (__m128bh )_pA , (__m128bh )_mm_set1_epi32 (((const int * )pB )[1 ]));
1718+ _sum2 = _mm_dpbf16_ps (_sum2 , (__m128bh )_pA , (__m128bh )_mm_set1_epi32 (((const int * )pB )[2 ]));
1719+ _sum3 = _mm_dpbf16_ps (_sum3 , (__m128bh )_pA , (__m128bh )_mm_set1_epi32 (((const int * )pB )[3 ]));
17271720 pA += 8 ;
17281721 pB += 8 ;
17291722 }
@@ -2782,10 +2775,10 @@ static void convolution_im2col_input_tile_conv1x1s1d1_bf16s(const Mat& bottom_bl
27822775#if __AVX512BF16__
27832776 __m128i _t0 = float2bfloat_sse (_r0 );
27842777 __m128i _t1 = float2bfloat_sse (_r1 );
2785- _mm_storel_epi64 ((__m128i * )(pp + 0 ), _mm_unpacklo_epi16 (_t0 , _t1 ));
2778+ _mm_storeu_si128 ((__m128i * )(pp + 0 ), _mm_unpacklo_epi16 (_t0 , _t1 ));
27862779 __m128i _t2 = float2bfloat_sse (_r2 );
27872780 __m128i _t3 = float2bfloat_sse (_r3 );
2788- _mm_storel_epi64 ((__m128i * )(pp + 8 ), _mm_unpacklo_epi16 (_t2 , _t3 ));
2781+ _mm_storeu_si128 ((__m128i * )(pp + 8 ), _mm_unpacklo_epi16 (_t2 , _t3 ));
27892782#else
27902783
27912784 _mm_storel_epi64 ((__m128i * )pp , float2bfloat_sse (_r0 ));
@@ -2808,7 +2801,7 @@ static void convolution_im2col_input_tile_conv1x1s1d1_bf16s(const Mat& bottom_bl
28082801 {
28092802 __m128i _r0 = _mm_loadl_epi64 ((const __m128i * )(p0 ));
28102803 __m128i _r1 = _mm_loadl_epi64 ((const __m128i * )(p0 + bottom_blob .cstep ));
2811- _mm_storel_epi64 ((__m128i * )pp , _mm_unpacklo_epi16 (_r0 , _r1 ));
2804+ _mm_storeu_si128 ((__m128i * )pp , _mm_unpacklo_epi16 (_r0 , _r1 ));
28122805 pp += 8 ;
28132806 p0 += bottom_blob .cstep * 2 ;
28142807 }
@@ -2926,6 +2919,49 @@ static inline void convolution_im2col_input_tile_impl_bf16s(const Mat& bottom_bl
29262919 if (dy0 == dy7 )
29272920 {
29282921 int kk = 0 ;
2922+ #if __AVX512BF16__
2923+ if (elempack == 1 )
2924+ {
2925+ for (; kk + 1 < max_kk ; kk += 2 )
2926+ {
2927+ int p0 = (k + kk ) / maxk ;
2928+ int uv0 = (k + kk ) % maxk ;
2929+ int u0 = uv0 / kernel_w ;
2930+ int v0 = uv0 % kernel_w ;
2931+ const Mat img0 = bottom_blob .channel (p0 );
2932+ int sx0 = stride_w * dx0 + dilation_w * v0 ;
2933+ int sy0 = stride_h * dy0 + dilation_h * u0 ;
2934+ const unsigned short * sptr0 = img0 .row < const unsigned short > (sy0 ) + sx0 ;
2935+
2936+ int p1 = (k + kk + 1 ) / maxk ;
2937+ int uv1 = (k + kk + 1 ) % maxk ;
2938+ int u1 = uv1 / kernel_w ;
2939+ int v1 = uv1 % kernel_w ;
2940+ const Mat img1 = bottom_blob .channel (p1 );
2941+ int sx1 = stride_w * dx0 + dilation_w * v1 ;
2942+ int sy1 = stride_h * dy0 + dilation_h * u1 ;
2943+ const unsigned short * sptr1 = img1 .row < const unsigned short > (sy1 ) + sx1 ;
2944+
2945+ pp [0 ] = sptr0 [0 ];
2946+ pp [1 ] = sptr1 [0 ];
2947+ pp [2 ] = sptr0 [stride_w ];
2948+ pp [3 ] = sptr1 [stride_w ];
2949+ pp [4 ] = sptr0 [stride_w * 2 ];
2950+ pp [5 ] = sptr1 [stride_w * 2 ];
2951+ pp [6 ] = sptr0 [stride_w * 3 ];
2952+ pp [7 ] = sptr1 [stride_w * 3 ];
2953+ pp [8 ] = sptr0 [stride_w * 4 ];
2954+ pp [9 ] = sptr1 [stride_w * 4 ];
2955+ pp [10 ] = sptr0 [stride_w * 5 ];
2956+ pp [11 ] = sptr1 [stride_w * 5 ];
2957+ pp [12 ] = sptr0 [stride_w * 6 ];
2958+ pp [13 ] = sptr1 [stride_w * 6 ];
2959+ pp [14 ] = sptr0 [stride_w * 7 ];
2960+ pp [15 ] = sptr1 [stride_w * 7 ];
2961+ pp += 16 ;
2962+ }
2963+ }
2964+ #endif // __AVX512BF16__
29292965 for (; kk < max_kk / elempack ; kk ++ )
29302966 {
29312967 int p = (k / elempack + kk ) / maxk ;
@@ -3101,6 +3137,43 @@ static inline void convolution_im2col_input_tile_impl_bf16s(const Mat& bottom_bl
31013137 else
31023138 {
31033139 int kk = 0 ;
3140+ #if __AVX512BF16__
3141+ if (elempack == 1 )
3142+ {
3143+ for (; kk + 1 < max_kk ; kk += 2 )
3144+ {
3145+ int p0 = (k + kk ) / maxk ;
3146+ int uv0 = (k + kk ) % maxk ;
3147+ int u0 = uv0 / kernel_w ;
3148+ int v0 = uv0 % kernel_w ;
3149+ const Mat img0 = bottom_blob .channel (p0 );
3150+
3151+ int p1 = (k + kk + 1 ) / maxk ;
3152+ int uv1 = (k + kk + 1 ) % maxk ;
3153+ int u1 = uv1 / kernel_w ;
3154+ int v1 = uv1 % kernel_w ;
3155+ const Mat img1 = bottom_blob .channel (p1 );
3156+
3157+ pp [0 ] = img0 .row < const unsigned short> (stride_h * dy0 + dilation_h * u0 )[stride_w * dx0 + dilation_w * v0 ];
3158+ pp [1 ] = img1 .row < const unsigned short> (stride_h * dy0 + dilation_h * u1 )[stride_w * dx0 + dilation_w * v1 ];
3159+ pp [2 ] = img0 .row < const unsigned short> (stride_h * dy1 + dilation_h * u0 )[stride_w * dx1 + dilation_w * v0 ];
3160+ pp [3 ] = img1 .row < const unsigned short> (stride_h * dy1 + dilation_h * u1 )[stride_w * dx1 + dilation_w * v1 ];
3161+ pp [4 ] = img0 .row < const unsigned short> (stride_h * dy2 + dilation_h * u0 )[stride_w * dx2 + dilation_w * v0 ];
3162+ pp [5 ] = img1 .row < const unsigned short> (stride_h * dy2 + dilation_h * u1 )[stride_w * dx2 + dilation_w * v1 ];
3163+ pp [6 ] = img0 .row < const unsigned short> (stride_h * dy3 + dilation_h * u0 )[stride_w * dx3 + dilation_w * v0 ];
3164+ pp [7 ] = img1 .row < const unsigned short> (stride_h * dy3 + dilation_h * u1 )[stride_w * dx3 + dilation_w * v1 ];
3165+ pp [8 ] = img0 .row < const unsigned short> (stride_h * dy4 + dilation_h * u0 )[stride_w * dx4 + dilation_w * v0 ];
3166+ pp [9 ] = img1 .row < const unsigned short> (stride_h * dy4 + dilation_h * u1 )[stride_w * dx4 + dilation_w * v1 ];
3167+ pp [10 ] = img0 .row < const unsigned short> (stride_h * dy5 + dilation_h * u0 )[stride_w * dx5 + dilation_w * v0 ];
3168+ pp [11 ] = img1 .row < const unsigned short> (stride_h * dy5 + dilation_h * u1 )[stride_w * dx5 + dilation_w * v1 ];
3169+ pp [12 ] = img0 .row < const unsigned short> (stride_h * dy6 + dilation_h * u0 )[stride_w * dx6 + dilation_w * v0 ];
3170+ pp [13 ] = img1 .row < const unsigned short> (stride_h * dy6 + dilation_h * u1 )[stride_w * dx6 + dilation_w * v1 ];
3171+ pp [14 ] = img0 .row < const unsigned short> (stride_h * dy7 + dilation_h * u0 )[stride_w * dx7 + dilation_w * v0 ];
3172+ pp [15 ] = img1 .row < const unsigned short> (stride_h * dy7 + dilation_h * u1 )[stride_w * dx7 + dilation_w * v1 ];
3173+ pp += 16 ;
3174+ }
3175+ }
3176+ #endif // __AVX512BF16__
31043177 for (; kk < max_kk / elempack ; kk ++ )
31053178 {
31063179 int p = (k / elempack + kk ) / maxk ;
@@ -3311,6 +3384,41 @@ static inline void convolution_im2col_input_tile_impl_bf16s(const Mat& bottom_bl
33113384 if (dy0 == dy3 )
33123385 {
33133386 int kk = 0 ;
3387+ #if __AVX512BF16__
3388+ if (elempack == 1 )
3389+ {
3390+ for (; kk + 1 < max_kk ; kk += 2 )
3391+ {
3392+ int p0 = (k + kk ) / maxk ;
3393+ int uv0 = (k + kk ) % maxk ;
3394+ int u0 = uv0 / kernel_w ;
3395+ int v0 = uv0 % kernel_w ;
3396+ const Mat img0 = bottom_blob .channel (p0 );
3397+ int sx0 = stride_w * dx0 + dilation_w * v0 ;
3398+ int sy0 = stride_h * dy0 + dilation_h * u0 ;
3399+ const unsigned short * sptr0 = img0 .row < const unsigned short > (sy0 ) + sx0 ;
3400+
3401+ int p1 = (k + kk + 1 ) / maxk ;
3402+ int uv1 = (k + kk + 1 ) % maxk ;
3403+ int u1 = uv1 / kernel_w ;
3404+ int v1 = uv1 % kernel_w ;
3405+ const Mat img1 = bottom_blob .channel (p1 );
3406+ int sx1 = stride_w * dx0 + dilation_w * v1 ;
3407+ int sy1 = stride_h * dy0 + dilation_h * u1 ;
3408+ const unsigned short * sptr1 = img1 .row < const unsigned short > (sy1 ) + sx1 ;
3409+
3410+ pp [0 ] = sptr0 [0 ];
3411+ pp [1 ] = sptr1 [0 ];
3412+ pp [2 ] = sptr0 [stride_w ];
3413+ pp [3 ] = sptr1 [stride_w ];
3414+ pp [4 ] = sptr0 [stride_w * 2 ];
3415+ pp [5 ] = sptr1 [stride_w * 2 ];
3416+ pp [6 ] = sptr0 [stride_w * 3 ];
3417+ pp [7 ] = sptr1 [stride_w * 3 ];
3418+ pp += 8 ;
3419+ }
3420+ }
3421+ #endif // __AVX512BF16__
33143422 for (; kk < max_kk / elempack ; kk ++ )
33153423 {
33163424 int p = (k / elempack + kk ) / maxk ;
@@ -3409,10 +3517,10 @@ static inline void convolution_im2col_input_tile_impl_bf16s(const Mat& bottom_bl
34093517#if __AVX512BF16__
34103518 __m128i _t0 = float2bfloat_sse (_r0 );
34113519 __m128i _t1 = float2bfloat_sse (_r1 );
3412- _mm_storel_epi64 ((__m128i * )(pp + 0 ), _mm_unpacklo_epi16 (_t0 , _t1 ));
3520+ _mm_storeu_si128 ((__m128i * )(pp + 0 ), _mm_unpacklo_epi16 (_t0 , _t1 ));
34133521 __m128i _t2 = float2bfloat_sse (_r2 );
34143522 __m128i _t3 = float2bfloat_sse (_r3 );
3415- _mm_storel_epi64 ((__m128i * )(pp + 8 ), _mm_unpacklo_epi16 (_t2 , _t3 ));
3523+ _mm_storeu_si128 ((__m128i * )(pp + 8 ), _mm_unpacklo_epi16 (_t2 , _t3 ));
34163524#else
34173525
34183526 _mm_storel_epi64 ((__m128i * )pp , float2bfloat_sse (_r0 ));
@@ -3435,6 +3543,35 @@ static inline void convolution_im2col_input_tile_impl_bf16s(const Mat& bottom_bl
34353543 else
34363544 {
34373545 int kk = 0 ;
3546+ #if __AVX512BF16__
3547+ if (elempack == 1 )
3548+ {
3549+ for (; kk + 1 < max_kk ; kk += 2 )
3550+ {
3551+ int p0 = (k + kk ) / maxk ;
3552+ int uv0 = (k + kk ) % maxk ;
3553+ int u0 = uv0 / kernel_w ;
3554+ int v0 = uv0 % kernel_w ;
3555+ const Mat img0 = bottom_blob .channel (p0 );
3556+
3557+ int p1 = (k + kk + 1 ) / maxk ;
3558+ int uv1 = (k + kk + 1 ) % maxk ;
3559+ int u1 = uv1 / kernel_w ;
3560+ int v1 = uv1 % kernel_w ;
3561+ const Mat img1 = bottom_blob .channel (p1 );
3562+
3563+ pp [0 ] = img0 .row < const unsigned short> (stride_h * dy0 + dilation_h * u0 )[stride_w * dx0 + dilation_w * v0 ];
3564+ pp [1 ] = img1 .row < const unsigned short> (stride_h * dy0 + dilation_h * u1 )[stride_w * dx0 + dilation_w * v1 ];
3565+ pp [2 ] = img0 .row < const unsigned short> (stride_h * dy1 + dilation_h * u0 )[stride_w * dx1 + dilation_w * v0 ];
3566+ pp [3 ] = img1 .row < const unsigned short> (stride_h * dy1 + dilation_h * u1 )[stride_w * dx1 + dilation_w * v1 ];
3567+ pp [4 ] = img0 .row < const unsigned short> (stride_h * dy2 + dilation_h * u0 )[stride_w * dx2 + dilation_w * v0 ];
3568+ pp [5 ] = img1 .row < const unsigned short> (stride_h * dy2 + dilation_h * u1 )[stride_w * dx2 + dilation_w * v1 ];
3569+ pp [6 ] = img0 .row < const unsigned short> (stride_h * dy3 + dilation_h * u0 )[stride_w * dx3 + dilation_w * v0 ];
3570+ pp [7 ] = img1 .row < const unsigned short> (stride_h * dy3 + dilation_h * u1 )[stride_w * dx3 + dilation_w * v1 ];
3571+ pp += 8 ;
3572+ }
3573+ }
3574+ #endif // __AVX512BF16__
34383575 for (; kk < max_kk / elempack ; kk ++ )
34393576 {
34403577 int p = (k / elempack + kk ) / maxk ;
@@ -3542,10 +3679,10 @@ static inline void convolution_im2col_input_tile_impl_bf16s(const Mat& bottom_bl
35423679#if __AVX512BF16__
35433680 __m128i _t0 = float2bfloat_sse (_r0 );
35443681 __m128i _t1 = float2bfloat_sse (_r1 );
3545- _mm_storel_epi64 ((__m128i * )(pp + 0 ), _mm_unpacklo_epi16 (_t0 , _t1 ));
3682+ _mm_storeu_si128 ((__m128i * )(pp + 0 ), _mm_unpacklo_epi16 (_t0 , _t1 ));
35463683 __m128i _t2 = float2bfloat_sse (_r2 );
35473684 __m128i _t3 = float2bfloat_sse (_r3 );
3548- _mm_storel_epi64 ((__m128i * )(pp + 8 ), _mm_unpacklo_epi16 (_t2 , _t3 ));
3685+ _mm_storeu_si128 ((__m128i * )(pp + 8 ), _mm_unpacklo_epi16 (_t2 , _t3 ));
35493686#else
35503687
35513688 _mm_storel_epi64 ((__m128i * )pp , float2bfloat_sse (_r0 ));
0 commit comments