Skip to content

Commit 5bfcbfd

Browse files
authored
rotaryembed/tanh/selu/mish/hardswish/hardsigmoid/gelu/erf/elu/eltwise/dropout/quantize/dequantize/bnll x86 support bf16 storage (#6624)
1 parent 371bbad commit 5bfcbfd

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

56 files changed

+3523
-1
lines changed

src/layer/x86/bnll_bf16s.h

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
// Copyright 2026 Tencent
2+
// SPDX-License-Identifier: BSD-3-Clause
3+
4+
#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__
// Forward declaration of the runtime-dispatched kernel built in a separate
// translation unit where __AVX512BF16__ is enabled (see bnll_x86_avx512bf16.cpp).
void bnll_bf16s_avx512bf16(Mat& a, const Option& opt);
#endif

// In-place BNLL activation over a bf16-stored blob (elements held as u16):
//   y = x + log(1 + exp(-x))  if x > 0
//   y = log(1 + exp(x))       otherwise
// Both branches evaluate log(1 + exp(-|x|)) and add x back only for positive
// inputs — the numerically stable form that avoids overflow in exp().
static void bnll_bf16s(Mat& a, const Option& opt)
{
#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__
    // Runtime dispatch: if the CPU supports AVX512-BF16, use the kernel
    // compiled with native bf16 conversion instructions.
    if (ncnn::cpu_support_x86_avx512_bf16())
    {
        bnll_bf16s_avx512bf16(a, opt);
        return;
    }
#endif

    int w = a.w;
    int h = a.h;
    int d = a.d;
    int channels = a.c;
    int elempack = a.elempack;
    // Elementwise op: each channel is processed as a flat run of
    // w*h*d*elempack u16 values regardless of packing layout.
    int size = w * h * d * elempack;

    #pragma omp parallel for num_threads(opt.num_threads)
    for (int q = 0; q < channels; q++)
    {
        unsigned short* ptr = a.channel(q);

        int i = 0;
#if __SSE2__
#if __AVX__
#if __AVX512F__
        __m512 _one_avx512 = _mm512_set1_ps(1.f);
        __m512 _zero_avx512 = _mm512_setzero_ps();
        // Main 16-lane loop: load bf16, widen to fp32, compute, narrow back.
        for (; i + 15 < size; i += 16)
        {
            __m512 _p = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr));
            // mask marks lanes with x > 0 (the branch that adds x back).
            __mmask16 mask = _mm512_cmp_ps_mask(_p, _zero_avx512, _CMP_GT_OQ);
            // |x| via clearing the sign bit.
            __m512 _abs_p = _mm512_castsi512_ps(_mm512_and_epi32(_mm512_castps_si512(_p), _mm512_set1_epi32(0x7fffffff)));
            // _tmp = log(1 + exp(-|x|)) for every lane.
            __m512 _tmp = log512_ps(_mm512_add_ps(_one_avx512, exp512_ps(_mm512_sub_ps(_zero_avx512, _abs_p))));
            // Positive lanes get _tmp + x, the rest keep _tmp.
            _p = _mm512_mask_add_ps(_tmp, mask, _tmp, _p);
            _mm256_storeu_si256((__m256i*)ptr, float2bfloat_avx512(_p));
            ptr += 16;
        }
        // Masked tail: handle the final 1..15 elements in one masked pass.
        if (i < size)
        {
            const unsigned int remain = size - i;
            // Low `remain` bits set => only the valid lanes load/store.
            __mmask16 _mask = (__mmask16)((1u << remain) - 1);
            __m512 _p = bfloat2float_avx512(_mm256_maskz_loadu_epi16(_mask, ptr));
            __mmask16 mask = _mm512_cmp_ps_mask(_p, _zero_avx512, _CMP_GT_OQ);
            __m512 _abs_p = _mm512_castsi512_ps(_mm512_and_epi32(_mm512_castps_si512(_p), _mm512_set1_epi32(0x7fffffff)));
            __m512 _tmp = log512_ps(_mm512_add_ps(_one_avx512, exp512_ps(_mm512_sub_ps(_zero_avx512, _abs_p))));
            _p = _mm512_mask_add_ps(_tmp, mask, _tmp, _p);
            _mm256_mask_storeu_epi16(ptr, _mask, float2bfloat_avx512(_p));
            i += remain;
        }
#else // __AVX512F__
        __m256 _one_avx = _mm256_set1_ps(1.f);
        __m256 _zero_avx = _mm256_setzero_ps();
        // 8-lane AVX loop; no masked loads here, tail falls through to SSE/scalar.
        for (; i + 7 < size; i += 8)
        {
            __m256 _p = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr));
            // All-ones lane mask where x > 0.
            __m256 mask = _mm256_cmp_ps(_p, _mm256_setzero_ps(), _CMP_GT_OQ);
            // |x| by AND with the inverse-sign-bit constant.
            __m256 _abs_p = _mm256_and_ps(_p, *(__m256*)_ps256_inv_sign_mask);
            __m256 _tmp = log256_ps(_mm256_add_ps(_one_avx, exp256_ps(_mm256_sub_ps(_zero_avx, _abs_p))));
            // Zero out x in non-positive lanes, then add: tmp + (x or 0).
            __m256 _x = _mm256_and_ps(_p, mask);
            _p = _mm256_add_ps(_x, _tmp);
            _mm_storeu_si128((__m128i*)ptr, float2bfloat_avx(_p));
            ptr += 8;
        }
        // 4-lane SSE cleanup for the remainder of the AVX path.
        __m128 _one = _mm_set1_ps(1.f);
        __m128 _zero = _mm_setzero_ps();
        for (; i + 3 < size; i += 4)
        {
            __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr));
            __m128 mask = _mm_cmpgt_ps(_p, _zero);
            __m128 _abs_p = _mm_and_ps(_p, *(__m128*)_ps_inv_sign_mask);
            __m128 _tmp = log_ps(_mm_add_ps(_one, exp_ps(_mm_sub_ps(_zero, _abs_p))));
            __m128 _x = _mm_and_ps(_p, mask);
            _p = _mm_add_ps(_x, _tmp);
            // Results are duplicated into both halves; only the low 64 bits stored.
            _mm_storel_epi64((__m128i*)ptr, float2bfloat_sse(_p, _p));
            ptr += 4;
        }
