5,533 changes: 3,664 additions & 1,869 deletions src/layer/x86/gemm_bf16s.h

Large diffs are not rendered by default.
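Note: the 5,533-line gemm_bf16s.h rewrite is collapsed here, but the call sites below show its new contract: `unpack_output_tile_fp32_to_bf16` now takes an `output_elemtype` flag so the fp32 accumulator tile can be stored as either bf16 or fp32. For background, a hedged scalar sketch of the bf16/fp32 conversions these kernels are built on (ncnn defines equivalent helpers in src/mat.h; the real kernels widen whole SSE2/AVX registers at a time):

```cpp
// bf16 keeps fp32's sign bit and 8-bit exponent and truncates the mantissa
// to 7 bits, so both conversions are a 16-bit shift of the IEEE-754 bits.
static inline float bf16_to_fp32(unsigned short v)
{
    union { unsigned int u; float f; } tmp;
    tmp.u = (unsigned int)v << 16; // bf16 occupies the high half of fp32
    return tmp.f;
}

static inline unsigned short fp32_to_bf16(float v)
{
    union { unsigned int u; float f; } tmp;
    tmp.f = v;
    return (unsigned short)(tmp.u >> 16); // truncation; ncnn's helper may round
}
```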

26 changes: 13 additions & 13 deletions src/layer/x86/gemm_x86.cpp
@@ -8518,7 +8518,7 @@ int Gemm_x86::create_pipeline_bf16s(const Option& opt)
return 0;
}

-static int gemm_AT_x86_bf16s(const Mat& AT, const Mat& B, const Mat& C, Mat& top_blob, int broadcast_type_C, int M, int K, int transB, int output_transpose, float alpha, float beta, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt)
+static int gemm_AT_x86_bf16s(const Mat& AT, const Mat& B, const Mat& C, Mat& top_blob, int broadcast_type_C, int M, int K, int transB, int output_transpose, float alpha, float beta, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, int output_elemtype, const Option& opt)
{
const int N = transB ? (B.dims == 3 ? B.c : B.h) * B.elempack : B.w;

@@ -8587,14 +8587,14 @@ static int gemm_AT_x86_bf16s(const Mat& AT, const Mat& B, const Mat& C, Mat& top
gemm_transB_packed_tile_bf16s(AT_tile, BT_tile, topT_tile, i, max_ii, j, max_jj, k, max_kk);
}

-unpack_output_tile_fp32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, alpha, beta, output_transpose);
+unpack_output_tile_fp32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, alpha, beta, output_transpose, output_elemtype);
}
}

return 0;
}

-static int gemm_BT_x86_bf16s(const Mat& A, const Mat& BT, const Mat& C, Mat& top_blob, int broadcast_type_C, int N, int K, int transA, int output_transpose, float alpha, float beta, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt)
+static int gemm_BT_x86_bf16s(const Mat& A, const Mat& BT, const Mat& C, Mat& top_blob, int broadcast_type_C, int N, int K, int transA, int output_transpose, float alpha, float beta, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, int output_elemtype, const Option& opt)
{
const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack;

@@ -8652,14 +8652,14 @@ static int gemm_BT_x86_bf16s(const Mat& A, const Mat& BT, const Mat& C, Mat& top
gemm_transB_packed_tile_bf16s(AT_tile, BT_tile, topT_tile, i, max_ii, j, max_jj, k, max_kk);
}

-unpack_output_tile_fp32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, alpha, beta, output_transpose);
+unpack_output_tile_fp32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, alpha, beta, output_transpose, output_elemtype);
}
}

return 0;
}

-static int gemm_AT_BT_x86_bf16s(const Mat& AT, const Mat& BT, const Mat& C, Mat& top_blob, int broadcast_type_C, int M, int N, int K, int output_transpose, float alpha, float beta, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt)
+static int gemm_AT_BT_x86_bf16s(const Mat& AT, const Mat& BT, const Mat& C, Mat& top_blob, int broadcast_type_C, int M, int N, int K, int output_transpose, float alpha, float beta, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, int output_elemtype, const Option& opt)
{
int TILE_M, TILE_N, TILE_K;
get_optimal_tile_mnk_bf16(M, N, K, constant_TILE_M, constant_TILE_N, constant_TILE_K, TILE_M, TILE_N, TILE_K, nT);
@@ -8693,14 +8693,14 @@ static int gemm_AT_BT_x86_bf16s(const Mat& AT, const Mat& BT, const Mat& C, Mat&
gemm_transB_packed_tile_bf16s(AT_tile, BT_tile, topT_tile, i, max_ii, j, max_jj, k, max_kk);
}

-unpack_output_tile_fp32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, alpha, beta, output_transpose);
+unpack_output_tile_fp32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, alpha, beta, output_transpose, output_elemtype);
}
}

return 0;
}

-static int gemm_x86_bf16s(const Mat& A, const Mat& B, const Mat& C, Mat& top_blob, int broadcast_type_C, int transA, int transB, int output_transpose, float alpha, float beta, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt)
+static int gemm_x86_bf16s(const Mat& A, const Mat& B, const Mat& C, Mat& top_blob, int broadcast_type_C, int transA, int transB, int output_transpose, float alpha, float beta, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, int output_elemtype, const Option& opt)
{
const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack;
const int K = transA ? (A.dims == 3 ? A.c : A.h) * A.elempack : A.w;
@@ -8788,7 +8788,7 @@ static int gemm_x86_bf16s(const Mat& A, const Mat& B, const Mat& C, Mat& top_blo
gemm_transB_packed_tile_bf16s(AT_tile, BT_tile, topT_tile, i, max_ii, j, max_jj, k, max_kk);
}

-unpack_output_tile_fp32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, alpha, beta, output_transpose);
+unpack_output_tile_fp32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, alpha, beta, output_transpose, output_elemtype);
}
}

@@ -8908,7 +8908,7 @@ int Gemm_x86::forward_bf16s(const std::vector<Mat>& bottom_blobs, std::vector<Ma
#endif // __SSE2__
if (output_elempack)
out_elempack = output_elempack;
-size_t out_elemsize = 2u * out_elempack;
+size_t out_elemsize = (output_elemtype == 1 ? 4u : 2u) * out_elempack;

Mat& top_blob = top_blobs[0];
if (output_transpose)
@@ -8937,23 +8937,23 @@ int Gemm_x86::forward_bf16s(const std::vector<Mat>& bottom_blobs, std::vector<Ma
int ret = 0;
if (constantA && constantB)
{
-ret = gemm_AT_BT_x86_bf16s(AT_data, BT_data, C, top_blob, broadcast_type_C, constantM, constantN, constantK, output_transpose, alpha, beta, constant_TILE_M, constant_TILE_N, constant_TILE_K, _nT, opt);
+ret = gemm_AT_BT_x86_bf16s(AT_data, BT_data, C, top_blob, broadcast_type_C, constantM, constantN, constantK, output_transpose, alpha, beta, constant_TILE_M, constant_TILE_N, constant_TILE_K, _nT, output_elemtype, opt);
}
else if (constantA)
{
const Mat& B = bottom_blobs[0];
-ret = gemm_AT_x86_bf16s(AT_data, B, C, top_blob, broadcast_type_C, constantM, constantK, transB, output_transpose, alpha, beta, constant_TILE_M, constant_TILE_N, constant_TILE_K, _nT, opt);
+ret = gemm_AT_x86_bf16s(AT_data, B, C, top_blob, broadcast_type_C, constantM, constantK, transB, output_transpose, alpha, beta, constant_TILE_M, constant_TILE_N, constant_TILE_K, _nT, output_elemtype, opt);
}
else if (constantB)
{
const Mat& A = bottom_blobs[0];
-ret = gemm_BT_x86_bf16s(A, BT_data, C, top_blob, broadcast_type_C, constantN, constantK, transA, output_transpose, alpha, beta, constant_TILE_M, constant_TILE_N, constant_TILE_K, _nT, opt);
+ret = gemm_BT_x86_bf16s(A, BT_data, C, top_blob, broadcast_type_C, constantN, constantK, transA, output_transpose, alpha, beta, constant_TILE_M, constant_TILE_N, constant_TILE_K, _nT, output_elemtype, opt);
}
else
{
const Mat& A = bottom_blobs[0];
const Mat& B = bottom_blobs[1];
-ret = gemm_x86_bf16s(A, B, C, top_blob, broadcast_type_C, transA, transB, output_transpose, alpha, beta, constant_TILE_M, constant_TILE_N, constant_TILE_K, _nT, opt);
+ret = gemm_x86_bf16s(A, B, C, top_blob, broadcast_type_C, transA, transB, output_transpose, alpha, beta, constant_TILE_M, constant_TILE_N, constant_TILE_K, _nT, output_elemtype, opt);
}

return ret;
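Every bf16s gemm entry point in this file gains an `output_elemtype` argument that is forwarded into the output unpack, and `forward_bf16s` sizes the destination blob to match. A minimal sketch of the dispatch, assuming the encoding implied by the diff (1 = fp32, anything else = bf16); the real `unpack_output_tile_fp32_to_bf16` in gemm_bf16s.h is tiled, packed, and vectorized:

```cpp
// Per-element store of the fp32 accumulator (sketch), mirroring
//   size_t out_elemsize = (output_elemtype == 1 ? 4u : 2u) * out_elempack;
// fp32_to_bf16 is the truncating helper sketched earlier.
static void store_output(float v, void* outptr, int output_elemtype)
{
    if (output_elemtype == 1)
        *(float*)outptr = v; // keep fp32
    else
        *(unsigned short*)outptr = fp32_to_bf16(v); // narrow to bf16
}
```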
31 changes: 29 additions & 2 deletions src/layer/x86/multiheadattention_x86.cpp
@@ -13,6 +13,10 @@ MultiHeadAttention_x86::MultiHeadAttention_x86()
support_packing = true;
#endif // __SSE2__

+#if NCNN_BF16
+support_bf16_storage = true;
+#endif

q_gemm = 0;
k_gemm = 0;
v_gemm = 0;
@@ -31,6 +35,7 @@ int MultiHeadAttention_x86::create_pipeline(const Option& _opt)
if (int8_scale_term)
{
support_packing = false;
+support_bf16_storage = false;

opt.use_packing_layout = false; // TODO enable packing
}
@@ -180,7 +185,10 @@ int MultiHeadAttention_x86::create_pipeline(const Option& _opt)
weights[2] = out_weight_data_int8_scales;
#endif
o_gemm->load_model(ModelBinFromMatArray(weights));
-o_gemm->create_pipeline(opt);
+Option opt_fp32 = opt;
+opt_fp32.use_bf16_packed = false;
+opt_fp32.use_bf16_storage = false;
+o_gemm->create_pipeline(opt_fp32);

if (opt.lightmode)
{
@@ -203,12 +211,15 @@ int MultiHeadAttention_x86::create_pipeline(const Option& _opt)
pd.set(10, attn_mask ? 3 : -1); // constant_broadcast_type_C
pd.set(11, 0); // output_N1M
pd.set(12, 1); // output_elempack
+pd.set(13, 1); // output_elemtype = fp32
#if NCNN_INT8
pd.set(18, int8_scale_term);
#endif
qk_gemm->load_param(pd);
qk_gemm->load_model(ModelBinFromMatArray(0));
Option opt1 = opt;
+opt1.use_bf16_packed = false;
+opt1.use_bf16_storage = false;
opt1.num_threads = 1;
qk_gemm->create_pipeline(opt1);
}
@@ -227,13 +238,16 @@ int MultiHeadAttention_x86::create_pipeline(const Option& _opt)
pd.set(10, -1); // constant_broadcast_type_C
pd.set(11, 0); // output_N1M
pd.set(12, 1); // output_elempack
+pd.set(13, 1); // output_elemtype = fp32
pd.set(14, 1); // output_transpose
#if NCNN_INT8
pd.set(18, int8_scale_term);
#endif
qkv_gemm->load_param(pd);
qkv_gemm->load_model(ModelBinFromMatArray(0));
Option opt1 = opt;
+opt1.use_bf16_packed = false;
+opt1.use_bf16_storage = false;
opt1.num_threads = 1;
qkv_gemm->create_pipeline(opt1);
}
@@ -488,6 +502,17 @@ int MultiHeadAttention_x86::forward(const std::vector<Mat>& bottom_blobs, std::v
return retv;
}

+Mat v_affine_fp32 = v_affine;
+#if NCNN_BF16
+if (opt.use_bf16_storage && v_affine.elembits() == 16)
+{
+// qkv_gemm needs fp32 inputs
+cast_bfloat16_to_float32(v_affine, v_affine_fp32, opt);
+if (v_affine_fp32.empty())
+return -100;
+}
+#endif

Mat qkv_cross(src_seqlen, embed_dim_per_head * num_heads, 4u, opt.blob_allocator);
if (qkv_cross.empty())
return -100;
Expand All @@ -499,7 +524,7 @@ int MultiHeadAttention_x86::forward(const std::vector<Mat>& bottom_blobs, std::v
{
std::vector<Mat> qkv_bottom_blobs(2);
qkv_bottom_blobs[0] = qk_cross.row_range(i * src_seqlen, src_seqlen);
-qkv_bottom_blobs[1] = v_affine.row_range(i * embed_dim_per_head, embed_dim_per_head);
+qkv_bottom_blobs[1] = v_affine_fp32.row_range(i * embed_dim_per_head, embed_dim_per_head);
std::vector<Mat> qkv_top_blobs(1);
qkv_top_blobs[0] = qkv_cross.row_range(i * embed_dim_per_head, embed_dim_per_head);
Option opt1 = opt;
@@ -512,6 +537,8 @@ int MultiHeadAttention_x86::forward(const std::vector<Mat>& bottom_blobs, std::v
return retqkvs[i];
}

+v_affine_fp32.release();

if (!kv_cache)
{
v_affine.release();
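MultiHeadAttention_x86 now advertises bf16 storage but keeps its internal gemms in fp32: o_gemm is pipelined with bf16 disabled, and qk_gemm/qkv_gemm request fp32 outputs via param 13, which is why a bf16-stored `v_affine` is widened before the second gemm. A sketch of that hand-off as a helper, under the same assumptions as the inline code above (`cast_bfloat16_to_float32` is ncnn's cast helper from mat.h; `ensure_fp32` is a hypothetical name):

```cpp
#include "mat.h" // ncnn::Mat, ncnn::Option, cast_bfloat16_to_float32

// Widen a possibly-bf16 activation to fp32 before feeding a pipeline that
// was created with use_bf16_storage/use_bf16_packed disabled (sketch).
static ncnn::Mat ensure_fp32(const ncnn::Mat& m, const ncnn::Option& opt)
{
    ncnn::Mat m_fp32 = m;
#if NCNN_BF16
    if (opt.use_bf16_storage && m.elembits() == 16)
    {
        ncnn::cast_bfloat16_to_float32(m, m_fp32, opt);
        // caller must treat an empty m_fp32 as an allocation failure (-100)
    }
#endif
    return m_fp32;
}
```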
49 changes: 36 additions & 13 deletions src/layer/x86/sdpa_x86.cpp
@@ -9,6 +9,10 @@ namespace ncnn {

SDPA_x86::SDPA_x86()
{
+#if NCNN_BF16
+support_bf16_storage = true;
+#endif

qk_gemm = 0;
qkv_gemm = 0;
qk_softmax = 0;
@@ -20,6 +24,7 @@ int SDPA_x86::create_pipeline(const Option& _opt)
if (int8_scale_term)
{
opt.use_packing_layout = false; // TODO enable packing
+support_bf16_storage = false;
}

{
@@ -51,6 +56,7 @@ int SDPA_x86::create_pipeline(const Option& _opt)
pd.set(10, attn_mask ? 3 : -1); // constant_broadcast_type_C (MxN)
pd.set(11, 0); // output_N1M
pd.set(12, 1); // output_elempack
+pd.set(13, 1); // output_elemtype = fp32
#if NCNN_INT8
pd.set(18, int8_scale_term);
#endif
@@ -78,6 +84,7 @@ int SDPA_x86::create_pipeline(const Option& _opt)
pd.set(10, -1); // constant_broadcast_type_C
pd.set(11, 0); // output_N1M
pd.set(12, 1); // output_elempack
+pd.set(13, 1); // output_elemtype = fp32
pd.set(14, 0); // output_transpose
#if NCNN_INT8
pd.set(18, int8_scale_term);
@@ -148,10 +155,12 @@ int SDPA_x86::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& to
const int past_seqlen = kv_cache ? past_key.h : 0;
const int dst_seqlen = past_seqlen + cur_seqlen;

+const size_t elemsize = query.elemsize;

Mat key;
if (past_seqlen > 0)
{
-key.create(embed_dim, dst_seqlen, num_group, 4u, opt.blob_allocator);
+key.create(embed_dim, dst_seqlen, num_group, elemsize, opt.blob_allocator);
if (key.empty())
return -100;

@@ -162,8 +171,8 @@ int SDPA_x86::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& to
const Mat cur_key_head = cur_key.channel(q);
Mat key_head = key.channel(q);

-memcpy(key_head.row(0), past_key_head, embed_dim * past_seqlen * sizeof(float));
-memcpy(key_head.row(past_seqlen), cur_key_head, embed_dim * cur_seqlen * sizeof(float));
+memcpy(key_head.row(0), past_key_head, embed_dim * past_seqlen * elemsize);
+memcpy(key_head.row(past_seqlen), cur_key_head, embed_dim * cur_seqlen * elemsize);
}
}
else
@@ -174,7 +183,7 @@ int SDPA_x86::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& to
Mat value;
if (past_seqlen > 0)
{
-value.create(out_embed_dim, dst_seqlen, num_group, 4u, opt.blob_allocator);
+value.create(out_embed_dim, dst_seqlen, num_group, elemsize, opt.blob_allocator);
if (value.empty())
return -100;

@@ -185,20 +194,15 @@ int SDPA_x86::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& to
const Mat cur_value_head = cur_value.channel(q);
Mat value_head = value.channel(q);

-memcpy(value_head.row(0), past_value_head, out_embed_dim * past_seqlen * sizeof(float));
-memcpy(value_head.row(past_seqlen), cur_value_head, out_embed_dim * cur_seqlen * sizeof(float));
+memcpy(value_head.row(0), past_value_head, out_embed_dim * past_seqlen * elemsize);
+memcpy(value_head.row(past_seqlen), cur_value_head, out_embed_dim * cur_seqlen * elemsize);
}
}
else
{
value = cur_value;
}

-Mat& top_blob = top_blobs[0];
-top_blob.create(out_embed_dim, src_seqlen, num_heads, 4u, opt.blob_allocator);
-if (top_blob.empty())
-return -100;

const int num_heads_per_group = num_heads / num_group;

Mat qk_cross(dst_seqlen, src_seqlen, num_heads, 4u, opt.workspace_allocator);
@@ -229,6 +233,7 @@ int SDPA_x86::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& to
pd.set(10, attn_mask ? 3 : -1); // constant_broadcast_type_C (MxN)
pd.set(11, 0); // output_N1M
pd.set(12, 1); // output_elempack
+pd.set(13, 1); // output_elemtype = fp32
#if NCNN_INT8
pd.set(18, int8_scale_term);
#endif
@@ -290,15 +295,31 @@ int SDPA_x86::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& to
if (retqk != 0)
return retqk;

+Mat value_fp32 = value;
+#if NCNN_BF16
+if (opt.use_bf16_storage && value.elembits() == 16)
+{
+// qkv_gemm needs fp32 inputs
+cast_bfloat16_to_float32(value, value_fp32, opt);
+if (value_fp32.empty())
+return -100;
+}
+#endif

+Mat& top_blob = top_blobs[0];
+top_blob.create(out_embed_dim, src_seqlen, num_heads, 4u, opt.blob_allocator);
+if (top_blob.empty())
+return -100;

// 3. Attn * V
std::vector<int> retqkvs(num_heads);

#pragma omp parallel for num_threads(opt.num_threads)
for (int i = 0; i < num_heads; i++)
{
std::vector<Mat> qkv_bottom_blobs(2);
-qkv_bottom_blobs[0] = qk_cross.channel(i); // Attn: [DstSeq, Seq]
-qkv_bottom_blobs[1] = value.channel(i / num_heads_per_group); // V: [DstSeq, OutEmbed]
+qkv_bottom_blobs[0] = qk_cross.channel(i);                         // Attn: [DstSeq, Seq]
+qkv_bottom_blobs[1] = value_fp32.channel(i / num_heads_per_group); // V: [DstSeq, OutEmbed]

std::vector<Mat> qkv_top_blobs(1);
qkv_top_blobs[0] = top_blob.channel(i); // Output
@@ -314,6 +335,8 @@ int SDPA_x86::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& to
return retqkvs[i];
}

+value_fp32.release();

if (kv_cache)
{
top_blobs[1] = key;
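SDPA_x86 follows the same pattern: the kv-cache concat now copies with the incoming `elemsize` (2 bytes under bf16 storage, 4 under fp32) instead of a hard-coded `sizeof(float)`, while `qk_cross` and `top_blob` stay fp32. `Mat::row(y)` already scales its offset by elemsize, so only the memcpy length needs it; a self-contained sketch of the append (`append_rows` is a hypothetical helper):

```cpp
#include <string.h>
#include "mat.h" // ncnn::Mat

// Append a [rows_cur x w] block under a [rows_past x w] block (sketch).
// dst.row(y) is already elemsize-aware; only the byte count changes with
// precision: w * rows * 2u under bf16 storage, w * rows * 4u under fp32.
static void append_rows(ncnn::Mat& dst, const void* past, const void* cur,
                        int w, int rows_past, int rows_cur)
{
    const size_t elemsize = dst.elemsize;
    memcpy(dst.row(0), past, (size_t)w * rows_past * elemsize);
    memcpy(dst.row(rows_past), cur, (size_t)w * rows_cur * elemsize);
}
```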
6 changes: 5 additions & 1 deletion tests/test_gemm_2e.cpp
@@ -14,7 +14,11 @@ int main()
{44, 19, 7},
{47, 35, 48},
{47, 48, 47},
-{48, 35, 47}
+{48, 35, 47},
+{32, 24, 5},
+{20, 24, 5},
+{32, 20, 5},
+{24, 20, 5},
};

int mnk_count = sizeof(mnk) / sizeof(int) / 3;
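The four added triples mix dimensions that are not multiples of the 8/16-wide tiles with a small K of 5, exercising the tail paths of the rewritten bf16s kernels. For orientation, a hypothetical driver loop matching the `mnk_count` expression above (the real per-case test calls in tests/test_gemm_2e.cpp are outside this diff):

```cpp
for (int i = 0; i < mnk_count; i++)
{
    const int M = mnk[i][0];
    const int N = mnk[i][1];
    const int K = mnk[i][2];
    // each (M, N, K) is run through the gemm test harness in all
    // supported precisions and packing layouts
}
```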