Skip to content

Commit 031a7b9

Browse files
committed
f
1 parent 697d5ff commit 031a7b9

File tree

1 file changed

+161
-24
lines changed

1 file changed

+161
-24
lines changed

src/layer/x86/convolution_im2col_gemm_bf16s.h

Lines changed: 161 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -1356,15 +1356,11 @@ static void convolution_gemm_transB_packed_tile_bf16s(const Mat& AT_tile, const
13561356
#if __AVX512BF16__
13571357
for (; kk + 1 < max_kk; kk += 2)
13581358
{
1359-
__m256i _pA0 = _mm256_loadu_si256((const __m256i*)pA);
1360-
__m128i _pB = _mm_loadu_si128((const __m128i*)pB);
1361-
__m256i _pB0 = combine4x2_epi32(_pB, _pB);
1362-
__m256i _pA1 = _mm256_alignr_epi8(_pA0, _pA0, 8);
1363-
__m256i _pB1 = _mm256_alignr_epi8(_pB0, _pB0, 4);
1364-
_sum0 = _mm256_dpbf16_ps(_sum0, (__m256bh)_pA0, (__m256bh)_pB0);
1365-
_sum1 = _mm256_dpbf16_ps(_sum1, (__m256bh)_pA0, (__m256bh)_pB1);
1366-
_sum2 = _mm256_dpbf16_ps(_sum2, (__m256bh)_pA1, (__m256bh)_pB0);
1367-
_sum3 = _mm256_dpbf16_ps(_sum3, (__m256bh)_pA1, (__m256bh)_pB1);
1359+
__m256i _pA = _mm256_loadu_si256((const __m256i*)pA);
1360+
_sum0 = _mm256_dpbf16_ps(_sum0, (__m256bh)_pA, (__m256bh)_mm256_set1_epi32(((const int*)pB)[0]));
1361+
_sum1 = _mm256_dpbf16_ps(_sum1, (__m256bh)_pA, (__m256bh)_mm256_set1_epi32(((const int*)pB)[1]));
1362+
_sum2 = _mm256_dpbf16_ps(_sum2, (__m256bh)_pA, (__m256bh)_mm256_set1_epi32(((const int*)pB)[2]));
1363+
_sum3 = _mm256_dpbf16_ps(_sum3, (__m256bh)_pA, (__m256bh)_mm256_set1_epi32(((const int*)pB)[3]));
13681364
pA += 16;
13691365
pB += 8;
13701366
}
@@ -1716,14 +1712,11 @@ static void convolution_gemm_transB_packed_tile_bf16s(const Mat& AT_tile, const
17161712
#if __AVX512BF16__
17171713
for (; kk + 1 < max_kk; kk += 2)
17181714
{
1719-
__m128i _pA0 = _mm_loadu_si128((const __m128i*)pA);
1720-
__m128i _pB0 = _mm_loadu_si128((const __m128i*)pB);
1721-
__m128i _pA1 = _mm_alignr_epi8(_pA0, _pA0, 8);
1722-
__m128i _pB1 = _mm_alignr_epi8(_pB0, _pB0, 4);
1723-
_sum0 = _mm_dpbf16_ps(_sum0, (__m128bh)_pA0, (__m128bh)_pB0);
1724-
_sum1 = _mm_dpbf16_ps(_sum1, (__m128bh)_pA0, (__m128bh)_pB1);
1725-
_sum2 = _mm_dpbf16_ps(_sum2, (__m128bh)_pA1, (__m128bh)_pB0);
1726-
_sum3 = _mm_dpbf16_ps(_sum3, (__m128bh)_pA1, (__m128bh)_pB1);
1715+
__m128i _pA = _mm_loadu_si128((const __m128i*)pA);
1716+
_sum0 = _mm_dpbf16_ps(_sum0, (__m128bh)_pA, (__m128bh)_mm_set1_epi32(((const int*)pB)[0]));
1717+
_sum1 = _mm_dpbf16_ps(_sum1, (__m128bh)_pA, (__m128bh)_mm_set1_epi32(((const int*)pB)[1]));
1718+
_sum2 = _mm_dpbf16_ps(_sum2, (__m128bh)_pA, (__m128bh)_mm_set1_epi32(((const int*)pB)[2]));
1719+
_sum3 = _mm_dpbf16_ps(_sum3, (__m128bh)_pA, (__m128bh)_mm_set1_epi32(((const int*)pB)[3]));
17271720
pA += 8;
17281721
pB += 8;
17291722
}
@@ -2782,10 +2775,10 @@ static void convolution_im2col_input_tile_conv1x1s1d1_bf16s(const Mat& bottom_bl
27822775
#if __AVX512BF16__
27832776
__m128i _t0 = float2bfloat_sse(_r0);
27842777
__m128i _t1 = float2bfloat_sse(_r1);
2785-
_mm_storel_epi64((__m128i*)(pp + 0), _mm_unpacklo_epi16(_t0, _t1));
2778+
_mm_storeu_si128((__m128i*)(pp + 0), _mm_unpacklo_epi16(_t0, _t1));
27862779
__m128i _t2 = float2bfloat_sse(_r2);
27872780
__m128i _t3 = float2bfloat_sse(_r3);
2788-
_mm_storel_epi64((__m128i*)(pp + 8), _mm_unpacklo_epi16(_t2, _t3));
2781+
_mm_storeu_si128((__m128i*)(pp + 8), _mm_unpacklo_epi16(_t2, _t3));
27892782
#else
27902783

27912784
_mm_storel_epi64((__m128i*)pp, float2bfloat_sse(_r0));
@@ -2808,7 +2801,7 @@ static void convolution_im2col_input_tile_conv1x1s1d1_bf16s(const Mat& bottom_bl
28082801
{
28092802
__m128i _r0 = _mm_loadl_epi64((const __m128i*)(p0));
28102803
__m128i _r1 = _mm_loadl_epi64((const __m128i*)(p0 + bottom_blob.cstep));
2811-
_mm_storel_epi64((__m128i*)pp, _mm_unpacklo_epi16(_r0, _r1));
2804+
_mm_storeu_si128((__m128i*)pp, _mm_unpacklo_epi16(_r0, _r1));
28122805
pp += 8;
28132806
p0 += bottom_blob.cstep * 2;
28142807
}
@@ -2926,6 +2919,49 @@ static inline void convolution_im2col_input_tile_impl_bf16s(const Mat& bottom_bl
29262919
if (dy0 == dy7)
29272920
{
29282921
int kk = 0;
2922+
#if __AVX512BF16__
2923+
if (elempack == 1)
2924+
{
2925+
for (; kk + 1 < max_kk; kk += 2)
2926+
{
2927+
int p0 = (k + kk) / maxk;
2928+
int uv0 = (k + kk) % maxk;
2929+
int u0 = uv0 / kernel_w;
2930+
int v0 = uv0 % kernel_w;
2931+
const Mat img0 = bottom_blob.channel(p0);
2932+
int sx0 = stride_w * dx0 + dilation_w * v0;
2933+
int sy0 = stride_h * dy0 + dilation_h * u0;
2934+
const unsigned short* sptr0 = img0.row<const unsigned short>(sy0) + sx0;
2935+
2936+
int p1 = (k + kk + 1) / maxk;
2937+
int uv1 = (k + kk + 1) % maxk;
2938+
int u1 = uv1 / kernel_w;
2939+
int v1 = uv1 % kernel_w;
2940+
const Mat img1 = bottom_blob.channel(p1);
2941+
int sx1 = stride_w * dx0 + dilation_w * v1;
2942+
int sy1 = stride_h * dy0 + dilation_h * u1;
2943+
const unsigned short* sptr1 = img1.row<const unsigned short>(sy1) + sx1;
2944+
2945+
pp[0] = sptr0[0];
2946+
pp[1] = sptr1[0];
2947+
pp[2] = sptr0[stride_w];
2948+
pp[3] = sptr1[stride_w];
2949+
pp[4] = sptr0[stride_w * 2];
2950+
pp[5] = sptr1[stride_w * 2];
2951+
pp[6] = sptr0[stride_w * 3];
2952+
pp[7] = sptr1[stride_w * 3];
2953+
pp[8] = sptr0[stride_w * 4];
2954+
pp[9] = sptr1[stride_w * 4];
2955+
pp[10] = sptr0[stride_w * 5];
2956+
pp[11] = sptr1[stride_w * 5];
2957+
pp[12] = sptr0[stride_w * 6];
2958+
pp[13] = sptr1[stride_w * 6];
2959+
pp[14] = sptr0[stride_w * 7];
2960+
pp[15] = sptr1[stride_w * 7];
2961+
pp += 16;
2962+
}
2963+
}
2964+
#endif // __AVX512BF16__
29292965
for (; kk < max_kk / elempack; kk++)
29302966
{
29312967
int p = (k / elempack + kk) / maxk;
@@ -3101,6 +3137,43 @@ static inline void convolution_im2col_input_tile_impl_bf16s(const Mat& bottom_bl
31013137
else
31023138
{
31033139
int kk = 0;
3140+
#if __AVX512BF16__
3141+
if (elempack == 1)
3142+
{
3143+
for (; kk + 1 < max_kk; kk += 2)
3144+
{
3145+
int p0 = (k + kk) / maxk;
3146+
int uv0 = (k + kk) % maxk;
3147+
int u0 = uv0 / kernel_w;
3148+
int v0 = uv0 % kernel_w;
3149+
const Mat img0 = bottom_blob.channel(p0);
3150+
3151+
int p1 = (k + kk + 1) / maxk;
3152+
int uv1 = (k + kk + 1) % maxk;
3153+
int u1 = uv1 / kernel_w;
3154+
int v1 = uv1 % kernel_w;
3155+
const Mat img1 = bottom_blob.channel(p1);
3156+
3157+
pp[0] = img0.row<const unsigned short>(stride_h * dy0 + dilation_h * u0)[stride_w * dx0 + dilation_w * v0];
3158+
pp[1] = img1.row<const unsigned short>(stride_h * dy0 + dilation_h * u1)[stride_w * dx0 + dilation_w * v1];
3159+
pp[2] = img0.row<const unsigned short>(stride_h * dy1 + dilation_h * u0)[stride_w * dx1 + dilation_w * v0];
3160+
pp[3] = img1.row<const unsigned short>(stride_h * dy1 + dilation_h * u1)[stride_w * dx1 + dilation_w * v1];
3161+
pp[4] = img0.row<const unsigned short>(stride_h * dy2 + dilation_h * u0)[stride_w * dx2 + dilation_w * v0];
3162+
pp[5] = img1.row<const unsigned short>(stride_h * dy2 + dilation_h * u1)[stride_w * dx2 + dilation_w * v1];
3163+
pp[6] = img0.row<const unsigned short>(stride_h * dy3 + dilation_h * u0)[stride_w * dx3 + dilation_w * v0];
3164+
pp[7] = img1.row<const unsigned short>(stride_h * dy3 + dilation_h * u1)[stride_w * dx3 + dilation_w * v1];
3165+
pp[8] = img0.row<const unsigned short>(stride_h * dy4 + dilation_h * u0)[stride_w * dx4 + dilation_w * v0];
3166+
pp[9] = img1.row<const unsigned short>(stride_h * dy4 + dilation_h * u1)[stride_w * dx4 + dilation_w * v1];
3167+
pp[10] = img0.row<const unsigned short>(stride_h * dy5 + dilation_h * u0)[stride_w * dx5 + dilation_w * v0];
3168+
pp[11] = img1.row<const unsigned short>(stride_h * dy5 + dilation_h * u1)[stride_w * dx5 + dilation_w * v1];
3169+
pp[12] = img0.row<const unsigned short>(stride_h * dy6 + dilation_h * u0)[stride_w * dx6 + dilation_w * v0];
3170+
pp[13] = img1.row<const unsigned short>(stride_h * dy6 + dilation_h * u1)[stride_w * dx6 + dilation_w * v1];
3171+
pp[14] = img0.row<const unsigned short>(stride_h * dy7 + dilation_h * u0)[stride_w * dx7 + dilation_w * v0];
3172+
pp[15] = img1.row<const unsigned short>(stride_h * dy7 + dilation_h * u1)[stride_w * dx7 + dilation_w * v1];
3173+
pp += 16;
3174+
}
3175+
}
3176+
#endif // __AVX512BF16__
31043177
for (; kk < max_kk / elempack; kk++)
31053178
{
31063179
int p = (k / elempack + kk) / maxk;
@@ -3311,6 +3384,41 @@ static inline void convolution_im2col_input_tile_impl_bf16s(const Mat& bottom_bl
33113384
if (dy0 == dy3)
33123385
{
33133386
int kk = 0;
3387+
#if __AVX512BF16__
3388+
if (elempack == 1)
3389+
{
3390+
for (; kk + 1 < max_kk; kk += 2)
3391+
{
3392+
int p0 = (k + kk) / maxk;
3393+
int uv0 = (k + kk) % maxk;
3394+
int u0 = uv0 / kernel_w;
3395+
int v0 = uv0 % kernel_w;
3396+
const Mat img0 = bottom_blob.channel(p0);
3397+
int sx0 = stride_w * dx0 + dilation_w * v0;
3398+
int sy0 = stride_h * dy0 + dilation_h * u0;
3399+
const unsigned short* sptr0 = img0.row<const unsigned short>(sy0) + sx0;
3400+
3401+
int p1 = (k + kk + 1) / maxk;
3402+
int uv1 = (k + kk + 1) % maxk;
3403+
int u1 = uv1 / kernel_w;
3404+
int v1 = uv1 % kernel_w;
3405+
const Mat img1 = bottom_blob.channel(p1);
3406+
int sx1 = stride_w * dx0 + dilation_w * v1;
3407+
int sy1 = stride_h * dy0 + dilation_h * u1;
3408+
const unsigned short* sptr1 = img1.row<const unsigned short>(sy1) + sx1;
3409+
3410+
pp[0] = sptr0[0];
3411+
pp[1] = sptr1[0];
3412+
pp[2] = sptr0[stride_w];
3413+
pp[3] = sptr1[stride_w];
3414+
pp[4] = sptr0[stride_w * 2];
3415+
pp[5] = sptr1[stride_w * 2];
3416+
pp[6] = sptr0[stride_w * 3];
3417+
pp[7] = sptr1[stride_w * 3];
3418+
pp += 8;
3419+
}
3420+
}
3421+
#endif // __AVX512BF16__
33143422
for (; kk < max_kk / elempack; kk++)
33153423
{
33163424
int p = (k / elempack + kk) / maxk;
@@ -3409,10 +3517,10 @@ static inline void convolution_im2col_input_tile_impl_bf16s(const Mat& bottom_bl
34093517
#if __AVX512BF16__
34103518
__m128i _t0 = float2bfloat_sse(_r0);
34113519
__m128i _t1 = float2bfloat_sse(_r1);
3412-
_mm_storel_epi64((__m128i*)(pp + 0), _mm_unpacklo_epi16(_t0, _t1));
3520+
_mm_storeu_si128((__m128i*)(pp + 0), _mm_unpacklo_epi16(_t0, _t1));
34133521
__m128i _t2 = float2bfloat_sse(_r2);
34143522
__m128i _t3 = float2bfloat_sse(_r3);
3415-
_mm_storel_epi64((__m128i*)(pp + 8), _mm_unpacklo_epi16(_t2, _t3));
3523+
_mm_storeu_si128((__m128i*)(pp + 8), _mm_unpacklo_epi16(_t2, _t3));
34163524
#else
34173525

34183526
_mm_storel_epi64((__m128i*)pp, float2bfloat_sse(_r0));
@@ -3435,6 +3543,35 @@ static inline void convolution_im2col_input_tile_impl_bf16s(const Mat& bottom_bl
34353543
else
34363544
{
34373545
int kk = 0;
3546+
#if __AVX512BF16__
3547+
if (elempack == 1)
3548+
{
3549+
for (; kk + 1 < max_kk; kk += 2)
3550+
{
3551+
int p0 = (k + kk) / maxk;
3552+
int uv0 = (k + kk) % maxk;
3553+
int u0 = uv0 / kernel_w;
3554+
int v0 = uv0 % kernel_w;
3555+
const Mat img0 = bottom_blob.channel(p0);
3556+
3557+
int p1 = (k + kk + 1) / maxk;
3558+
int uv1 = (k + kk + 1) % maxk;
3559+
int u1 = uv1 / kernel_w;
3560+
int v1 = uv1 % kernel_w;
3561+
const Mat img1 = bottom_blob.channel(p1);
3562+
3563+
pp[0] = img0.row<const unsigned short>(stride_h * dy0 + dilation_h * u0)[stride_w * dx0 + dilation_w * v0];
3564+
pp[1] = img1.row<const unsigned short>(stride_h * dy0 + dilation_h * u1)[stride_w * dx0 + dilation_w * v1];
3565+
pp[2] = img0.row<const unsigned short>(stride_h * dy1 + dilation_h * u0)[stride_w * dx1 + dilation_w * v0];
3566+
pp[3] = img1.row<const unsigned short>(stride_h * dy1 + dilation_h * u1)[stride_w * dx1 + dilation_w * v1];
3567+
pp[4] = img0.row<const unsigned short>(stride_h * dy2 + dilation_h * u0)[stride_w * dx2 + dilation_w * v0];
3568+
pp[5] = img1.row<const unsigned short>(stride_h * dy2 + dilation_h * u1)[stride_w * dx2 + dilation_w * v1];
3569+
pp[6] = img0.row<const unsigned short>(stride_h * dy3 + dilation_h * u0)[stride_w * dx3 + dilation_w * v0];
3570+
pp[7] = img1.row<const unsigned short>(stride_h * dy3 + dilation_h * u1)[stride_w * dx3 + dilation_w * v1];
3571+
pp += 8;
3572+
}
3573+
}
3574+
#endif // __AVX512BF16__
34383575
for (; kk < max_kk / elempack; kk++)
34393576
{
34403577
int p = (k / elempack + kk) / maxk;
@@ -3542,10 +3679,10 @@ static inline void convolution_im2col_input_tile_impl_bf16s(const Mat& bottom_bl
35423679
#if __AVX512BF16__
35433680
__m128i _t0 = float2bfloat_sse(_r0);
35443681
__m128i _t1 = float2bfloat_sse(_r1);
3545-
_mm_storel_epi64((__m128i*)(pp + 0), _mm_unpacklo_epi16(_t0, _t1));
3682+
_mm_storeu_si128((__m128i*)(pp + 0), _mm_unpacklo_epi16(_t0, _t1));
35463683
__m128i _t2 = float2bfloat_sse(_r2);
35473684
__m128i _t3 = float2bfloat_sse(_r3);
3548-
_mm_storel_epi64((__m128i*)(pp + 8), _mm_unpacklo_epi16(_t2, _t3));
3685+
_mm_storeu_si128((__m128i*)(pp + 8), _mm_unpacklo_epi16(_t2, _t3));
35493686
#else
35503687

35513688
_mm_storel_epi64((__m128i*)pp, float2bfloat_sse(_r0));

0 commit comments

Comments
 (0)