#endif // __AVX512F__
#else // __AVX__
        // SSE2-only build: 4-lane loop, same masked-add formulation as above.
        __m128 _one = _mm_set1_ps(1.f);
        __m128 _zero = _mm_setzero_ps();
        for (; i + 3 < size; i += 4)
        {
            __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr));
            __m128 mask = _mm_cmpgt_ps(_p, _zero);
            __m128 _abs_p = _mm_and_ps(_p, *(__m128*)_ps_inv_sign_mask);
            __m128 _tmp = log_ps(_mm_add_ps(_one, exp_ps(_mm_sub_ps(_zero, _abs_p))));
            __m128 _x = _mm_and_ps(_p, mask);
            _p = _mm_add_ps(_x, _tmp);
            _mm_storel_epi64((__m128i*)ptr, float2bfloat_sse(_p, _p));
            ptr += 4;
        }
#endif // __AVX__
#endif // __SSE2__
        // Scalar tail (and the whole loop on non-SSE2 builds).
        for (; i < size; i++)
        {
            float v = bfloat16_to_float32(*ptr);
            if (v > 0)
                v = v + logf(1.f + expf(-v));
            else
                v = logf(1.f + expf(v));
            *ptr = float32_to_bfloat16(v);
            ptr++;
        }
    }
}

src/layer/x86/bnll_x86.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,24 @@
1515
#endif // __AVX__
1616
#endif // __SSE2__
1717

18+
#include "x86_usability.h"
19+
20+
#include "cpu.h"
21+
1822
namespace ncnn {
1923

24+
#if NCNN_BF16
25+
#include "bnll_bf16s.h"
26+
#endif
27+
2028
BNLL_x86::BNLL_x86()
{
#if __SSE2__
    // SIMD builds accept packed layouts (elempack > 1); the kernels treat
    // each channel as a flat array so packing needs no special handling.
    support_packing = true;
#endif // __SSE2__
#if NCNN_BF16
    // Advertise the bf16 storage path added in this layer
    // (see forward_inplace_bf16s / bnll_bf16s.h).
    support_bf16_storage = true;
#endif
}
2637

2738
int BNLL_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
@@ -33,6 +44,11 @@ int BNLL_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
3344
int elempack = bottom_top_blob.elempack;
3445
int size = w * h * d * elempack;
3546

47+
#if NCNN_BF16
48+
if (opt.use_bf16_storage && bottom_top_blob.elembits() == 16)
49+
return forward_inplace_bf16s(bottom_top_blob, opt);
50+
#endif
51+
3652
#pragma omp parallel for num_threads(opt.num_threads)
3753
for (int q = 0; q < channels; q++)
3854
{
@@ -95,4 +111,13 @@ int BNLL_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
95111
return 0;
96112
}
97113

114+
#if NCNN_BF16
// bf16 storage path: delegates to the shared bnll_bf16s kernel, which applies
// BNLL in place on the u16-stored blob and handles runtime AVX512-BF16
// dispatch internally. Always succeeds.
int BNLL_x86::forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt) const
{
    bnll_bf16s(bottom_top_blob, opt);

    return 0;
}
#endif // NCNN_BF16
98123
} // namespace ncnn

src/layer/x86/bnll_x86.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,10 @@ class BNLL_x86 : public BNLL
1414
BNLL_x86();
1515
virtual int forward_inplace(Mat& bottom_top_blob, const Option& opt) const;
1616

17-
public:
17+
protected:
18+
#if NCNN_BF16
19+
int forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt) const;
20+
#endif
1821
};
1922

2023
} // namespace ncnn
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
// Copyright 2026 Tencent
2+
// SPDX-License-Identifier: BSD-3-Clause
3+
4+
#include "bnll_x86.h"
5+
6+
#if __SSE2__
7+
#include <emmintrin.h>
8+
#include "sse_mathfun.h"
9+
#if __AVX__
10+
#include <immintrin.h>
11+
#include "avx_mathfun.h"
12+
#if __AVX512F__
13+
#include "avx512_mathfun.h"
14+
#endif // __AVX512F__
15+
#endif // __AVX__
16+
#endif // __SSE2__
17+
18+
#include "x86_usability.h"
19+
20+
#include "cpu.h"
21+
#include "mat.h"
22+
23+
namespace ncnn {
24+
25+
#include "bnll_bf16s.h"
26+
27+
// Runtime-dispatch entry point: forwards to the bnll_bf16s implementation
// included into this translation unit. NOTE(review): presumably this TU is
// compiled with AVX512-BF16 enabled by the build system (so the included
// kernel picks up native bf16 conversions) — confirm against the CMake rules.
void bnll_bf16s_avx512bf16(Mat& a, const Option& opt)
{
    bnll_bf16s(a, opt);
}
31+
32+
} // namespace ncnn

0 commit comments

Comments
 (0)