From 78b9ae6b49b6300013ac31943fa47da8d7cefc90 Mon Sep 17 00:00:00 2001
From: chenglimin
Date: Thu, 26 Feb 2026 17:10:43 +0800
Subject: [PATCH 01/36] sdpa

---
 src/layer/riscv/sdpa_riscv.cpp | 380 +++++++++++++++++++++++++++++++++
 src/layer/riscv/sdpa_riscv.h | 30 +++
 2 files changed, 410 insertions(+)
 create mode 100644 src/layer/riscv/sdpa_riscv.cpp
 create mode 100644 src/layer/riscv/sdpa_riscv.h

diff --git a/src/layer/riscv/sdpa_riscv.cpp b/src/layer/riscv/sdpa_riscv.cpp
new file mode 100644
index 00000000000..3dfe9525744
--- /dev/null
+++ b/src/layer/riscv/sdpa_riscv.cpp
@@ -0,0 +1,380 @@
+// Copyright 2025 Tencent
+// SPDX-License-Identifier: BSD-3-Clause
+
+#include "sdpa_riscv.h"
+
+#if __riscv_vector
+#include <riscv_vector.h>
+#endif
+
+#include "layer_type.h"
+
+namespace ncnn {
+
+SDPA_riscv::SDPA_riscv()
+{
+    qk_gemm = 0;
+    qkv_gemm = 0;
+    qk_softmax = 0;
+}
+
+int SDPA_riscv::create_pipeline(const Option& _opt)
+{
+    Option opt = _opt;
+    if (int8_scale_term)
+    {
+        opt.use_packing_layout = false; // TODO enable packing
+    }
+
+    {
+        qk_softmax = ncnn::create_layer_cpu(ncnn::LayerType::Softmax);
+        ncnn::ParamDict pd;
+        pd.set(0, -1); // axis
+        pd.set(1, 1);
+        qk_softmax->load_param(pd);
+        qk_softmax->load_model(ModelBinFromMatArray(0));
+        qk_softmax->create_pipeline(opt);
+    }
+
+    // Q * K^T
+    if (scale != 0.f)
+    {
+        qk_gemm = ncnn::create_layer_cpu(ncnn::LayerType::Gemm);
+        ncnn::ParamDict pd;
+
+        pd.set(0, scale); // alpha
+        pd.set(1, 1.f / scale); // beta
+        pd.set(2, 0); // transA (Q: Seq x Embed)
+        pd.set(3, 1); // transB (K: Seq x Embed -> K^T: Embed x Seq) => Q * K^T
+        pd.set(4, 0); // constantA
+        pd.set(5, 0); // constantB
+        pd.set(6, attn_mask ? 0 : 1); // constantC (if mask exists, use it)
+        pd.set(7, 0); // M
+        pd.set(8, 0); // N
+        pd.set(9, 0); // K
+        pd.set(10, attn_mask ? 3 : -1); // constant_broadcast_type_C (MxN)
+        pd.set(11, 0); // output_N1M
+        pd.set(12, 1); // output_elempack
+#if NCNN_INT8
+        pd.set(18, int8_scale_term);
+#endif
+        qk_gemm->load_param(pd);
+        qk_gemm->load_model(ModelBinFromMatArray(0));
+        Option opt1 = opt;
+        opt1.num_threads = 1;
+        qk_gemm->create_pipeline(opt1);
+    }
+
+    // Attn * V
+    {
+        qkv_gemm = ncnn::create_layer_cpu(ncnn::LayerType::Gemm);
+        ncnn::ParamDict pd;
+        pd.set(0, 1.f); // alpha
+        pd.set(1, 1.f); // beta
+        pd.set(2, 0); // transA (Attn: Seq x Seq)
+        pd.set(3, 0); // transB (V: Seq x Embed) => Attn * V
+        pd.set(4, 0); // constantA
+        pd.set(5, 0); // constantB
+        pd.set(6, 1); // constantC (None)
+        pd.set(7, 0); // M
+        pd.set(8, 0); // N
+        pd.set(9, 0); // K
+        pd.set(10, -1); // constant_broadcast_type_C
+        pd.set(11, 0); // output_N1M
+        pd.set(12, 1); // output_elempack
+        pd.set(14, 0); // output_transpose
+#if NCNN_INT8
+        pd.set(18, int8_scale_term);
+#endif
+        qkv_gemm->load_param(pd);
+        qkv_gemm->load_model(ModelBinFromMatArray(0));
+        Option opt1 = opt;
+        opt1.num_threads = 1;
+        qkv_gemm->create_pipeline(opt1);
+    }
+
+    return 0;
+}
+
+int SDPA_riscv::destroy_pipeline(const Option& _opt)
+{
+    Option opt = _opt;
+    if (int8_scale_term)
+    {
+        opt.use_packing_layout = false; // TODO enable packing
+    }
+
+    if (qk_softmax)
+    {
+        qk_softmax->destroy_pipeline(opt);
+        delete qk_softmax;
+        qk_softmax = 0;
+    }
+
+    if (qk_gemm)
+    {
+        qk_gemm->destroy_pipeline(opt);
+        delete qk_gemm;
+        qk_gemm = 0;
+    }
+
+    if (qkv_gemm)
+    {
+        qkv_gemm->destroy_pipeline(opt);
+        delete qkv_gemm;
+        qkv_gemm = 0;
+    }
+
+    return 0;
+}
+
+int SDPA_riscv::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& _opt) const
+{
+    const Mat& query = bottom_blobs[0];
+    int elempack = query.elempack;
+
+    if (elempack > 1)
+    {
+        // fallback: unpack -> forward -> repack
+        Option opt = _opt;
+        opt.blob_allocator = _opt.workspace_allocator;
+
+        std::vector<Mat> unpacked_bottom_blobs(bottom_blobs.size());
+        for (size_t i = 0; i < bottom_blobs.size(); i++)
+        {
+            if (bottom_blobs[i].empty()) continue;
+
+            if (bottom_blobs[i].elempack == 1)
+            {
+                unpacked_bottom_blobs[i] = bottom_blobs[i];
+            }
+            else
+            {
+                ncnn::Layer* packing = ncnn::create_layer_cpu(ncnn::LayerType::Packing);
+                ncnn::ParamDict pd;
+                pd.set(0, 1); // out_elempack
+                packing->load_param(pd);
+                packing->forward(bottom_blobs[i], unpacked_bottom_blobs[i], opt);
+                delete packing;
+            }
+        }
+
+        std::vector<Mat> unpacked_top_blobs(top_blobs.size());
+
+        // call forward with elempack=1
+        int ret = forward(unpacked_bottom_blobs, unpacked_top_blobs, _opt);
+        if (ret != 0) return ret;
+
+        // repack outputs
+        for (size_t i = 0; i < top_blobs.size(); i++)
+        {
+            if (unpacked_top_blobs[i].empty()) continue;
+
+            ncnn::Layer* packing = ncnn::create_layer_cpu(ncnn::LayerType::Packing);
+            ncnn::ParamDict pd;
+            pd.set(0, elempack); // out_elempack
+            packing->load_param(pd);
+            packing->forward(unpacked_top_blobs[i], top_blobs[i], _opt); // use original allocator for output
+            delete packing;
+        }
+
+        return 0;
+    }
+
+    Option opt = _opt;
+    if (int8_scale_term)
+    {
+        opt.use_packing_layout = false; // TODO enable packing
+    }
+
+    const Mat& cur_key = bottom_blobs[1];
+    const Mat& cur_value = bottom_blobs[2];
+    const Mat& attn_mask_blob = attn_mask ? bottom_blobs[3] : Mat();
+    const Mat& past_key = kv_cache ? bottom_blobs[attn_mask ? 4 : 3] : Mat();
+    const Mat& past_value = kv_cache ? bottom_blobs[attn_mask ?
5 : 4] : Mat(); + + const int embed_dim = query.w; + const int src_seqlen = query.h; + const int num_heads = query.c; + const int cur_seqlen = cur_key.h; + const int num_group = cur_key.c; + const int out_embed_dim = cur_value.w; + const int past_seqlen = kv_cache ? past_key.h : 0; + const int dst_seqlen = past_seqlen + cur_seqlen; + + Mat key; + if (past_seqlen > 0) + { + key.create(embed_dim, dst_seqlen, num_group, 4u, opt.blob_allocator); + if (key.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < num_group; q++) + { + const Mat past_key_head = past_key.channel(q); + const Mat cur_key_head = cur_key.channel(q); + Mat key_head = key.channel(q); + + memcpy(key_head.row(0), past_key_head, embed_dim * past_seqlen * sizeof(float)); + memcpy(key_head.row(past_seqlen), cur_key_head, embed_dim * cur_seqlen * sizeof(float)); + } + } + else + { + key = cur_key; + } + + Mat value; + if (past_seqlen > 0) + { + value.create(out_embed_dim, dst_seqlen, num_group, 4u, opt.blob_allocator); + if (value.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < num_group; q++) + { + const Mat past_value_head = past_value.channel(q); + const Mat cur_value_head = cur_value.channel(q); + Mat value_head = value.channel(q); + + memcpy(value_head.row(0), past_value_head, out_embed_dim * past_seqlen * sizeof(float)); + memcpy(value_head.row(past_seqlen), cur_value_head, out_embed_dim * cur_seqlen * sizeof(float)); + } + } + else + { + value = cur_value; + } + + Mat& top_blob = top_blobs[0]; + top_blob.create(out_embed_dim, src_seqlen, num_heads, 4u, opt.blob_allocator); + if (top_blob.empty()) + return -100; + + const int num_heads_per_group = num_heads / num_group; + + Mat qk_cross(dst_seqlen, src_seqlen, num_heads, 4u, opt.workspace_allocator); + if (qk_cross.empty()) + return -100; + + std::vector retqks(num_heads); + + // Dynamic Scale Calculation and Beta Correction + Layer* _qk_gemm = qk_gemm; + if (scale == 0.f) + { + float _scale = 1.f / sqrt(embed_dim); + + _qk_gemm = ncnn::create_layer_cpu(ncnn::LayerType::Gemm); + ncnn::ParamDict pd; + + pd.set(0, _scale); // alpha + pd.set(1, 1.f / _scale); // beta + pd.set(2, 0); // transA (Q: Seq x Embed) + pd.set(3, 1); // transB (K: Seq x Embed -> K^T: Embed x Seq) => Q * K^T + pd.set(4, 0); // constantA + pd.set(5, 0); // constantB + pd.set(6, attn_mask ? 0 : 1); // constantC (if mask exists, use it) + pd.set(7, 0); // M + pd.set(8, 0); // N + pd.set(9, 0); // K + pd.set(10, attn_mask ? 3 : -1); // constant_broadcast_type_C (MxN) + pd.set(11, 0); // output_N1M + pd.set(12, 1); // output_elempack +#if NCNN_INT8 + pd.set(18, int8_scale_term); +#endif + _qk_gemm->load_param(pd); + _qk_gemm->load_model(ModelBinFromMatArray(0)); + + Option opt1 = opt; + opt1.num_threads = 1; + _qk_gemm->create_pipeline(opt1); + } + + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = 0; i < num_heads; i++) + { + // 1. Q * K^T + std::vector qk_bottom_blobs; + qk_bottom_blobs.push_back(query.channel(i)); // Q: [Seq, Embed] + qk_bottom_blobs.push_back(key.channel(i / num_heads_per_group)); // K: [DstSeq, Embed] + + if (attn_mask) + { + // Ensure mask is 2D for Gemm auto-broadcast detection + Mat maskm = attn_mask_blob; + if (maskm.dims == 3) + { + // If c > 1, pick i-th head mask. If c == 1, pick 0-th (broadcast) + maskm = maskm.channel(maskm.c > 1 ? 
i : 0); + } + qk_bottom_blobs.push_back(maskm); + } + + std::vector qk_top_blobs(1); + qk_top_blobs[0] = qk_cross.channel(i); + + Option opt1 = opt; + opt1.num_threads = 1; + opt1.blob_allocator = qk_cross.allocator; + retqks[i] = _qk_gemm->forward(qk_bottom_blobs, qk_top_blobs, opt1); + } + + if (scale == 0.f) + { + Option opt1 = opt; + opt1.num_threads = 1; + _qk_gemm->destroy_pipeline(opt1); + + delete _qk_gemm; + _qk_gemm = 0; + } + + for (int i = 0; i < num_heads; i++) + { + if (retqks[i] != 0) + return retqks[i]; + } + + // 2. Softmax + int retqk = qk_softmax->forward_inplace(qk_cross, opt); + if (retqk != 0) + return retqk; + + // 3. Attn * V + std::vector retqkvs(num_heads); + + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = 0; i < num_heads; i++) + { + std::vector qkv_bottom_blobs(2); + qkv_bottom_blobs[0] = qk_cross.channel(i); // Attn: [DstSeq, Seq] + qkv_bottom_blobs[1] = value.channel(i / num_heads_per_group); // V: [DstSeq, OutEmbed] + + std::vector qkv_top_blobs(1); + qkv_top_blobs[0] = top_blob.channel(i); // Output + + Option opt1 = opt; + opt1.num_threads = 1; + retqkvs[i] = qkv_gemm->forward(qkv_bottom_blobs, qkv_top_blobs, opt1); + } + + for (int i = 0; i < num_heads; i++) + { + if (retqkvs[i] != 0) + return retqkvs[i]; + } + + if (kv_cache) + { + top_blobs[1] = key; + top_blobs[2] = value; + } + + return 0; +} + +} // namespace ncnn diff --git a/src/layer/riscv/sdpa_riscv.h b/src/layer/riscv/sdpa_riscv.h new file mode 100644 index 00000000000..23cf95b6e79 --- /dev/null +++ b/src/layer/riscv/sdpa_riscv.h @@ -0,0 +1,30 @@ +// Copyright 2025 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef LAYER_SDPA_RISCV_H +#define LAYER_SDPA_RISCV_H + +#include "sdpa.h" + +namespace ncnn { + +class SDPA_riscv : public SDPA +{ +public: + SDPA_riscv(); + + virtual int create_pipeline(const Option& opt); + virtual int destroy_pipeline(const Option& opt); + + virtual int forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; + +public: + Layer* qk_gemm; + Layer* qkv_gemm; + + Layer* qk_softmax; +}; + +} // namespace ncnn + +#endif // LAYER_SDPA_RISCV_H From 9d189bf414dcae06e001c7e422cc87e571d19c04 Mon Sep 17 00:00:00 2001 From: chenglimin Date: Wed, 1 Apr 2026 11:40:14 +0800 Subject: [PATCH 02/36] WIP: save local changes before rebase --- src/layer/riscv/sdpa_riscv.cpp | 300 +++++++++++++++++++-------------- src/layer/riscv/sdpa_riscv.h | 2 +- 2 files changed, 172 insertions(+), 130 deletions(-) diff --git a/src/layer/riscv/sdpa_riscv.cpp b/src/layer/riscv/sdpa_riscv.cpp index 3dfe9525744..de01f447212 100644 --- a/src/layer/riscv/sdpa_riscv.cpp +++ b/src/layer/riscv/sdpa_riscv.cpp @@ -1,31 +1,28 @@ -// Copyright 2025 Tencent +// Copyright 2026 Tencent // SPDX-License-Identifier: BSD-3-Clause #include "sdpa_riscv.h" +#include "layer_type.h" + #if __riscv_vector #include #endif - -#include "layer_type.h" +#include "riscv_usability.h" namespace ncnn { SDPA_riscv::SDPA_riscv() { + support_packing = true; + qk_gemm = 0; qkv_gemm = 0; qk_softmax = 0; } -int SDPA_riscv::create_pipeline(const Option& _opt) +int SDPA_riscv::create_pipeline(const Option& opt) { - Option opt = _opt; - if (int8_scale_term) - { - opt.use_packing_layout = false; // TODO enable packing - } - { qk_softmax = ncnn::create_layer_cpu(ncnn::LayerType::Softmax); ncnn::ParamDict pd; @@ -37,22 +34,21 @@ int SDPA_riscv::create_pipeline(const Option& _opt) } // Q * K^T - if (scale != 0.f) { qk_gemm = ncnn::create_layer_cpu(ncnn::LayerType::Gemm); ncnn::ParamDict 
pd; - pd.set(0, scale); // alpha - pd.set(1, 1.f / scale); // beta + pd.set(0, 1.f); // alpha (will be set in forward) + pd.set(1, 0.f); // beta pd.set(2, 0); // transA (Q: Seq x Embed) pd.set(3, 1); // transB (K: Seq x Embed -> K^T: Embed x Seq) => Q * K^T pd.set(4, 0); // constantA pd.set(5, 0); // constantB - pd.set(6, attn_mask ? 0 : 1); // constantC (if mask exists, use it) + pd.set(6, 1); // constantC (None) pd.set(7, 0); // M pd.set(8, 0); // N pd.set(9, 0); // K - pd.set(10, attn_mask ? 3 : -1); // constant_broadcast_type_C (MxN) + pd.set(10, -1); // constant_broadcast_type_C pd.set(11, 0); // output_N1M pd.set(12, 1); // output_elempack #if NCNN_INT8 @@ -96,14 +92,8 @@ int SDPA_riscv::create_pipeline(const Option& _opt) return 0; } -int SDPA_riscv::destroy_pipeline(const Option& _opt) +int SDPA_riscv::destroy_pipeline(const Option& opt) { - Option opt = _opt; - if (int8_scale_term) - { - opt.use_packing_layout = false; // TODO enable packing - } - if (qk_softmax) { qk_softmax->destroy_pipeline(opt); @@ -130,78 +120,80 @@ int SDPA_riscv::destroy_pipeline(const Option& _opt) int SDPA_riscv::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& _opt) const { + Option opt = _opt; const Mat& query = bottom_blobs[0]; - int elempack = query.elempack; + const Mat& cur_key = bottom_blobs[1]; + const Mat& cur_value = bottom_blobs[2]; + const Mat& attn_mask_blob = attn_mask ? bottom_blobs[3] : Mat(); + const Mat& past_key = kv_cache ? bottom_blobs[attn_mask ? 4 : 3] : Mat(); + const Mat& past_value = kv_cache ? bottom_blobs[attn_mask ? 5 : 4] : Mat(); + + const int embed_dim = query.w; + const int src_seqlen = query.h; + const int num_heads = query.c; + const int cur_seqlen = cur_key.h; + const int num_group = cur_key.c; + const int out_embed_dim = cur_value.w; + const int past_seqlen = kv_cache ? 
past_key.h : 0; + const int dst_seqlen = past_seqlen + cur_seqlen; + const int elempack = query.elempack; if (elempack > 1) { - // fallback: unpack -> forward -> repack - Option opt = _opt; - opt.blob_allocator = _opt.workspace_allocator; + // Fallback for packed data + // TODO: Implement optimized RVV paths for group=2 with elempack=2,4,8, and group=4 with elempack=4 + + // Unpack input blobs + std::vector bottom_blobs_unpacked = bottom_blobs; + Option opt_unpack = opt; + opt_unpack.blob_allocator = opt.workspace_allocator; + + Mat query_unpacked; + convert_packing(query, query_unpacked, 1, opt_unpack); + bottom_blobs_unpacked[0] = query_unpacked; + + Mat cur_key_unpacked; + convert_packing(cur_key, cur_key_unpacked, 1, opt_unpack); + bottom_blobs_unpacked[1] = cur_key_unpacked; + + Mat cur_value_unpacked; + convert_packing(cur_value, cur_value_unpacked, 1, opt_unpack); + bottom_blobs_unpacked[2] = cur_value_unpacked; - std::vector unpacked_bottom_blobs(bottom_blobs.size()); - for (size_t i = 0; i < bottom_blobs.size(); i++) + if (attn_mask) { - if (bottom_blobs[i].empty()) continue; - - if (bottom_blobs[i].elempack == 1) - { - unpacked_bottom_blobs[i] = bottom_blobs[i]; - } - else - { - ncnn::Layer* packing = ncnn::create_layer_cpu(ncnn::LayerType::Packing); - ncnn::ParamDict pd; - pd.set(0, 1); // out_elempack - packing->load_param(pd); - packing->forward(bottom_blobs[i], unpacked_bottom_blobs[i], opt); - delete packing; - } + Mat attn_mask_unpacked; + convert_packing(attn_mask_blob, attn_mask_unpacked, 1, opt_unpack); + bottom_blobs_unpacked[3] = attn_mask_unpacked; } - std::vector unpacked_top_blobs(top_blobs.size()); + if (kv_cache) + { + Mat past_key_unpacked; + convert_packing(past_key, past_key_unpacked, 1, opt_unpack); + bottom_blobs_unpacked[attn_mask ? 4 : 3] = past_key_unpacked; + + Mat past_value_unpacked; + convert_packing(past_value, past_value_unpacked, 1, opt_unpack); + bottom_blobs_unpacked[attn_mask ? 5 : 4] = past_value_unpacked; + } - // call forward with elempack=1 - int ret = forward(unpacked_bottom_blobs, unpacked_top_blobs, _opt); - if (ret != 0) return ret; + std::vector top_blobs_unpacked(top_blobs.size()); + int ret = SDPA::forward(bottom_blobs_unpacked, top_blobs_unpacked, opt); + if (ret != 0) + return ret; - // repack outputs + // Repack output blobs for (size_t i = 0; i < top_blobs.size(); i++) { - if (unpacked_top_blobs[i].empty()) continue; - - ncnn::Layer* packing = ncnn::create_layer_cpu(ncnn::LayerType::Packing); - ncnn::ParamDict pd; - pd.set(0, elempack); // out_elempack - packing->load_param(pd); - packing->forward(unpacked_top_blobs[i], top_blobs[i], _opt); // use original allocator for output - delete packing; + if (top_blobs_unpacked[i].empty()) + continue; + convert_packing(top_blobs_unpacked[i], top_blobs[i], elempack, opt); } return 0; } - Option opt = _opt; - if (int8_scale_term) - { - opt.use_packing_layout = false; // TODO enable packing - } - - const Mat& cur_key = bottom_blobs[1]; - const Mat& cur_value = bottom_blobs[2]; - const Mat& attn_mask_blob = attn_mask ? bottom_blobs[3] : Mat(); - const Mat& past_key = kv_cache ? bottom_blobs[attn_mask ? 4 : 3] : Mat(); - const Mat& past_value = kv_cache ? bottom_blobs[attn_mask ? 5 : 4] : Mat(); - - const int embed_dim = query.w; - const int src_seqlen = query.h; - const int num_heads = query.c; - const int cur_seqlen = cur_key.h; - const int num_group = cur_key.c; - const int out_embed_dim = cur_value.w; - const int past_seqlen = kv_cache ? 
past_key.h : 0; - const int dst_seqlen = past_seqlen + cur_seqlen; - Mat key; if (past_seqlen > 0) { @@ -261,26 +253,30 @@ int SDPA_riscv::forward(const std::vector& bottom_blobs, std::vector& std::vector retqks(num_heads); - // Dynamic Scale Calculation and Beta Correction - Layer* _qk_gemm = qk_gemm; - if (scale == 0.f) + float _scale = scale; + if (_scale == 0.f) { - float _scale = 1.f / sqrt(embed_dim); + _scale = 1.f / sqrt(embed_dim); + } + // Create local Gemm if scale is dynamic or different from 1.f + Layer* _qk_gemm = qk_gemm; + bool local_gemm = false; + if (_scale != 1.f) + { _qk_gemm = ncnn::create_layer_cpu(ncnn::LayerType::Gemm); ncnn::ParamDict pd; - pd.set(0, _scale); // alpha - pd.set(1, 1.f / _scale); // beta - pd.set(2, 0); // transA (Q: Seq x Embed) - pd.set(3, 1); // transB (K: Seq x Embed -> K^T: Embed x Seq) => Q * K^T + pd.set(1, 0.f); // beta + pd.set(2, 0); // transA + pd.set(3, 1); // transB pd.set(4, 0); // constantA pd.set(5, 0); // constantB - pd.set(6, attn_mask ? 0 : 1); // constantC (if mask exists, use it) + pd.set(6, 1); // constantC (None) pd.set(7, 0); // M pd.set(8, 0); // N pd.set(9, 0); // K - pd.set(10, attn_mask ? 3 : -1); // constant_broadcast_type_C (MxN) + pd.set(10, -1); // constant_broadcast_type_C pd.set(11, 0); // output_N1M pd.set(12, 1); // output_elempack #if NCNN_INT8 @@ -288,55 +284,76 @@ int SDPA_riscv::forward(const std::vector& bottom_blobs, std::vector& #endif _qk_gemm->load_param(pd); _qk_gemm->load_model(ModelBinFromMatArray(0)); - Option opt1 = opt; opt1.num_threads = 1; _qk_gemm->create_pipeline(opt1); + local_gemm = true; } #pragma omp parallel for num_threads(opt.num_threads) for (int i = 0; i < num_heads; i++) { // 1. Q * K^T - std::vector qk_bottom_blobs; - qk_bottom_blobs.push_back(query.channel(i)); // Q: [Seq, Embed] - qk_bottom_blobs.push_back(key.channel(i / num_heads_per_group)); // K: [DstSeq, Embed] + const Mat q_head = query.channel(i); + const Mat k_head = key.channel(i / num_heads_per_group); + Mat score_head = qk_cross.channel(i); - if (attn_mask) + for (int j = 0; j < src_seqlen; j++) { - // Ensure mask is 2D for Gemm auto-broadcast detection - Mat maskm = attn_mask_blob; - if (maskm.dims == 3) + const float* qptr = q_head.row(j); + float* outptr = score_head.row(j); + const float* mptr_row = 0; + if (attn_mask) { - // If c > 1, pick i-th head mask. If c == 1, pick 0-th (broadcast) - maskm = maskm.channel(maskm.c > 1 ? i : 0); + const Mat& maskm = attn_mask_blob.c > 1 ? 
attn_mask_blob.channel(i) : attn_mask_blob; + mptr_row = maskm.row(j); } - qk_bottom_blobs.push_back(maskm); - } - std::vector qk_top_blobs(1); - qk_top_blobs[0] = qk_cross.channel(i); + for (int k = 0; k < dst_seqlen; k++) + { + const float* kptr = k_head.row(k); + float sum = 0.f; - Option opt1 = opt; - opt1.num_threads = 1; - opt1.blob_allocator = qk_cross.allocator; - retqks[i] = _qk_gemm->forward(qk_bottom_blobs, qk_top_blobs, opt1); +#if __riscv_vector + size_t vlmax = __riscv_vsetvlmax_e32m8(); + vfloat32m8_t _sum_v = __riscv_vfmv_v_f_f32m8(0.0f, vlmax); + int l = 0; + for (; l < embed_dim; ) + { + size_t vl = __riscv_vsetvl_e32m8(embed_dim - l); + vfloat32m8_t _q = __riscv_vle32_v_f32m8(qptr + l, vl); + vfloat32m8_t _k = __riscv_vle32_v_f32m8(kptr + l, vl); + _sum_v = __riscv_vfmacc_vv_f32m8(_sum_v, _q, _k, vl); + l += vl; + } + vfloat32m1_t _sum_scalar = __riscv_vfmv_s_f_f32m1(0.0f, 1); + _sum_scalar = __riscv_vfredusum_vs_f32m8_f32m1(_sum_v, _sum_scalar, vlmax); + sum = __riscv_vfmv_f_s_f32m1_f32(_sum_scalar); +#else + for (int l = 0; l < embed_dim; l++) + { + sum += qptr[l] * kptr[l]; + } +#endif + outptr[k] = sum * _scale; + if (attn_mask) + outptr[k] += mptr_row[k]; + } + } } - if (scale == 0.f) + for (int i = 0; i < num_heads; i++) + { + if (retqks[i] != 0) + return retqks[i]; + } + + if (local_gemm) { Option opt1 = opt; opt1.num_threads = 1; _qk_gemm->destroy_pipeline(opt1); - delete _qk_gemm; - _qk_gemm = 0; - } - - for (int i = 0; i < num_heads; i++) - { - if (retqks[i] != 0) - return retqks[i]; } // 2. Softmax @@ -345,28 +362,52 @@ int SDPA_riscv::forward(const std::vector& bottom_blobs, std::vector& return retqk; // 3. Attn * V - std::vector retqkvs(num_heads); - #pragma omp parallel for num_threads(opt.num_threads) for (int i = 0; i < num_heads; i++) { - std::vector qkv_bottom_blobs(2); - qkv_bottom_blobs[0] = qk_cross.channel(i); // Attn: [DstSeq, Seq] - qkv_bottom_blobs[1] = value.channel(i / num_heads_per_group); // V: [DstSeq, OutEmbed] + const Mat score_head = qk_cross.channel(i); + const Mat v_head = value.channel(i / num_heads_per_group); + Mat out_head = top_blob.channel(i); - std::vector qkv_top_blobs(1); - qkv_top_blobs[0] = top_blob.channel(i); // Output + for (int j = 0; j < src_seqlen; j++) + { + const float* qkptr = score_head.row(j); + float* outptr = out_head.row(j); - Option opt1 = opt; - opt1.num_threads = 1; - retqkvs[i] = qkv_gemm->forward(qkv_bottom_blobs, qkv_top_blobs, opt1); + for (int k = 0; k < out_embed_dim; k++) + { + float sum = 0.f; +#if __riscv_vector + size_t vlmax = __riscv_vsetvlmax_e32m8(); + vfloat32m8_t _sum_v = __riscv_vfmv_v_f_f32m8(0.0f, vlmax); + int l = 0; + for (; l < dst_seqlen; ) + { + size_t vl = __riscv_vsetvl_e32m8(dst_seqlen - l); + vfloat32m8_t _qk = __riscv_vle32_v_f32m8(qkptr + l, vl); + vfloat32m8_t _v = __riscv_vlse32_v_f32m8(v_head.row(l) + k, out_embed_dim * sizeof(float), vl); + _sum_v = __riscv_vfmacc_vv_f32m8(_sum_v, _qk, _v, vl); + l += vl; + } + vfloat32m1_t _sum_scalar = __riscv_vfmv_s_f_f32m1(0.0f, 1); + _sum_scalar = __riscv_vfredusum_vs_f32m8_f32m1(_sum_v, _sum_scalar, vlmax); + sum = __riscv_vfmv_f_s_f32m1_f32(_sum_scalar); +#else + for (int l = 0; l < dst_seqlen; l++) + { + sum += qkptr[l] * v_head.row(l)[k]; + } +#endif + outptr[k] = sum; + } + } } - for (int i = 0; i < num_heads; i++) - { - if (retqkvs[i] != 0) - return retqkvs[i]; - } + // for (int i = 0; i < num_heads; i++) + // { + // if (retqkvs[i] != 0) + // return retqkvs[i]; + // } if (kv_cache) { @@ -378,3 +419,4 @@ int 
SDPA_riscv::forward(const std::vector& bottom_blobs, std::vector& } } // namespace ncnn + diff --git a/src/layer/riscv/sdpa_riscv.h b/src/layer/riscv/sdpa_riscv.h index 23cf95b6e79..796a31b3eae 100644 --- a/src/layer/riscv/sdpa_riscv.h +++ b/src/layer/riscv/sdpa_riscv.h @@ -1,4 +1,4 @@ -// Copyright 2025 Tencent +// Copyright 2026 Tencent // SPDX-License-Identifier: BSD-3-Clause #ifndef LAYER_SDPA_RISCV_H From ec9ef0eb7afba5aacfb22931ae523dc8a5957397 Mon Sep 17 00:00:00 2001 From: chenglimin <18213449+chenglimin@users.noreply.github.com> Date: Wed, 1 Apr 2026 04:17:15 +0000 Subject: [PATCH 03/36] apply code-format changes --- src/layer/riscv/sdpa_riscv.cpp | 59 +++++++++++++++++----------------- 1 file changed, 29 insertions(+), 30 deletions(-) diff --git a/src/layer/riscv/sdpa_riscv.cpp b/src/layer/riscv/sdpa_riscv.cpp index de01f447212..b4d63e31566 100644 --- a/src/layer/riscv/sdpa_riscv.cpp +++ b/src/layer/riscv/sdpa_riscv.cpp @@ -38,19 +38,19 @@ int SDPA_riscv::create_pipeline(const Option& opt) qk_gemm = ncnn::create_layer_cpu(ncnn::LayerType::Gemm); ncnn::ParamDict pd; - pd.set(0, 1.f); // alpha (will be set in forward) - pd.set(1, 0.f); // beta - pd.set(2, 0); // transA (Q: Seq x Embed) - pd.set(3, 1); // transB (K: Seq x Embed -> K^T: Embed x Seq) => Q * K^T - pd.set(4, 0); // constantA - pd.set(5, 0); // constantB - pd.set(6, 1); // constantC (None) - pd.set(7, 0); // M - pd.set(8, 0); // N - pd.set(9, 0); // K - pd.set(10, -1); // constant_broadcast_type_C - pd.set(11, 0); // output_N1M - pd.set(12, 1); // output_elempack + pd.set(0, 1.f); // alpha (will be set in forward) + pd.set(1, 0.f); // beta + pd.set(2, 0); // transA (Q: Seq x Embed) + pd.set(3, 1); // transB (K: Seq x Embed -> K^T: Embed x Seq) => Q * K^T + pd.set(4, 0); // constantA + pd.set(5, 0); // constantB + pd.set(6, 1); // constantC (None) + pd.set(7, 0); // M + pd.set(8, 0); // N + pd.set(9, 0); // K + pd.set(10, -1); // constant_broadcast_type_C + pd.set(11, 0); // output_N1M + pd.set(12, 1); // output_elempack #if NCNN_INT8 pd.set(18, int8_scale_term); #endif @@ -142,7 +142,7 @@ int SDPA_riscv::forward(const std::vector& bottom_blobs, std::vector& { // Fallback for packed data // TODO: Implement optimized RVV paths for group=2 with elempack=2,4,8, and group=4 with elempack=4 - + // Unpack input blobs std::vector bottom_blobs_unpacked = bottom_blobs; Option opt_unpack = opt; @@ -266,19 +266,19 @@ int SDPA_riscv::forward(const std::vector& bottom_blobs, std::vector& { _qk_gemm = ncnn::create_layer_cpu(ncnn::LayerType::Gemm); ncnn::ParamDict pd; - pd.set(0, _scale); // alpha - pd.set(1, 0.f); // beta - pd.set(2, 0); // transA - pd.set(3, 1); // transB - pd.set(4, 0); // constantA - pd.set(5, 0); // constantB - pd.set(6, 1); // constantC (None) - pd.set(7, 0); // M - pd.set(8, 0); // N - pd.set(9, 0); // K - pd.set(10, -1); // constant_broadcast_type_C - pd.set(11, 0); // output_N1M - pd.set(12, 1); // output_elempack + pd.set(0, _scale); // alpha + pd.set(1, 0.f); // beta + pd.set(2, 0); // transA + pd.set(3, 1); // transB + pd.set(4, 0); // constantA + pd.set(5, 0); // constantB + pd.set(6, 1); // constantC (None) + pd.set(7, 0); // M + pd.set(8, 0); // N + pd.set(9, 0); // K + pd.set(10, -1); // constant_broadcast_type_C + pd.set(11, 0); // output_N1M + pd.set(12, 1); // output_elempack #if NCNN_INT8 pd.set(18, int8_scale_term); #endif @@ -318,7 +318,7 @@ int SDPA_riscv::forward(const std::vector& bottom_blobs, std::vector& size_t vlmax = __riscv_vsetvlmax_e32m8(); vfloat32m8_t _sum_v = 
__riscv_vfmv_v_f_f32m8(0.0f, vlmax); int l = 0; - for (; l < embed_dim; ) + for (; l < embed_dim;) { size_t vl = __riscv_vsetvl_e32m8(embed_dim - l); vfloat32m8_t _q = __riscv_vle32_v_f32m8(qptr + l, vl); @@ -381,7 +381,7 @@ int SDPA_riscv::forward(const std::vector& bottom_blobs, std::vector& size_t vlmax = __riscv_vsetvlmax_e32m8(); vfloat32m8_t _sum_v = __riscv_vfmv_v_f_f32m8(0.0f, vlmax); int l = 0; - for (; l < dst_seqlen; ) + for (; l < dst_seqlen;) { size_t vl = __riscv_vsetvl_e32m8(dst_seqlen - l); vfloat32m8_t _qk = __riscv_vle32_v_f32m8(qkptr + l, vl); @@ -419,4 +419,3 @@ int SDPA_riscv::forward(const std::vector& bottom_blobs, std::vector& } } // namespace ncnn - From 50c13d52e8b60213badaa316d9588e374cb35bed Mon Sep 17 00:00:00 2001 From: nihui Date: Fri, 6 Mar 2026 14:37:37 +0800 Subject: [PATCH 04/36] implement vulkan gemm packed (#6573) * drop out pad --- src/layer/vulkan/gemm_vulkan.cpp | 532 +++- src/layer/vulkan/shader/gemm.comp | 508 +++- src/layer/vulkan/shader/gemm_cm.comp | 4177 ++++++++++++++++---------- src/layer/vulkan/shader/gemm_sg.comp | 256 +- 4 files changed, 3606 insertions(+), 1867 deletions(-) diff --git a/src/layer/vulkan/gemm_vulkan.cpp b/src/layer/vulkan/gemm_vulkan.cpp index 0b6cc390bc9..06cb843d3e7 100644 --- a/src/layer/vulkan/gemm_vulkan.cpp +++ b/src/layer/vulkan/gemm_vulkan.cpp @@ -99,7 +99,344 @@ int Gemm_vulkan::create_pipeline(const Option& opt) UNROLL_WG_M = std::min((M + coopmat_M * UNROLL_SG_M - 1) / (coopmat_M * UNROLL_SG_M), 2); UNROLL_WG_N = std::min((N + coopmat_N * UNROLL_SG_N - 1) / (coopmat_N * UNROLL_SG_N), 2); - std::vector specializations(15 + 9); + if (constantA == 1) + { + // +-K-+ + // M | + // +- -+ + // SG_UM | + // ^ +---+ + // | | | + // SG_UK+- -+ + // | | | + // ^ v +---+ + // | | | + // | +- -+ + // | | | + // WG_UM +---+ + // | | | + // | +- -+ + // | | | + // v +---+ + + // +-K-+ + // M | + // +SG_UM + // | | + // ^ +---+ + // | | | + // WG_UM+- -+ + // | | | + // v +---+ + + const int blocks_m = (M + coopmat_M * UNROLL_SG_M * UNROLL_WG_M - 1) / (coopmat_M * UNROLL_SG_M * UNROLL_WG_M); + const int kk = (K + coopmat_K - 1) / coopmat_K; + + A_data_packed.create(coopmat_M * coopmat_K * UNROLL_SG_M * UNROLL_WG_M * kk, blocks_m); + + if (transA == 0) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int bm = 0; bm < blocks_m; bm++) + { + float* p = A_data_packed.row(bm); + + int k = 0; + for (; k + UNROLL_SG_K - 1 < kk; k += UNROLL_SG_K) + { + for (int wm = 0; wm < UNROLL_WG_M; wm++) + { + for (int zk = 0; zk < UNROLL_SG_K; zk++) + { + for (int zm = 0; zm < UNROLL_SG_M; zm++) + { + for (int i = 0; i < coopmat_M; i++) + { + for (int j = 0; j < coopmat_K; j++) + { + const int gmi = ((bm * UNROLL_WG_M + wm) * UNROLL_SG_M + zm) * coopmat_M + i; + const int gki = (k + zk) * coopmat_K + j; + + if (gmi < M && gki < K) + { + *p++ = A_data[gmi * K + gki]; + } + else + { + *p++ = 0.f; + } + } + } + } + } + } + } + for (; k < kk; k++) + { + for (int wm = 0; wm < UNROLL_WG_M; wm++) + { + for (int zm = 0; zm < UNROLL_SG_M; zm++) + { + for (int i = 0; i < coopmat_M; i++) + { + for (int j = 0; j < coopmat_K; j++) + { + const int gmi = ((bm * UNROLL_WG_M + wm) * UNROLL_SG_M + zm) * coopmat_M + i; + const int gki = k * coopmat_K + j; + + if (gmi < M && gki < K) + { + *p++ = A_data[gmi * K + gki]; + } + else + { + *p++ = 0.f; + } + } + } + } + } + } + } + } + else + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int bm = 0; bm < blocks_m; bm++) + { + float* p = A_data_packed.row(bm); + + int k = 0; + for (; k 
+ UNROLL_SG_K - 1 < kk; k += UNROLL_SG_K) + { + for (int wm = 0; wm < UNROLL_WG_M; wm++) + { + for (int zk = 0; zk < UNROLL_SG_K; zk++) + { + for (int zm = 0; zm < UNROLL_SG_M; zm++) + { + for (int i = 0; i < coopmat_M; i++) + { + for (int j = 0; j < coopmat_K; j++) + { + const int gmi = ((bm * UNROLL_WG_M + wm) * UNROLL_SG_M + zm) * coopmat_M + i; + const int gki = (k + zk) * coopmat_K + j; + + if (gmi < M && gki < K) + { + *p++ = A_data[gki * M + gmi]; + } + else + { + *p++ = 0.f; + } + } + } + } + } + } + } + for (; k < kk; k++) + { + for (int wm = 0; wm < UNROLL_WG_M; wm++) + { + for (int zm = 0; zm < UNROLL_SG_M; zm++) + { + for (int i = 0; i < coopmat_M; i++) + { + for (int j = 0; j < coopmat_K; j++) + { + const int gmi = ((bm * UNROLL_WG_M + wm) * UNROLL_SG_M + zm) * coopmat_M + i; + const int gki = k * coopmat_K + j; + + if (gmi < M && gki < K) + { + *p++ = A_data[gki * M + gmi]; + } + else + { + *p++ = 0.f; + } + } + } + } + } + } + } + } + } + + if (constantB == 1) + { + // +-N-+ + // K | + // +SG_UN + // | | + // ^ +---+ + // | | | + // SG_UK+- -+ + // | | | + // ^ v +---+ + // | | | + // | +- -+ + // | | | + // WG_UN +---+ + // | | | + // | +- -+ + // | | | + // v +---+ + + // +-N-+ + // K | + // +SG_UN + // | | + // ^ +---+ + // | | | + // WG_UN+- -+ + // | | | + // v +---+ + + const int blocks_n = (N + coopmat_N * UNROLL_SG_N * UNROLL_WG_N - 1) / (coopmat_N * UNROLL_SG_N * UNROLL_WG_N); + const int kk = (K + coopmat_K - 1) / coopmat_K; + + B_data_packed.create(coopmat_N * coopmat_K * UNROLL_SG_N * UNROLL_WG_N * kk, blocks_n); + + if (transB == 1) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int bn = 0; bn < blocks_n; bn++) + { + float* p = B_data_packed.row(bn); + + int k = 0; + for (; k + UNROLL_SG_K - 1 < kk; k += UNROLL_SG_K) + { + for (int wn = 0; wn < UNROLL_WG_N; wn++) + { + for (int zk = 0; zk < UNROLL_SG_K; zk++) + { + for (int zn = 0; zn < UNROLL_SG_N; zn++) + { + for (int i = 0; i < coopmat_K; i++) + { + for (int j = 0; j < coopmat_N; j++) + { + const int gni = ((bn * UNROLL_WG_N + wn) * UNROLL_SG_N + zn) * coopmat_N + j; + const int gki = (k + zk) * coopmat_K + i; + + if (gni < N && gki < K) + { + *p++ = B_data[gni * K + gki]; + } + else + { + *p++ = 0.f; + } + } + } + } + } + } + } + for (; k < kk; k++) + { + for (int wn = 0; wn < UNROLL_WG_N; wn++) + { + for (int zn = 0; zn < UNROLL_SG_N; zn++) + { + for (int i = 0; i < coopmat_K; i++) + { + for (int j = 0; j < coopmat_N; j++) + { + const int gni = ((bn * UNROLL_WG_N + wn) * UNROLL_SG_N + zn) * coopmat_N + j; + const int gki = k * coopmat_K + i; + + if (gni < N && gki < K) + { + *p++ = B_data[gni * K + gki]; + } + else + { + *p++ = 0.f; + } + } + } + } + } + } + } + } + else + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int bn = 0; bn < blocks_n; bn++) + { + float* p = B_data_packed.row(bn); + + int k = 0; + for (; k + UNROLL_SG_K - 1 < kk; k += UNROLL_SG_K) + { + for (int wn = 0; wn < UNROLL_WG_N; wn++) + { + for (int zk = 0; zk < UNROLL_SG_K; zk++) + { + for (int zn = 0; zn < UNROLL_SG_N; zn++) + { + for (int i = 0; i < coopmat_K; i++) + { + for (int j = 0; j < coopmat_N; j++) + { + const int gni = ((bn * UNROLL_WG_N + wn) * UNROLL_SG_N + zn) * coopmat_N + j; + const int gki = (k + zk) * coopmat_K + i; + + if (gni < N && gki < K) + { + *p++ = B_data[gki * N + gni]; + } + else + { + *p++ = 0.f; + } + } + } + } + } + } + } + for (; k < kk; k++) + { + for (int wn = 0; wn < UNROLL_WG_N; wn++) + { + for (int zn = 0; zn < UNROLL_SG_N; zn++) + { + for (int i = 0; i < 
coopmat_K; i++) + { + for (int j = 0; j < coopmat_N; j++) + { + const int gni = ((bn * UNROLL_WG_N + wn) * UNROLL_SG_N + zn) * coopmat_N + j; + const int gki = k * coopmat_K + i; + + if (gni < N && gki < K) + { + *p++ = B_data[gki * N + gni]; + } + else + { + *p++ = 0.f; + } + } + } + } + } + } + } + } + } + + int outh = output_transpose ? constantN : constantM; + int out_elempack = outh ? (outh % 4 == 0 ? 4 : 1) : 0; + + std::vector specializations(18 + 9); specializations[0].f = alpha; specializations[1].f = beta; specializations[2].i = transA; @@ -115,16 +452,19 @@ int Gemm_vulkan::create_pipeline(const Option& opt) specializations[12].i = output_elempack; specializations[13].i = output_elemtype; specializations[14].i = output_transpose; - - specializations[15].u32 = coopmat_M; - specializations[16].u32 = coopmat_N; - specializations[17].u32 = coopmat_K; - specializations[18].u32 = coopmat_subgroup_size; - specializations[19].u32 = UNROLL_SG_M; - specializations[20].u32 = UNROLL_SG_N; - specializations[21].u32 = UNROLL_SG_K; - specializations[22].u32 = UNROLL_WG_M; - specializations[23].u32 = UNROLL_WG_N; + specializations[15].i = A_data_packed.elempack; + specializations[16].i = B_data_packed.elempack; + specializations[17].i = output_elempack ? output_elempack : out_elempack; + + specializations[18 + 0].u32 = coopmat_M; + specializations[18 + 1].u32 = coopmat_N; + specializations[18 + 2].u32 = coopmat_K; + specializations[18 + 3].u32 = coopmat_subgroup_size; + specializations[18 + 4].u32 = UNROLL_SG_M; + specializations[18 + 5].u32 = UNROLL_SG_N; + specializations[18 + 6].u32 = UNROLL_SG_K; + specializations[18 + 7].u32 = UNROLL_WG_M; + specializations[18 + 8].u32 = UNROLL_WG_N; pipeline_gemm = new Pipeline(vkdev); pipeline_gemm->set_subgroup_size(coopmat_subgroup_size); @@ -303,17 +643,14 @@ int Gemm_vulkan::upload_model(VkTransfer& cmd, const Option& opt) int Gemm_vulkan::forward(const std::vector& bottom_blobs, std::vector& top_blobs, VkCompute& cmd, const Option& opt) const { - const VkMat& A0 = constantA ? A_data_gpu : bottom_blobs[0]; - const VkMat& B0 = constantB ? B_data_gpu : constantA ? bottom_blobs[0] : bottom_blobs[1]; - - VkMat A; - VkMat B; - vkdev->convert_packing(A0, A, 1, cmd, opt); - vkdev->convert_packing(B0, B, 1, cmd, opt); + const VkMat& A = constantA ? A_data_gpu : bottom_blobs[0]; + const VkMat& B = constantB ? B_data_gpu : constantA ? bottom_blobs[0] : bottom_blobs[1]; - const int M = constantM ? constantM : transA ? A.w : (A.dims == 3 ? A.c : A.h); - const int K = constantK ? constantK : transA ? (A.dims == 3 ? A.c : A.h) : A.w; - const int N = constantN ? constantN : transB ? (B.dims == 3 ? B.c : B.h) : B.w; + const int A_elempack = A.elempack; + const int B_elempack = B.elempack; + const int M = constantM ? constantM : transA ? A.w : (A.dims == 3 ? A.c * A_elempack : A.h * A_elempack); + const int K = constantK ? constantK : transA ? (A.dims == 3 ? A.c * A_elempack : A.h * A_elempack) : A.w; + const int N = constantN ? constantN : transB ? (B.dims == 3 ? B.c * B_elempack : B.h * B_elempack) : B.w; VkMat C; int broadcast_type_C = -1; @@ -380,47 +717,60 @@ int Gemm_vulkan::forward(const std::vector& bottom_blobs, std::vector bindings(4); - bindings[0] = top_blob; - bindings[1] = A; - bindings[2] = B; - bindings[3] = C; - - std::vector constants(10); - constants[0].i = M; - constants[1].i = N; - constants[2].i = K; - constants[3].i = broadcast_type_C; - constants[4].i = A.dims; - constants[5].i = A.dims == 3 ? A.cstep : transA ? 
M : K; - constants[6].i = B.dims; - constants[7].i = B.dims == 3 ? B.cstep : transB ? K : N; - constants[8].i = top_blob.dims; - constants[9].i = top_blob.dims == 3 ? top_blob.cstep : top_blob.w; - if (use_cooperative_matrix) { + std::vector bindings(5); + bindings[0] = top_blob; + bindings[1] = A; + bindings[2] = B; + bindings[3] = C; + bindings[4] = top_blob; + + std::vector constants(13); + constants[0].i = M; + constants[1].i = N; + constants[2].i = K; + constants[3].i = broadcast_type_C; + constants[4].i = A.dims; + constants[5].i = A.dims == 3 ? A.cstep : A.dims == 2 ? A.w : transA ? M : K; + constants[6].i = B.dims; + constants[7].i = B.dims == 3 ? B.cstep : B.dims == 2 ? B.w : transB ? K : N; + constants[8].i = top_blob.dims; + constants[9].i = top_blob.dims == 3 ? top_blob.cstep : top_blob.w; + constants[10].i = A_elempack; + constants[11].i = B_elempack; + constants[12].i = out_elempack; + const int blocks_x = (M + coopmat_M * UNROLL_SG_M * UNROLL_WG_M - 1) / (coopmat_M * UNROLL_SG_M * UNROLL_WG_M); const int blocks_y = (N + coopmat_N * UNROLL_SG_N * UNROLL_WG_N - 1) / (coopmat_N * UNROLL_SG_N * UNROLL_WG_N); @@ -431,49 +781,63 @@ int Gemm_vulkan::forward(const std::vector& bottom_blobs, std::vectorinfo.subgroup_size(); - - const int blocks_x = (M + (UNROLL_SG_M * 4 - 1)) / (UNROLL_SG_M * 4); - const int blocks_y = (N + (UNROLL_SG_N * 4 - 1)) / (UNROLL_SG_N * 4); - - VkMat dispatcher; - dispatcher.w = (blocks_x * blocks_y) * subgroup_size; - dispatcher.h = 1; - dispatcher.c = 1; - cmd.record_pipeline(pipeline_gemm, bindings, constants, dispatcher); - } else { - VkMat dispatcher; - dispatcher.w = (N + 3) / 4; - dispatcher.h = (M + 3) / 4; - dispatcher.c = 1; - cmd.record_pipeline(pipeline_gemm, bindings, constants, dispatcher); - } + std::vector bindings(5); + bindings[0] = top_blob; + bindings[1] = A; + bindings[2] = B; + bindings[3] = C; + bindings[4] = top_blob; + + std::vector constants(13); + constants[0].i = M; + constants[1].i = N; + constants[2].i = K; + constants[3].i = broadcast_type_C; + constants[4].i = A.dims; + constants[5].i = A.dims == 3 ? A.cstep : A.dims == 2 ? A.w : transA ? M : K; + constants[6].i = B.dims; + constants[7].i = B.dims == 3 ? B.cstep : B.dims == 2 ? B.w : transB ? K : N; + constants[8].i = top_blob.dims; + constants[9].i = top_blob.dims == 3 ? top_blob.cstep : top_blob.w; + constants[10].i = out_elempack; + constants[11].i = A_elempack; + constants[12].i = B_elempack; + + if (opt.use_shader_local_memory) + { + VkMat dispatcher; + dispatcher.w = (N + 3) / 4; + dispatcher.h = (M + 3) / 4; + dispatcher.c = 1; + cmd.record_pipeline(pipeline_gemm, bindings, constants, dispatcher); + } + else if (use_subgroup_ops) + { + bindings.resize(7); + bindings[5] = A; + bindings[6] = B; - int out_elempack = 1; - { - int outh = output_transpose ? N : M; - out_elempack = outh % 4 == 0 ? 
4 : 1; - } - if (output_elempack) - out_elempack = output_elempack; + const int subgroup_size = vkdev->info.subgroup_size(); - if (out_elempack != 1) - { - VkMat top_blob0; - vkdev->convert_packing(top_blob, top_blob0, out_elempack, cmd, opt); - top_blobs[0] = top_blob0; + const int blocks_x = (M + (UNROLL_SG_M * 4 - 1)) / (UNROLL_SG_M * 4); + const int blocks_y = (N + (UNROLL_SG_N * 4 - 1)) / (UNROLL_SG_N * 4); + + VkMat dispatcher; + dispatcher.w = (blocks_x * blocks_y) * subgroup_size; + dispatcher.h = 1; + dispatcher.c = 1; + cmd.record_pipeline(pipeline_gemm, bindings, constants, dispatcher); + } + else + { + VkMat dispatcher; + dispatcher.w = (N + 3) / 4; + dispatcher.h = (M + 3) / 4; + dispatcher.c = 1; + cmd.record_pipeline(pipeline_gemm, bindings, constants, dispatcher); + } } return 0; diff --git a/src/layer/vulkan/shader/gemm.comp b/src/layer/vulkan/shader/gemm.comp index e5498437f42..2251ecd351a 100644 --- a/src/layer/vulkan/shader/gemm.comp +++ b/src/layer/vulkan/shader/gemm.comp @@ -27,6 +27,7 @@ layout(binding = 0) writeonly buffer top_blob { sfp top_blob_data[]; }; layout(binding = 1) readonly buffer A_blob { sfpvec4 A_blob_data[]; }; layout(binding = 2) readonly buffer B_blob { sfpvec4 B_blob_data[]; }; layout(binding = 3) readonly buffer C_blob { sfp C_blob_data[]; }; +layout(binding = 4) writeonly buffer top_blob_4 { sfpvec4 top_blob_data_4[]; }; layout(push_constant) uniform parameter { @@ -40,6 +41,9 @@ layout(push_constant) uniform parameter int B_hstep; int outdims; int outhstep; + int out_elempack; + int A_elempack; + int B_elempack; } p; #if NCNN_shader_local_memory @@ -136,16 +140,58 @@ void main() for (; k + (LOCAL_MEMORY_UNROLL_INCH - 1) < NN; k += LOCAL_MEMORY_UNROLL_INCH) { { - if (transA == 1) + if (p.A_elempack == 4) { - if (p.A_hstep % 4 == 0) + if (transA == 1) + { + // A is (M, K), K packed + const uint kd4 = (k + lx) / 4; + const uint km4 = (k + lx) % 4; + const uvec4 gy4 = gy * 4 + uvec4(0, 1, 2, 3); + + afpvec4 a; + a.r = buffer_ld4(A_blob_data, kd4 * p.A_hstep + gy4.r)[km4]; + a.g = buffer_ld4(A_blob_data, kd4 * p.A_hstep + gy4.g)[km4]; + a.b = buffer_ld4(A_blob_data, kd4 * p.A_hstep + gy4.b)[km4]; + a.a = buffer_ld4(A_blob_data, kd4 * p.A_hstep + gy4.a)[km4]; + + tmp_a[ly][lx] = afp2lfpvec4(a); + } + else + { + // A is (K, M), M packed. Single vec4 load. 
+ tmp_a[ly][lx] = buffer_sm4(A_blob_data, gy * p.A_hstep + (k + lx)); + } + } + else + { + if (transA == 1) { - const uint ai = (k + lx) * (p.A_hstep / 4) + gy; - tmp_a[ly][lx] = buffer_sm4(A_blob_data, ai); + if (p.A_hstep % 4 == 0) + { + const uint ai = (k + lx) * (p.A_hstep / 4) + gy; + tmp_a[ly][lx] = buffer_sm4(A_blob_data, ai); + } + else + { + const uvec4 ai4 = (k + lx) * p.A_hstep + gy * 4 + uvec4(0, 1, 2, 3); + + const uvec4 ai4d4 = ai4 / 4; + const uvec4 ai4m4 = ai4 % 4; + + afpvec4 a; + a.r = buffer_ld4(A_blob_data, ai4d4.r)[ai4m4.r]; + a.g = buffer_ld4(A_blob_data, ai4d4.g)[ai4m4.g]; + a.b = buffer_ld4(A_blob_data, ai4d4.b)[ai4m4.b]; + a.a = buffer_ld4(A_blob_data, ai4d4.a)[ai4m4.a]; + + tmp_a[ly][lx] = afp2lfpvec4(a); + } } else { - const uvec4 ai4 = (k + lx) * p.A_hstep + gy * 4 + uvec4(0, 1, 2, 3); + const uvec4 gy4 = gy * 4 + uvec4(0, 1, 2, 3); + const uvec4 ai4 = gy4 * p.A_hstep + (k + lx); const uvec4 ai4d4 = ai4 / 4; const uvec4 ai4m4 = ai4 % 4; @@ -159,49 +205,36 @@ void main() tmp_a[ly][lx] = afp2lfpvec4(a); } } - else - { - const uvec4 gy4 = gy * 4 + uvec4(0, 1, 2, 3); - const uvec4 ai4 = gy4 * p.A_hstep + (k + lx); - - const uvec4 ai4d4 = ai4 / 4; - const uvec4 ai4m4 = ai4 % 4; - - afpvec4 a; - a.r = buffer_ld4(A_blob_data, ai4d4.r)[ai4m4.r]; - a.g = buffer_ld4(A_blob_data, ai4d4.g)[ai4m4.g]; - a.b = buffer_ld4(A_blob_data, ai4d4.b)[ai4m4.b]; - a.a = buffer_ld4(A_blob_data, ai4d4.a)[ai4m4.a]; - - tmp_a[ly][lx] = afp2lfpvec4(a); - } - if (transB == 1) + if (p.B_elempack == 4) { - const uvec4 gx4 = gx * 4 + uvec4(0, 1, 2, 3); - const uvec4 bi4 = gx4 * p.B_hstep + (k + ly); - - const uvec4 bi4d4 = bi4 / 4; - const uvec4 bi4m4 = bi4 % 4; + if (transB == 0) + { + // B is (N, K), K packed + const uint kd4 = (k + ly) / 4; + const uint km4 = (k + ly) % 4; + const uvec4 gx4 = gx * 4 + uvec4(0, 1, 2, 3); - afpvec4 b; - b.r = buffer_ld4(B_blob_data, bi4d4.r)[bi4m4.r]; - b.g = buffer_ld4(B_blob_data, bi4d4.g)[bi4m4.g]; - b.b = buffer_ld4(B_blob_data, bi4d4.b)[bi4m4.b]; - b.a = buffer_ld4(B_blob_data, bi4d4.a)[bi4m4.a]; + afpvec4 b; + b.r = buffer_ld4(B_blob_data, kd4 * p.B_hstep + gx4.r)[km4]; + b.g = buffer_ld4(B_blob_data, kd4 * p.B_hstep + gx4.g)[km4]; + b.b = buffer_ld4(B_blob_data, kd4 * p.B_hstep + gx4.b)[km4]; + b.a = buffer_ld4(B_blob_data, kd4 * p.B_hstep + gx4.a)[km4]; - tmp_b[lx][ly] = afp2lfpvec4(b); + tmp_b[lx][ly] = afp2lfpvec4(b); + } + else + { + // B is (K, N), N packed. Single vec4 load. 
+ tmp_b[lx][ly] = buffer_sm4(B_blob_data, gx * p.B_hstep + (k + ly)); + } } else { - if (p.B_hstep % 4 == 0) - { - const uint bi = (k + ly) * (p.B_hstep / 4) + gx; - tmp_b[lx][ly] = buffer_sm4(B_blob_data, bi); - } - else + if (transB == 1) { - const uvec4 bi4 = (k + ly) * p.B_hstep + gx * 4 + uvec4(0, 1, 2, 3); + const uvec4 gx4 = gx * 4 + uvec4(0, 1, 2, 3); + const uvec4 bi4 = gx4 * p.B_hstep + (k + ly); const uvec4 bi4d4 = bi4 / 4; const uvec4 bi4m4 = bi4 % 4; @@ -214,6 +247,29 @@ void main() tmp_b[lx][ly] = afp2lfpvec4(b); } + else + { + if (p.B_hstep % 4 == 0) + { + const uint bi = (k + ly) * (p.B_hstep / 4) + gx; + tmp_b[lx][ly] = buffer_sm4(B_blob_data, bi); + } + else + { + const uvec4 bi4 = (k + ly) * p.B_hstep + gx * 4 + uvec4(0, 1, 2, 3); + + const uvec4 bi4d4 = bi4 / 4; + const uvec4 bi4m4 = bi4 % 4; + + afpvec4 b; + b.r = buffer_ld4(B_blob_data, bi4d4.r)[bi4m4.r]; + b.g = buffer_ld4(B_blob_data, bi4d4.g)[bi4m4.g]; + b.b = buffer_ld4(B_blob_data, bi4d4.b)[bi4m4.b]; + b.a = buffer_ld4(B_blob_data, bi4d4.a)[bi4m4.a]; + + tmp_b[lx][ly] = afp2lfpvec4(b); + } + } } } @@ -240,16 +296,56 @@ void main() if (lx < remain) { - if (transA == 1) + if (p.A_elempack == 4) { - if (p.A_hstep % 4 == 0) + if (transA == 1) + { + const uint kd4 = (k + lx) / 4; + const uint km4 = (k + lx) % 4; + const uvec4 gy4 = gy * 4 + uvec4(0, 1, 2, 3); + + afpvec4 a; + a.r = buffer_ld4(A_blob_data, kd4 * p.A_hstep + gy4.r)[km4]; + a.g = buffer_ld4(A_blob_data, kd4 * p.A_hstep + gy4.g)[km4]; + a.b = buffer_ld4(A_blob_data, kd4 * p.A_hstep + gy4.b)[km4]; + a.a = buffer_ld4(A_blob_data, kd4 * p.A_hstep + gy4.a)[km4]; + + tmp_a[ly][lx] = afp2lfpvec4(a); + } + else + { + tmp_a[ly][lx] = buffer_sm4(A_blob_data, gy * p.A_hstep + (k + lx)); + } + } + else + { + if (transA == 1) { - const uint ai = (k + lx) * (p.A_hstep / 4) + gy; - tmp_a[ly][lx] = buffer_sm4(A_blob_data, ai); + if (p.A_hstep % 4 == 0) + { + const uint ai = (k + lx) * (p.A_hstep / 4) + gy; + tmp_a[ly][lx] = buffer_sm4(A_blob_data, ai); + } + else + { + const uvec4 ai4 = (k + lx) * p.A_hstep + gy * 4 + uvec4(0, 1, 2, 3); + + const uvec4 ai4d4 = ai4 / 4; + const uvec4 ai4m4 = ai4 % 4; + + afpvec4 a; + a.r = buffer_ld4(A_blob_data, ai4d4.r)[ai4m4.r]; + a.g = buffer_ld4(A_blob_data, ai4d4.g)[ai4m4.g]; + a.b = buffer_ld4(A_blob_data, ai4d4.b)[ai4m4.b]; + a.a = buffer_ld4(A_blob_data, ai4d4.a)[ai4m4.a]; + + tmp_a[ly][lx] = afp2lfpvec4(a); + } } else { - const uvec4 ai4 = (k + lx) * p.A_hstep + gy * 4 + uvec4(0, 1, 2, 3); + const uvec4 gy4 = gy * 4 + uvec4(0, 1, 2, 3); + const uvec4 ai4 = gy4 * p.A_hstep + (k + lx); const uvec4 ai4d4 = ai4 / 4; const uvec4 ai4m4 = ai4 % 4; @@ -263,52 +359,37 @@ void main() tmp_a[ly][lx] = afp2lfpvec4(a); } } - else - { - const uvec4 gy4 = gy * 4 + uvec4(0, 1, 2, 3); - const uvec4 ai4 = gy4 * p.A_hstep + (k + lx); - - const uvec4 ai4d4 = ai4 / 4; - const uvec4 ai4m4 = ai4 % 4; - - afpvec4 a; - a.r = buffer_ld4(A_blob_data, ai4d4.r)[ai4m4.r]; - a.g = buffer_ld4(A_blob_data, ai4d4.g)[ai4m4.g]; - a.b = buffer_ld4(A_blob_data, ai4d4.b)[ai4m4.b]; - a.a = buffer_ld4(A_blob_data, ai4d4.a)[ai4m4.a]; - - tmp_a[ly][lx] = afp2lfpvec4(a); - } } if (ly < remain) { - if (transB == 1) + if (p.B_elempack == 4) { - const uvec4 gx4 = gx * 4 + uvec4(0, 1, 2, 3); - const uvec4 bi4 = gx4 * p.B_hstep + (k + ly); - - const uvec4 bi4d4 = bi4 / 4; - const uvec4 bi4m4 = bi4 % 4; + if (transB == 0) + { + const uint kd4 = (k + ly) / 4; + const uint km4 = (k + ly) % 4; + const uvec4 gx4 = gx * 4 + uvec4(0, 1, 2, 3); - afpvec4 b; - b.r = 
buffer_ld4(B_blob_data, bi4d4.r)[bi4m4.r]; - b.g = buffer_ld4(B_blob_data, bi4d4.g)[bi4m4.g]; - b.b = buffer_ld4(B_blob_data, bi4d4.b)[bi4m4.b]; - b.a = buffer_ld4(B_blob_data, bi4d4.a)[bi4m4.a]; + afpvec4 b; + b.r = buffer_ld4(B_blob_data, kd4 * p.B_hstep + gx4.r)[km4]; + b.g = buffer_ld4(B_blob_data, kd4 * p.B_hstep + gx4.g)[km4]; + b.b = buffer_ld4(B_blob_data, kd4 * p.B_hstep + gx4.b)[km4]; + b.a = buffer_ld4(B_blob_data, kd4 * p.B_hstep + gx4.a)[km4]; - tmp_b[lx][ly] = afp2lfpvec4(b); + tmp_b[lx][ly] = afp2lfpvec4(b); + } + else + { + tmp_b[lx][ly] = buffer_sm4(B_blob_data, gx * p.B_hstep + (k + ly)); + } } else { - if (p.B_hstep % 4 == 0) + if (transB == 1) { - const uint bi = (k + ly) * (p.B_hstep / 4) + gx; - tmp_b[lx][ly] = buffer_sm4(B_blob_data, bi); - } - else - { - const uvec4 bi4 = (k + ly) * p.B_hstep + gx * 4 + uvec4(0, 1, 2, 3); + const uvec4 gx4 = gx * 4 + uvec4(0, 1, 2, 3); + const uvec4 bi4 = gx4 * p.B_hstep + (k + ly); const uvec4 bi4d4 = bi4 / 4; const uvec4 bi4m4 = bi4 % 4; @@ -321,6 +402,29 @@ void main() tmp_b[lx][ly] = afp2lfpvec4(b); } + else + { + if (p.B_hstep % 4 == 0) + { + const uint bi = (k + ly) * (p.B_hstep / 4) + gx; + tmp_b[lx][ly] = buffer_sm4(B_blob_data, bi); + } + else + { + const uvec4 bi4 = (k + ly) * p.B_hstep + gx * 4 + uvec4(0, 1, 2, 3); + + const uvec4 bi4d4 = bi4 / 4; + const uvec4 bi4m4 = bi4 % 4; + + afpvec4 b; + b.r = buffer_ld4(B_blob_data, bi4d4.r)[bi4m4.r]; + b.g = buffer_ld4(B_blob_data, bi4d4.g)[bi4m4.g]; + b.b = buffer_ld4(B_blob_data, bi4d4.b)[bi4m4.b]; + b.a = buffer_ld4(B_blob_data, bi4d4.a)[bi4m4.a]; + + tmp_b[lx][ly] = afp2lfpvec4(b); + } + } } } @@ -342,16 +446,53 @@ void main() for (int k = 0; k < psc(K); k++) { afpvec4 a; - if (transA == 1) + if (p.A_elempack == 4) + { + if (transA == 1) + { + // A is (M, K), K packed. k is K-index, gy is M-group. + // vec4 at (k/4, row_m): data[(k/4) * hstep + row_m], component k%4 + const uint kd4 = k / 4; + const uint km4 = k % 4; + const uvec4 gy4 = gy * 4 + uvec4(0, 1, 2, 3); + a.r = buffer_ld4(A_blob_data, kd4 * p.A_hstep + gy4.r)[km4]; + a.g = buffer_ld4(A_blob_data, kd4 * p.A_hstep + gy4.g)[km4]; + a.b = buffer_ld4(A_blob_data, kd4 * p.A_hstep + gy4.b)[km4]; + a.a = buffer_ld4(A_blob_data, kd4 * p.A_hstep + gy4.a)[km4]; + } + else + { + // A is (K, M), M packed. gy is M-group. 
+ // vec4 at (gy, k): data[gy * hstep + k] = vec4(A[4*gy+0..3][k]) + a = buffer_ld4(A_blob_data, gy * p.A_hstep + k); + } + } + else { - if (p.A_hstep % 4 == 0) + if (transA == 1) { - const uint ai = k * (p.A_hstep / 4) + gy; - a = buffer_ld4(A_blob_data, ai); + if (p.A_hstep % 4 == 0) + { + const uint ai = k * (p.A_hstep / 4) + gy; + a = buffer_ld4(A_blob_data, ai); + } + else + { + const uvec4 ai4 = k * p.A_hstep + gy * 4 + uvec4(0, 1, 2, 3); + + const uvec4 ai4d4 = ai4 / 4; + const uvec4 ai4m4 = ai4 % 4; + + a.r = buffer_ld4(A_blob_data, ai4d4.r)[ai4m4.r]; + a.g = buffer_ld4(A_blob_data, ai4d4.g)[ai4m4.g]; + a.b = buffer_ld4(A_blob_data, ai4d4.b)[ai4m4.b]; + a.a = buffer_ld4(A_blob_data, ai4d4.a)[ai4m4.a]; + } } else { - const uvec4 ai4 = k * p.A_hstep + gy * 4 + uvec4(0, 1, 2, 3); + const uvec4 gy4 = gy * 4 + uvec4(0, 1, 2, 3); + const uvec4 ai4 = gy4 * p.A_hstep + k; const uvec4 ai4d4 = ai4 / 4; const uvec4 ai4m4 = ai4 % 4; @@ -362,44 +503,33 @@ void main() a.a = buffer_ld4(A_blob_data, ai4d4.a)[ai4m4.a]; } } - else - { - const uvec4 gy4 = gy * 4 + uvec4(0, 1, 2, 3); - const uvec4 ai4 = gy4 * p.A_hstep + k; - - const uvec4 ai4d4 = ai4 / 4; - const uvec4 ai4m4 = ai4 % 4; - - a.r = buffer_ld4(A_blob_data, ai4d4.r)[ai4m4.r]; - a.g = buffer_ld4(A_blob_data, ai4d4.g)[ai4m4.g]; - a.b = buffer_ld4(A_blob_data, ai4d4.b)[ai4m4.b]; - a.a = buffer_ld4(A_blob_data, ai4d4.a)[ai4m4.a]; - } afpvec4 b; - if (transB == 1) + if (p.B_elempack == 4) { - const uvec4 gx4 = gx * 4 + uvec4(0, 1, 2, 3); - const uvec4 bi4 = gx4 * p.B_hstep + k; - - const uvec4 bi4d4 = bi4 / 4; - const uvec4 bi4m4 = bi4 % 4; - - b.r = buffer_ld4(B_blob_data, bi4d4.r)[bi4m4.r]; - b.g = buffer_ld4(B_blob_data, bi4d4.g)[bi4m4.g]; - b.b = buffer_ld4(B_blob_data, bi4d4.b)[bi4m4.b]; - b.a = buffer_ld4(B_blob_data, bi4d4.a)[bi4m4.a]; - } - else - { - if (p.B_hstep % 4 == 0) + if (transB == 0) { - const uint bi = k * (p.B_hstep / 4) + gx; - b = buffer_ld4(B_blob_data, bi); + // B is (N, K), K packed. k is K-index, gx is N-group. + const uint kd4 = k / 4; + const uint km4 = k % 4; + const uvec4 gx4 = gx * 4 + uvec4(0, 1, 2, 3); + b.r = buffer_ld4(B_blob_data, kd4 * p.B_hstep + gx4.r)[km4]; + b.g = buffer_ld4(B_blob_data, kd4 * p.B_hstep + gx4.g)[km4]; + b.b = buffer_ld4(B_blob_data, kd4 * p.B_hstep + gx4.b)[km4]; + b.a = buffer_ld4(B_blob_data, kd4 * p.B_hstep + gx4.a)[km4]; } else { - const uvec4 bi4 = k * p.B_hstep + gx * 4 + uvec4(0, 1, 2, 3); + // B is (K, N), N packed. gx is N-group. 
+ b = buffer_ld4(B_blob_data, gx * p.B_hstep + k); + } + } + else + { + if (transB == 1) + { + const uvec4 gx4 = gx * 4 + uvec4(0, 1, 2, 3); + const uvec4 bi4 = gx4 * p.B_hstep + k; const uvec4 bi4d4 = bi4 / 4; const uvec4 bi4m4 = bi4 % 4; @@ -409,6 +539,26 @@ void main() b.b = buffer_ld4(B_blob_data, bi4d4.b)[bi4m4.b]; b.a = buffer_ld4(B_blob_data, bi4d4.a)[bi4m4.a]; } + else + { + if (p.B_hstep % 4 == 0) + { + const uint bi = k * (p.B_hstep / 4) + gx; + b = buffer_ld4(B_blob_data, bi); + } + else + { + const uvec4 bi4 = k * p.B_hstep + gx * 4 + uvec4(0, 1, 2, 3); + + const uvec4 bi4d4 = bi4 / 4; + const uvec4 bi4m4 = bi4 % 4; + + b.r = buffer_ld4(B_blob_data, bi4d4.r)[bi4m4.r]; + b.g = buffer_ld4(B_blob_data, bi4d4.g)[bi4m4.g]; + b.b = buffer_ld4(B_blob_data, bi4d4.b)[bi4m4.b]; + b.a = buffer_ld4(B_blob_data, bi4d4.a)[bi4m4.a]; + } + } } sum0 += a.r * b; @@ -433,64 +583,94 @@ void main() if (output_transpose == 1) { - const uvec4 gx4 = gx * 4 + uvec4(0, 1, 2, 3); - const uvec4 gi4 = gx4 * p.outhstep + gy * 4; - - buffer_st1(top_blob_data, gi4.r, sum0.r); - if (gy * 4 + 1 < psc(M)) buffer_st1(top_blob_data, gi4.r + 1, sum1.r); - if (gy * 4 + 2 < psc(M)) buffer_st1(top_blob_data, gi4.r + 2, sum2.r); - if (gy * 4 + 3 < psc(M)) buffer_st1(top_blob_data, gi4.r + 3, sum3.r); - if (gx4.g < psc(N)) + if (p.out_elempack == 4) { - buffer_st1(top_blob_data, gi4.g, sum0.g); - if (gy * 4 + 1 < psc(M)) buffer_st1(top_blob_data, gi4.g + 1, sum1.g); - if (gy * 4 + 2 < psc(M)) buffer_st1(top_blob_data, gi4.g + 2, sum2.g); - if (gy * 4 + 3 < psc(M)) buffer_st1(top_blob_data, gi4.g + 3, sum3.g); - } - if (gx4.b < psc(N)) - { - buffer_st1(top_blob_data, gi4.b, sum0.b); - if (gy * 4 + 1 < psc(M)) buffer_st1(top_blob_data, gi4.b + 1, sum1.b); - if (gy * 4 + 2 < psc(M)) buffer_st1(top_blob_data, gi4.b + 2, sum2.b); - if (gy * 4 + 3 < psc(M)) buffer_st1(top_blob_data, gi4.b + 3, sum3.b); + // transpose output, pack4 on N dimension + // sum_i = vec4(C[gy*4+i][gx*4+0..3]) + // store sum_i directly as vec4 + const uvec4 gx4 = gx * 4 + uvec4(0, 1, 2, 3); + const uint gi = gx * p.outhstep + gy * 4; + + buffer_st4(top_blob_data_4, gi, sum0); + if (gy * 4 + 1 < psc(M)) buffer_st4(top_blob_data_4, gi + 1, sum1); + if (gy * 4 + 2 < psc(M)) buffer_st4(top_blob_data_4, gi + 2, sum2); + if (gy * 4 + 3 < psc(M)) buffer_st4(top_blob_data_4, gi + 3, sum3); } - if (gx4.a < psc(N)) + else { - buffer_st1(top_blob_data, gi4.a, sum0.a); - if (gy * 4 + 1 < psc(M)) buffer_st1(top_blob_data, gi4.a + 1, sum1.a); - if (gy * 4 + 2 < psc(M)) buffer_st1(top_blob_data, gi4.a + 2, sum2.a); - if (gy * 4 + 3 < psc(M)) buffer_st1(top_blob_data, gi4.a + 3, sum3.a); + const uvec4 gx4 = gx * 4 + uvec4(0, 1, 2, 3); + const uvec4 gi4 = gx4 * p.outhstep + gy * 4; + + buffer_st1(top_blob_data, gi4.r, sum0.r); + if (gy * 4 + 1 < psc(M)) buffer_st1(top_blob_data, gi4.r + 1, sum1.r); + if (gy * 4 + 2 < psc(M)) buffer_st1(top_blob_data, gi4.r + 2, sum2.r); + if (gy * 4 + 3 < psc(M)) buffer_st1(top_blob_data, gi4.r + 3, sum3.r); + if (gx4.g < psc(N)) + { + buffer_st1(top_blob_data, gi4.g, sum0.g); + if (gy * 4 + 1 < psc(M)) buffer_st1(top_blob_data, gi4.g + 1, sum1.g); + if (gy * 4 + 2 < psc(M)) buffer_st1(top_blob_data, gi4.g + 2, sum2.g); + if (gy * 4 + 3 < psc(M)) buffer_st1(top_blob_data, gi4.g + 3, sum3.g); + } + if (gx4.b < psc(N)) + { + buffer_st1(top_blob_data, gi4.b, sum0.b); + if (gy * 4 + 1 < psc(M)) buffer_st1(top_blob_data, gi4.b + 1, sum1.b); + if (gy * 4 + 2 < psc(M)) buffer_st1(top_blob_data, gi4.b + 2, sum2.b); + if (gy * 4 + 3 < psc(M)) 
buffer_st1(top_blob_data, gi4.b + 3, sum3.b); + } + if (gx4.a < psc(N)) + { + buffer_st1(top_blob_data, gi4.a, sum0.a); + if (gy * 4 + 1 < psc(M)) buffer_st1(top_blob_data, gi4.a + 1, sum1.a); + if (gy * 4 + 2 < psc(M)) buffer_st1(top_blob_data, gi4.a + 2, sum2.a); + if (gy * 4 + 3 < psc(M)) buffer_st1(top_blob_data, gi4.a + 3, sum3.a); + } } } else { - const uvec4 gy4 = gy * 4 + uvec4(0, 1, 2, 3); - const uvec4 gi4 = gy4 * p.outhstep + gx * 4; - - buffer_st1(top_blob_data, gi4.r, sum0.r); - if (gx * 4 + 1 < psc(N)) buffer_st1(top_blob_data, gi4.r + 1, sum0.g); - if (gx * 4 + 2 < psc(N)) buffer_st1(top_blob_data, gi4.r + 2, sum0.b); - if (gx * 4 + 3 < psc(N)) buffer_st1(top_blob_data, gi4.r + 3, sum0.a); - if (gy4.g < psc(M)) + if (p.out_elempack == 4) { - buffer_st1(top_blob_data, gi4.g, sum1.r); - if (gx * 4 + 1 < psc(N)) buffer_st1(top_blob_data, gi4.g + 1, sum1.g); - if (gx * 4 + 2 < psc(N)) buffer_st1(top_blob_data, gi4.g + 2, sum1.b); - if (gx * 4 + 3 < psc(N)) buffer_st1(top_blob_data, gi4.g + 3, sum1.a); + // non-transpose output, pack4 on M dimension + // pack vec4(sum0[j], sum1[j], sum2[j], sum3[j]) for each column j + const uint gi = gy * p.outhstep + gx * 4; + + buffer_st4(top_blob_data_4, gi, afpvec4(sum0.r, sum1.r, sum2.r, sum3.r)); + if (gx * 4 + 1 < psc(N)) buffer_st4(top_blob_data_4, gi + 1, afpvec4(sum0.g, sum1.g, sum2.g, sum3.g)); + if (gx * 4 + 2 < psc(N)) buffer_st4(top_blob_data_4, gi + 2, afpvec4(sum0.b, sum1.b, sum2.b, sum3.b)); + if (gx * 4 + 3 < psc(N)) buffer_st4(top_blob_data_4, gi + 3, afpvec4(sum0.a, sum1.a, sum2.a, sum3.a)); } - if (gy4.b < psc(M)) - { - buffer_st1(top_blob_data, gi4.b, sum2.r); - if (gx * 4 + 1 < psc(N)) buffer_st1(top_blob_data, gi4.b + 1, sum2.g); - if (gx * 4 + 2 < psc(N)) buffer_st1(top_blob_data, gi4.b + 2, sum2.b); - if (gx * 4 + 3 < psc(N)) buffer_st1(top_blob_data, gi4.b + 3, sum2.a); - } - if (gy4.a < psc(M)) + else { - buffer_st1(top_blob_data, gi4.a, sum3.r); - if (gx * 4 + 1 < psc(N)) buffer_st1(top_blob_data, gi4.a + 1, sum3.g); - if (gx * 4 + 2 < psc(N)) buffer_st1(top_blob_data, gi4.a + 2, sum3.b); - if (gx * 4 + 3 < psc(N)) buffer_st1(top_blob_data, gi4.a + 3, sum3.a); + const uvec4 gy4 = gy * 4 + uvec4(0, 1, 2, 3); + const uvec4 gi4 = gy4 * p.outhstep + gx * 4; + + buffer_st1(top_blob_data, gi4.r, sum0.r); + if (gx * 4 + 1 < psc(N)) buffer_st1(top_blob_data, gi4.r + 1, sum0.g); + if (gx * 4 + 2 < psc(N)) buffer_st1(top_blob_data, gi4.r + 2, sum0.b); + if (gx * 4 + 3 < psc(N)) buffer_st1(top_blob_data, gi4.r + 3, sum0.a); + if (gy4.g < psc(M)) + { + buffer_st1(top_blob_data, gi4.g, sum1.r); + if (gx * 4 + 1 < psc(N)) buffer_st1(top_blob_data, gi4.g + 1, sum1.g); + if (gx * 4 + 2 < psc(N)) buffer_st1(top_blob_data, gi4.g + 2, sum1.b); + if (gx * 4 + 3 < psc(N)) buffer_st1(top_blob_data, gi4.g + 3, sum1.a); + } + if (gy4.b < psc(M)) + { + buffer_st1(top_blob_data, gi4.b, sum2.r); + if (gx * 4 + 1 < psc(N)) buffer_st1(top_blob_data, gi4.b + 1, sum2.g); + if (gx * 4 + 2 < psc(N)) buffer_st1(top_blob_data, gi4.b + 2, sum2.b); + if (gx * 4 + 3 < psc(N)) buffer_st1(top_blob_data, gi4.b + 3, sum2.a); + } + if (gy4.a < psc(M)) + { + buffer_st1(top_blob_data, gi4.a, sum3.r); + if (gx * 4 + 1 < psc(N)) buffer_st1(top_blob_data, gi4.a + 1, sum3.g); + if (gx * 4 + 2 < psc(N)) buffer_st1(top_blob_data, gi4.a + 2, sum3.b); + if (gx * 4 + 3 < psc(N)) buffer_st1(top_blob_data, gi4.a + 3, sum3.a); + } } } } diff --git a/src/layer/vulkan/shader/gemm_cm.comp b/src/layer/vulkan/shader/gemm_cm.comp index 58eb46fad46..8069dc647f2 100644 --- 
a/src/layer/vulkan/shader/gemm_cm.comp +++ b/src/layer/vulkan/shader/gemm_cm.comp @@ -31,23 +31,27 @@ layout(constant_id = 11) const int output_N1M = 0; layout(constant_id = 12) const int output_elempack = 0; layout(constant_id = 13) const int output_elemtype = 0; layout(constant_id = 14) const int output_transpose = 0; - -layout(constant_id = 15) const uint M = 1; -layout(constant_id = 16) const uint N = 1; -layout(constant_id = 17) const uint K = 1; -layout(constant_id = 18) const uint subgroup_size = 32; -layout(constant_id = 19) const uint UNROLL_SG_M = 2; -layout(constant_id = 20) const uint UNROLL_SG_N = 2; -layout(constant_id = 21) const uint UNROLL_SG_K = 2; -layout(constant_id = 22) const uint UNROLL_WG_M = 2; -layout(constant_id = 23) const uint UNROLL_WG_N = 2; +layout(constant_id = 15) const int A_elempack = 0; +layout(constant_id = 16) const int B_elempack = 0; +layout(constant_id = 17) const int out_elempack = 0; + +layout(constant_id = 18 + 0) const uint M = 1; +layout(constant_id = 18 + 1) const uint N = 1; +layout(constant_id = 18 + 2) const uint K = 1; +layout(constant_id = 18 + 3) const uint subgroup_size = 32; +layout(constant_id = 18 + 4) const uint UNROLL_SG_M = 2; +layout(constant_id = 18 + 5) const uint UNROLL_SG_N = 2; +layout(constant_id = 18 + 6) const uint UNROLL_SG_K = 2; +layout(constant_id = 18 + 7) const uint UNROLL_WG_M = 2; +layout(constant_id = 18 + 8) const uint UNROLL_WG_N = 2; // TODO psc more layout(binding = 0) writeonly buffer top_blob { sfp top_blob_data[]; }; -layout(binding = 1) readonly buffer A_blob { uvec4 A_blob_data[]; }; -layout(binding = 2) readonly buffer B_blob { uvec4 B_blob_data[]; }; +layout(binding = 1) readonly buffer A_blob { uvec2 A_blob_data[]; }; +layout(binding = 2) readonly buffer B_blob { uvec2 B_blob_data[]; }; layout(binding = 3) readonly buffer C_blob { sfp C_blob_data[]; }; +layout(binding = 4) writeonly buffer top_blob_4 { uvec2 top_blob_data_4[]; }; layout(push_constant) uniform parameter { @@ -61,28 +65,32 @@ layout(push_constant) uniform parameter int B_hstep; int outdims; int outhstep; + int A_elempack; + int B_elempack; + int out_elempack; } p; -const uint Md8 = M / 8; -const uint Nd8 = N / 8; -const uint Kd8 = K / 8; +const uint Md4 = M / 4; +const uint Nd4 = N / 4; +const uint Kd4 = K / 4; // avoid bank conflict +#if ncnn_VK_KHR_cooperative_matrix #define PAD 1 +#elif ncnn_VK_NV_cooperative_matrix +// fixme: pad causes incorrect result on old driver +#define PAD 0 +#endif -const uint Md8p = Md8 + PAD; -const uint Nd8p = Nd8 + PAD; -const uint Kd8p = Kd8 + PAD; - -const uint tmp_a_size = UNROLL_WG_M * UNROLL_SG_K * UNROLL_SG_M * (transA == 0 ? M * Kd8p : K * Md8p); -const uint tmp_b_size = UNROLL_WG_N * UNROLL_SG_K * UNROLL_SG_N * (transB == 0 ? K * Nd8p : N * Kd8p); -const uint tmp_o_size = UNROLL_WG_N * UNROLL_WG_M * UNROLL_SG_N * UNROLL_SG_M * (output_transpose == 0 ? M * Nd8p : (M * Nd8p > N * Md8p ? M * Nd8p : N * Md8p)); +const uint Md4p = Md4 + PAD; +const uint Nd4p = Nd4 + PAD; +const uint Kd4p = Kd4 + PAD; // cannot alias output with a and b // cm store may happen while another subgroup is loading -shared uvec4 tmp_a[tmp_a_size]; -shared uvec4 tmp_b[tmp_b_size]; -shared uvec4 tmp_o[tmp_o_size]; +shared uvec2 tmp_a[UNROLL_WG_M][UNROLL_SG_K * UNROLL_SG_M * (M * Kd4p > K * Md4p ? M * Kd4p : K * Md4p)]; +shared uvec2 tmp_b[UNROLL_WG_N][UNROLL_SG_K * UNROLL_SG_N * (K * Nd4p > N * Kd4p ? 
K * Nd4p : N * Kd4p)]; +shared uvec2 tmp_o[UNROLL_WG_N * UNROLL_WG_M][UNROLL_SG_N * UNROLL_SG_M * M * N / 4]; void main() { @@ -228,49 +236,42 @@ void main() } if (broadcast_type_C == 3) { - const uint Nd8_M_USGM_USGN = Nd8 * M * UNROLL_SG_M * UNROLL_SG_N; - const uint Nd8_M_USGM_USGN_d_subgroupsize = (Nd8_M_USGM_USGN + subgroup_size - 1) / subgroup_size; - [[unroll]] for (uint q = 0; q < Nd8_M_USGM_USGN_d_subgroupsize; q++) + const uint Nd4_M_USGM_USGN = Nd4 * M * UNROLL_SG_M * UNROLL_SG_N; + const uint Nd4_M_USGM_USGN_d_subgroupsize = (Nd4_M_USGM_USGN + subgroup_size - 1) / subgroup_size; + [[unroll]] for (uint q = 0; q < Nd4_M_USGM_USGN_d_subgroupsize; q++) { const uint siq = si + q * subgroup_size; - if (Nd8_M_USGM_USGN % subgroup_size == 0 || siq < Nd8_M_USGM_USGN) + if (Nd4_M_USGM_USGN % subgroup_size == 0 || siq < Nd4_M_USGM_USGN) { - const uint zn = siq / (Nd8 * M * UNROLL_SG_M); - const uint zmij = siq % (Nd8 * M * UNROLL_SG_M); - const uint zm = zmij / (Nd8 * M); - const uint ij = zmij % (Nd8 * M); - const uint i = ij / Nd8; - const uint j = ij % Nd8; + const uint zn = siq / (Nd4 * M * UNROLL_SG_M); + const uint zmij = siq % (Nd4 * M * UNROLL_SG_M); + const uint zm = zmij / (Nd4 * M); + const uint ij = zmij % (Nd4 * M); + const uint i = ij / Nd4; + const uint j = ij % Nd4; const uint gm = (mi + zm) * M + i; - const uint gn = (ni + zn) * Nd8 + j; + const uint gn = (ni + zn) * Nd4 + j; if (gm < psc(GM)) { - const uvec4 ci4 = gm * psc(GN) + gn * 8 + uvec4(0, 1, 2, 3); - const uvec4 ci8 = ci4 + 4; + const uvec4 ci4 = gm * psc(GN) + gn * 4 + uvec4(0, 1, 2, 3); vec2 va; vec2 vb; - vec2 vc; - vec2 vd; - if (gn * 8 < psc(GN)) va.r = float(buffer_ld1(C_blob_data, ci4.r)); - if (gn * 8 + 1 < psc(GN)) va.g = float(buffer_ld1(C_blob_data, ci4.g)); - if (gn * 8 + 2 < psc(GN)) vb.r = float(buffer_ld1(C_blob_data, ci4.b)); - if (gn * 8 + 3 < psc(GN)) vb.g = float(buffer_ld1(C_blob_data, ci4.a)); - if (gn * 8 + 4 < psc(GN)) vc.r = float(buffer_ld1(C_blob_data, ci8.r)); - if (gn * 8 + 5 < psc(GN)) vc.g = float(buffer_ld1(C_blob_data, ci8.g)); - if (gn * 8 + 6 < psc(GN)) vd.r = float(buffer_ld1(C_blob_data, ci8.b)); - if (gn * 8 + 7 < psc(GN)) vd.g = float(buffer_ld1(C_blob_data, ci8.a)); + if (gn * 4 < psc(GN)) va.r = float(buffer_ld1(C_blob_data, ci4.r)); + if (gn * 4 + 1 < psc(GN)) va.g = float(buffer_ld1(C_blob_data, ci4.g)); + if (gn * 4 + 2 < psc(GN)) vb.r = float(buffer_ld1(C_blob_data, ci4.b)); + if (gn * 4 + 3 < psc(GN)) vb.g = float(buffer_ld1(C_blob_data, ci4.a)); #if NCNN_bf16_storage || NCNN_bf16_packed - uvec4 v = uvec4(packBFloat2x16(va), packBFloat2x16(vb), packBFloat2x16(vc), packBFloat2x16(vd)); + uvec2 v = uvec2(packBFloat2x16(va), packBFloat2x16(vb)); #else - uvec4 v = uvec4(packHalf2x16(va), packHalf2x16(vb), packHalf2x16(vc), packHalf2x16(vd)); + uvec2 v = uvec2(packHalf2x16(va), packHalf2x16(vb)); #endif - tmp_o[(((sgi * UNROLL_SG_N + zn) * UNROLL_SG_M + zm) * M + i) * Nd8p + j] = v; + tmp_o[sgi][siq] = v; } } } @@ -283,22 +284,22 @@ void main() { #if ncnn_VK_KHR_cooperative_matrix #if NCNN_fp16_arithmetic - coopMatLoad(sum[zn][zm], tmp_o, ((sgi * UNROLL_SG_N + zn) * UNROLL_SG_M + zm) * (Nd8p * M), Nd8p, gl_CooperativeMatrixLayoutRowMajor); + coopMatLoad(sum[zn][zm], tmp_o[sgi], (zn * UNROLL_SG_M + zm) * (Nd4 * M), Nd4, gl_CooperativeMatrixLayoutRowMajor); #else #if NCNN_bf16_storage || NCNN_bf16_packed coopmat sum_fp16; #else coopmat sum_fp16; #endif - coopMatLoad(sum_fp16, tmp_o, ((sgi * UNROLL_SG_N + zn) * UNROLL_SG_M + zm) * (Nd8p * M), Nd8p, 
gl_CooperativeMatrixLayoutRowMajor); + coopMatLoad(sum_fp16, tmp_o[sgi], (zn * UNROLL_SG_M + zm) * (Nd4 * M), Nd4, gl_CooperativeMatrixLayoutRowMajor); sum[zn][zm] = coopmat(sum_fp16); #endif #elif ncnn_VK_NV_cooperative_matrix #if NCNN_fp16_arithmetic - coopMatLoadNV(sum[zn][zm], tmp_o, ((sgi * UNROLL_SG_N + zn) * UNROLL_SG_M + zm) * (Nd8p * M), Nd8p, false); + coopMatLoadNV(sum[zn][zm], tmp_o[sgi], (zn * UNROLL_SG_M + zm) * (Nd4 * M), Nd4, false); #else fcoopmatNV<16, gl_ScopeSubgroup, M, N> sum_fp16; - coopMatLoadNV(sum_fp16, tmp_o, ((sgi * UNROLL_SG_N + zn) * UNROLL_SG_M + zm) * (Nd8p * M), Nd8p, false); + coopMatLoadNV(sum_fp16, tmp_o[sgi], (zn * UNROLL_SG_M + zm) * (Nd4 * M), Nd4, false); sum[zn][zm] = fcoopmatNV<32, gl_ScopeSubgroup, M, N>(sum_fp16); #endif #endif @@ -364,516 +365,438 @@ void main() // local stack and shared memory ping-pong // prefetch - uvec4 prefetch_tmp_a[(UNROLL_SG_M * UNROLL_SG_K * M * K / 8 + (subgroup_size * UNROLL_WG_N - 1)) / (subgroup_size * UNROLL_WG_N)]; - uvec4 prefetch_tmp_b[(UNROLL_SG_N * UNROLL_SG_K * K * N / 8 + (subgroup_size * UNROLL_WG_M - 1)) / (subgroup_size * UNROLL_WG_M)]; + uvec2 prefetch_tmp_a[(UNROLL_SG_M * UNROLL_SG_K * M * K / 4 + (subgroup_size * UNROLL_WG_N - 1)) / (subgroup_size * UNROLL_WG_N)]; + uvec2 prefetch_tmp_b[(UNROLL_SG_N * UNROLL_SG_K * K * N / 4 + (subgroup_size * UNROLL_WG_M - 1)) / (subgroup_size * UNROLL_WG_M)]; // prefetch the very first { const uint ki = 0; // load A - if (transA == 0) + if (constantA == 1) { - // +-K-+ - // M | - // +- -+ - // SG_UM | - // ^ +---+ - // | | | - // SG_UK+- -+ - // | | | - // ^ v +---+ - // | | | - // | +- -+ - // | | | - // WG_UM +---+ - // | | | - // | +- -+ - // | | | - // v +---+ - - const uint Kd8_M_USGM_USGK = Kd8 * M * UNROLL_SG_M * UNROLL_SG_K; - const uint Kd8_M_USGM_USGK_d_subgroupsize = (Kd8_M_USGM_USGK + (subgroup_size * UNROLL_WG_N - 1)) / (subgroup_size * UNROLL_WG_N); - [[unroll]] for (uint q = 0; q < Kd8_M_USGM_USGK_d_subgroupsize; q++) + const uint Kd4_M_USGM = Kd4 * M * UNROLL_SG_M; + const uint Kd4_M_USGM_USGK = Kd4_M_USGM * UNROLL_SG_K; + const uint A_offset = (wgmi * kk * UNROLL_WG_M + sgmi * UNROLL_SG_K) * Kd4_M_USGM; + const uint Kd4_M_USGM_USGK_d_subgroupsize = (Kd4_M_USGM_USGK + (subgroup_size * UNROLL_WG_N - 1)) / (subgroup_size * UNROLL_WG_N); + [[unroll]] for (uint q = 0; q < Kd4_M_USGM_USGK_d_subgroupsize; q++) { const uint siq = (q * UNROLL_WG_N + sgni) * subgroup_size + si; - if (Kd8_M_USGM_USGK % (subgroup_size * UNROLL_WG_N) == 0 || siq < Kd8_M_USGM_USGK) + if (Kd4_M_USGM_USGK % (subgroup_size * UNROLL_WG_N) == 0 || siq < Kd4_M_USGM_USGK) { - const uint zk = siq / (Kd8 * M * UNROLL_SG_M); - const uint zmij = siq % (Kd8 * M * UNROLL_SG_M); - const uint zm = zmij / (Kd8 * M); - const uint ij = zmij % (Kd8 * M); - const uint j = ij / Kd8; - const uint i = ij % Kd8; + prefetch_tmp_a[q] = A_blob_data[A_offset + siq]; + } + } + } + else if (transA == 0) + { + if (psc(A_elempack) == 1) + { + // +-K-+ + // M | + // +- -+ + // SG_UM | + // ^ +---+ + // | | | + // SG_UK+- -+ + // | | | + // ^ v +---+ + // | | | + // | +- -+ + // | | | + // WG_UM +---+ + // | | | + // | +- -+ + // | | | + // v +---+ - const uint gk = ki / 8 + zk * Kd8 + i; - const uint gm = (mi + zm) * M + j; + const uint Kd4_M_USGM_USGK = Kd4 * M * UNROLL_SG_M * UNROLL_SG_K; + const uint Kd4_M_USGM_USGK_d_subgroupsize = (Kd4_M_USGM_USGK + (subgroup_size * UNROLL_WG_N - 1)) / (subgroup_size * UNROLL_WG_N); + [[unroll]] for (uint q = 0; q < Kd4_M_USGM_USGK_d_subgroupsize; q++) + { + const uint siq = 
(q * UNROLL_WG_N + sgni) * subgroup_size + si; - uvec4 v = uvec4(0); - if (gm < psc(GM)) + if (Kd4_M_USGM_USGK % (subgroup_size * UNROLL_WG_N) == 0 || siq < Kd4_M_USGM_USGK) { - const uvec4 gk4 = gk * 8 + uvec4(0, 1, 2, 3); - const uvec4 gk8 = gk4 + 4; + const uint zk = siq / (Kd4 * M * UNROLL_SG_M); + const uint zmij = siq % (Kd4 * M * UNROLL_SG_M); + const uint zm = zmij / (Kd4 * M); + const uint ij = zmij % (Kd4 * M); + const uint j = ij / Kd4; + const uint i = ij % Kd4; + + uvec2 v = uvec2(0); + + const uint gk = ki / 4 + zk * Kd4 + i; + const uint gm = (mi + zm) * M + j; - if (p.A_hstep % 8 == 0) + if (gm < psc(GM)) { - const uint ai = gm * (p.A_hstep / 8) + gk; + const uvec4 gk4 = gk * 4 + uvec4(0, 1, 2, 3); - v = A_blob_data[ai]; + if (p.A_hstep % 4 == 0) + { + const uint ai = gm * (p.A_hstep / 4) + gk; - uvec4 mask4 = uvec4(lessThan(gk4, uvec4(psc(GK)))) * 0xFFFFu; - uvec4 mask8 = uvec4(lessThan(gk8, uvec4(psc(GK)))) * 0xFFFFu; - uvec2 packed_mask4 = uvec2(mask4.x | (mask4.y << 16), mask4.z | (mask4.w << 16)); - uvec2 packed_mask8 = uvec2(mask8.x | (mask8.y << 16), mask8.z | (mask8.w << 16)); + v = A_blob_data[ai]; - v.rg = v.rg & packed_mask4; - v.ba = v.ba & packed_mask8; - } - else - { - vec4 v4 = vec4(0.f); - vec4 v8 = vec4(0.f); + uvec4 mask4 = uvec4(lessThan(gk4, uvec4(psc(GK)))) * 0xFFFFu; + uvec2 packed_mask = uvec2(mask4.x | (mask4.y << 16), mask4.z | (mask4.w << 16)); - const uvec4 ai4 = gm * p.A_hstep + gk4; - const uvec4 ai4d8 = ai4 / 8; - const uvec4 ai4m8d2 = (ai4 % 8) / 2; - const uvec4 ai4m2 = ai4 % 2; + v = v & packed_mask; + } + else + { + vec2 v4a = vec2(0.f); + vec2 v4b = vec2(0.f); - const uvec4 ai8 = gm * p.A_hstep + gk8; - const uvec4 ai8d8 = ai8 / 8; - const uvec4 ai8m8d2 = (ai8 % 8) / 2; - const uvec4 ai8m2 = ai8 % 2; + const uvec4 ai4 = gm * p.A_hstep + gk4; + const uvec4 ai4d4 = ai4 / 4; + const uvec4 ai4m4d2 = (ai4 % 4) / 2; + const uvec4 ai4m2 = ai4 % 2; #if NCNN_bf16_storage || NCNN_bf16_packed - if (gk4.r < psc(GK)) v4.r = unpackBFloat2x16(A_blob_data[ai4d8.r][ai4m8d2.r])[ai4m2.r]; - if (gk4.g < psc(GK)) v4.g = unpackBFloat2x16(A_blob_data[ai4d8.g][ai4m8d2.g])[ai4m2.g]; - if (gk4.b < psc(GK)) v4.b = unpackBFloat2x16(A_blob_data[ai4d8.b][ai4m8d2.b])[ai4m2.b]; - if (gk4.a < psc(GK)) v4.a = unpackBFloat2x16(A_blob_data[ai4d8.a][ai4m8d2.a])[ai4m2.a]; + if (gk4.r < psc(GK)) v4a.r = unpackBFloat2x16(A_blob_data[ai4d4.r][ai4m4d2.r])[ai4m2.r]; + if (gk4.g < psc(GK)) v4a.g = unpackBFloat2x16(A_blob_data[ai4d4.g][ai4m4d2.g])[ai4m2.g]; + if (gk4.b < psc(GK)) v4b.r = unpackBFloat2x16(A_blob_data[ai4d4.b][ai4m4d2.b])[ai4m2.b]; + if (gk4.a < psc(GK)) v4b.g = unpackBFloat2x16(A_blob_data[ai4d4.a][ai4m4d2.a])[ai4m2.a]; - if (gk8.r < psc(GK)) v8.r = unpackBFloat2x16(A_blob_data[ai8d8.r][ai8m8d2.r])[ai8m2.r]; - if (gk8.g < psc(GK)) v8.g = unpackBFloat2x16(A_blob_data[ai8d8.g][ai8m8d2.g])[ai8m2.g]; - if (gk8.b < psc(GK)) v8.b = unpackBFloat2x16(A_blob_data[ai8d8.b][ai8m8d2.b])[ai8m2.b]; - if (gk8.a < psc(GK)) v8.a = unpackBFloat2x16(A_blob_data[ai8d8.a][ai8m8d2.a])[ai8m2.a]; - - v = uvec4(packBFloat2x16(v4.rg), packBFloat2x16(v4.ba), packBFloat2x16(v8.rg), packBFloat2x16(v8.ba)); + v = uvec2(packBFloat2x16(v4a), packBFloat2x16(v4b)); #else - if (gk4.r < psc(GK)) v4.r = unpackHalf2x16(A_blob_data[ai4d8.r][ai4m8d2.r])[ai4m2.r]; - if (gk4.g < psc(GK)) v4.g = unpackHalf2x16(A_blob_data[ai4d8.g][ai4m8d2.g])[ai4m2.g]; - if (gk4.b < psc(GK)) v4.b = unpackHalf2x16(A_blob_data[ai4d8.b][ai4m8d2.b])[ai4m2.b]; - if (gk4.a < psc(GK)) v4.a = 
unpackHalf2x16(A_blob_data[ai4d8.a][ai4m8d2.a])[ai4m2.a]; - - if (gk8.r < psc(GK)) v8.r = unpackHalf2x16(A_blob_data[ai8d8.r][ai8m8d2.r])[ai8m2.r]; - if (gk8.g < psc(GK)) v8.g = unpackHalf2x16(A_blob_data[ai8d8.g][ai8m8d2.g])[ai8m2.g]; - if (gk8.b < psc(GK)) v8.b = unpackHalf2x16(A_blob_data[ai8d8.b][ai8m8d2.b])[ai8m2.b]; - if (gk8.a < psc(GK)) v8.a = unpackHalf2x16(A_blob_data[ai8d8.a][ai8m8d2.a])[ai8m2.a]; + if (gk4.r < psc(GK)) v4a.r = unpackHalf2x16(A_blob_data[ai4d4.r][ai4m4d2.r])[ai4m2.r]; + if (gk4.g < psc(GK)) v4a.g = unpackHalf2x16(A_blob_data[ai4d4.g][ai4m4d2.g])[ai4m2.g]; + if (gk4.b < psc(GK)) v4b.r = unpackHalf2x16(A_blob_data[ai4d4.b][ai4m4d2.b])[ai4m2.b]; + if (gk4.a < psc(GK)) v4b.g = unpackHalf2x16(A_blob_data[ai4d4.a][ai4m4d2.a])[ai4m2.a]; - v = uvec4(packHalf2x16(v4.rg), packHalf2x16(v4.ba), packHalf2x16(v8.rg), packHalf2x16(v8.ba)); + v = uvec2(packHalf2x16(v4a), packHalf2x16(v4b)); #endif + } } - } - prefetch_tmp_a[q] = v; + prefetch_tmp_a[q] = v; + } } } - } - else - { - // +-M-+ - // K | - // +SG_UM - // | | - // ^ +---+ - // | | | - // SG_UK+- -+ - // | | | - // ^ v +---+ - // | | | - // | +- -+ - // | | | - // WG_UM +---+ - // | | | - // | +- -+ - // | | | - // v +---+ - - const uint Md8_K_USGM_USGK = Md8 * K * UNROLL_SG_M * UNROLL_SG_K; - const uint Md8_K_USGM_USGK_d_subgroupsize = (Md8_K_USGM_USGK + (subgroup_size * UNROLL_WG_N - 1)) / (subgroup_size * UNROLL_WG_N); - [[unroll]] for (uint q = 0; q < Md8_K_USGM_USGK_d_subgroupsize; q++) + else // if (psc(A_elempack) == 4) { - const uint siq = (q * UNROLL_WG_N + sgni) * subgroup_size + si; - - if (Md8_K_USGM_USGK % (subgroup_size * UNROLL_WG_N) == 0 || siq < Md8_K_USGM_USGK) + const uint Md4_K_USGM_USGK = Md4 * K * UNROLL_SG_M * UNROLL_SG_K; + const uint Md4_K_USGM_USGK_d_subgroupsize = (Md4_K_USGM_USGK + (subgroup_size * UNROLL_WG_N - 1)) / (subgroup_size * UNROLL_WG_N); + [[unroll]] for (uint q = 0; q < Md4_K_USGM_USGK_d_subgroupsize; q++) { - const uint zk = siq / (Md8 * K * UNROLL_SG_M); - const uint zmij = siq % (Md8 * K * UNROLL_SG_M); - const uint zm = zmij / (Md8 * K); - const uint ij = zmij % (Md8 * K); - const uint i = ij / Md8; - const uint j = ij % Md8; - - const uint gk = ki + zk * K + i; - const uint gm = (mi + zm) * Md8 + j; + const uint siq = (q * UNROLL_WG_N + sgni) * subgroup_size + si; - uvec4 v = uvec4(0); - if (gk < psc(GK)) + if (Md4_K_USGM_USGK % (subgroup_size * UNROLL_WG_N) == 0 || siq < Md4_K_USGM_USGK) { - if (p.A_hstep % 8 == 0) - { - const uint ai = gk * (p.A_hstep / 8) + gm; - - if (gm * 8 < psc(GM)) v = A_blob_data[ai]; - } - else - { - const uvec4 gm4 = gm * 8 + uvec4(0, 1, 2, 3); - const uvec4 gm8 = gm4 + 4; - - vec4 v4 = vec4(0.f); - vec4 v8 = vec4(0.f); - - const uvec4 ai4 = gk * p.A_hstep + gm4; - const uvec4 ai4d8 = ai4 / 8; - const uvec4 ai4m8d2 = (ai4 % 8) / 2; - const uvec4 ai4m2 = ai4 % 2; + const uint zk = siq / (Md4 * K * UNROLL_SG_M); + const uint zmij = siq % (Md4 * K * UNROLL_SG_M); + const uint zm = zmij / (Md4 * K); + const uint ij = zmij % (Md4 * K); + const uint i = ij / Md4; + const uint j = ij % Md4; - const uvec4 ai8 = gk * p.A_hstep + gm8; - const uvec4 ai8d8 = ai8 / 8; - const uvec4 ai8m8d2 = (ai8 % 8) / 2; - const uvec4 ai8m2 = ai8 % 2; - -#if NCNN_bf16_storage || NCNN_bf16_packed - if (gm4.r < psc(GM)) v4.r = unpackBFloat2x16(A_blob_data[ai4d8.r][ai4m8d2.r])[ai4m2.r]; - if (gm4.g < psc(GM)) v4.g = unpackBFloat2x16(A_blob_data[ai4d8.g][ai4m8d2.g])[ai4m2.g]; - if (gm4.b < psc(GM)) v4.b = unpackBFloat2x16(A_blob_data[ai4d8.b][ai4m8d2.b])[ai4m2.b]; - if (gm4.a < 
psc(GM)) v4.a = unpackBFloat2x16(A_blob_data[ai4d8.a][ai4m8d2.a])[ai4m2.a]; + uvec2 v = uvec2(0); - if (gm8.r < psc(GM)) v8.r = unpackBFloat2x16(A_blob_data[ai8d8.r][ai8m8d2.r])[ai8m2.r]; - if (gm8.g < psc(GM)) v8.g = unpackBFloat2x16(A_blob_data[ai8d8.g][ai8m8d2.g])[ai8m2.g]; - if (gm8.b < psc(GM)) v8.b = unpackBFloat2x16(A_blob_data[ai8d8.b][ai8m8d2.b])[ai8m2.b]; - if (gm8.a < psc(GM)) v8.a = unpackBFloat2x16(A_blob_data[ai8d8.a][ai8m8d2.a])[ai8m2.a]; + const uint gk = ki + zk * K + i; + const uint gm = (mi + zm) * Md4 + j; - v = uvec4(packBFloat2x16(v4.rg), packBFloat2x16(v4.ba), packBFloat2x16(v8.rg), packBFloat2x16(v8.ba)); -#else - if (gm4.r < psc(GM)) v4.r = unpackHalf2x16(A_blob_data[ai4d8.r][ai4m8d2.r])[ai4m2.r]; - if (gm4.g < psc(GM)) v4.g = unpackHalf2x16(A_blob_data[ai4d8.g][ai4m8d2.g])[ai4m2.g]; - if (gm4.b < psc(GM)) v4.b = unpackHalf2x16(A_blob_data[ai4d8.b][ai4m8d2.b])[ai4m2.b]; - if (gm4.a < psc(GM)) v4.a = unpackHalf2x16(A_blob_data[ai4d8.a][ai4m8d2.a])[ai4m2.a]; + if (gk < psc(GK)) + { + v = A_blob_data[gm * p.A_hstep + gk]; - if (gm8.r < psc(GM)) v8.r = unpackHalf2x16(A_blob_data[ai8d8.r][ai8m8d2.r])[ai8m2.r]; - if (gm8.g < psc(GM)) v8.g = unpackHalf2x16(A_blob_data[ai8d8.g][ai8m8d2.g])[ai8m2.g]; - if (gm8.b < psc(GM)) v8.b = unpackHalf2x16(A_blob_data[ai8d8.b][ai8m8d2.b])[ai8m2.b]; - if (gm8.a < psc(GM)) v8.a = unpackHalf2x16(A_blob_data[ai8d8.a][ai8m8d2.a])[ai8m2.a]; + const uvec4 gm4 = gm * 4 + uvec4(0, 1, 2, 3); + uvec4 mask4 = uvec4(lessThan(gm4, uvec4(psc(GM)))) * 0xFFFFu; + uvec2 packed_mask = uvec2(mask4.x | (mask4.y << 16), mask4.z | (mask4.w << 16)); - v = uvec4(packHalf2x16(v4.rg), packHalf2x16(v4.ba), packHalf2x16(v8.rg), packHalf2x16(v8.ba)); -#endif + v = v & packed_mask; } - } - prefetch_tmp_a[q] = v; + prefetch_tmp_a[q] = v; + } } } } - - // load B - if (transB == 0) + else { - // +-N-+ - // K | - // +SG_UN - // | | - // ^ +---+ - // | | | - // SG_UK+- -+ - // | | | - // ^ v +---+ - // | | | - // | +- -+ - // | | | - // WG_UN +---+ - // | | | - // | +- -+ - // | | | - // v +---+ - - const uint Nd8_K_USGN_USGK = Nd8 * K * UNROLL_SG_N * UNROLL_SG_K; - const uint Nd8_K_USGN_USGK_d_subgroupsize = (Nd8_K_USGN_USGK + (subgroup_size * UNROLL_WG_M - 1)) / (subgroup_size * UNROLL_WG_M); - [[unroll]] for (uint q = 0; q < Nd8_K_USGN_USGK_d_subgroupsize; q++) + if (psc(A_elempack) == 1) { - const uint siq = (q * UNROLL_WG_M + sgmi) * subgroup_size + si; + // +-M-+ + // K | + // +SG_UM + // | | + // ^ +---+ + // | | | + // SG_UK+- -+ + // | | | + // ^ v +---+ + // | | | + // | +- -+ + // | | | + // WG_UM +---+ + // | | | + // | +- -+ + // | | | + // v +---+ - if (Nd8_K_USGN_USGK % (subgroup_size * UNROLL_WG_M) == 0 || siq < Nd8_K_USGN_USGK) + const uint Md4_K_USGM_USGK = Md4 * K * UNROLL_SG_M * UNROLL_SG_K; + const uint Md4_K_USGM_USGK_d_subgroupsize = (Md4_K_USGM_USGK + (subgroup_size * UNROLL_WG_N - 1)) / (subgroup_size * UNROLL_WG_N); + [[unroll]] for (uint q = 0; q < Md4_K_USGM_USGK_d_subgroupsize; q++) { - const uint zk = siq / (Nd8 * K * UNROLL_SG_N); - const uint znij = siq % (Nd8 * K * UNROLL_SG_N); - const uint zn = znij / (Nd8 * K); - const uint ij = znij % (Nd8 * K); - const uint i = ij / Nd8; - const uint j = ij % Nd8; - - const uint gk = ki + zk * K + i; - const uint gn = (ni + zn) * Nd8 + j; + const uint siq = (q * UNROLL_WG_N + sgni) * subgroup_size + si; - uvec4 v = uvec4(0); - if (gk < psc(GK)) + if (Md4_K_USGM_USGK % (subgroup_size * UNROLL_WG_N) == 0 || siq < Md4_K_USGM_USGK) { - const uvec4 gn4 = gn * 8 + uvec4(0, 1, 2, 3); - const uvec4 gn8 = gn4 
+ 4; + const uint zk = siq / (Md4 * K * UNROLL_SG_M); + const uint zmij = siq % (Md4 * K * UNROLL_SG_M); + const uint zm = zmij / (Md4 * K); + const uint ij = zmij % (Md4 * K); + const uint i = ij / Md4; + const uint j = ij % Md4; - if (p.B_hstep % 8 == 0) - { - const uint bi = gk * (p.B_hstep / 8) + gn; + uvec2 v = uvec2(0); - if (gn * 8 < psc(GN)) v = B_blob_data[bi]; - } - else + const uint gk = ki + zk * K + i; + const uint gm = (mi + zm) * Md4 + j; + + if (gk < psc(GK)) { - vec4 v4 = vec4(0.f); - vec4 v8 = vec4(0.f); + const uvec4 gm4 = gm * 4 + uvec4(0, 1, 2, 3); - const uvec4 bi4 = gk * p.B_hstep + gn4; - const uvec4 bi4d8 = bi4 / 8; - const uvec4 bi4m8d2 = (bi4 % 8) / 2; - const uvec4 bi4m2 = bi4 % 2; + if (p.A_hstep % 4 == 0) + { + const uint ai = gk * (p.A_hstep / 4) + gm; - const uvec4 bi8 = gk * p.B_hstep + gn8; - const uvec4 bi8d8 = bi8 / 8; - const uvec4 bi8m8d2 = (bi8 % 8) / 2; - const uvec4 bi8m2 = bi8 % 2; + v = A_blob_data[ai]; -#if NCNN_bf16_storage || NCNN_bf16_packed - if (gn4.r < psc(GN)) v4.r = unpackBFloat2x16(B_blob_data[bi4d8.r][bi4m8d2.r])[bi4m2.r]; - if (gn4.g < psc(GN)) v4.g = unpackBFloat2x16(B_blob_data[bi4d8.g][bi4m8d2.g])[bi4m2.g]; - if (gn4.b < psc(GN)) v4.b = unpackBFloat2x16(B_blob_data[bi4d8.b][bi4m8d2.b])[bi4m2.b]; - if (gn4.a < psc(GN)) v4.a = unpackBFloat2x16(B_blob_data[bi4d8.a][bi4m8d2.a])[bi4m2.a]; + uvec4 mask4 = uvec4(lessThan(gm4, uvec4(psc(GM)))) * 0xFFFFu; + uvec2 packed_mask = uvec2(mask4.x | (mask4.y << 16), mask4.z | (mask4.w << 16)); + + v = v & packed_mask; + } + else + { + vec2 v4a = vec2(0.f); + vec2 v4b = vec2(0.f); + + const uvec4 ai4 = gk * p.A_hstep + gm4; + const uvec4 ai4d4 = ai4 / 4; + const uvec4 ai4m4d2 = (ai4 % 4) / 2; + const uvec4 ai4m2 = ai4 % 2; - if (gn8.r < psc(GN)) v8.r = unpackBFloat2x16(B_blob_data[bi8d8.r][bi8m8d2.r])[bi8m2.r]; - if (gn8.g < psc(GN)) v8.g = unpackBFloat2x16(B_blob_data[bi8d8.g][bi8m8d2.g])[bi8m2.g]; - if (gn8.b < psc(GN)) v8.b = unpackBFloat2x16(B_blob_data[bi8d8.b][bi8m8d2.b])[bi8m2.b]; - if (gn8.a < psc(GN)) v8.a = unpackBFloat2x16(B_blob_data[bi8d8.a][bi8m8d2.a])[bi8m2.a]; +#if NCNN_bf16_storage || NCNN_bf16_packed + if (gm4.r < psc(GM)) v4a.r = unpackBFloat2x16(A_blob_data[ai4d4.r][ai4m4d2.r])[ai4m2.r]; + if (gm4.g < psc(GM)) v4a.g = unpackBFloat2x16(A_blob_data[ai4d4.g][ai4m4d2.g])[ai4m2.g]; + if (gm4.b < psc(GM)) v4b.r = unpackBFloat2x16(A_blob_data[ai4d4.b][ai4m4d2.b])[ai4m2.b]; + if (gm4.a < psc(GM)) v4b.g = unpackBFloat2x16(A_blob_data[ai4d4.a][ai4m4d2.a])[ai4m2.a]; - v = uvec4(packBFloat2x16(v4.rg), packBFloat2x16(v4.ba), packBFloat2x16(v8.rg), packBFloat2x16(v8.ba)); + v = uvec2(packBFloat2x16(v4a), packBFloat2x16(v4b)); #else - if (gn4.r < psc(GN)) v4.r = unpackHalf2x16(B_blob_data[bi4d8.r][bi4m8d2.r])[bi4m2.r]; - if (gn4.g < psc(GN)) v4.g = unpackHalf2x16(B_blob_data[bi4d8.g][bi4m8d2.g])[bi4m2.g]; - if (gn4.b < psc(GN)) v4.b = unpackHalf2x16(B_blob_data[bi4d8.b][bi4m8d2.b])[bi4m2.b]; - if (gn4.a < psc(GN)) v4.a = unpackHalf2x16(B_blob_data[bi4d8.a][bi4m8d2.a])[bi4m2.a]; + if (gm4.r < psc(GM)) v4a.r = unpackHalf2x16(A_blob_data[ai4d4.r][ai4m4d2.r])[ai4m2.r]; + if (gm4.g < psc(GM)) v4a.g = unpackHalf2x16(A_blob_data[ai4d4.g][ai4m4d2.g])[ai4m2.g]; + if (gm4.b < psc(GM)) v4b.r = unpackHalf2x16(A_blob_data[ai4d4.b][ai4m4d2.b])[ai4m2.b]; + if (gm4.a < psc(GM)) v4b.g = unpackHalf2x16(A_blob_data[ai4d4.a][ai4m4d2.a])[ai4m2.a]; - if (gn8.r < psc(GN)) v8.r = unpackHalf2x16(B_blob_data[bi8d8.r][bi8m8d2.r])[bi8m2.r]; - if (gn8.g < psc(GN)) v8.g = 
unpackHalf2x16(B_blob_data[bi8d8.g][bi8m8d2.g])[bi8m2.g]; - if (gn8.b < psc(GN)) v8.b = unpackHalf2x16(B_blob_data[bi8d8.b][bi8m8d2.b])[bi8m2.b]; - if (gn8.a < psc(GN)) v8.a = unpackHalf2x16(B_blob_data[bi8d8.a][bi8m8d2.a])[bi8m2.a]; - - v = uvec4(packHalf2x16(v4.rg), packHalf2x16(v4.ba), packHalf2x16(v8.rg), packHalf2x16(v8.ba)); + v = uvec2(packHalf2x16(v4a), packHalf2x16(v4b)); #endif + } } - } - prefetch_tmp_b[q] = v; + prefetch_tmp_a[q] = v; + } } } - } - else - { - // +-K-+ - // N | - // +SG_UN - // | | - // ^ +---+ - // | | | - // SG_UK+- -+ - // | | | - // ^ v +---+ - // | | | - // | +- -+ - // | | | - // WG_UN +---+ - // | | | - // | +- -+ - // | | | - // v +---+ - - const uint Kd8_N_USGN_USGK = Kd8 * N * UNROLL_SG_N * UNROLL_SG_K; - const uint Kd8_N_USGN_USGK_d_subgroupsize = (Kd8_N_USGN_USGK + (subgroup_size * UNROLL_WG_M - 1)) / (subgroup_size * UNROLL_WG_M); - [[unroll]] for (uint q = 0; q < Kd8_N_USGN_USGK_d_subgroupsize; q++) + else // if (psc(A_elempack) == 4) { - const uint siq = (q * UNROLL_WG_M + sgmi) * subgroup_size + si; - - if (Kd8_N_USGN_USGK % (subgroup_size * UNROLL_WG_M) == 0 || siq < Kd8_N_USGN_USGK) + const uint Kd4_M_USGM_USGK = Kd4 * M * UNROLL_SG_M * UNROLL_SG_K; + const uint Kd4_M_USGM_USGK_d_subgroupsize = (Kd4_M_USGM_USGK + (subgroup_size * UNROLL_WG_N - 1)) / (subgroup_size * UNROLL_WG_N); + [[unroll]] for (uint q = 0; q < Kd4_M_USGM_USGK_d_subgroupsize; q++) { - const uint zk = siq / (Kd8 * N * UNROLL_SG_N); - const uint znij = siq % (Kd8 * N * UNROLL_SG_N); - const uint zn = znij / (Kd8 * N); - const uint ij = znij % (Kd8 * N); - const uint j = ij / Kd8; - const uint i = ij % Kd8; - - const uint gk = ki / 8 + zk * Kd8 + i; - const uint gn = (ni + zn) * N + j; + const uint siq = (q * UNROLL_WG_N + sgni) * subgroup_size + si; - uvec4 v = uvec4(0); - if (gn < psc(GN)) + if (Kd4_M_USGM_USGK % (subgroup_size * UNROLL_WG_N) == 0 || siq < Kd4_M_USGM_USGK) { - const uvec4 gk4 = gk * 8 + uvec4(0, 1, 2, 3); - const uvec4 gk8 = gk4 + 4; - - if (p.B_hstep % 8 == 0) - { - const uint bi = gn * (p.B_hstep / 8) + gk; + const uint zk = siq / (Kd4 * M * UNROLL_SG_M); + const uint zmij = siq % (Kd4 * M * UNROLL_SG_M); + const uint zm = zmij / (Kd4 * M); + const uint ij = zmij % (Kd4 * M); + const uint j = ij / Kd4; + const uint i = ij % Kd4; - v = B_blob_data[bi]; + uvec2 v = uvec2(0); - uvec4 mask4 = uvec4(lessThan(gk4, uvec4(psc(GK)))) * 0xFFFFu; - uvec4 mask8 = uvec4(lessThan(gk8, uvec4(psc(GK)))) * 0xFFFFu; - uvec2 packed_mask4 = uvec2(mask4.x | (mask4.y << 16), mask4.z | (mask4.w << 16)); - uvec2 packed_mask8 = uvec2(mask8.x | (mask8.y << 16), mask8.z | (mask8.w << 16)); + const uint gk = ki / 4 + zk * Kd4 + i; + const uint gm = (mi + zm) * M + j; - v.rg = v.rg & packed_mask4; - v.ba = v.ba & packed_mask8; - } - else + if (gm < psc(GM)) { - vec4 v4 = vec4(0.f); - vec4 v8 = vec4(0.f); - - const uvec4 bi4 = gn * p.B_hstep + gk4; - const uvec4 bi4d8 = bi4 / 8; - const uvec4 bi4m8d2 = (bi4 % 8) / 2; - const uvec4 bi4m2 = bi4 % 2; - - const uvec4 bi8 = gn * p.B_hstep + gk8; - const uvec4 bi8d8 = bi8 / 8; - const uvec4 bi8m8d2 = (bi8 % 8) / 2; - const uvec4 bi8m2 = bi8 % 2; - -#if NCNN_bf16_storage || NCNN_bf16_packed - if (gk4.r < psc(GK)) v4.r = unpackBFloat2x16(B_blob_data[bi4d8.r][bi4m8d2.r])[bi4m2.r]; - if (gk4.g < psc(GK)) v4.g = unpackBFloat2x16(B_blob_data[bi4d8.g][bi4m8d2.g])[bi4m2.g]; - if (gk4.b < psc(GK)) v4.b = unpackBFloat2x16(B_blob_data[bi4d8.b][bi4m8d2.b])[bi4m2.b]; - if (gk4.a < psc(GK)) v4.a = unpackBFloat2x16(B_blob_data[bi4d8.a][bi4m8d2.a])[bi4m2.a]; 
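+                        // descriptive note on the pack4 load below: one uvec2 carries four fp16 values (two per uint),
+                        // and the lessThan-derived packed mask that follows zeroes the lanes whose index runs past the
+                        // valid extent, so no separate per-lane bounds check is needed after the single load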
+ v = A_blob_data[gk * p.A_hstep + gm]; - if (gk8.r < psc(GK)) v8.r = unpackBFloat2x16(B_blob_data[bi8d8.r][bi8m8d2.r])[bi8m2.r]; - if (gk8.g < psc(GK)) v8.g = unpackBFloat2x16(B_blob_data[bi8d8.g][bi8m8d2.g])[bi8m2.g]; - if (gk8.b < psc(GK)) v8.b = unpackBFloat2x16(B_blob_data[bi8d8.b][bi8m8d2.b])[bi8m2.b]; - if (gk8.a < psc(GK)) v8.a = unpackBFloat2x16(B_blob_data[bi8d8.a][bi8m8d2.a])[bi8m2.a]; - - v = uvec4(packBFloat2x16(v4.rg), packBFloat2x16(v4.ba), packBFloat2x16(v8.rg), packBFloat2x16(v8.ba)); -#else - if (gk4.r < psc(GK)) v4.r = unpackHalf2x16(B_blob_data[bi4d8.r][bi4m8d2.r])[bi4m2.r]; - if (gk4.g < psc(GK)) v4.g = unpackHalf2x16(B_blob_data[bi4d8.g][bi4m8d2.g])[bi4m2.g]; - if (gk4.b < psc(GK)) v4.b = unpackHalf2x16(B_blob_data[bi4d8.b][bi4m8d2.b])[bi4m2.b]; - if (gk4.a < psc(GK)) v4.a = unpackHalf2x16(B_blob_data[bi4d8.a][bi4m8d2.a])[bi4m2.a]; - - if (gk8.r < psc(GK)) v8.r = unpackHalf2x16(B_blob_data[bi8d8.r][bi8m8d2.r])[bi8m2.r]; - if (gk8.g < psc(GK)) v8.g = unpackHalf2x16(B_blob_data[bi8d8.g][bi8m8d2.g])[bi8m2.g]; - if (gk8.b < psc(GK)) v8.b = unpackHalf2x16(B_blob_data[bi8d8.b][bi8m8d2.b])[bi8m2.b]; - if (gk8.a < psc(GK)) v8.a = unpackHalf2x16(B_blob_data[bi8d8.a][bi8m8d2.a])[bi8m2.a]; + const uvec4 gk4 = gk * 4 + uvec4(0, 1, 2, 3); + uvec4 mask4 = uvec4(lessThan(gk4, uvec4(psc(GK)))) * 0xFFFFu; + uvec2 packed_mask = uvec2(mask4.x | (mask4.y << 16), mask4.z | (mask4.w << 16)); - v = uvec4(packHalf2x16(v4.rg), packHalf2x16(v4.ba), packHalf2x16(v8.rg), packHalf2x16(v8.ba)); -#endif + v = v & packed_mask; } - } - prefetch_tmp_b[q] = v; + prefetch_tmp_a[q] = v; + } } } } - } - - k += UNROLL_SG_K; - - for (; k + UNROLL_SG_K - 1 < kk; k += UNROLL_SG_K) - { - barrier(); - // copy prefetch to shared memory + // load B + if (constantB == 1) { - // load A - if (transA == 0) + const uint Nd4_K_USGN = Nd4 * K * UNROLL_SG_N; + const uint Nd4_K_USGN_USGK = Nd4_K_USGN * UNROLL_SG_K; + const uint B_offset = (wgni * kk * UNROLL_WG_N + sgni * UNROLL_SG_K) * Nd4_K_USGN; + const uint Nd4_K_USGN_USGK_d_subgroupsize = (Nd4_K_USGN_USGK + (subgroup_size * UNROLL_WG_M - 1)) / (subgroup_size * UNROLL_WG_M); + [[unroll]] for (uint q = 0; q < Nd4_K_USGN_USGK_d_subgroupsize; q++) { - const uint Kd8_M_USGM_USGK = Kd8 * M * UNROLL_SG_M * UNROLL_SG_K; - const uint Kd8_M_USGM_USGK_d_subgroupsize = (Kd8_M_USGM_USGK + (subgroup_size * UNROLL_WG_N - 1)) / (subgroup_size * UNROLL_WG_N); - [[unroll]] for (uint q = 0; q < Kd8_M_USGM_USGK_d_subgroupsize; q++) - { - const uint siq = (q * UNROLL_WG_N + sgni) * subgroup_size + si; - - if (Kd8_M_USGM_USGK % (subgroup_size * UNROLL_WG_N) == 0 || siq < Kd8_M_USGM_USGK) - { - const uint j = siq / Kd8; - const uint i = siq % Kd8; + const uint siq = (q * UNROLL_WG_M + sgmi) * subgroup_size + si; - tmp_a[(sgmi * UNROLL_SG_K * UNROLL_SG_M * M + j) * Kd8p + i] = prefetch_tmp_a[q]; - } + if (Nd4_K_USGN_USGK % (subgroup_size * UNROLL_WG_M) == 0 || siq < Nd4_K_USGN_USGK) + { + prefetch_tmp_b[q] = B_blob_data[B_offset + siq]; } } - else + } + else if (transB == 0) + { + if (psc(B_elempack) == 1) { - const uint Md8_K_USGM_USGK = Md8 * K * UNROLL_SG_M * UNROLL_SG_K; - const uint Md8_K_USGM_USGK_d_subgroupsize = (Md8_K_USGM_USGK + (subgroup_size * UNROLL_WG_N - 1)) / (subgroup_size * UNROLL_WG_N); - [[unroll]] for (uint q = 0; q < Md8_K_USGM_USGK_d_subgroupsize; q++) + // +-N-+ + // K | + // +SG_UN + // | | + // ^ +---+ + // | | | + // SG_UK+- -+ + // | | | + // ^ v +---+ + // | | | + // | +- -+ + // | | | + // WG_UN +---+ + // | | | + // | +- -+ + // | | | + // v +---+ + + const 
uint Nd4_K_USGN_USGK = Nd4 * K * UNROLL_SG_N * UNROLL_SG_K; + const uint Nd4_K_USGN_USGK_d_subgroupsize = (Nd4_K_USGN_USGK + (subgroup_size * UNROLL_WG_M - 1)) / (subgroup_size * UNROLL_WG_M); + [[unroll]] for (uint q = 0; q < Nd4_K_USGN_USGK_d_subgroupsize; q++) { - const uint siq = (q * UNROLL_WG_N + sgni) * subgroup_size + si; + const uint siq = (q * UNROLL_WG_M + sgmi) * subgroup_size + si; - if (Md8_K_USGM_USGK % (subgroup_size * UNROLL_WG_N) == 0 || siq < Md8_K_USGM_USGK) + if (Nd4_K_USGN_USGK % (subgroup_size * UNROLL_WG_M) == 0 || siq < Nd4_K_USGN_USGK) { - const uint i = siq / Md8; - const uint j = siq % Md8; + const uint zk = siq / (Nd4 * K * UNROLL_SG_N); + const uint znij = siq % (Nd4 * K * UNROLL_SG_N); + const uint zn = znij / (Nd4 * K); + const uint ij = znij % (Nd4 * K); + const uint i = ij / Nd4; + const uint j = ij % Nd4; - tmp_a[(sgmi * UNROLL_SG_K * UNROLL_SG_M * K + i) * Md8p + j] = prefetch_tmp_a[q]; - } - } - } + uvec2 v = uvec2(0); - // load B - if (transB == 0) - { - const uint Nd8_K_USGN_USGK = Nd8 * K * UNROLL_SG_N * UNROLL_SG_K; - const uint Nd8_K_USGN_USGK_d_subgroupsize = (Nd8_K_USGN_USGK + (subgroup_size * UNROLL_WG_M - 1)) / (subgroup_size * UNROLL_WG_M); - [[unroll]] for (uint q = 0; q < Nd8_K_USGN_USGK_d_subgroupsize; q++) - { - const uint siq = (q * UNROLL_WG_M + sgmi) * subgroup_size + si; + const uint gk = ki + zk * K + i; + const uint gn = (ni + zn) * Nd4 + j; - if (Nd8_K_USGN_USGK % (subgroup_size * UNROLL_WG_M) == 0 || siq < Nd8_K_USGN_USGK) - { - const uint i = siq / Nd8; - const uint j = siq % Nd8; + if (gk < psc(GK)) + { + const uvec4 gn4 = gn * 4 + uvec4(0, 1, 2, 3); + + if (p.B_hstep % 4 == 0) + { + const uint bi = gk * (p.B_hstep / 4) + gn; + + v = B_blob_data[bi]; + + uvec4 mask4 = uvec4(lessThan(gn4, uvec4(psc(GN)))) * 0xFFFFu; + uvec2 packed_mask = uvec2(mask4.x | (mask4.y << 16), mask4.z | (mask4.w << 16)); + + v = v & packed_mask; + } + else + { + vec2 v4a = vec2(0.f); + vec2 v4b = vec2(0.f); + + const uvec4 bi4 = gk * p.B_hstep + gn4; + const uvec4 bi4d4 = bi4 / 4; + const uvec4 bi4m4d2 = (bi4 % 4) / 2; + const uvec4 bi4m2 = bi4 % 2; + +#if NCNN_bf16_storage || NCNN_bf16_packed + if (gn4.r < psc(GN)) v4a.r = unpackBFloat2x16(B_blob_data[bi4d4.r][bi4m4d2.r])[bi4m2.r]; + if (gn4.g < psc(GN)) v4a.g = unpackBFloat2x16(B_blob_data[bi4d4.g][bi4m4d2.g])[bi4m2.g]; + if (gn4.b < psc(GN)) v4b.r = unpackBFloat2x16(B_blob_data[bi4d4.b][bi4m4d2.b])[bi4m2.b]; + if (gn4.a < psc(GN)) v4b.g = unpackBFloat2x16(B_blob_data[bi4d4.a][bi4m4d2.a])[bi4m2.a]; + + v = uvec2(packBFloat2x16(v4a), packBFloat2x16(v4b)); +#else + if (gn4.r < psc(GN)) v4a.r = unpackHalf2x16(B_blob_data[bi4d4.r][bi4m4d2.r])[bi4m2.r]; + if (gn4.g < psc(GN)) v4a.g = unpackHalf2x16(B_blob_data[bi4d4.g][bi4m4d2.g])[bi4m2.g]; + if (gn4.b < psc(GN)) v4b.r = unpackHalf2x16(B_blob_data[bi4d4.b][bi4m4d2.b])[bi4m2.b]; + if (gn4.a < psc(GN)) v4b.g = unpackHalf2x16(B_blob_data[bi4d4.a][bi4m4d2.a])[bi4m2.a]; + + v = uvec2(packHalf2x16(v4a), packHalf2x16(v4b)); +#endif + } + } - tmp_b[(sgni * UNROLL_SG_K * UNROLL_SG_N * K + i) * Nd8p + j] = prefetch_tmp_b[q]; + prefetch_tmp_b[q] = v; } } } - else + else // if (psc(B_elempack) == 4) { - const uint Kd8_N_USGN_USGK = Kd8 * N * UNROLL_SG_N * UNROLL_SG_K; - const uint Kd8_N_USGN_USGK_d_subgroupsize = (Kd8_N_USGN_USGK + (subgroup_size * UNROLL_WG_M - 1)) / (subgroup_size * UNROLL_WG_M); - [[unroll]] for (uint q = 0; q < Kd8_N_USGN_USGK_d_subgroupsize; q++) + const uint Kd4_N_USGN_USGK = Kd4 * N * UNROLL_SG_N * UNROLL_SG_K; + const uint 
Kd4_N_USGN_USGK_d_subgroupsize = (Kd4_N_USGN_USGK + (subgroup_size * UNROLL_WG_M - 1)) / (subgroup_size * UNROLL_WG_M); + [[unroll]] for (uint q = 0; q < Kd4_N_USGN_USGK_d_subgroupsize; q++) { const uint siq = (q * UNROLL_WG_M + sgmi) * subgroup_size + si; - if (Kd8_N_USGN_USGK % (subgroup_size * UNROLL_WG_M) == 0 || siq < Kd8_N_USGN_USGK) + if (Kd4_N_USGN_USGK % (subgroup_size * UNROLL_WG_M) == 0 || siq < Kd4_N_USGN_USGK) { - const uint j = siq / Kd8; - const uint i = siq % Kd8; + const uint zk = siq / (Kd4 * N * UNROLL_SG_N); + const uint znij = siq % (Kd4 * N * UNROLL_SG_N); + const uint zn = znij / (Kd4 * N); + const uint ij = znij % (Kd4 * N); + + uvec2 v = uvec2(0); + + const uint i = ij / Kd4; + const uint j = ij % Kd4; + + const uint gn = (ni + zn) * N + i; + const uint gk = ki / 4 + zk * Kd4 + j; - tmp_b[(sgni * UNROLL_SG_K * UNROLL_SG_N * N + j) * Kd8p + i] = prefetch_tmp_b[q]; + if (gn < psc(GN)) + { + v = B_blob_data[gk * p.B_hstep + gn]; + + const uvec4 gk4 = gk * 4 + uvec4(0, 1, 2, 3); + uvec4 mask4 = uvec4(lessThan(gk4, uvec4(psc(GK)))) * 0xFFFFu; + uvec2 packed_mask = uvec2(mask4.x | (mask4.y << 16), mask4.z | (mask4.w << 16)); + + v = v & packed_mask; + } + + prefetch_tmp_b[q] = v; } } } } - - barrier(); - - // prefetch the next + else { - const uint ki = k * K; - - // load A - if (transA == 0) + if (psc(B_elempack) == 1) { // +-K-+ - // M | - // +- -+ - // SG_UM | + // N | + // +SG_UN + // | | // ^ +---+ // | | | // SG_UK+- -+ @@ -882,395 +805,828 @@ void main() // | | | // | +- -+ // | | | - // WG_UM +---+ + // WG_UN +---+ // | | | // | +- -+ // | | | // v +---+ - const uint Kd8_M_USGM_USGK = Kd8 * M * UNROLL_SG_M * UNROLL_SG_K; - const uint Kd8_M_USGM_USGK_d_subgroupsize = (Kd8_M_USGM_USGK + (subgroup_size * UNROLL_WG_N - 1)) / (subgroup_size * UNROLL_WG_N); - [[unroll]] for (uint q = 0; q < Kd8_M_USGM_USGK_d_subgroupsize; q++) + const uint Kd4_N_USGN_USGK = Kd4 * N * UNROLL_SG_N * UNROLL_SG_K; + const uint Kd4_N_USGN_USGK_d_subgroupsize = (Kd4_N_USGN_USGK + (subgroup_size * UNROLL_WG_M - 1)) / (subgroup_size * UNROLL_WG_M); + [[unroll]] for (uint q = 0; q < Kd4_N_USGN_USGK_d_subgroupsize; q++) { - const uint siq = (q * UNROLL_WG_N + sgni) * subgroup_size + si; + const uint siq = (q * UNROLL_WG_M + sgmi) * subgroup_size + si; - if (Kd8_M_USGM_USGK % (subgroup_size * UNROLL_WG_N) == 0 || siq < Kd8_M_USGM_USGK) + if (Kd4_N_USGN_USGK % (subgroup_size * UNROLL_WG_M) == 0 || siq < Kd4_N_USGN_USGK) { - const uint zk = siq / (Kd8 * M * UNROLL_SG_M); - const uint zmij = siq % (Kd8 * M * UNROLL_SG_M); - const uint zm = zmij / (Kd8 * M); - const uint ij = zmij % (Kd8 * M); - const uint j = ij / Kd8; - const uint i = ij % Kd8; - - const uint gk = ki / 8 + zk * Kd8 + i; - const uint gm = (mi + zm) * M + j; + const uint zk = siq / (Kd4 * N * UNROLL_SG_N); + const uint znij = siq % (Kd4 * N * UNROLL_SG_N); + const uint zn = znij / (Kd4 * N); + const uint ij = znij % (Kd4 * N); + const uint j = ij / Kd4; + const uint i = ij % Kd4; - uvec4 v = uvec4(0); - if (gm < psc(GM)) + uvec2 v = uvec2(0); + + const uint gk = ki / 4 + zk * Kd4 + i; + const uint gn = (ni + zn) * N + j; + + if (gn < psc(GN)) { - const uvec4 gk4 = gk * 8 + uvec4(0, 1, 2, 3); - const uvec4 gk8 = gk4 + 4; + const uvec4 gk4 = gk * 4 + uvec4(0, 1, 2, 3); - if (p.A_hstep % 8 == 0) + if (p.B_hstep % 4 == 0) { - const uint ai = gm * (p.A_hstep / 8) + gk; + const uint bi = gn * (p.B_hstep / 4) + gk; - v = A_blob_data[ai]; + v = B_blob_data[bi]; uvec4 mask4 = uvec4(lessThan(gk4, uvec4(psc(GK)))) * 0xFFFFu; - uvec4 mask8 = 
uvec4(lessThan(gk8, uvec4(psc(GK)))) * 0xFFFFu; - uvec2 packed_mask4 = uvec2(mask4.x | (mask4.y << 16), mask4.z | (mask4.w << 16)); - uvec2 packed_mask8 = uvec2(mask8.x | (mask8.y << 16), mask8.z | (mask8.w << 16)); + uvec2 packed_mask = uvec2(mask4.x | (mask4.y << 16), mask4.z | (mask4.w << 16)); - v.rg = v.rg & packed_mask4; - v.ba = v.ba & packed_mask8; + v = v & packed_mask; } else { - vec4 v4 = vec4(0.f); - vec4 v8 = vec4(0.f); - - const uvec4 ai4 = gm * p.A_hstep + gk4; - const uvec4 ai4d8 = ai4 / 8; - const uvec4 ai4m8d2 = (ai4 % 8) / 2; - const uvec4 ai4m2 = ai4 % 2; + vec2 v4a = vec2(0.f); + vec2 v4b = vec2(0.f); - const uvec4 ai8 = gm * p.A_hstep + gk8; - const uvec4 ai8d8 = ai8 / 8; - const uvec4 ai8m8d2 = (ai8 % 8) / 2; - const uvec4 ai8m2 = ai8 % 2; + const uvec4 bi4 = gn * p.B_hstep + gk4; + const uvec4 bi4d4 = bi4 / 4; + const uvec4 bi4m4d2 = (bi4 % 4) / 2; + const uvec4 bi4m2 = bi4 % 2; #if NCNN_bf16_storage || NCNN_bf16_packed - if (gk4.r < psc(GK)) v4.r = unpackBFloat2x16(A_blob_data[ai4d8.r][ai4m8d2.r])[ai4m2.r]; - if (gk4.g < psc(GK)) v4.g = unpackBFloat2x16(A_blob_data[ai4d8.g][ai4m8d2.g])[ai4m2.g]; - if (gk4.b < psc(GK)) v4.b = unpackBFloat2x16(A_blob_data[ai4d8.b][ai4m8d2.b])[ai4m2.b]; - if (gk4.a < psc(GK)) v4.a = unpackBFloat2x16(A_blob_data[ai4d8.a][ai4m8d2.a])[ai4m2.a]; + if (gk4.r < psc(GK)) v4a.r = unpackBFloat2x16(B_blob_data[bi4d4.r][bi4m4d2.r])[bi4m2.r]; + if (gk4.g < psc(GK)) v4a.g = unpackBFloat2x16(B_blob_data[bi4d4.g][bi4m4d2.g])[bi4m2.g]; + if (gk4.b < psc(GK)) v4b.r = unpackBFloat2x16(B_blob_data[bi4d4.b][bi4m4d2.b])[bi4m2.b]; + if (gk4.a < psc(GK)) v4b.g = unpackBFloat2x16(B_blob_data[bi4d4.a][bi4m4d2.a])[bi4m2.a]; - if (gk8.r < psc(GK)) v8.r = unpackBFloat2x16(A_blob_data[ai8d8.r][ai8m8d2.r])[ai8m2.r]; - if (gk8.g < psc(GK)) v8.g = unpackBFloat2x16(A_blob_data[ai8d8.g][ai8m8d2.g])[ai8m2.g]; - if (gk8.b < psc(GK)) v8.b = unpackBFloat2x16(A_blob_data[ai8d8.b][ai8m8d2.b])[ai8m2.b]; - if (gk8.a < psc(GK)) v8.a = unpackBFloat2x16(A_blob_data[ai8d8.a][ai8m8d2.a])[ai8m2.a]; - - v = uvec4(packBFloat2x16(v4.rg), packBFloat2x16(v4.ba), packBFloat2x16(v8.rg), packBFloat2x16(v8.ba)); + v = uvec2(packBFloat2x16(v4a), packBFloat2x16(v4b)); #else - if (gk4.r < psc(GK)) v4.r = unpackHalf2x16(A_blob_data[ai4d8.r][ai4m8d2.r])[ai4m2.r]; - if (gk4.g < psc(GK)) v4.g = unpackHalf2x16(A_blob_data[ai4d8.g][ai4m8d2.g])[ai4m2.g]; - if (gk4.b < psc(GK)) v4.b = unpackHalf2x16(A_blob_data[ai4d8.b][ai4m8d2.b])[ai4m2.b]; - if (gk4.a < psc(GK)) v4.a = unpackHalf2x16(A_blob_data[ai4d8.a][ai4m8d2.a])[ai4m2.a]; - - if (gk8.r < psc(GK)) v8.r = unpackHalf2x16(A_blob_data[ai8d8.r][ai8m8d2.r])[ai8m2.r]; - if (gk8.g < psc(GK)) v8.g = unpackHalf2x16(A_blob_data[ai8d8.g][ai8m8d2.g])[ai8m2.g]; - if (gk8.b < psc(GK)) v8.b = unpackHalf2x16(A_blob_data[ai8d8.b][ai8m8d2.b])[ai8m2.b]; - if (gk8.a < psc(GK)) v8.a = unpackHalf2x16(A_blob_data[ai8d8.a][ai8m8d2.a])[ai8m2.a]; + if (gk4.r < psc(GK)) v4a.r = unpackHalf2x16(B_blob_data[bi4d4.r][bi4m4d2.r])[bi4m2.r]; + if (gk4.g < psc(GK)) v4a.g = unpackHalf2x16(B_blob_data[bi4d4.g][bi4m4d2.g])[bi4m2.g]; + if (gk4.b < psc(GK)) v4b.r = unpackHalf2x16(B_blob_data[bi4d4.b][bi4m4d2.b])[bi4m2.b]; + if (gk4.a < psc(GK)) v4b.g = unpackHalf2x16(B_blob_data[bi4d4.a][bi4m4d2.a])[bi4m2.a]; - v = uvec4(packHalf2x16(v4.rg), packHalf2x16(v4.ba), packHalf2x16(v8.rg), packHalf2x16(v8.ba)); + v = uvec2(packHalf2x16(v4a), packHalf2x16(v4b)); #endif } } - prefetch_tmp_a[q] = v; + prefetch_tmp_b[q] = v; } } } - else + else // if (psc(B_elempack) == 4) { - // +-M-+ - // K | - 
// +SG_UM - // | | - // ^ +---+ - // | | | - // SG_UK+- -+ - // | | | - // ^ v +---+ - // | | | - // | +- -+ - // | | | - // WG_UM +---+ - // | | | - // | +- -+ - // | | | - // v +---+ - - const uint Md8_K_USGM_USGK = Md8 * K * UNROLL_SG_M * UNROLL_SG_K; - const uint Md8_K_USGM_USGK_d_subgroupsize = (Md8_K_USGM_USGK + (subgroup_size * UNROLL_WG_N - 1)) / (subgroup_size * UNROLL_WG_N); - [[unroll]] for (uint q = 0; q < Md8_K_USGM_USGK_d_subgroupsize; q++) + const uint Nd4_K_USGN_USGK = Nd4 * K * UNROLL_SG_N * UNROLL_SG_K; + const uint Nd4_K_USGN_USGK_d_subgroupsize = (Nd4_K_USGN_USGK + (subgroup_size * UNROLL_WG_M - 1)) / (subgroup_size * UNROLL_WG_M); + [[unroll]] for (uint q = 0; q < Nd4_K_USGN_USGK_d_subgroupsize; q++) { - const uint siq = (q * UNROLL_WG_N + sgni) * subgroup_size + si; + const uint siq = (q * UNROLL_WG_M + sgmi) * subgroup_size + si; - if (Md8_K_USGM_USGK % (subgroup_size * UNROLL_WG_N) == 0 || siq < Md8_K_USGM_USGK) + if (Nd4_K_USGN_USGK % (subgroup_size * UNROLL_WG_M) == 0 || siq < Nd4_K_USGN_USGK) { - const uint zk = siq / (Md8 * K * UNROLL_SG_M); - const uint zmij = siq % (Md8 * K * UNROLL_SG_M); - const uint zm = zmij / (Md8 * K); - const uint ij = zmij % (Md8 * K); - const uint i = ij / Md8; - const uint j = ij % Md8; + const uint zk = siq / (Nd4 * K * UNROLL_SG_N); + const uint znij = siq % (Nd4 * K * UNROLL_SG_N); + const uint zn = znij / (Nd4 * K); + const uint ij = znij % (Nd4 * K); + const uint j = ij / Nd4; + const uint i = ij % Nd4; - const uint gk = ki + zk * K + i; - const uint gm = (mi + zm) * Md8 + j; + uvec2 v = uvec2(0); + + const uint gk = ki + zk * K + j; + const uint gn = (ni + zn) * Nd4 + i; - uvec4 v = uvec4(0); if (gk < psc(GK)) { - if (p.A_hstep % 8 == 0) - { - const uint ai = gk * (p.A_hstep / 8) + gm; + v = B_blob_data[gn * p.B_hstep + gk]; - if (gm * 8 < psc(GM)) v = A_blob_data[ai]; - } - else - { - const uvec4 gm4 = gm * 8 + uvec4(0, 1, 2, 3); - const uvec4 gm8 = gm4 + 4; + const uvec4 gn4 = gn * 4 + uvec4(0, 1, 2, 3); + uvec4 mask4 = uvec4(lessThan(gn4, uvec4(psc(GN)))) * 0xFFFFu; + uvec2 packed_mask = uvec2(mask4.x | (mask4.y << 16), mask4.z | (mask4.w << 16)); - vec4 v4 = vec4(0.f); - vec4 v8 = vec4(0.f); + v = v & packed_mask; + } - const uvec4 ai4 = gk * p.A_hstep + gm4; - const uvec4 ai4d8 = ai4 / 8; - const uvec4 ai4m8d2 = (ai4 % 8) / 2; - const uvec4 ai4m2 = ai4 % 2; + prefetch_tmp_b[q] = v; + } + } + } + } + } - const uvec4 ai8 = gk * p.A_hstep + gm8; - const uvec4 ai8d8 = ai8 / 8; - const uvec4 ai8m8d2 = (ai8 % 8) / 2; - const uvec4 ai8m2 = ai8 % 2; + k += UNROLL_SG_K; -#if NCNN_bf16_storage || NCNN_bf16_packed - if (gm4.r < psc(GM)) v4.r = unpackBFloat2x16(A_blob_data[ai4d8.r][ai4m8d2.r])[ai4m2.r]; - if (gm4.g < psc(GM)) v4.g = unpackBFloat2x16(A_blob_data[ai4d8.g][ai4m8d2.g])[ai4m2.g]; - if (gm4.b < psc(GM)) v4.b = unpackBFloat2x16(A_blob_data[ai4d8.b][ai4m8d2.b])[ai4m2.b]; - if (gm4.a < psc(GM)) v4.a = unpackBFloat2x16(A_blob_data[ai4d8.a][ai4m8d2.a])[ai4m2.a]; + for (; k + UNROLL_SG_K - 1 < kk; k += UNROLL_SG_K) + { + barrier(); - if (gm8.r < psc(GM)) v8.r = unpackBFloat2x16(A_blob_data[ai8d8.r][ai8m8d2.r])[ai8m2.r]; - if (gm8.g < psc(GM)) v8.g = unpackBFloat2x16(A_blob_data[ai8d8.g][ai8m8d2.g])[ai8m2.g]; - if (gm8.b < psc(GM)) v8.b = unpackBFloat2x16(A_blob_data[ai8d8.b][ai8m8d2.b])[ai8m2.b]; - if (gm8.a < psc(GM)) v8.a = unpackBFloat2x16(A_blob_data[ai8d8.a][ai8m8d2.a])[ai8m2.a]; + // copy prefetch to shared memory + { + // load A + if (constantA == 1) + { + const uint Kd4_M_USGM_USGK = Kd4 * M * UNROLL_SG_M * UNROLL_SG_K; + 
const uint Kd4_M_USGM_USGK_d_subgroupsize = (Kd4_M_USGM_USGK + (subgroup_size * UNROLL_WG_N - 1)) / (subgroup_size * UNROLL_WG_N); + [[unroll]] for (uint q = 0; q < Kd4_M_USGM_USGK_d_subgroupsize; q++) + { + const uint siq = (q * UNROLL_WG_N + sgni) * subgroup_size + si; - v = uvec4(packBFloat2x16(v4.rg), packBFloat2x16(v4.ba), packBFloat2x16(v8.rg), packBFloat2x16(v8.ba)); -#else - if (gm4.r < psc(GM)) v4.r = unpackHalf2x16(A_blob_data[ai4d8.r][ai4m8d2.r])[ai4m2.r]; - if (gm4.g < psc(GM)) v4.g = unpackHalf2x16(A_blob_data[ai4d8.g][ai4m8d2.g])[ai4m2.g]; - if (gm4.b < psc(GM)) v4.b = unpackHalf2x16(A_blob_data[ai4d8.b][ai4m8d2.b])[ai4m2.b]; - if (gm4.a < psc(GM)) v4.a = unpackHalf2x16(A_blob_data[ai4d8.a][ai4m8d2.a])[ai4m2.a]; + if (Kd4_M_USGM_USGK % (subgroup_size * UNROLL_WG_N) == 0 || siq < Kd4_M_USGM_USGK) + { + const uint zk = siq / (Kd4 * M * UNROLL_SG_M); + const uint zmij = siq % (Kd4 * M * UNROLL_SG_M); + const uint zm = zmij / (Kd4 * M); + const uint ij = zmij % (Kd4 * M); + const uint j = ij / Kd4; + const uint i = ij % Kd4; + + tmp_a[sgmi][((zk * UNROLL_SG_M + zm) * M + j) * Kd4p + i] = prefetch_tmp_a[q]; + } + } + } + else if (transA == 0) + { + const uint Kd4_M_USGM_USGK = Kd4 * M * UNROLL_SG_M * UNROLL_SG_K; + const uint Kd4_M_USGM_USGK_d_subgroupsize = (Kd4_M_USGM_USGK + (subgroup_size * UNROLL_WG_N - 1)) / (subgroup_size * UNROLL_WG_N); + [[unroll]] for (uint q = 0; q < Kd4_M_USGM_USGK_d_subgroupsize; q++) + { + const uint siq = (q * UNROLL_WG_N + sgni) * subgroup_size + si; - if (gm8.r < psc(GM)) v8.r = unpackHalf2x16(A_blob_data[ai8d8.r][ai8m8d2.r])[ai8m2.r]; - if (gm8.g < psc(GM)) v8.g = unpackHalf2x16(A_blob_data[ai8d8.g][ai8m8d2.g])[ai8m2.g]; - if (gm8.b < psc(GM)) v8.b = unpackHalf2x16(A_blob_data[ai8d8.b][ai8m8d2.b])[ai8m2.b]; - if (gm8.a < psc(GM)) v8.a = unpackHalf2x16(A_blob_data[ai8d8.a][ai8m8d2.a])[ai8m2.a]; + if (Kd4_M_USGM_USGK % (subgroup_size * UNROLL_WG_N) == 0 || siq < Kd4_M_USGM_USGK) + { + if (psc(A_elempack) == 4) + { + const uint i = siq / Md4; + const uint j = siq % Md4; - v = uvec4(packHalf2x16(v4.rg), packHalf2x16(v4.ba), packHalf2x16(v8.rg), packHalf2x16(v8.ba)); -#endif - } + tmp_a[sgmi][i * Md4p + j] = prefetch_tmp_a[q]; } + else + { + const uint j = siq / Kd4; + const uint i = siq % Kd4; - prefetch_tmp_a[q] = v; + tmp_a[sgmi][j * Kd4p + i] = prefetch_tmp_a[q]; + } + } + } + } + else + { + const uint Md4_K_USGM_USGK = Md4 * K * UNROLL_SG_M * UNROLL_SG_K; + const uint Md4_K_USGM_USGK_d_subgroupsize = (Md4_K_USGM_USGK + (subgroup_size * UNROLL_WG_N - 1)) / (subgroup_size * UNROLL_WG_N); + [[unroll]] for (uint q = 0; q < Md4_K_USGM_USGK_d_subgroupsize; q++) + { + const uint siq = (q * UNROLL_WG_N + sgni) * subgroup_size + si; + + if (Md4_K_USGM_USGK % (subgroup_size * UNROLL_WG_N) == 0 || siq < Md4_K_USGM_USGK) + { + if (psc(A_elempack) == 4) + { + const uint j = siq / Kd4; + const uint i = siq % Kd4; + + tmp_a[sgmi][j * Kd4p + i] = prefetch_tmp_a[q]; + } + else + { + const uint i = siq / Md4; + const uint j = siq % Md4; + + tmp_a[sgmi][i * Md4p + j] = prefetch_tmp_a[q]; + } } } } // load B - if (transB == 0) + if (constantB == 1) { - // +-N-+ - // K | - // +SG_UN - // | | - // ^ +---+ - // | | | - // SG_UK+- -+ - // | | | - // ^ v +---+ - // | | | - // | +- -+ - // | | | - // WG_UN +---+ - // | | | - // | +- -+ - // | | | - // v +---+ + const uint Nd4_K_USGN_USGK = Nd4 * K * UNROLL_SG_N * UNROLL_SG_K; + const uint Nd4_K_USGN_USGK_d_subgroupsize = (Nd4_K_USGN_USGK + (subgroup_size * UNROLL_WG_M - 1)) / (subgroup_size * UNROLL_WG_M); + [[unroll]] for 
(uint q = 0; q < Nd4_K_USGN_USGK_d_subgroupsize; q++) + { + const uint siq = (q * UNROLL_WG_M + sgmi) * subgroup_size + si; - const uint Nd8_K_USGN_USGK = Nd8 * K * UNROLL_SG_N * UNROLL_SG_K; - const uint Nd8_K_USGN_USGK_d_subgroupsize = (Nd8_K_USGN_USGK + (subgroup_size * UNROLL_WG_M - 1)) / (subgroup_size * UNROLL_WG_M); - [[unroll]] for (uint q = 0; q < Nd8_K_USGN_USGK_d_subgroupsize; q++) + if (Nd4_K_USGN_USGK % (subgroup_size * UNROLL_WG_M) == 0 || siq < Nd4_K_USGN_USGK) + { + const uint zk = siq / (Nd4 * K * UNROLL_SG_N); + const uint znij = siq % (Nd4 * K * UNROLL_SG_N); + const uint zn = znij / (Nd4 * K); + const uint ij = znij % (Nd4 * K); + const uint i = ij / Nd4; + const uint j = ij % Nd4; + + tmp_b[sgni][((zk * UNROLL_SG_N + zn) * K + i) * Nd4p + j] = prefetch_tmp_b[q]; + } + } + } + else if (transB == 0) + { + const uint Nd4_K_USGN_USGK = Nd4 * K * UNROLL_SG_N * UNROLL_SG_K; + const uint Nd4_K_USGN_USGK_d_subgroupsize = (Nd4_K_USGN_USGK + (subgroup_size * UNROLL_WG_M - 1)) / (subgroup_size * UNROLL_WG_M); + [[unroll]] for (uint q = 0; q < Nd4_K_USGN_USGK_d_subgroupsize; q++) { const uint siq = (q * UNROLL_WG_M + sgmi) * subgroup_size + si; - if (Nd8_K_USGN_USGK % (subgroup_size * UNROLL_WG_M) == 0 || siq < Nd8_K_USGN_USGK) + if (Nd4_K_USGN_USGK % (subgroup_size * UNROLL_WG_M) == 0 || siq < Nd4_K_USGN_USGK) { - const uint zk = siq / (Nd8 * K * UNROLL_SG_N); - const uint znij = siq % (Nd8 * K * UNROLL_SG_N); - const uint zn = znij / (Nd8 * K); - const uint ij = znij % (Nd8 * K); - const uint i = ij / Nd8; - const uint j = ij % Nd8; + if (psc(B_elempack) == 4) + { + const uint i = siq / Kd4; + const uint j = siq % Kd4; - const uint gk = ki + zk * K + i; - const uint gn = (ni + zn) * Nd8 + j; + tmp_b[sgni][i * Kd4p + j] = prefetch_tmp_b[q]; + } + else + { + const uint i = siq / Nd4; + const uint j = siq % Nd4; - uvec4 v = uvec4(0); - if (gk < psc(GK)) + tmp_b[sgni][i * Nd4p + j] = prefetch_tmp_b[q]; + } + } + } + } + else + { + const uint Kd4_N_USGN_USGK = Kd4 * N * UNROLL_SG_N * UNROLL_SG_K; + const uint Kd4_N_USGN_USGK_d_subgroupsize = (Kd4_N_USGN_USGK + (subgroup_size * UNROLL_WG_M - 1)) / (subgroup_size * UNROLL_WG_M); + [[unroll]] for (uint q = 0; q < Kd4_N_USGN_USGK_d_subgroupsize; q++) + { + const uint siq = (q * UNROLL_WG_M + sgmi) * subgroup_size + si; + + if (Kd4_N_USGN_USGK % (subgroup_size * UNROLL_WG_M) == 0 || siq < Kd4_N_USGN_USGK) + { + if (psc(B_elempack) == 4) { - const uvec4 gn4 = gn * 8 + uvec4(0, 1, 2, 3); - const uvec4 gn8 = gn4 + 4; + const uint j = siq / Nd4; + const uint i = siq % Nd4; - if (p.B_hstep % 8 == 0) - { - const uint bi = gk * (p.B_hstep / 8) + gn; + tmp_b[sgni][j * Nd4p + i] = prefetch_tmp_b[q]; + } + else + { + const uint j = siq / Kd4; + const uint i = siq % Kd4; - if (gn * 8 < psc(GN)) v = B_blob_data[bi]; - } - else + tmp_b[sgni][j * Kd4p + i] = prefetch_tmp_b[q]; + } + } + } + } + } + + barrier(); + + // prefetch the next + { + const uint ki = k * K; + + // load A + if (constantA == 1) + { + const uint Kd4_M_USGM = Kd4 * M * UNROLL_SG_M; + const uint Kd4_M_USGM_USGK = Kd4_M_USGM * UNROLL_SG_K; + const uint A_offset = (wgmi * kk * UNROLL_WG_M + k * UNROLL_WG_M + sgmi * UNROLL_SG_K) * Kd4_M_USGM; + const uint Kd4_M_USGM_USGK_d_subgroupsize = (Kd4_M_USGM_USGK + (subgroup_size * UNROLL_WG_N - 1)) / (subgroup_size * UNROLL_WG_N); + [[unroll]] for (uint q = 0; q < Kd4_M_USGM_USGK_d_subgroupsize; q++) + { + const uint siq = (q * UNROLL_WG_N + sgni) * subgroup_size + si; + + if (Kd4_M_USGM_USGK % (subgroup_size * UNROLL_WG_N) == 0 || siq < 
Kd4_M_USGM_USGK) + { + prefetch_tmp_a[q] = A_blob_data[A_offset + siq]; + } + } + } + else if (transA == 0) + { + if (psc(A_elempack) == 1) + { + // +-K-+ + // M | + // +- -+ + // SG_UM | + // ^ +---+ + // | | | + // SG_UK+- -+ + // | | | + // ^ v +---+ + // | | | + // | +- -+ + // | | | + // WG_UM +---+ + // | | | + // | +- -+ + // | | | + // v +---+ + + const uint Kd4_M_USGM_USGK = Kd4 * M * UNROLL_SG_M * UNROLL_SG_K; + const uint Kd4_M_USGM_USGK_d_subgroupsize = (Kd4_M_USGM_USGK + (subgroup_size * UNROLL_WG_N - 1)) / (subgroup_size * UNROLL_WG_N); + [[unroll]] for (uint q = 0; q < Kd4_M_USGM_USGK_d_subgroupsize; q++) + { + const uint siq = (q * UNROLL_WG_N + sgni) * subgroup_size + si; + + if (Kd4_M_USGM_USGK % (subgroup_size * UNROLL_WG_N) == 0 || siq < Kd4_M_USGM_USGK) + { + const uint zk = siq / (Kd4 * M * UNROLL_SG_M); + const uint zmij = siq % (Kd4 * M * UNROLL_SG_M); + const uint zm = zmij / (Kd4 * M); + const uint ij = zmij % (Kd4 * M); + const uint j = ij / Kd4; + const uint i = ij % Kd4; + + uvec2 v = uvec2(0); + + const uint gk = ki / 4 + zk * Kd4 + i; + const uint gm = (mi + zm) * M + j; + + if (gm < psc(GM)) { - vec4 v4 = vec4(0.f); - vec4 v8 = vec4(0.f); + const uvec4 gk4 = gk * 4 + uvec4(0, 1, 2, 3); - const uvec4 bi4 = gk * p.B_hstep + gn4; - const uvec4 bi4d8 = bi4 / 8; - const uvec4 bi4m8d2 = (bi4 % 8) / 2; - const uvec4 bi4m2 = bi4 % 2; + if (p.A_hstep % 4 == 0) + { + const uint ai = gm * (p.A_hstep / 4) + gk; - const uvec4 bi8 = gk * p.B_hstep + gn8; - const uvec4 bi8d8 = bi8 / 8; - const uvec4 bi8m8d2 = (bi8 % 8) / 2; - const uvec4 bi8m2 = bi8 % 2; + v = A_blob_data[ai]; -#if NCNN_bf16_storage || NCNN_bf16_packed - if (gn4.r < psc(GN)) v4.r = unpackBFloat2x16(B_blob_data[bi4d8.r][bi4m8d2.r])[bi4m2.r]; - if (gn4.g < psc(GN)) v4.g = unpackBFloat2x16(B_blob_data[bi4d8.g][bi4m8d2.g])[bi4m2.g]; - if (gn4.b < psc(GN)) v4.b = unpackBFloat2x16(B_blob_data[bi4d8.b][bi4m8d2.b])[bi4m2.b]; - if (gn4.a < psc(GN)) v4.a = unpackBFloat2x16(B_blob_data[bi4d8.a][bi4m8d2.a])[bi4m2.a]; + uvec4 mask4 = uvec4(lessThan(gk4, uvec4(psc(GK)))) * 0xFFFFu; + uvec2 packed_mask = uvec2(mask4.x | (mask4.y << 16), mask4.z | (mask4.w << 16)); - if (gn8.r < psc(GN)) v8.r = unpackBFloat2x16(B_blob_data[bi8d8.r][bi8m8d2.r])[bi8m2.r]; - if (gn8.g < psc(GN)) v8.g = unpackBFloat2x16(B_blob_data[bi8d8.g][bi8m8d2.g])[bi8m2.g]; - if (gn8.b < psc(GN)) v8.b = unpackBFloat2x16(B_blob_data[bi8d8.b][bi8m8d2.b])[bi8m2.b]; - if (gn8.a < psc(GN)) v8.a = unpackBFloat2x16(B_blob_data[bi8d8.a][bi8m8d2.a])[bi8m2.a]; + v = v & packed_mask; + } + else + { + vec2 v4a = vec2(0.f); + vec2 v4b = vec2(0.f); - v = uvec4(packBFloat2x16(v4.rg), packBFloat2x16(v4.ba), packBFloat2x16(v8.rg), packBFloat2x16(v8.ba)); -#else - if (gn4.r < psc(GN)) v4.r = unpackHalf2x16(B_blob_data[bi4d8.r][bi4m8d2.r])[bi4m2.r]; - if (gn4.g < psc(GN)) v4.g = unpackHalf2x16(B_blob_data[bi4d8.g][bi4m8d2.g])[bi4m2.g]; - if (gn4.b < psc(GN)) v4.b = unpackHalf2x16(B_blob_data[bi4d8.b][bi4m8d2.b])[bi4m2.b]; - if (gn4.a < psc(GN)) v4.a = unpackHalf2x16(B_blob_data[bi4d8.a][bi4m8d2.a])[bi4m2.a]; + const uvec4 ai4 = gm * p.A_hstep + gk4; + const uvec4 ai4d4 = ai4 / 4; + const uvec4 ai4m4d2 = (ai4 % 4) / 2; + const uvec4 ai4m2 = ai4 % 2; + +#if NCNN_bf16_storage || NCNN_bf16_packed + if (gk4.r < psc(GK)) v4a.r = unpackBFloat2x16(A_blob_data[ai4d4.r][ai4m4d2.r])[ai4m2.r]; + if (gk4.g < psc(GK)) v4a.g = unpackBFloat2x16(A_blob_data[ai4d4.g][ai4m4d2.g])[ai4m2.g]; + if (gk4.b < psc(GK)) v4b.r = unpackBFloat2x16(A_blob_data[ai4d4.b][ai4m4d2.b])[ai4m2.b]; + if (gk4.a 
< psc(GK)) v4b.g = unpackBFloat2x16(A_blob_data[ai4d4.a][ai4m4d2.a])[ai4m2.a]; - if (gn8.r < psc(GN)) v8.r = unpackHalf2x16(B_blob_data[bi8d8.r][bi8m8d2.r])[bi8m2.r]; - if (gn8.g < psc(GN)) v8.g = unpackHalf2x16(B_blob_data[bi8d8.g][bi8m8d2.g])[bi8m2.g]; - if (gn8.b < psc(GN)) v8.b = unpackHalf2x16(B_blob_data[bi8d8.b][bi8m8d2.b])[bi8m2.b]; - if (gn8.a < psc(GN)) v8.a = unpackHalf2x16(B_blob_data[bi8d8.a][bi8m8d2.a])[bi8m2.a]; + v = uvec2(packBFloat2x16(v4a), packBFloat2x16(v4b)); +#else + if (gk4.r < psc(GK)) v4a.r = unpackHalf2x16(A_blob_data[ai4d4.r][ai4m4d2.r])[ai4m2.r]; + if (gk4.g < psc(GK)) v4a.g = unpackHalf2x16(A_blob_data[ai4d4.g][ai4m4d2.g])[ai4m2.g]; + if (gk4.b < psc(GK)) v4b.r = unpackHalf2x16(A_blob_data[ai4d4.b][ai4m4d2.b])[ai4m2.b]; + if (gk4.a < psc(GK)) v4b.g = unpackHalf2x16(A_blob_data[ai4d4.a][ai4m4d2.a])[ai4m2.a]; - v = uvec4(packHalf2x16(v4.rg), packHalf2x16(v4.ba), packHalf2x16(v8.rg), packHalf2x16(v8.ba)); + v = uvec2(packHalf2x16(v4a), packHalf2x16(v4b)); #endif + } } + + prefetch_tmp_a[q] = v; } + } + } + else // if (psc(A_elempack) == 4) + { + const uint Md4_K_USGM_USGK = Md4 * K * UNROLL_SG_M * UNROLL_SG_K; + const uint Md4_K_USGM_USGK_d_subgroupsize = (Md4_K_USGM_USGK + (subgroup_size * UNROLL_WG_N - 1)) / (subgroup_size * UNROLL_WG_N); + [[unroll]] for (uint q = 0; q < Md4_K_USGM_USGK_d_subgroupsize; q++) + { + const uint siq = (q * UNROLL_WG_N + sgni) * subgroup_size + si; - prefetch_tmp_b[q] = v; + if (Md4_K_USGM_USGK % (subgroup_size * UNROLL_WG_N) == 0 || siq < Md4_K_USGM_USGK) + { + const uint zk = siq / (Md4 * K * UNROLL_SG_M); + const uint zmij = siq % (Md4 * K * UNROLL_SG_M); + const uint zm = zmij / (Md4 * K); + const uint ij = zmij % (Md4 * K); + const uint i = ij / Md4; + const uint j = ij % Md4; + + uvec2 v = uvec2(0); + + const uint gk = ki + zk * K + i; + const uint gm = (mi + zm) * Md4 + j; + + if (gk < psc(GK)) + { + v = A_blob_data[gm * p.A_hstep + gk]; + + const uvec4 gm4 = gm * 4 + uvec4(0, 1, 2, 3); + uvec4 mask4 = uvec4(lessThan(gm4, uvec4(psc(GM)))) * 0xFFFFu; + uvec2 packed_mask = uvec2(mask4.x | (mask4.y << 16), mask4.z | (mask4.w << 16)); + + v = v & packed_mask; + } + + prefetch_tmp_a[q] = v; + } } } } else { - // +-K-+ - // N | - // +SG_UN - // | | - // ^ +---+ - // | | | - // SG_UK+- -+ - // | | | - // ^ v +---+ - // | | | - // | +- -+ - // | | | - // WG_UN +---+ - // | | | - // | +- -+ - // | | | - // v +---+ + if (psc(A_elempack) == 1) + { + // +-M-+ + // K | + // +SG_UM + // | | + // ^ +---+ + // | | | + // SG_UK+- -+ + // | | | + // ^ v +---+ + // | | | + // | +- -+ + // | | | + // WG_UM +---+ + // | | | + // | +- -+ + // | | | + // v +---+ + + const uint Md4_K_USGM_USGK = Md4 * K * UNROLL_SG_M * UNROLL_SG_K; + const uint Md4_K_USGM_USGK_d_subgroupsize = (Md4_K_USGM_USGK + (subgroup_size * UNROLL_WG_N - 1)) / (subgroup_size * UNROLL_WG_N); + [[unroll]] for (uint q = 0; q < Md4_K_USGM_USGK_d_subgroupsize; q++) + { + const uint siq = (q * UNROLL_WG_N + sgni) * subgroup_size + si; + + if (Md4_K_USGM_USGK % (subgroup_size * UNROLL_WG_N) == 0 || siq < Md4_K_USGM_USGK) + { + const uint zk = siq / (Md4 * K * UNROLL_SG_M); + const uint zmij = siq % (Md4 * K * UNROLL_SG_M); + const uint zm = zmij / (Md4 * K); + const uint ij = zmij % (Md4 * K); + const uint i = ij / Md4; + const uint j = ij % Md4; + + uvec2 v = uvec2(0); + + const uint gk = ki + zk * K + i; + const uint gm = (mi + zm) * Md4 + j; + + if (gk < psc(GK)) + { + const uvec4 gm4 = gm * 4 + uvec4(0, 1, 2, 3); + + if (p.A_hstep % 4 == 0) + { + const uint ai = gk * (p.A_hstep / 
4) + gm; + + v = A_blob_data[ai]; + + uvec4 mask4 = uvec4(lessThan(gm4, uvec4(psc(GM)))) * 0xFFFFu; + uvec2 packed_mask = uvec2(mask4.x | (mask4.y << 16), mask4.z | (mask4.w << 16)); + + v = v & packed_mask; + } + else + { + vec2 v4a = vec2(0.f); + vec2 v4b = vec2(0.f); + + const uvec4 ai4 = gk * p.A_hstep + gm4; + const uvec4 ai4d4 = ai4 / 4; + const uvec4 ai4m4d2 = (ai4 % 4) / 2; + const uvec4 ai4m2 = ai4 % 2; + +#if NCNN_bf16_storage || NCNN_bf16_packed + if (gm4.r < psc(GM)) v4a.r = unpackBFloat2x16(A_blob_data[ai4d4.r][ai4m4d2.r])[ai4m2.r]; + if (gm4.g < psc(GM)) v4a.g = unpackBFloat2x16(A_blob_data[ai4d4.g][ai4m4d2.g])[ai4m2.g]; + if (gm4.b < psc(GM)) v4b.r = unpackBFloat2x16(A_blob_data[ai4d4.b][ai4m4d2.b])[ai4m2.b]; + if (gm4.a < psc(GM)) v4b.g = unpackBFloat2x16(A_blob_data[ai4d4.a][ai4m4d2.a])[ai4m2.a]; - const uint Kd8_N_USGN_USGK = Kd8 * N * UNROLL_SG_N * UNROLL_SG_K; - const uint Kd8_N_USGN_USGK_d_subgroupsize = (Kd8_N_USGN_USGK + (subgroup_size * UNROLL_WG_M - 1)) / (subgroup_size * UNROLL_WG_M); - [[unroll]] for (uint q = 0; q < Kd8_N_USGN_USGK_d_subgroupsize; q++) + v = uvec2(packBFloat2x16(v4a), packBFloat2x16(v4b)); +#else + if (gm4.r < psc(GM)) v4a.r = unpackHalf2x16(A_blob_data[ai4d4.r][ai4m4d2.r])[ai4m2.r]; + if (gm4.g < psc(GM)) v4a.g = unpackHalf2x16(A_blob_data[ai4d4.g][ai4m4d2.g])[ai4m2.g]; + if (gm4.b < psc(GM)) v4b.r = unpackHalf2x16(A_blob_data[ai4d4.b][ai4m4d2.b])[ai4m2.b]; + if (gm4.a < psc(GM)) v4b.g = unpackHalf2x16(A_blob_data[ai4d4.a][ai4m4d2.a])[ai4m2.a]; + + v = uvec2(packHalf2x16(v4a), packHalf2x16(v4b)); +#endif + } + } + + prefetch_tmp_a[q] = v; + } + } + } + else // if (psc(A_elempack) == 4) + { + const uint Kd4_M_USGM_USGK = Kd4 * M * UNROLL_SG_M * UNROLL_SG_K; + const uint Kd4_M_USGM_USGK_d_subgroupsize = (Kd4_M_USGM_USGK + (subgroup_size * UNROLL_WG_N - 1)) / (subgroup_size * UNROLL_WG_N); + [[unroll]] for (uint q = 0; q < Kd4_M_USGM_USGK_d_subgroupsize; q++) + { + const uint siq = (q * UNROLL_WG_N + sgni) * subgroup_size + si; + + if (Kd4_M_USGM_USGK % (subgroup_size * UNROLL_WG_N) == 0 || siq < Kd4_M_USGM_USGK) + { + const uint zk = siq / (Kd4 * M * UNROLL_SG_M); + const uint zmij = siq % (Kd4 * M * UNROLL_SG_M); + const uint zm = zmij / (Kd4 * M); + const uint ij = zmij % (Kd4 * M); + const uint j = ij / Kd4; + const uint i = ij % Kd4; + + uvec2 v = uvec2(0); + + const uint gk = ki / 4 + zk * Kd4 + i; + const uint gm = (mi + zm) * M + j; + + if (gm < psc(GM)) + { + v = A_blob_data[gk * p.A_hstep + gm]; + + const uvec4 gk4 = gk * 4 + uvec4(0, 1, 2, 3); + uvec4 mask4 = uvec4(lessThan(gk4, uvec4(psc(GK)))) * 0xFFFFu; + uvec2 packed_mask = uvec2(mask4.x | (mask4.y << 16), mask4.z | (mask4.w << 16)); + + v = v & packed_mask; + } + + prefetch_tmp_a[q] = v; + } + } + } + } + + // load B + if (constantB == 1) + { + const uint Nd4_K_USGN = Nd4 * K * UNROLL_SG_N; + const uint Nd4_K_USGN_USGK = Nd4_K_USGN * UNROLL_SG_K; + const uint B_offset = (wgni * kk * UNROLL_WG_N + k * UNROLL_WG_N + sgni * UNROLL_SG_K) * Nd4_K_USGN; + const uint Nd4_K_USGN_USGK_d_subgroupsize = (Nd4_K_USGN_USGK + (subgroup_size * UNROLL_WG_M - 1)) / (subgroup_size * UNROLL_WG_M); + [[unroll]] for (uint q = 0; q < Nd4_K_USGN_USGK_d_subgroupsize; q++) { const uint siq = (q * UNROLL_WG_M + sgmi) * subgroup_size + si; - if (Kd8_N_USGN_USGK % (subgroup_size * UNROLL_WG_M) == 0 || siq < Kd8_N_USGN_USGK) + if (Nd4_K_USGN_USGK % (subgroup_size * UNROLL_WG_M) == 0 || siq < Nd4_K_USGN_USGK) { - const uint zk = siq / (Kd8 * N * UNROLL_SG_N); - const uint znij = siq % (Kd8 * N * UNROLL_SG_N); - 
const uint zn = znij / (Kd8 * N); - const uint ij = znij % (Kd8 * N); - const uint j = ij / Kd8; - const uint i = ij % Kd8; - - const uint gk = ki / 8 + zk * Kd8 + i; - const uint gn = (ni + zn) * N + j; + prefetch_tmp_b[q] = B_blob_data[B_offset + siq]; + } + } + } + else if (transB == 0) + { + if (psc(B_elempack) == 1) + { + // +-N-+ + // K | + // +SG_UN + // | | + // ^ +---+ + // | | | + // SG_UK+- -+ + // | | | + // ^ v +---+ + // | | | + // | +- -+ + // | | | + // WG_UN +---+ + // | | | + // | +- -+ + // | | | + // v +---+ + + const uint Nd4_K_USGN_USGK = Nd4 * K * UNROLL_SG_N * UNROLL_SG_K; + const uint Nd4_K_USGN_USGK_d_subgroupsize = (Nd4_K_USGN_USGK + (subgroup_size * UNROLL_WG_M - 1)) / (subgroup_size * UNROLL_WG_M); + [[unroll]] for (uint q = 0; q < Nd4_K_USGN_USGK_d_subgroupsize; q++) + { + const uint siq = (q * UNROLL_WG_M + sgmi) * subgroup_size + si; - uvec4 v = uvec4(0); - if (gn < psc(GN)) + if (Nd4_K_USGN_USGK % (subgroup_size * UNROLL_WG_M) == 0 || siq < Nd4_K_USGN_USGK) { - const uvec4 gk4 = gk * 8 + uvec4(0, 1, 2, 3); - const uvec4 gk8 = gk4 + 4; + const uint zk = siq / (Nd4 * K * UNROLL_SG_N); + const uint znij = siq % (Nd4 * K * UNROLL_SG_N); + const uint zn = znij / (Nd4 * K); + const uint ij = znij % (Nd4 * K); + const uint i = ij / Nd4; + const uint j = ij % Nd4; + + uvec2 v = uvec2(0); - if (p.B_hstep % 8 == 0) + const uint gk = ki + zk * K + i; + const uint gn = (ni + zn) * Nd4 + j; + + if (gk < psc(GK)) { - const uint bi = gn * (p.B_hstep / 8) + gk; + const uvec4 gn4 = gn * 4 + uvec4(0, 1, 2, 3); - v = B_blob_data[bi]; + if (p.B_hstep % 4 == 0) + { + const uint bi = gk * (p.B_hstep / 4) + gn; + + v = B_blob_data[bi]; + + uvec4 mask4 = uvec4(lessThan(gn4, uvec4(psc(GN)))) * 0xFFFFu; + uvec2 packed_mask = uvec2(mask4.x | (mask4.y << 16), mask4.z | (mask4.w << 16)); + + v = v & packed_mask; + } + else + { + vec2 v4a = vec2(0.f); + vec2 v4b = vec2(0.f); + + const uvec4 bi4 = gk * p.B_hstep + gn4; + const uvec4 bi4d4 = bi4 / 4; + const uvec4 bi4m4d2 = (bi4 % 4) / 2; + const uvec4 bi4m2 = bi4 % 2; + +#if NCNN_bf16_storage || NCNN_bf16_packed + if (gn4.r < psc(GN)) v4a.r = unpackBFloat2x16(B_blob_data[bi4d4.r][bi4m4d2.r])[bi4m2.r]; + if (gn4.g < psc(GN)) v4a.g = unpackBFloat2x16(B_blob_data[bi4d4.g][bi4m4d2.g])[bi4m2.g]; + if (gn4.b < psc(GN)) v4b.r = unpackBFloat2x16(B_blob_data[bi4d4.b][bi4m4d2.b])[bi4m2.b]; + if (gn4.a < psc(GN)) v4b.g = unpackBFloat2x16(B_blob_data[bi4d4.a][bi4m4d2.a])[bi4m2.a]; + + v = uvec2(packBFloat2x16(v4a), packBFloat2x16(v4b)); +#else + if (gn4.r < psc(GN)) v4a.r = unpackHalf2x16(B_blob_data[bi4d4.r][bi4m4d2.r])[bi4m2.r]; + if (gn4.g < psc(GN)) v4a.g = unpackHalf2x16(B_blob_data[bi4d4.g][bi4m4d2.g])[bi4m2.g]; + if (gn4.b < psc(GN)) v4b.r = unpackHalf2x16(B_blob_data[bi4d4.b][bi4m4d2.b])[bi4m2.b]; + if (gn4.a < psc(GN)) v4b.g = unpackHalf2x16(B_blob_data[bi4d4.a][bi4m4d2.a])[bi4m2.a]; + + v = uvec2(packHalf2x16(v4a), packHalf2x16(v4b)); +#endif + } + } + + prefetch_tmp_b[q] = v; + } + } + } + else // if (psc(B_elempack) == 4) + { + const uint Kd4_N_USGN_USGK = Kd4 * N * UNROLL_SG_N * UNROLL_SG_K; + const uint Kd4_N_USGN_USGK_d_subgroupsize = (Kd4_N_USGN_USGK + (subgroup_size * UNROLL_WG_M - 1)) / (subgroup_size * UNROLL_WG_M); + [[unroll]] for (uint q = 0; q < Kd4_N_USGN_USGK_d_subgroupsize; q++) + { + const uint siq = (q * UNROLL_WG_M + sgmi) * subgroup_size + si; + if (Kd4_N_USGN_USGK % (subgroup_size * UNROLL_WG_M) == 0 || siq < Kd4_N_USGN_USGK) + { + const uint zk = siq / (Kd4 * N * UNROLL_SG_N); + const uint znij = siq % (Kd4 * N * 
UNROLL_SG_N); + const uint zn = znij / (Kd4 * N); + const uint ij = znij % (Kd4 * N); + const uint i = ij / Kd4; + const uint j = ij % Kd4; + + uvec2 v = uvec2(0); + + const uint gn = (ni + zn) * N + i; + const uint gk = ki / 4 + zk * Kd4 + j; + + if (gn < psc(GN)) + { + v = B_blob_data[gk * p.B_hstep + gn]; + + const uvec4 gk4 = gk * 4 + uvec4(0, 1, 2, 3); uvec4 mask4 = uvec4(lessThan(gk4, uvec4(psc(GK)))) * 0xFFFFu; - uvec4 mask8 = uvec4(lessThan(gk8, uvec4(psc(GK)))) * 0xFFFFu; - uvec2 packed_mask4 = uvec2(mask4.x | (mask4.y << 16), mask4.z | (mask4.w << 16)); - uvec2 packed_mask8 = uvec2(mask8.x | (mask8.y << 16), mask8.z | (mask8.w << 16)); + uvec2 packed_mask = uvec2(mask4.x | (mask4.y << 16), mask4.z | (mask4.w << 16)); - v.rg = v.rg & packed_mask4; - v.ba = v.ba & packed_mask8; + v = v & packed_mask; } - else + + prefetch_tmp_b[q] = v; + } + } + } + } + else + { + if (psc(B_elempack) == 1) + { + // +-K-+ + // N | + // +SG_UN + // | | + // ^ +---+ + // | | | + // SG_UK+- -+ + // | | | + // ^ v +---+ + // | | | + // | +- -+ + // | | | + // WG_UN +---+ + // | | | + // | +- -+ + // | | | + // v +---+ + + const uint Kd4_N_USGN_USGK = Kd4 * N * UNROLL_SG_N * UNROLL_SG_K; + const uint Kd4_N_USGN_USGK_d_subgroupsize = (Kd4_N_USGN_USGK + (subgroup_size * UNROLL_WG_M - 1)) / (subgroup_size * UNROLL_WG_M); + [[unroll]] for (uint q = 0; q < Kd4_N_USGN_USGK_d_subgroupsize; q++) + { + const uint siq = (q * UNROLL_WG_M + sgmi) * subgroup_size + si; + + if (Kd4_N_USGN_USGK % (subgroup_size * UNROLL_WG_M) == 0 || siq < Kd4_N_USGN_USGK) + { + const uint zk = siq / (Kd4 * N * UNROLL_SG_N); + const uint znij = siq % (Kd4 * N * UNROLL_SG_N); + const uint zn = znij / (Kd4 * N); + const uint ij = znij % (Kd4 * N); + const uint j = ij / Kd4; + const uint i = ij % Kd4; + + uvec2 v = uvec2(0); + + const uint gk = ki / 4 + zk * Kd4 + i; + const uint gn = (ni + zn) * N + j; + + if (gn < psc(GN)) { - vec4 v4 = vec4(0.f); - vec4 v8 = vec4(0.f); + const uvec4 gk4 = gk * 4 + uvec4(0, 1, 2, 3); - const uvec4 bi4 = gn * p.B_hstep + gk4; - const uvec4 bi4d8 = bi4 / 8; - const uvec4 bi4m8d2 = (bi4 % 8) / 2; - const uvec4 bi4m2 = bi4 % 2; + if (p.B_hstep % 4 == 0) + { + const uint bi = gn * (p.B_hstep / 4) + gk; - const uvec4 bi8 = gn * p.B_hstep + gk8; - const uvec4 bi8d8 = bi8 / 8; - const uvec4 bi8m8d2 = (bi8 % 8) / 2; - const uvec4 bi8m2 = bi8 % 2; + v = B_blob_data[bi]; -#if NCNN_bf16_storage || NCNN_bf16_packed - if (gk4.r < psc(GK)) v4.r = unpackBFloat2x16(B_blob_data[bi4d8.r][bi4m8d2.r])[bi4m2.r]; - if (gk4.g < psc(GK)) v4.g = unpackBFloat2x16(B_blob_data[bi4d8.g][bi4m8d2.g])[bi4m2.g]; - if (gk4.b < psc(GK)) v4.b = unpackBFloat2x16(B_blob_data[bi4d8.b][bi4m8d2.b])[bi4m2.b]; - if (gk4.a < psc(GK)) v4.a = unpackBFloat2x16(B_blob_data[bi4d8.a][bi4m8d2.a])[bi4m2.a]; + uvec4 mask4 = uvec4(lessThan(gk4, uvec4(psc(GK)))) * 0xFFFFu; + uvec2 packed_mask = uvec2(mask4.x | (mask4.y << 16), mask4.z | (mask4.w << 16)); - if (gk8.r < psc(GK)) v8.r = unpackBFloat2x16(B_blob_data[bi8d8.r][bi8m8d2.r])[bi8m2.r]; - if (gk8.g < psc(GK)) v8.g = unpackBFloat2x16(B_blob_data[bi8d8.g][bi8m8d2.g])[bi8m2.g]; - if (gk8.b < psc(GK)) v8.b = unpackBFloat2x16(B_blob_data[bi8d8.b][bi8m8d2.b])[bi8m2.b]; - if (gk8.a < psc(GK)) v8.a = unpackBFloat2x16(B_blob_data[bi8d8.a][bi8m8d2.a])[bi8m2.a]; + v = v & packed_mask; + } + else + { + vec2 v4a = vec2(0.f); + vec2 v4b = vec2(0.f); - v = uvec4(packBFloat2x16(v4.rg), packBFloat2x16(v4.ba), packBFloat2x16(v8.rg), packBFloat2x16(v8.ba)); -#else - if (gk4.r < psc(GK)) v4.r = 
unpackHalf2x16(B_blob_data[bi4d8.r][bi4m8d2.r])[bi4m2.r]; - if (gk4.g < psc(GK)) v4.g = unpackHalf2x16(B_blob_data[bi4d8.g][bi4m8d2.g])[bi4m2.g]; - if (gk4.b < psc(GK)) v4.b = unpackHalf2x16(B_blob_data[bi4d8.b][bi4m8d2.b])[bi4m2.b]; - if (gk4.a < psc(GK)) v4.a = unpackHalf2x16(B_blob_data[bi4d8.a][bi4m8d2.a])[bi4m2.a]; + const uvec4 bi4 = gn * p.B_hstep + gk4; + const uvec4 bi4d4 = bi4 / 4; + const uvec4 bi4m4d2 = (bi4 % 4) / 2; + const uvec4 bi4m2 = bi4 % 2; - if (gk8.r < psc(GK)) v8.r = unpackHalf2x16(B_blob_data[bi8d8.r][bi8m8d2.r])[bi8m2.r]; - if (gk8.g < psc(GK)) v8.g = unpackHalf2x16(B_blob_data[bi8d8.g][bi8m8d2.g])[bi8m2.g]; - if (gk8.b < psc(GK)) v8.b = unpackHalf2x16(B_blob_data[bi8d8.b][bi8m8d2.b])[bi8m2.b]; - if (gk8.a < psc(GK)) v8.a = unpackHalf2x16(B_blob_data[bi8d8.a][bi8m8d2.a])[bi8m2.a]; +#if NCNN_bf16_storage || NCNN_bf16_packed + if (gk4.r < psc(GK)) v4a.r = unpackBFloat2x16(B_blob_data[bi4d4.r][bi4m4d2.r])[bi4m2.r]; + if (gk4.g < psc(GK)) v4a.g = unpackBFloat2x16(B_blob_data[bi4d4.g][bi4m4d2.g])[bi4m2.g]; + if (gk4.b < psc(GK)) v4b.r = unpackBFloat2x16(B_blob_data[bi4d4.b][bi4m4d2.b])[bi4m2.b]; + if (gk4.a < psc(GK)) v4b.g = unpackBFloat2x16(B_blob_data[bi4d4.a][bi4m4d2.a])[bi4m2.a]; + + v = uvec2(packBFloat2x16(v4a), packBFloat2x16(v4b)); +#else + if (gk4.r < psc(GK)) v4a.r = unpackHalf2x16(B_blob_data[bi4d4.r][bi4m4d2.r])[bi4m2.r]; + if (gk4.g < psc(GK)) v4a.g = unpackHalf2x16(B_blob_data[bi4d4.g][bi4m4d2.g])[bi4m2.g]; + if (gk4.b < psc(GK)) v4b.r = unpackHalf2x16(B_blob_data[bi4d4.b][bi4m4d2.b])[bi4m2.b]; + if (gk4.a < psc(GK)) v4b.g = unpackHalf2x16(B_blob_data[bi4d4.a][bi4m4d2.a])[bi4m2.a]; - v = uvec4(packHalf2x16(v4.rg), packHalf2x16(v4.ba), packHalf2x16(v8.rg), packHalf2x16(v8.ba)); + v = uvec2(packHalf2x16(v4a), packHalf2x16(v4b)); #endif + } } + + prefetch_tmp_b[q] = v; } + } + } + else // if (psc(B_elempack) == 4) + { + const uint Nd4_K_USGN_USGK = Nd4 * K * UNROLL_SG_N * UNROLL_SG_K; + const uint Nd4_K_USGN_USGK_d_subgroupsize = (Nd4_K_USGN_USGK + (subgroup_size * UNROLL_WG_M - 1)) / (subgroup_size * UNROLL_WG_M); + [[unroll]] for (uint q = 0; q < Nd4_K_USGN_USGK_d_subgroupsize; q++) + { + const uint siq = (q * UNROLL_WG_M + sgmi) * subgroup_size + si; - prefetch_tmp_b[q] = v; + if (Nd4_K_USGN_USGK % (subgroup_size * UNROLL_WG_M) == 0 || siq < Nd4_K_USGN_USGK) + { + const uint zk = siq / (Nd4 * K * UNROLL_SG_N); + const uint znij = siq % (Nd4 * K * UNROLL_SG_N); + const uint zn = znij / (Nd4 * K); + const uint ij = znij % (Nd4 * K); + const uint j = ij / Nd4; + const uint i = ij % Nd4; + + uvec2 v = uvec2(0); + + const uint gk = ki + zk * K + j; + const uint gn = (ni + zn) * Nd4 + i; + + if (gk < psc(GK)) + { + v = B_blob_data[gn * p.B_hstep + gk]; + + const uvec4 gn4 = gn * 4 + uvec4(0, 1, 2, 3); + uvec4 mask4 = uvec4(lessThan(gn4, uvec4(psc(GN)))) * 0xFFFFu; + uvec2 packed_mask = uvec2(mask4.x | (mask4.y << 16), mask4.z | (mask4.w << 16)); + + v = v & packed_mask; + } + + prefetch_tmp_b[q] = v; + } } } } @@ -1293,41 +1649,101 @@ void main() { [[unroll]] for (uint zm = 0; zm < UNROLL_SG_M; zm++) { - if (transA == 0) + if (constantA == 1) + { +#if ncnn_VK_KHR_cooperative_matrix + coopMatLoad(A[zm], tmp_a[sgmi], (zk * UNROLL_SG_M + zm) * (Kd4p * M), Kd4p, gl_CooperativeMatrixLayoutRowMajor); +#elif ncnn_VK_NV_cooperative_matrix + coopMatLoadNV(A[zm], tmp_a[sgmi], (zk * UNROLL_SG_M + zm) * (Kd4p * M), Kd4p, false); +#endif + } + else if (transA == 0) { + if (psc(A_elempack) == 4) + { +#if ncnn_VK_KHR_cooperative_matrix + coopMatLoad(A[zm], tmp_a[sgmi], (zk * 
UNROLL_SG_M + zm) * (Md4p * K), Md4p, gl_CooperativeMatrixLayoutColumnMajor); +#elif ncnn_VK_NV_cooperative_matrix + coopMatLoadNV(A[zm], tmp_a[sgmi], (zk * UNROLL_SG_M + zm) * (Md4p * K), Md4p, true); +#endif + } + else + { #if ncnn_VK_KHR_cooperative_matrix - coopMatLoad(A[zm], tmp_a, ((sgmi * UNROLL_SG_K + zk) * UNROLL_SG_M + zm) * (Kd8p * M), Kd8p, gl_CooperativeMatrixLayoutRowMajor); + coopMatLoad(A[zm], tmp_a[sgmi], (zk * UNROLL_SG_M + zm) * (Kd4p * M), Kd4p, gl_CooperativeMatrixLayoutRowMajor); #elif ncnn_VK_NV_cooperative_matrix - coopMatLoadNV(A[zm], tmp_a, ((sgmi * UNROLL_SG_K + zk) * UNROLL_SG_M + zm) * (Kd8p * M), Kd8p, false); + coopMatLoadNV(A[zm], tmp_a[sgmi], (zk * UNROLL_SG_M + zm) * (Kd4p * M), Kd4p, false); #endif + } } else { + if (psc(A_elempack) == 4) + { +#if ncnn_VK_KHR_cooperative_matrix + coopMatLoad(A[zm], tmp_a[sgmi], (zk * UNROLL_SG_M + zm) * (Kd4p * M), Kd4p, gl_CooperativeMatrixLayoutRowMajor); +#elif ncnn_VK_NV_cooperative_matrix + coopMatLoadNV(A[zm], tmp_a[sgmi], (zk * UNROLL_SG_M + zm) * (Kd4p * M), Kd4p, false); +#endif + } + else + { +#if ncnn_VK_KHR_cooperative_matrix + coopMatLoad(A[zm], tmp_a[sgmi], (zk * UNROLL_SG_M + zm) * (Md4p * K), Md4p, gl_CooperativeMatrixLayoutColumnMajor); +#elif ncnn_VK_NV_cooperative_matrix + coopMatLoadNV(A[zm], tmp_a[sgmi], (zk * UNROLL_SG_M + zm) * (Md4p * K), Md4p, true); +#endif + } + } + } + + [[unroll]] for (uint zn = 0; zn < UNROLL_SG_N; zn++) + { + if (constantB == 1) + { +#if ncnn_VK_KHR_cooperative_matrix + coopMatLoad(B[zn], tmp_b[sgni], (zk * UNROLL_SG_N + zn) * (Nd4p * K), Nd4p, gl_CooperativeMatrixLayoutRowMajor); +#elif ncnn_VK_NV_cooperative_matrix + coopMatLoadNV(B[zn], tmp_b[sgni], (zk * UNROLL_SG_N + zn) * (Nd4p * K), Nd4p, false); +#endif + } + else if (transB == 0) + { + if (psc(B_elempack) == 4) + { +#if ncnn_VK_KHR_cooperative_matrix + coopMatLoad(B[zn], tmp_b[sgni], (zk * UNROLL_SG_N + zn) * (Kd4p * N), Kd4p, gl_CooperativeMatrixLayoutColumnMajor); +#elif ncnn_VK_NV_cooperative_matrix + coopMatLoadNV(B[zn], tmp_b[sgni], (zk * UNROLL_SG_N + zn) * (Kd4p * N), Kd4p, true); +#endif + } + else + { #if ncnn_VK_KHR_cooperative_matrix - coopMatLoad(A[zm], tmp_a, ((sgmi * UNROLL_SG_K + zk) * UNROLL_SG_M + zm) * (Md8p * K), Md8p, gl_CooperativeMatrixLayoutColumnMajor); + coopMatLoad(B[zn], tmp_b[sgni], (zk * UNROLL_SG_N + zn) * (Nd4p * K), Nd4p, gl_CooperativeMatrixLayoutRowMajor); #elif ncnn_VK_NV_cooperative_matrix - coopMatLoadNV(A[zm], tmp_a, ((sgmi * UNROLL_SG_K + zk) * UNROLL_SG_M + zm) * (Md8p * K), Md8p, true); + coopMatLoadNV(B[zn], tmp_b[sgni], (zk * UNROLL_SG_N + zn) * (Nd4p * K), Nd4p, false); #endif + } } - } - - [[unroll]] for (uint zn = 0; zn < UNROLL_SG_N; zn++) - { - if (transB == 0) + else { + if (psc(B_elempack) == 4) + { #if ncnn_VK_KHR_cooperative_matrix - coopMatLoad(B[zn], tmp_b, ((sgni * UNROLL_SG_K + zk) * UNROLL_SG_N + zn) * (Nd8p * K), Nd8p, gl_CooperativeMatrixLayoutRowMajor); + coopMatLoad(B[zn], tmp_b[sgni], (zk * UNROLL_SG_N + zn) * (Nd4p * K), Nd4p, gl_CooperativeMatrixLayoutRowMajor); #elif ncnn_VK_NV_cooperative_matrix - coopMatLoadNV(B[zn], tmp_b, ((sgni * UNROLL_SG_K + zk) * UNROLL_SG_N + zn) * (Nd8p * K), Nd8p, false); + coopMatLoadNV(B[zn], tmp_b[sgni], (zk * UNROLL_SG_N + zn) * (Nd4p * K), Nd4p, false); #endif - } - else - { + } + else + { #if ncnn_VK_KHR_cooperative_matrix - coopMatLoad(B[zn], tmp_b, ((sgni * UNROLL_SG_K + zk) * UNROLL_SG_N + zn) * (Kd8p * N), Kd8p, gl_CooperativeMatrixLayoutColumnMajor); + coopMatLoad(B[zn], tmp_b[sgni], (zk * UNROLL_SG_N + zn) * 
(Kd4p * N), Kd4p, gl_CooperativeMatrixLayoutColumnMajor); #elif ncnn_VK_NV_cooperative_matrix - coopMatLoadNV(B[zn], tmp_b, ((sgni * UNROLL_SG_K + zk) * UNROLL_SG_N + zn) * (Kd8p * N), Kd8p, true); + coopMatLoadNV(B[zn], tmp_b[sgni], (zk * UNROLL_SG_N + zn) * (Kd4p * N), Kd4p, true); #endif + } } } @@ -1350,73 +1766,155 @@ void main() // the last copy prefetch to shared memory { - if (transA == 0) + if (constantA == 1) { - const uint Kd8_M_USGM_USGK = Kd8 * M * UNROLL_SG_M * UNROLL_SG_K; - const uint Kd8_M_USGM_USGK_d_subgroupsize = (Kd8_M_USGM_USGK + (subgroup_size * UNROLL_WG_N - 1)) / (subgroup_size * UNROLL_WG_N); - [[unroll]] for (uint q = 0; q < Kd8_M_USGM_USGK_d_subgroupsize; q++) + const uint Kd4_M_USGM_USGK = Kd4 * M * UNROLL_SG_M * UNROLL_SG_K; + const uint Kd4_M_USGM_USGK_d_subgroupsize = (Kd4_M_USGM_USGK + (subgroup_size * UNROLL_WG_N - 1)) / (subgroup_size * UNROLL_WG_N); + [[unroll]] for (uint q = 0; q < Kd4_M_USGM_USGK_d_subgroupsize; q++) { const uint siq = (q * UNROLL_WG_N + sgni) * subgroup_size + si; - if (Kd8_M_USGM_USGK % (subgroup_size * UNROLL_WG_N) == 0 || siq < Kd8_M_USGM_USGK) + if (Kd4_M_USGM_USGK % (subgroup_size * UNROLL_WG_N) == 0 || siq < Kd4_M_USGM_USGK) { - const uint j = siq / Kd8; - const uint i = siq % Kd8; + const uint zk = siq / (Kd4 * M * UNROLL_SG_M); + const uint zmij = siq % (Kd4 * M * UNROLL_SG_M); + const uint zm = zmij / (Kd4 * M); + const uint ij = zmij % (Kd4 * M); + const uint j = ij / Kd4; + const uint i = ij % Kd4; + + tmp_a[sgmi][((zk * UNROLL_SG_M + zm) * M + j) * Kd4p + i] = prefetch_tmp_a[q]; + } + } + } + else if (transA == 0) + { + const uint Kd4_M_USGM_USGK = Kd4 * M * UNROLL_SG_M * UNROLL_SG_K; + const uint Kd4_M_USGM_USGK_d_subgroupsize = (Kd4_M_USGM_USGK + (subgroup_size * UNROLL_WG_N - 1)) / (subgroup_size * UNROLL_WG_N); + [[unroll]] for (uint q = 0; q < Kd4_M_USGM_USGK_d_subgroupsize; q++) + { + const uint siq = (q * UNROLL_WG_N + sgni) * subgroup_size + si; + + if (Kd4_M_USGM_USGK % (subgroup_size * UNROLL_WG_N) == 0 || siq < Kd4_M_USGM_USGK) + { + if (psc(A_elempack) == 4) + { + const uint i = siq / Md4; + const uint j = siq % Md4; + + tmp_a[sgmi][i * Md4p + j] = prefetch_tmp_a[q]; + } + else + { + const uint j = siq / Kd4; + const uint i = siq % Kd4; - tmp_a[(sgmi * UNROLL_SG_K * UNROLL_SG_M * M + j) * Kd8p + i] = prefetch_tmp_a[q]; + tmp_a[sgmi][j * Kd4p + i] = prefetch_tmp_a[q]; + } } } } else { - const uint Md8_K_USGM_USGK = Md8 * K * UNROLL_SG_M * UNROLL_SG_K; - const uint Md8_K_USGM_USGK_d_subgroupsize = (Md8_K_USGM_USGK + (subgroup_size * UNROLL_WG_N - 1)) / (subgroup_size * UNROLL_WG_N); - [[unroll]] for (uint q = 0; q < Md8_K_USGM_USGK_d_subgroupsize; q++) + const uint Md4_K_USGM_USGK = Md4 * K * UNROLL_SG_M * UNROLL_SG_K; + const uint Md4_K_USGM_USGK_d_subgroupsize = (Md4_K_USGM_USGK + (subgroup_size * UNROLL_WG_N - 1)) / (subgroup_size * UNROLL_WG_N); + [[unroll]] for (uint q = 0; q < Md4_K_USGM_USGK_d_subgroupsize; q++) { const uint siq = (q * UNROLL_WG_N + sgni) * subgroup_size + si; - if (Md8_K_USGM_USGK % (subgroup_size * UNROLL_WG_N) == 0 || siq < Md8_K_USGM_USGK) + if (Md4_K_USGM_USGK % (subgroup_size * UNROLL_WG_N) == 0 || siq < Md4_K_USGM_USGK) { - const uint i = siq / Md8; - const uint j = siq % Md8; + if (psc(A_elempack) == 4) + { + const uint j = siq / Kd4; + const uint i = siq % Kd4; + + tmp_a[sgmi][j * Kd4p + i] = prefetch_tmp_a[q]; + } + else + { + const uint i = siq / Md4; + const uint j = siq % Md4; - tmp_a[(sgmi * UNROLL_SG_K * UNROLL_SG_M * K + i) * Md8p + j] = prefetch_tmp_a[q]; + tmp_a[sgmi][i * 
Md4p + j] = prefetch_tmp_a[q]; + } } } } // load B - if (transB == 0) + if (constantB == 1) + { + const uint Nd4_K_USGN_USGK = Nd4 * K * UNROLL_SG_N * UNROLL_SG_K; + const uint Nd4_K_USGN_USGK_d_subgroupsize = (Nd4_K_USGN_USGK + (subgroup_size * UNROLL_WG_M - 1)) / (subgroup_size * UNROLL_WG_M); + [[unroll]] for (uint q = 0; q < Nd4_K_USGN_USGK_d_subgroupsize; q++) + { + const uint siq = (q * UNROLL_WG_M + sgmi) * subgroup_size + si; + + if (Nd4_K_USGN_USGK % (subgroup_size * UNROLL_WG_M) == 0 || siq < Nd4_K_USGN_USGK) + { + const uint zk = siq / (Nd4 * K * UNROLL_SG_N); + const uint znij = siq % (Nd4 * K * UNROLL_SG_N); + const uint zn = znij / (Nd4 * K); + const uint ij = znij % (Nd4 * K); + const uint i = ij / Nd4; + const uint j = ij % Nd4; + + tmp_b[sgni][((zk * UNROLL_SG_N + zn) * K + i) * Nd4p + j] = prefetch_tmp_b[q]; + } + } + } + else if (transB == 0) { - const uint Nd8_K_USGN_USGK = Nd8 * K * UNROLL_SG_N * UNROLL_SG_K; - const uint Nd8_K_USGN_USGK_d_subgroupsize = (Nd8_K_USGN_USGK + (subgroup_size * UNROLL_WG_M - 1)) / (subgroup_size * UNROLL_WG_M); - [[unroll]] for (uint q = 0; q < Nd8_K_USGN_USGK_d_subgroupsize; q++) + const uint Nd4_K_USGN_USGK = Nd4 * K * UNROLL_SG_N * UNROLL_SG_K; + const uint Nd4_K_USGN_USGK_d_subgroupsize = (Nd4_K_USGN_USGK + (subgroup_size * UNROLL_WG_M - 1)) / (subgroup_size * UNROLL_WG_M); + [[unroll]] for (uint q = 0; q < Nd4_K_USGN_USGK_d_subgroupsize; q++) { const uint siq = (q * UNROLL_WG_M + sgmi) * subgroup_size + si; - if (Nd8_K_USGN_USGK % (subgroup_size * UNROLL_WG_M) == 0 || siq < Nd8_K_USGN_USGK) + if (Nd4_K_USGN_USGK % (subgroup_size * UNROLL_WG_M) == 0 || siq < Nd4_K_USGN_USGK) { - const uint i = siq / Nd8; - const uint j = siq % Nd8; + if (psc(B_elempack) == 4) + { + const uint i = siq / Kd4; + const uint j = siq % Kd4; + + tmp_b[sgni][i * Kd4p + j] = prefetch_tmp_b[q]; + } + else + { + const uint i = siq / Nd4; + const uint j = siq % Nd4; - tmp_b[(sgni * UNROLL_SG_K * UNROLL_SG_N * K + i) * Nd8p + j] = prefetch_tmp_b[q]; + tmp_b[sgni][i * Nd4p + j] = prefetch_tmp_b[q]; + } } } } else { - const uint Kd8_N_USGN_USGK = Kd8 * N * UNROLL_SG_N * UNROLL_SG_K; - const uint Kd8_N_USGN_USGK_d_subgroupsize = (Kd8_N_USGN_USGK + (subgroup_size * UNROLL_WG_M - 1)) / (subgroup_size * UNROLL_WG_M); - [[unroll]] for (uint q = 0; q < Kd8_N_USGN_USGK_d_subgroupsize; q++) + const uint Kd4_N_USGN_USGK = Kd4 * N * UNROLL_SG_N * UNROLL_SG_K; + const uint Kd4_N_USGN_USGK_d_subgroupsize = (Kd4_N_USGN_USGK + (subgroup_size * UNROLL_WG_M - 1)) / (subgroup_size * UNROLL_WG_M); + [[unroll]] for (uint q = 0; q < Kd4_N_USGN_USGK_d_subgroupsize; q++) { const uint siq = (q * UNROLL_WG_M + sgmi) * subgroup_size + si; - if (Kd8_N_USGN_USGK % (subgroup_size * UNROLL_WG_M) == 0 || siq < Kd8_N_USGN_USGK) + if (Kd4_N_USGN_USGK % (subgroup_size * UNROLL_WG_M) == 0 || siq < Kd4_N_USGN_USGK) { - const uint j = siq / Kd8; - const uint i = siq % Kd8; + if (psc(B_elempack) == 4) + { + const uint j = siq / Nd4; + const uint i = siq % Nd4; + + tmp_b[sgni][j * Nd4p + i] = prefetch_tmp_b[q]; + } + else + { + const uint j = siq / Kd4; + const uint i = siq % Kd4; - tmp_b[(sgni * UNROLL_SG_K * UNROLL_SG_N * N + j) * Kd8p + i] = prefetch_tmp_b[q]; + tmp_b[sgni][j * Kd4p + i] = prefetch_tmp_b[q]; + } } } } @@ -1441,41 +1939,101 @@ void main() { [[unroll]] for (uint zm = 0; zm < UNROLL_SG_M; zm++) { - if (transA == 0) + if (constantA == 1) { #if ncnn_VK_KHR_cooperative_matrix - coopMatLoad(A[zm], tmp_a, ((sgmi * UNROLL_SG_K + zk) * UNROLL_SG_M + zm) * (Kd8p * M), Kd8p, 
gl_CooperativeMatrixLayoutRowMajor); + coopMatLoad(A[zm], tmp_a[sgmi], (zk * UNROLL_SG_M + zm) * (Kd4p * M), Kd4p, gl_CooperativeMatrixLayoutRowMajor); +#elif ncnn_VK_NV_cooperative_matrix + coopMatLoadNV(A[zm], tmp_a[sgmi], (zk * UNROLL_SG_M + zm) * (Kd4p * M), Kd4p, false); +#endif + } + else if (transA == 0) + { + if (psc(A_elempack) == 4) + { +#if ncnn_VK_KHR_cooperative_matrix + coopMatLoad(A[zm], tmp_a[sgmi], (zk * UNROLL_SG_M + zm) * (Md4p * K), Md4p, gl_CooperativeMatrixLayoutColumnMajor); +#elif ncnn_VK_NV_cooperative_matrix + coopMatLoadNV(A[zm], tmp_a[sgmi], (zk * UNROLL_SG_M + zm) * (Md4p * K), Md4p, true); +#endif + } + else + { +#if ncnn_VK_KHR_cooperative_matrix + coopMatLoad(A[zm], tmp_a[sgmi], (zk * UNROLL_SG_M + zm) * (Kd4p * M), Kd4p, gl_CooperativeMatrixLayoutRowMajor); #elif ncnn_VK_NV_cooperative_matrix - coopMatLoadNV(A[zm], tmp_a, ((sgmi * UNROLL_SG_K + zk) * UNROLL_SG_M + zm) * (Kd8p * M), Kd8p, false); + coopMatLoadNV(A[zm], tmp_a[sgmi], (zk * UNROLL_SG_M + zm) * (Kd4p * M), Kd4p, false); #endif + } } else { + if (psc(A_elempack) == 4) + { +#if ncnn_VK_KHR_cooperative_matrix + coopMatLoad(A[zm], tmp_a[sgmi], (zk * UNROLL_SG_M + zm) * (Kd4p * M), Kd4p, gl_CooperativeMatrixLayoutRowMajor); +#elif ncnn_VK_NV_cooperative_matrix + coopMatLoadNV(A[zm], tmp_a[sgmi], (zk * UNROLL_SG_M + zm) * (Kd4p * M), Kd4p, false); +#endif + } + else + { #if ncnn_VK_KHR_cooperative_matrix - coopMatLoad(A[zm], tmp_a, ((sgmi * UNROLL_SG_K + zk) * UNROLL_SG_M + zm) * (Md8p * K), Md8p, gl_CooperativeMatrixLayoutColumnMajor); + coopMatLoad(A[zm], tmp_a[sgmi], (zk * UNROLL_SG_M + zm) * (Md4p * K), Md4p, gl_CooperativeMatrixLayoutColumnMajor); #elif ncnn_VK_NV_cooperative_matrix - coopMatLoadNV(A[zm], tmp_a, ((sgmi * UNROLL_SG_K + zk) * UNROLL_SG_M + zm) * (Md8p * K), Md8p, true); + coopMatLoadNV(A[zm], tmp_a[sgmi], (zk * UNROLL_SG_M + zm) * (Md4p * K), Md4p, true); #endif + } } } [[unroll]] for (uint zn = 0; zn < UNROLL_SG_N; zn++) { - if (transB == 0) + if (constantB == 1) + { +#if ncnn_VK_KHR_cooperative_matrix + coopMatLoad(B[zn], tmp_b[sgni], (zk * UNROLL_SG_N + zn) * (Nd4p * K), Nd4p, gl_CooperativeMatrixLayoutRowMajor); +#elif ncnn_VK_NV_cooperative_matrix + coopMatLoadNV(B[zn], tmp_b[sgni], (zk * UNROLL_SG_N + zn) * (Nd4p * K), Nd4p, false); +#endif + } + else if (transB == 0) { + if (psc(B_elempack) == 4) + { +#if ncnn_VK_KHR_cooperative_matrix + coopMatLoad(B[zn], tmp_b[sgni], (zk * UNROLL_SG_N + zn) * (Kd4p * N), Kd4p, gl_CooperativeMatrixLayoutColumnMajor); +#elif ncnn_VK_NV_cooperative_matrix + coopMatLoadNV(B[zn], tmp_b[sgni], (zk * UNROLL_SG_N + zn) * (Kd4p * N), Kd4p, true); +#endif + } + else + { #if ncnn_VK_KHR_cooperative_matrix - coopMatLoad(B[zn], tmp_b, ((sgni * UNROLL_SG_K + zk) * UNROLL_SG_N + zn) * (Nd8p * K), Nd8p, gl_CooperativeMatrixLayoutRowMajor); + coopMatLoad(B[zn], tmp_b[sgni], (zk * UNROLL_SG_N + zn) * (Nd4p * K), Nd4p, gl_CooperativeMatrixLayoutRowMajor); #elif ncnn_VK_NV_cooperative_matrix - coopMatLoadNV(B[zn], tmp_b, ((sgni * UNROLL_SG_K + zk) * UNROLL_SG_N + zn) * (Nd8p * K), Nd8p, false); + coopMatLoadNV(B[zn], tmp_b[sgni], (zk * UNROLL_SG_N + zn) * (Nd4p * K), Nd4p, false); #endif + } } else { + if (psc(B_elempack) == 4) + { +#if ncnn_VK_KHR_cooperative_matrix + coopMatLoad(B[zn], tmp_b[sgni], (zk * UNROLL_SG_N + zn) * (Nd4p * K), Nd4p, gl_CooperativeMatrixLayoutRowMajor); +#elif ncnn_VK_NV_cooperative_matrix + coopMatLoadNV(B[zn], tmp_b[sgni], (zk * UNROLL_SG_N + zn) * (Nd4p * K), Nd4p, false); +#endif + } + else + { #if ncnn_VK_KHR_cooperative_matrix 
- coopMatLoad(B[zn], tmp_b, ((sgni * UNROLL_SG_K + zk) * UNROLL_SG_N + zn) * (Kd8p * N), Kd8p, gl_CooperativeMatrixLayoutColumnMajor); + coopMatLoad(B[zn], tmp_b[sgni], (zk * UNROLL_SG_N + zn) * (Kd4p * N), Kd4p, gl_CooperativeMatrixLayoutColumnMajor); #elif ncnn_VK_NV_cooperative_matrix - coopMatLoadNV(B[zn], tmp_b, ((sgni * UNROLL_SG_K + zk) * UNROLL_SG_N + zn) * (Kd8p * N), Kd8p, true); + coopMatLoadNV(B[zn], tmp_b[sgni], (zk * UNROLL_SG_N + zn) * (Kd4p * N), Kd4p, true); #endif + } } } @@ -1500,409 +2058,559 @@ void main() const uint ki = 0; // load A - if (transA == 0) + if (constantA == 1) + { + const uint Kd4_M_USGM = Kd4 * M * UNROLL_SG_M; + const uint Kd4_M_USGM_USGK = Kd4_M_USGM * UNROLL_SG_K; + const uint A_offset = (wgmi * kk * UNROLL_WG_M + sgmi * UNROLL_SG_K) * Kd4_M_USGM; + const uint Kd4_M_USGM_USGK_d_subgroupsize = (Kd4_M_USGM_USGK + (subgroup_size * UNROLL_WG_N - 1)) / (subgroup_size * UNROLL_WG_N); + [[unroll]] for (uint q = 0; q < Kd4_M_USGM_USGK_d_subgroupsize; q++) + { + const uint siq = (q * UNROLL_WG_N + sgni) * subgroup_size + si; + + if (Kd4_M_USGM_USGK % (subgroup_size * UNROLL_WG_N) == 0 || siq < Kd4_M_USGM_USGK) + { + const uint zk = siq / (Kd4 * M * UNROLL_SG_M); + const uint zmij = siq % (Kd4 * M * UNROLL_SG_M); + const uint zm = zmij / (Kd4 * M); + const uint ij = zmij % (Kd4 * M); + const uint j = ij / Kd4; + const uint i = ij % Kd4; + + tmp_a[sgmi][((zk * UNROLL_SG_M + zm) * M + j) * Kd4p + i] = A_blob_data[A_offset + siq]; + } + } + } + else if (transA == 0) + { + if (psc(A_elempack) == 1) + { + // +-K-+ + // M | + // +- -+ + // SG_UM | + // ^ +---+ + // | | | + // SG_UK+- -+ + // | | | + // ^ v +---+ + // | | | + // | +- -+ + // | | | + // WG_UM +---+ + // | | | + // | +- -+ + // | | | + // v +---+ + + const uint Kd4_M_USGM_USGK = Kd4 * M * UNROLL_SG_M * UNROLL_SG_K; + const uint Kd4_M_USGM_USGK_d_subgroupsize = (Kd4_M_USGM_USGK + (subgroup_size * UNROLL_WG_N - 1)) / (subgroup_size * UNROLL_WG_N); + [[unroll]] for (uint q = 0; q < Kd4_M_USGM_USGK_d_subgroupsize; q++) + { + const uint siq = (q * UNROLL_WG_N + sgni) * subgroup_size + si; + + if (Kd4_M_USGM_USGK % (subgroup_size * UNROLL_WG_N) == 0 || siq < Kd4_M_USGM_USGK) + { + const uint zk = siq / (Kd4 * M * UNROLL_SG_M); + const uint zmij = siq % (Kd4 * M * UNROLL_SG_M); + const uint zm = zmij / (Kd4 * M); + const uint ij = zmij % (Kd4 * M); + const uint j = ij / Kd4; + const uint i = ij % Kd4; + + uvec2 v = uvec2(0); + + const uint gk = ki / 4 + zk * Kd4 + i; + const uint gm = (mi + zm) * M + j; + + if (gm < psc(GM)) + { + const uvec4 gk4 = gk * 4 + uvec4(0, 1, 2, 3); + + if (p.A_hstep % 4 == 0) + { + const uint ai = gm * (p.A_hstep / 4) + gk; + + v = A_blob_data[ai]; + + uvec4 mask4 = uvec4(lessThan(gk4, uvec4(psc(GK)))) * 0xFFFFu; + uvec2 packed_mask = uvec2(mask4.x | (mask4.y << 16), mask4.z | (mask4.w << 16)); + + v = v & packed_mask; + } + else + { + vec2 v4a = vec2(0.f); + vec2 v4b = vec2(0.f); + + const uvec4 ai4 = gm * p.A_hstep + gk4; + const uvec4 ai4d4 = ai4 / 4; + const uvec4 ai4m4d2 = (ai4 % 4) / 2; + const uvec4 ai4m2 = ai4 % 2; + +#if NCNN_bf16_storage || NCNN_bf16_packed + if (gk4.r < psc(GK)) v4a.r = unpackBFloat2x16(A_blob_data[ai4d4.r][ai4m4d2.r])[ai4m2.r]; + if (gk4.g < psc(GK)) v4a.g = unpackBFloat2x16(A_blob_data[ai4d4.g][ai4m4d2.g])[ai4m2.g]; + if (gk4.b < psc(GK)) v4b.r = unpackBFloat2x16(A_blob_data[ai4d4.b][ai4m4d2.b])[ai4m2.b]; + if (gk4.a < psc(GK)) v4b.g = unpackBFloat2x16(A_blob_data[ai4d4.a][ai4m4d2.a])[ai4m2.a]; + + v = uvec2(packBFloat2x16(v4a), packBFloat2x16(v4b)); 
+#else + if (gk4.r < psc(GK)) v4a.r = unpackHalf2x16(A_blob_data[ai4d4.r][ai4m4d2.r])[ai4m2.r]; + if (gk4.g < psc(GK)) v4a.g = unpackHalf2x16(A_blob_data[ai4d4.g][ai4m4d2.g])[ai4m2.g]; + if (gk4.b < psc(GK)) v4b.r = unpackHalf2x16(A_blob_data[ai4d4.b][ai4m4d2.b])[ai4m2.b]; + if (gk4.a < psc(GK)) v4b.g = unpackHalf2x16(A_blob_data[ai4d4.a][ai4m4d2.a])[ai4m2.a]; + + v = uvec2(packHalf2x16(v4a), packHalf2x16(v4b)); +#endif + } + } + + tmp_a[sgmi][((zk * UNROLL_SG_M + zm) * M + j) * Kd4p + i] = v; + } + } + } + else // if (psc(A_elempack) == 4) + { + const uint Md4_K_USGM_USGK = Md4 * K * UNROLL_SG_M * UNROLL_SG_K; + const uint Md4_K_USGM_USGK_d_subgroupsize = (Md4_K_USGM_USGK + (subgroup_size * UNROLL_WG_N - 1)) / (subgroup_size * UNROLL_WG_N); + [[unroll]] for (uint q = 0; q < Md4_K_USGM_USGK_d_subgroupsize; q++) + { + const uint siq = (q * UNROLL_WG_N + sgni) * subgroup_size + si; + + if (Md4_K_USGM_USGK % (subgroup_size * UNROLL_WG_N) == 0 || siq < Md4_K_USGM_USGK) + { + const uint zk = siq / (Md4 * K * UNROLL_SG_M); + const uint zmij = siq % (Md4 * K * UNROLL_SG_M); + const uint zm = zmij / (Md4 * K); + const uint ij = zmij % (Md4 * K); + const uint i = ij / Md4; + const uint j = ij % Md4; + + uvec2 v = uvec2(0); + + const uint gk = ki + zk * K + i; + const uint gm = (mi + zm) * Md4 + j; + + if (gk < psc(GK)) + { + v = A_blob_data[gm * p.A_hstep + gk]; + + const uvec4 gm4 = gm * 4 + uvec4(0, 1, 2, 3); + uvec4 mask4 = uvec4(lessThan(gm4, uvec4(psc(GM)))) * 0xFFFFu; + uvec2 packed_mask = uvec2(mask4.x | (mask4.y << 16), mask4.z | (mask4.w << 16)); + + v = v & packed_mask; + } + + tmp_a[sgmi][((zk * UNROLL_SG_M + zm) * K + i) * Md4p + j] = v; + } + } + } + } + else + { + if (psc(A_elempack) == 1) + { + // +-M-+ + // K | + // +SG_UM + // | | + // ^ +---+ + // | | | + // SG_UK+- -+ + // | | | + // ^ v +---+ + // | | | + // | +- -+ + // | | | + // WG_UM +---+ + // | | | + // | +- -+ + // | | | + // v +---+ + + const uint Md4_K_USGM_USGK = Md4 * K * UNROLL_SG_M * UNROLL_SG_K; + const uint Md4_K_USGM_USGK_d_subgroupsize = (Md4_K_USGM_USGK + (subgroup_size * UNROLL_WG_N - 1)) / (subgroup_size * UNROLL_WG_N); + [[unroll]] for (uint q = 0; q < Md4_K_USGM_USGK_d_subgroupsize; q++) + { + const uint siq = (q * UNROLL_WG_N + sgni) * subgroup_size + si; + + if (Md4_K_USGM_USGK % (subgroup_size * UNROLL_WG_N) == 0 || siq < Md4_K_USGM_USGK) + { + const uint zk = siq / (Md4 * K * UNROLL_SG_M); + const uint zmij = siq % (Md4 * K * UNROLL_SG_M); + const uint zm = zmij / (Md4 * K); + const uint ij = zmij % (Md4 * K); + const uint i = ij / Md4; + const uint j = ij % Md4; + + uvec2 v = uvec2(0); + + const uint gk = ki + zk * K + i; + const uint gm = (mi + zm) * Md4 + j; + + if (gk < psc(GK)) + { + const uvec4 gm4 = gm * 4 + uvec4(0, 1, 2, 3); + + if (p.A_hstep % 4 == 0) + { + const uint ai = gk * (p.A_hstep / 4) + gm; + + v = A_blob_data[ai]; + + uvec4 mask4 = uvec4(lessThan(gm4, uvec4(psc(GM)))) * 0xFFFFu; + uvec2 packed_mask = uvec2(mask4.x | (mask4.y << 16), mask4.z | (mask4.w << 16)); + + v = v & packed_mask; + } + else + { + vec2 v4a = vec2(0.f); + vec2 v4b = vec2(0.f); + + const uvec4 ai4 = gk * p.A_hstep + gm4; + const uvec4 ai4d4 = ai4 / 4; + const uvec4 ai4m4d2 = (ai4 % 4) / 2; + const uvec4 ai4m2 = ai4 % 2; + +#if NCNN_bf16_storage || NCNN_bf16_packed + if (gm4.r < psc(GM)) v4a.r = unpackBFloat2x16(A_blob_data[ai4d4.r][ai4m4d2.r])[ai4m2.r]; + if (gm4.g < psc(GM)) v4a.g = unpackBFloat2x16(A_blob_data[ai4d4.g][ai4m4d2.g])[ai4m2.g]; + if (gm4.b < psc(GM)) v4b.r = 
unpackBFloat2x16(A_blob_data[ai4d4.b][ai4m4d2.b])[ai4m2.b]; + if (gm4.a < psc(GM)) v4b.g = unpackBFloat2x16(A_blob_data[ai4d4.a][ai4m4d2.a])[ai4m2.a]; + + v = uvec2(packBFloat2x16(v4a), packBFloat2x16(v4b)); +#else + if (gm4.r < psc(GM)) v4a.r = unpackHalf2x16(A_blob_data[ai4d4.r][ai4m4d2.r])[ai4m2.r]; + if (gm4.g < psc(GM)) v4a.g = unpackHalf2x16(A_blob_data[ai4d4.g][ai4m4d2.g])[ai4m2.g]; + if (gm4.b < psc(GM)) v4b.r = unpackHalf2x16(A_blob_data[ai4d4.b][ai4m4d2.b])[ai4m2.b]; + if (gm4.a < psc(GM)) v4b.g = unpackHalf2x16(A_blob_data[ai4d4.a][ai4m4d2.a])[ai4m2.a]; + + v = uvec2(packHalf2x16(v4a), packHalf2x16(v4b)); +#endif + } + } + + tmp_a[sgmi][((zk * UNROLL_SG_M + zm) * K + i) * Md4p + j] = v; + } + } + } + else // if (psc(A_elempack) == 4) + { + const uint Kd4_M_USGM_USGK = Kd4 * M * UNROLL_SG_M * UNROLL_SG_K; + const uint Kd4_M_USGM_USGK_d_subgroupsize = (Kd4_M_USGM_USGK + (subgroup_size * UNROLL_WG_N - 1)) / (subgroup_size * UNROLL_WG_N); + [[unroll]] for (uint q = 0; q < Kd4_M_USGM_USGK_d_subgroupsize; q++) + { + const uint siq = (q * UNROLL_WG_N + sgni) * subgroup_size + si; + + if (Kd4_M_USGM_USGK % (subgroup_size * UNROLL_WG_N) == 0 || siq < Kd4_M_USGM_USGK) + { + const uint zk = siq / (Kd4 * M * UNROLL_SG_M); + const uint zmij = siq % (Kd4 * M * UNROLL_SG_M); + const uint zm = zmij / (Kd4 * M); + const uint ij = zmij % (Kd4 * M); + const uint j = ij / Kd4; + const uint i = ij % Kd4; + + uvec2 v = uvec2(0); + + const uint gk = ki / 4 + zk * Kd4 + i; + const uint gm = (mi + zm) * M + j; + + if (gm < psc(GM)) + { + v = A_blob_data[gk * p.A_hstep + gm]; + + const uvec4 gk4 = gk * 4 + uvec4(0, 1, 2, 3); + uvec4 mask4 = uvec4(lessThan(gk4, uvec4(psc(GK)))) * 0xFFFFu; + uvec2 packed_mask = uvec2(mask4.x | (mask4.y << 16), mask4.z | (mask4.w << 16)); + + v = v & packed_mask; + } + + tmp_a[sgmi][((zk * UNROLL_SG_M + zm) * M + j) * Kd4p + i] = v; + } + } + } + } + + // load B + if (constantB == 1) { - // +-K-+ - // M | - // +- -+ - // SG_UM | - // ^ +---+ - // | | | - // SG_UK+- -+ - // | | | - // ^ v +---+ - // | | | - // | +- -+ - // | | | - // WG_UM +---+ - // | | | - // | +- -+ - // | | | - // v +---+ - - const uint Kd8_M_USGM_USGK = Kd8 * M * UNROLL_SG_M * UNROLL_SG_K; - const uint Kd8_M_USGM_USGK_d_subgroupsize = (Kd8_M_USGM_USGK + (subgroup_size * UNROLL_WG_N - 1)) / (subgroup_size * UNROLL_WG_N); - [[unroll]] for (uint q = 0; q < Kd8_M_USGM_USGK_d_subgroupsize; q++) + const uint Nd4_K_USGN = Nd4 * K * UNROLL_SG_N; + const uint Nd4_K_USGN_USGK = Nd4_K_USGN * UNROLL_SG_K; + const uint B_offset = (wgni * kk * UNROLL_WG_N + sgni * UNROLL_SG_K) * Nd4_K_USGN; + const uint Nd4_K_USGN_USGK_d_subgroupsize = (Nd4_K_USGN_USGK + (subgroup_size * UNROLL_WG_M - 1)) / (subgroup_size * UNROLL_WG_M); + [[unroll]] for (uint q = 0; q < Nd4_K_USGN_USGK_d_subgroupsize; q++) { - const uint siq = (q * UNROLL_WG_N + sgni) * subgroup_size + si; + const uint siq = (q * UNROLL_WG_M + sgmi) * subgroup_size + si; - if (Kd8_M_USGM_USGK % (subgroup_size * UNROLL_WG_N) == 0 || siq < Kd8_M_USGM_USGK) + if (Nd4_K_USGN_USGK % (subgroup_size * UNROLL_WG_M) == 0 || siq < Nd4_K_USGN_USGK) { - const uint zk = siq / (Kd8 * M * UNROLL_SG_M); - const uint zmij = siq % (Kd8 * M * UNROLL_SG_M); - const uint zm = zmij / (Kd8 * M); - const uint ij = zmij % (Kd8 * M); - const uint j = ij / Kd8; - const uint i = ij % Kd8; + const uint zk = siq / (Nd4 * K * UNROLL_SG_N); + const uint znij = siq % (Nd4 * K * UNROLL_SG_N); + const uint zn = znij / (Nd4 * K); + const uint ij = znij % (Nd4 * K); + const uint i = ij / Nd4; + const 
uint j = ij % Nd4; + + tmp_b[sgni][((zk * UNROLL_SG_N + zn) * K + i) * Nd4p + j] = B_blob_data[B_offset + siq]; + } + } + } + else if (transB == 0) + { + if (psc(B_elempack) == 1) + { + // +-N-+ + // K | + // +SG_UN + // | | + // ^ +---+ + // | | | + // SG_UK+- -+ + // | | | + // ^ v +---+ + // | | | + // | +- -+ + // | | | + // WG_UN +---+ + // | | | + // | +- -+ + // | | | + // v +---+ - const uint gk = ki / 8 + zk * Kd8 + i; - const uint gm = (mi + zm) * M + j; + const uint Nd4_K_USGN_USGK = Nd4 * K * UNROLL_SG_N * UNROLL_SG_K; + const uint Nd4_K_USGN_USGK_d_subgroupsize = (Nd4_K_USGN_USGK + (subgroup_size * UNROLL_WG_M - 1)) / (subgroup_size * UNROLL_WG_M); + [[unroll]] for (uint q = 0; q < Nd4_K_USGN_USGK_d_subgroupsize; q++) + { + const uint siq = (q * UNROLL_WG_M + sgmi) * subgroup_size + si; - uvec4 v = uvec4(0); - if (gm < psc(GM)) + if (Nd4_K_USGN_USGK % (subgroup_size * UNROLL_WG_M) == 0 || siq < Nd4_K_USGN_USGK) { - const uvec4 gk4 = gk * 8 + uvec4(0, 1, 2, 3); - const uvec4 gk8 = gk4 + 4; + const uint zk = siq / (Nd4 * K * UNROLL_SG_N); + const uint znij = siq % (Nd4 * K * UNROLL_SG_N); + const uint zn = znij / (Nd4 * K); + const uint ij = znij % (Nd4 * K); + const uint i = ij / Nd4; + const uint j = ij % Nd4; - if (p.A_hstep % 8 == 0) + uvec2 v = uvec2(0); + + const uint gk = ki + zk * K + i; + const uint gn = (ni + zn) * Nd4 + j; + + if (gk < psc(GK)) { - const uint ai = gm * (p.A_hstep / 8) + gk; + const uvec4 gn4 = gn * 4 + uvec4(0, 1, 2, 3); - v = A_blob_data[ai]; + if (p.B_hstep % 4 == 0) + { + const uint bi = gk * (p.B_hstep / 4) + gn; - uvec4 mask4 = uvec4(lessThan(gk4, uvec4(psc(GK)))) * 0xFFFFu; - uvec4 mask8 = uvec4(lessThan(gk8, uvec4(psc(GK)))) * 0xFFFFu; - uvec2 packed_mask4 = uvec2(mask4.x | (mask4.y << 16), mask4.z | (mask4.w << 16)); - uvec2 packed_mask8 = uvec2(mask8.x | (mask8.y << 16), mask8.z | (mask8.w << 16)); + v = B_blob_data[bi]; - v.rg = v.rg & packed_mask4; - v.ba = v.ba & packed_mask8; - } - else - { - vec4 v4 = vec4(0.f); - vec4 v8 = vec4(0.f); + uvec4 mask4 = uvec4(lessThan(gn4, uvec4(psc(GN)))) * 0xFFFFu; + uvec2 packed_mask = uvec2(mask4.x | (mask4.y << 16), mask4.z | (mask4.w << 16)); - const uvec4 ai4 = gm * p.A_hstep + gk4; - const uvec4 ai4d8 = ai4 / 8; - const uvec4 ai4m8d2 = (ai4 % 8) / 2; - const uvec4 ai4m2 = ai4 % 2; + v = v & packed_mask; + } + else + { + vec2 v4a = vec2(0.f); + vec2 v4b = vec2(0.f); - const uvec4 ai8 = gm * p.A_hstep + gk8; - const uvec4 ai8d8 = ai8 / 8; - const uvec4 ai8m8d2 = (ai8 % 8) / 2; - const uvec4 ai8m2 = ai8 % 2; + const uvec4 bi4 = gk * p.B_hstep + gn4; + const uvec4 bi4d4 = bi4 / 4; + const uvec4 bi4m4d2 = (bi4 % 4) / 2; + const uvec4 bi4m2 = bi4 % 2; #if NCNN_bf16_storage || NCNN_bf16_packed - if (gk4.r < psc(GK)) v4.r = unpackBFloat2x16(A_blob_data[ai4d8.r][ai4m8d2.r])[ai4m2.r]; - if (gk4.g < psc(GK)) v4.g = unpackBFloat2x16(A_blob_data[ai4d8.g][ai4m8d2.g])[ai4m2.g]; - if (gk4.b < psc(GK)) v4.b = unpackBFloat2x16(A_blob_data[ai4d8.b][ai4m8d2.b])[ai4m2.b]; - if (gk4.a < psc(GK)) v4.a = unpackBFloat2x16(A_blob_data[ai4d8.a][ai4m8d2.a])[ai4m2.a]; - - if (gk8.r < psc(GK)) v8.r = unpackBFloat2x16(A_blob_data[ai8d8.r][ai8m8d2.r])[ai8m2.r]; - if (gk8.g < psc(GK)) v8.g = unpackBFloat2x16(A_blob_data[ai8d8.g][ai8m8d2.g])[ai8m2.g]; - if (gk8.b < psc(GK)) v8.b = unpackBFloat2x16(A_blob_data[ai8d8.b][ai8m8d2.b])[ai8m2.b]; - if (gk8.a < psc(GK)) v8.a = unpackBFloat2x16(A_blob_data[ai8d8.a][ai8m8d2.a])[ai8m2.a]; + if (gn4.r < psc(GN)) v4a.r = unpackBFloat2x16(B_blob_data[bi4d4.r][bi4m4d2.r])[bi4m2.r]; + if (gn4.g < 
psc(GN)) v4a.g = unpackBFloat2x16(B_blob_data[bi4d4.g][bi4m4d2.g])[bi4m2.g]; + if (gn4.b < psc(GN)) v4b.r = unpackBFloat2x16(B_blob_data[bi4d4.b][bi4m4d2.b])[bi4m2.b]; + if (gn4.a < psc(GN)) v4b.g = unpackBFloat2x16(B_blob_data[bi4d4.a][bi4m4d2.a])[bi4m2.a]; - v = uvec4(packBFloat2x16(v4.rg), packBFloat2x16(v4.ba), packBFloat2x16(v8.rg), packBFloat2x16(v8.ba)); + v = uvec2(packBFloat2x16(v4a), packBFloat2x16(v4b)); #else - if (gk4.r < psc(GK)) v4.r = unpackHalf2x16(A_blob_data[ai4d8.r][ai4m8d2.r])[ai4m2.r]; - if (gk4.g < psc(GK)) v4.g = unpackHalf2x16(A_blob_data[ai4d8.g][ai4m8d2.g])[ai4m2.g]; - if (gk4.b < psc(GK)) v4.b = unpackHalf2x16(A_blob_data[ai4d8.b][ai4m8d2.b])[ai4m2.b]; - if (gk4.a < psc(GK)) v4.a = unpackHalf2x16(A_blob_data[ai4d8.a][ai4m8d2.a])[ai4m2.a]; + if (gn4.r < psc(GN)) v4a.r = unpackHalf2x16(B_blob_data[bi4d4.r][bi4m4d2.r])[bi4m2.r]; + if (gn4.g < psc(GN)) v4a.g = unpackHalf2x16(B_blob_data[bi4d4.g][bi4m4d2.g])[bi4m2.g]; + if (gn4.b < psc(GN)) v4b.r = unpackHalf2x16(B_blob_data[bi4d4.b][bi4m4d2.b])[bi4m2.b]; + if (gn4.a < psc(GN)) v4b.g = unpackHalf2x16(B_blob_data[bi4d4.a][bi4m4d2.a])[bi4m2.a]; - if (gk8.r < psc(GK)) v8.r = unpackHalf2x16(A_blob_data[ai8d8.r][ai8m8d2.r])[ai8m2.r]; - if (gk8.g < psc(GK)) v8.g = unpackHalf2x16(A_blob_data[ai8d8.g][ai8m8d2.g])[ai8m2.g]; - if (gk8.b < psc(GK)) v8.b = unpackHalf2x16(A_blob_data[ai8d8.b][ai8m8d2.b])[ai8m2.b]; - if (gk8.a < psc(GK)) v8.a = unpackHalf2x16(A_blob_data[ai8d8.a][ai8m8d2.a])[ai8m2.a]; - - v = uvec4(packHalf2x16(v4.rg), packHalf2x16(v4.ba), packHalf2x16(v8.rg), packHalf2x16(v8.ba)); + v = uvec2(packHalf2x16(v4a), packHalf2x16(v4b)); #endif + } } - } - tmp_a[(((sgmi * UNROLL_SG_K + zk) * UNROLL_SG_M + zm) * M + j) * Kd8p + i] = v; + tmp_b[sgni][((zk * UNROLL_SG_N + zn) * K + i) * Nd4p + j] = v; + } } } - } - else - { - // +-M-+ - // K | - // +SG_UM - // | | - // ^ +---+ - // | | | - // SG_UK+- -+ - // | | | - // ^ v +---+ - // | | | - // | +- -+ - // | | | - // WG_UM +---+ - // | | | - // | +- -+ - // | | | - // v +---+ - - const uint Md8_K_USGM_USGK = Md8 * K * UNROLL_SG_M * UNROLL_SG_K; - const uint Md8_K_USGM_USGK_d_subgroupsize = (Md8_K_USGM_USGK + (subgroup_size * UNROLL_WG_N - 1)) / (subgroup_size * UNROLL_WG_N); - [[unroll]] for (uint q = 0; q < Md8_K_USGM_USGK_d_subgroupsize; q++) + else // if (psc(B_elempack) == 4) { - const uint siq = (q * UNROLL_WG_N + sgni) * subgroup_size + si; - - if (Md8_K_USGM_USGK % (subgroup_size * UNROLL_WG_N) == 0 || siq < Md8_K_USGM_USGK) + const uint Kd4_N_USGN_USGK = Kd4 * N * UNROLL_SG_N * UNROLL_SG_K; + const uint Kd4_N_USGN_USGK_d_subgroupsize = (Kd4_N_USGN_USGK + (subgroup_size * UNROLL_WG_M - 1)) / (subgroup_size * UNROLL_WG_M); + [[unroll]] for (uint q = 0; q < Kd4_N_USGN_USGK_d_subgroupsize; q++) { - const uint zk = siq / (Md8 * K * UNROLL_SG_M); - const uint zmij = siq % (Md8 * K * UNROLL_SG_M); - const uint zm = zmij / (Md8 * K); - const uint ij = zmij % (Md8 * K); - const uint i = ij / Md8; - const uint j = ij % Md8; - - const uint gk = ki + zk * K + i; - const uint gm = (mi + zm) * Md8 + j; + const uint siq = (q * UNROLL_WG_M + sgmi) * subgroup_size + si; - uvec4 v = uvec4(0); - if (gk < psc(GK)) + if (Kd4_N_USGN_USGK % (subgroup_size * UNROLL_WG_M) == 0 || siq < Kd4_N_USGN_USGK) { - if (p.A_hstep % 8 == 0) - { - const uint ai = gk * (p.A_hstep / 8) + gm; + const uint zk = siq / (Kd4 * N * UNROLL_SG_N); + const uint znij = siq % (Kd4 * N * UNROLL_SG_N); + const uint zn = znij / (Kd4 * N); + const uint ij = znij % (Kd4 * N); + const uint i = ij / Kd4; + const uint j 
= ij % Kd4; - if (gm * 8 < psc(GM)) v = A_blob_data[ai]; - } - else - { - const uvec4 gm4 = gm * 8 + uvec4(0, 1, 2, 3); - const uvec4 gm8 = gm4 + 4; - - vec4 v4 = vec4(0.f); - vec4 v8 = vec4(0.f); - - const uvec4 ai4 = gk * p.A_hstep + gm4; - const uvec4 ai4d8 = ai4 / 8; - const uvec4 ai4m8d2 = (ai4 % 8) / 2; - const uvec4 ai4m2 = ai4 % 2; - - const uvec4 ai8 = gk * p.A_hstep + gm8; - const uvec4 ai8d8 = ai8 / 8; - const uvec4 ai8m8d2 = (ai8 % 8) / 2; - const uvec4 ai8m2 = ai8 % 2; - -#if NCNN_bf16_storage || NCNN_bf16_packed - if (gm4.r < psc(GM)) v4.r = unpackBFloat2x16(A_blob_data[ai4d8.r][ai4m8d2.r])[ai4m2.r]; - if (gm4.g < psc(GM)) v4.g = unpackBFloat2x16(A_blob_data[ai4d8.g][ai4m8d2.g])[ai4m2.g]; - if (gm4.b < psc(GM)) v4.b = unpackBFloat2x16(A_blob_data[ai4d8.b][ai4m8d2.b])[ai4m2.b]; - if (gm4.a < psc(GM)) v4.a = unpackBFloat2x16(A_blob_data[ai4d8.a][ai4m8d2.a])[ai4m2.a]; + uvec2 v = uvec2(0); - if (gm8.r < psc(GM)) v8.r = unpackBFloat2x16(A_blob_data[ai8d8.r][ai8m8d2.r])[ai8m2.r]; - if (gm8.g < psc(GM)) v8.g = unpackBFloat2x16(A_blob_data[ai8d8.g][ai8m8d2.g])[ai8m2.g]; - if (gm8.b < psc(GM)) v8.b = unpackBFloat2x16(A_blob_data[ai8d8.b][ai8m8d2.b])[ai8m2.b]; - if (gm8.a < psc(GM)) v8.a = unpackBFloat2x16(A_blob_data[ai8d8.a][ai8m8d2.a])[ai8m2.a]; + const uint gn = (ni + zn) * N + i; + const uint gk = ki / 4 + zk * Kd4 + j; - v = uvec4(packBFloat2x16(v4.rg), packBFloat2x16(v4.ba), packBFloat2x16(v8.rg), packBFloat2x16(v8.ba)); -#else - if (gm4.r < psc(GM)) v4.r = unpackHalf2x16(A_blob_data[ai4d8.r][ai4m8d2.r])[ai4m2.r]; - if (gm4.g < psc(GM)) v4.g = unpackHalf2x16(A_blob_data[ai4d8.g][ai4m8d2.g])[ai4m2.g]; - if (gm4.b < psc(GM)) v4.b = unpackHalf2x16(A_blob_data[ai4d8.b][ai4m8d2.b])[ai4m2.b]; - if (gm4.a < psc(GM)) v4.a = unpackHalf2x16(A_blob_data[ai4d8.a][ai4m8d2.a])[ai4m2.a]; + if (gn < psc(GN)) + { + v = B_blob_data[gk * p.B_hstep + gn]; - if (gm8.r < psc(GM)) v8.r = unpackHalf2x16(A_blob_data[ai8d8.r][ai8m8d2.r])[ai8m2.r]; - if (gm8.g < psc(GM)) v8.g = unpackHalf2x16(A_blob_data[ai8d8.g][ai8m8d2.g])[ai8m2.g]; - if (gm8.b < psc(GM)) v8.b = unpackHalf2x16(A_blob_data[ai8d8.b][ai8m8d2.b])[ai8m2.b]; - if (gm8.a < psc(GM)) v8.a = unpackHalf2x16(A_blob_data[ai8d8.a][ai8m8d2.a])[ai8m2.a]; + const uvec4 gk4 = gk * 4 + uvec4(0, 1, 2, 3); + uvec4 mask4 = uvec4(lessThan(gk4, uvec4(psc(GK)))) * 0xFFFFu; + uvec2 packed_mask = uvec2(mask4.x | (mask4.y << 16), mask4.z | (mask4.w << 16)); - v = uvec4(packHalf2x16(v4.rg), packHalf2x16(v4.ba), packHalf2x16(v8.rg), packHalf2x16(v8.ba)); -#endif + v = v & packed_mask; } - } - tmp_a[(((sgmi * UNROLL_SG_K + zk) * UNROLL_SG_M + zm) * K + i) * Md8p + j] = v; + tmp_b[sgni][((zk * UNROLL_SG_N + zn) * N + i) * Kd4p + j] = v; + } } } } - - // load B - if (transB == 0) + else { - // +-N-+ - // K | - // +SG_UN - // | | - // ^ +---+ - // | | | - // SG_UK+- -+ - // | | | - // ^ v +---+ - // | | | - // | +- -+ - // | | | - // WG_UN +---+ - // | | | - // | +- -+ - // | | | - // v +---+ - - const uint Nd8_K_USGN_USGK = Nd8 * K * UNROLL_SG_N * UNROLL_SG_K; - const uint Nd8_K_USGN_USGK_d_subgroupsize = (Nd8_K_USGN_USGK + (subgroup_size * UNROLL_WG_M - 1)) / (subgroup_size * UNROLL_WG_M); - [[unroll]] for (uint q = 0; q < Nd8_K_USGN_USGK_d_subgroupsize; q++) + if (psc(B_elempack) == 1) { - const uint siq = (q * UNROLL_WG_M + sgmi) * subgroup_size + si; + // +-K-+ + // N | + // +SG_UN + // | | + // ^ +---+ + // | | | + // SG_UK+- -+ + // | | | + // ^ v +---+ + // | | | + // | +- -+ + // | | | + // WG_UN +---+ + // | | | + // | +- -+ + // | | | + // v +---+ - if 
(Nd8_K_USGN_USGK % (subgroup_size * UNROLL_WG_M) == 0 || siq < Nd8_K_USGN_USGK) + const uint Kd4_N_USGN_USGK = Kd4 * N * UNROLL_SG_N * UNROLL_SG_K; + const uint Kd4_N_USGN_USGK_d_subgroupsize = (Kd4_N_USGN_USGK + (subgroup_size * UNROLL_WG_M - 1)) / (subgroup_size * UNROLL_WG_M); + [[unroll]] for (uint q = 0; q < Kd4_N_USGN_USGK_d_subgroupsize; q++) { - const uint zk = siq / (Nd8 * K * UNROLL_SG_N); - const uint znij = siq % (Nd8 * K * UNROLL_SG_N); - const uint zn = znij / (Nd8 * K); - const uint ij = znij % (Nd8 * K); - const uint i = ij / Nd8; - const uint j = ij % Nd8; - - const uint gk = ki + zk * K + i; - const uint gn = (ni + zn) * Nd8 + j; + const uint siq = (q * UNROLL_WG_M + sgmi) * subgroup_size + si; - uvec4 v = uvec4(0); - if (gk < psc(GK)) + if (Kd4_N_USGN_USGK % (subgroup_size * UNROLL_WG_M) == 0 || siq < Kd4_N_USGN_USGK) { - const uvec4 gn4 = gn * 8 + uvec4(0, 1, 2, 3); - const uvec4 gn8 = gn4 + 4; + const uint zk = siq / (Kd4 * N * UNROLL_SG_N); + const uint znij = siq % (Kd4 * N * UNROLL_SG_N); + const uint zn = znij / (Kd4 * N); + const uint ij = znij % (Kd4 * N); + const uint j = ij / Kd4; + const uint i = ij % Kd4; - if (p.B_hstep % 8 == 0) - { - const uint bi = gk * (p.B_hstep / 8) + gn; + uvec2 v = uvec2(0); - if (gn * 8 < psc(GN)) v = B_blob_data[bi]; - } - else + const uint gk = ki / 4 + zk * Kd4 + i; + const uint gn = (ni + zn) * N + j; + + if (gn < psc(GN)) { - vec4 v4 = vec4(0.f); - vec4 v8 = vec4(0.f); + const uvec4 gk4 = gk * 4 + uvec4(0, 1, 2, 3); - const uvec4 bi4 = gk * p.B_hstep + gn4; - const uvec4 bi4d8 = bi4 / 8; - const uvec4 bi4m8d2 = (bi4 % 8) / 2; - const uvec4 bi4m2 = bi4 % 2; + if (p.B_hstep % 4 == 0) + { + const uint bi = gn * (p.B_hstep / 4) + gk; - const uvec4 bi8 = gk * p.B_hstep + gn8; - const uvec4 bi8d8 = bi8 / 8; - const uvec4 bi8m8d2 = (bi8 % 8) / 2; - const uvec4 bi8m2 = bi8 % 2; + v = B_blob_data[bi]; -#if NCNN_bf16_storage || NCNN_bf16_packed - if (gn4.r < psc(GN)) v4.r = unpackBFloat2x16(B_blob_data[bi4d8.r][bi4m8d2.r])[bi4m2.r]; - if (gn4.g < psc(GN)) v4.g = unpackBFloat2x16(B_blob_data[bi4d8.g][bi4m8d2.g])[bi4m2.g]; - if (gn4.b < psc(GN)) v4.b = unpackBFloat2x16(B_blob_data[bi4d8.b][bi4m8d2.b])[bi4m2.b]; - if (gn4.a < psc(GN)) v4.a = unpackBFloat2x16(B_blob_data[bi4d8.a][bi4m8d2.a])[bi4m2.a]; + uvec4 mask4 = uvec4(lessThan(gk4, uvec4(psc(GK)))) * 0xFFFFu; + uvec2 packed_mask = uvec2(mask4.x | (mask4.y << 16), mask4.z | (mask4.w << 16)); - if (gn8.r < psc(GN)) v8.r = unpackBFloat2x16(B_blob_data[bi8d8.r][bi8m8d2.r])[bi8m2.r]; - if (gn8.g < psc(GN)) v8.g = unpackBFloat2x16(B_blob_data[bi8d8.g][bi8m8d2.g])[bi8m2.g]; - if (gn8.b < psc(GN)) v8.b = unpackBFloat2x16(B_blob_data[bi8d8.b][bi8m8d2.b])[bi8m2.b]; - if (gn8.a < psc(GN)) v8.a = unpackBFloat2x16(B_blob_data[bi8d8.a][bi8m8d2.a])[bi8m2.a]; + v = v & packed_mask; + } + else + { + vec2 v4a = vec2(0.f); + vec2 v4b = vec2(0.f); - v = uvec4(packBFloat2x16(v4.rg), packBFloat2x16(v4.ba), packBFloat2x16(v8.rg), packBFloat2x16(v8.ba)); -#else - if (gn4.r < psc(GN)) v4.r = unpackHalf2x16(B_blob_data[bi4d8.r][bi4m8d2.r])[bi4m2.r]; - if (gn4.g < psc(GN)) v4.g = unpackHalf2x16(B_blob_data[bi4d8.g][bi4m8d2.g])[bi4m2.g]; - if (gn4.b < psc(GN)) v4.b = unpackHalf2x16(B_blob_data[bi4d8.b][bi4m8d2.b])[bi4m2.b]; - if (gn4.a < psc(GN)) v4.a = unpackHalf2x16(B_blob_data[bi4d8.a][bi4m8d2.a])[bi4m2.a]; + const uvec4 bi4 = gn * p.B_hstep + gk4; + const uvec4 bi4d4 = bi4 / 4; + const uvec4 bi4m4d2 = (bi4 % 4) / 2; + const uvec4 bi4m2 = bi4 % 2; + +#if NCNN_bf16_storage || NCNN_bf16_packed + if (gk4.r < 
psc(GK)) v4a.r = unpackBFloat2x16(B_blob_data[bi4d4.r][bi4m4d2.r])[bi4m2.r]; + if (gk4.g < psc(GK)) v4a.g = unpackBFloat2x16(B_blob_data[bi4d4.g][bi4m4d2.g])[bi4m2.g]; + if (gk4.b < psc(GK)) v4b.r = unpackBFloat2x16(B_blob_data[bi4d4.b][bi4m4d2.b])[bi4m2.b]; + if (gk4.a < psc(GK)) v4b.g = unpackBFloat2x16(B_blob_data[bi4d4.a][bi4m4d2.a])[bi4m2.a]; - if (gn8.r < psc(GN)) v8.r = unpackHalf2x16(B_blob_data[bi8d8.r][bi8m8d2.r])[bi8m2.r]; - if (gn8.g < psc(GN)) v8.g = unpackHalf2x16(B_blob_data[bi8d8.g][bi8m8d2.g])[bi8m2.g]; - if (gn8.b < psc(GN)) v8.b = unpackHalf2x16(B_blob_data[bi8d8.b][bi8m8d2.b])[bi8m2.b]; - if (gn8.a < psc(GN)) v8.a = unpackHalf2x16(B_blob_data[bi8d8.a][bi8m8d2.a])[bi8m2.a]; + v = uvec2(packBFloat2x16(v4a), packBFloat2x16(v4b)); +#else + if (gk4.r < psc(GK)) v4a.r = unpackHalf2x16(B_blob_data[bi4d4.r][bi4m4d2.r])[bi4m2.r]; + if (gk4.g < psc(GK)) v4a.g = unpackHalf2x16(B_blob_data[bi4d4.g][bi4m4d2.g])[bi4m2.g]; + if (gk4.b < psc(GK)) v4b.r = unpackHalf2x16(B_blob_data[bi4d4.b][bi4m4d2.b])[bi4m2.b]; + if (gk4.a < psc(GK)) v4b.g = unpackHalf2x16(B_blob_data[bi4d4.a][bi4m4d2.a])[bi4m2.a]; - v = uvec4(packHalf2x16(v4.rg), packHalf2x16(v4.ba), packHalf2x16(v8.rg), packHalf2x16(v8.ba)); + v = uvec2(packHalf2x16(v4a), packHalf2x16(v4b)); #endif + } } - } - tmp_b[(((sgni * UNROLL_SG_K + zk) * UNROLL_SG_N + zn) * K + i) * Nd8p + j] = v; + tmp_b[sgni][((zk * UNROLL_SG_N + zn) * N + j) * Kd4p + i] = v; + } } } - } - else - { - // +-K-+ - // N | - // +SG_UN - // | | - // ^ +---+ - // | | | - // SG_UK+- -+ - // | | | - // ^ v +---+ - // | | | - // | +- -+ - // | | | - // WG_UN +---+ - // | | | - // | +- -+ - // | | | - // v +---+ - - const uint Kd8_N_USGN_USGK = Kd8 * N * UNROLL_SG_N * UNROLL_SG_K; - const uint Kd8_N_USGN_USGK_d_subgroupsize = (Kd8_N_USGN_USGK + (subgroup_size * UNROLL_WG_M - 1)) / (subgroup_size * UNROLL_WG_M); - [[unroll]] for (uint q = 0; q < Kd8_N_USGN_USGK_d_subgroupsize; q++) + else // if (psc(B_elempack) == 4) { - const uint siq = (q * UNROLL_WG_M + sgmi) * subgroup_size + si; - - if (Kd8_N_USGN_USGK % (subgroup_size * UNROLL_WG_M) == 0 || siq < Kd8_N_USGN_USGK) + const uint Nd4_K_USGN_USGK = Nd4 * K * UNROLL_SG_N * UNROLL_SG_K; + const uint Nd4_K_USGN_USGK_d_subgroupsize = (Nd4_K_USGN_USGK + (subgroup_size * UNROLL_WG_M - 1)) / (subgroup_size * UNROLL_WG_M); + [[unroll]] for (uint q = 0; q < Nd4_K_USGN_USGK_d_subgroupsize; q++) { - const uint zk = siq / (Kd8 * N * UNROLL_SG_N); - const uint znij = siq % (Kd8 * N * UNROLL_SG_N); - const uint zn = znij / (Kd8 * N); - const uint ij = znij % (Kd8 * N); - const uint j = ij / Kd8; - const uint i = ij % Kd8; - - const uint gk = ki / 8 + zk * Kd8 + i; - const uint gn = (ni + zn) * N + j; + const uint siq = (q * UNROLL_WG_M + sgmi) * subgroup_size + si; - uvec4 v = uvec4(0); - if (gn < psc(GN)) + if (Nd4_K_USGN_USGK % (subgroup_size * UNROLL_WG_M) == 0 || siq < Nd4_K_USGN_USGK) { - const uvec4 gk4 = gk * 8 + uvec4(0, 1, 2, 3); - const uvec4 gk8 = gk4 + 4; + const uint zk = siq / (Nd4 * K * UNROLL_SG_N); + const uint znij = siq % (Nd4 * K * UNROLL_SG_N); + const uint zn = znij / (Nd4 * K); + const uint ij = znij % (Nd4 * K); + const uint j = ij / Nd4; + const uint i = ij % Nd4; - if (p.B_hstep % 8 == 0) - { - const uint bi = gn * (p.B_hstep / 8) + gk; - - v = B_blob_data[bi]; + uvec2 v = uvec2(0); - uvec4 mask4 = uvec4(lessThan(gk4, uvec4(psc(GK)))) * 0xFFFFu; - uvec4 mask8 = uvec4(lessThan(gk8, uvec4(psc(GK)))) * 0xFFFFu; - uvec2 packed_mask4 = uvec2(mask4.x | (mask4.y << 16), mask4.z | (mask4.w << 16)); - uvec2 
packed_mask8 = uvec2(mask8.x | (mask8.y << 16), mask8.z | (mask8.w << 16)); + const uint gk = ki + zk * K + j; + const uint gn = (ni + zn) * Nd4 + i; - v.rg = v.rg & packed_mask4; - v.ba = v.ba & packed_mask8; - } - else + if (gk < psc(GK)) { - vec4 v4 = vec4(0.f); - vec4 v8 = vec4(0.f); - - const uvec4 bi4 = gn * p.B_hstep + gk4; - const uvec4 bi4d8 = bi4 / 8; - const uvec4 bi4m8d2 = (bi4 % 8) / 2; - const uvec4 bi4m2 = bi4 % 2; - - const uvec4 bi8 = gn * p.B_hstep + gk8; - const uvec4 bi8d8 = bi8 / 8; - const uvec4 bi8m8d2 = (bi8 % 8) / 2; - const uvec4 bi8m2 = bi8 % 2; - -#if NCNN_bf16_storage || NCNN_bf16_packed - if (gk4.r < psc(GK)) v4.r = unpackBFloat2x16(B_blob_data[bi4d8.r][bi4m8d2.r])[bi4m2.r]; - if (gk4.g < psc(GK)) v4.g = unpackBFloat2x16(B_blob_data[bi4d8.g][bi4m8d2.g])[bi4m2.g]; - if (gk4.b < psc(GK)) v4.b = unpackBFloat2x16(B_blob_data[bi4d8.b][bi4m8d2.b])[bi4m2.b]; - if (gk4.a < psc(GK)) v4.a = unpackBFloat2x16(B_blob_data[bi4d8.a][bi4m8d2.a])[bi4m2.a]; - - if (gk8.r < psc(GK)) v8.r = unpackBFloat2x16(B_blob_data[bi8d8.r][bi8m8d2.r])[bi8m2.r]; - if (gk8.g < psc(GK)) v8.g = unpackBFloat2x16(B_blob_data[bi8d8.g][bi8m8d2.g])[bi8m2.g]; - if (gk8.b < psc(GK)) v8.b = unpackBFloat2x16(B_blob_data[bi8d8.b][bi8m8d2.b])[bi8m2.b]; - if (gk8.a < psc(GK)) v8.a = unpackBFloat2x16(B_blob_data[bi8d8.a][bi8m8d2.a])[bi8m2.a]; - - v = uvec4(packBFloat2x16(v4.rg), packBFloat2x16(v4.ba), packBFloat2x16(v8.rg), packBFloat2x16(v8.ba)); -#else - if (gk4.r < psc(GK)) v4.r = unpackHalf2x16(B_blob_data[bi4d8.r][bi4m8d2.r])[bi4m2.r]; - if (gk4.g < psc(GK)) v4.g = unpackHalf2x16(B_blob_data[bi4d8.g][bi4m8d2.g])[bi4m2.g]; - if (gk4.b < psc(GK)) v4.b = unpackHalf2x16(B_blob_data[bi4d8.b][bi4m8d2.b])[bi4m2.b]; - if (gk4.a < psc(GK)) v4.a = unpackHalf2x16(B_blob_data[bi4d8.a][bi4m8d2.a])[bi4m2.a]; + v = B_blob_data[gn * p.B_hstep + gk]; - if (gk8.r < psc(GK)) v8.r = unpackHalf2x16(B_blob_data[bi8d8.r][bi8m8d2.r])[bi8m2.r]; - if (gk8.g < psc(GK)) v8.g = unpackHalf2x16(B_blob_data[bi8d8.g][bi8m8d2.g])[bi8m2.g]; - if (gk8.b < psc(GK)) v8.b = unpackHalf2x16(B_blob_data[bi8d8.b][bi8m8d2.b])[bi8m2.b]; - if (gk8.a < psc(GK)) v8.a = unpackHalf2x16(B_blob_data[bi8d8.a][bi8m8d2.a])[bi8m2.a]; + const uvec4 gn4 = gn * 4 + uvec4(0, 1, 2, 3); + uvec4 mask4 = uvec4(lessThan(gn4, uvec4(psc(GN)))) * 0xFFFFu; + uvec2 packed_mask = uvec2(mask4.x | (mask4.y << 16), mask4.z | (mask4.w << 16)); - v = uvec4(packHalf2x16(v4.rg), packHalf2x16(v4.ba), packHalf2x16(v8.rg), packHalf2x16(v8.ba)); -#endif + v = v & packed_mask; } - } - tmp_b[(((sgni * UNROLL_SG_K + zk) * UNROLL_SG_N + zn) * N + j) * Kd8p + i] = v; + tmp_b[sgni][((zk * UNROLL_SG_N + zn) * K + j) * Nd4p + i] = v; + } } } } @@ -1926,41 +2634,101 @@ void main() { [[unroll]] for (uint zm = 0; zm < UNROLL_SG_M; zm++) { - if (transA == 0) + if (constantA == 1) + { +#if ncnn_VK_KHR_cooperative_matrix + coopMatLoad(A[zm], tmp_a[sgmi], (zk * UNROLL_SG_M + zm) * (Kd4p * M), Kd4p, gl_CooperativeMatrixLayoutRowMajor); +#elif ncnn_VK_NV_cooperative_matrix + coopMatLoadNV(A[zm], tmp_a[sgmi], (zk * UNROLL_SG_M + zm) * (Kd4p * M), Kd4p, false); +#endif + } + else if (transA == 0) { + if (psc(A_elempack) == 4) + { +#if ncnn_VK_KHR_cooperative_matrix + coopMatLoad(A[zm], tmp_a[sgmi], (zk * UNROLL_SG_M + zm) * (Md4p * K), Md4p, gl_CooperativeMatrixLayoutColumnMajor); +#elif ncnn_VK_NV_cooperative_matrix + coopMatLoadNV(A[zm], tmp_a[sgmi], (zk * UNROLL_SG_M + zm) * (Md4p * K), Md4p, true); +#endif + } + else + { #if ncnn_VK_KHR_cooperative_matrix - coopMatLoad(A[zm], tmp_a, ((sgmi * 
UNROLL_SG_K + zk) * UNROLL_SG_M + zm) * (Kd8p * M), Kd8p, gl_CooperativeMatrixLayoutRowMajor); + coopMatLoad(A[zm], tmp_a[sgmi], (zk * UNROLL_SG_M + zm) * (Kd4p * M), Kd4p, gl_CooperativeMatrixLayoutRowMajor); #elif ncnn_VK_NV_cooperative_matrix - coopMatLoadNV(A[zm], tmp_a, ((sgmi * UNROLL_SG_K + zk) * UNROLL_SG_M + zm) * (Kd8p * M), Kd8p, false); + coopMatLoadNV(A[zm], tmp_a[sgmi], (zk * UNROLL_SG_M + zm) * (Kd4p * M), Kd4p, false); #endif + } } else { + if (psc(A_elempack) == 4) + { +#if ncnn_VK_KHR_cooperative_matrix + coopMatLoad(A[zm], tmp_a[sgmi], (zk * UNROLL_SG_M + zm) * (Kd4p * M), Kd4p, gl_CooperativeMatrixLayoutRowMajor); +#elif ncnn_VK_NV_cooperative_matrix + coopMatLoadNV(A[zm], tmp_a[sgmi], (zk * UNROLL_SG_M + zm) * (Kd4p * M), Kd4p, false); +#endif + } + else + { #if ncnn_VK_KHR_cooperative_matrix - coopMatLoad(A[zm], tmp_a, ((sgmi * UNROLL_SG_K + zk) * UNROLL_SG_M + zm) * (Md8p * K), Md8p, gl_CooperativeMatrixLayoutColumnMajor); + coopMatLoad(A[zm], tmp_a[sgmi], (zk * UNROLL_SG_M + zm) * (Md4p * K), Md4p, gl_CooperativeMatrixLayoutColumnMajor); #elif ncnn_VK_NV_cooperative_matrix - coopMatLoadNV(A[zm], tmp_a, ((sgmi * UNROLL_SG_K + zk) * UNROLL_SG_M + zm) * (Md8p * K), Md8p, true); + coopMatLoadNV(A[zm], tmp_a[sgmi], (zk * UNROLL_SG_M + zm) * (Md4p * K), Md4p, true); #endif + } } } [[unroll]] for (uint zn = 0; zn < UNROLL_SG_N; zn++) { - if (transB == 0) + if (constantB == 1) + { +#if ncnn_VK_KHR_cooperative_matrix + coopMatLoad(B[zn], tmp_b[sgni], (zk * UNROLL_SG_N + zn) * (Nd4p * K), Nd4p, gl_CooperativeMatrixLayoutRowMajor); +#elif ncnn_VK_NV_cooperative_matrix + coopMatLoadNV(B[zn], tmp_b[sgni], (zk * UNROLL_SG_N + zn) * (Nd4p * K), Nd4p, false); +#endif + } + else if (transB == 0) { + if (psc(B_elempack) == 4) + { +#if ncnn_VK_KHR_cooperative_matrix + coopMatLoad(B[zn], tmp_b[sgni], (zk * UNROLL_SG_N + zn) * (Kd4p * N), Kd4p, gl_CooperativeMatrixLayoutColumnMajor); +#elif ncnn_VK_NV_cooperative_matrix + coopMatLoadNV(B[zn], tmp_b[sgni], (zk * UNROLL_SG_N + zn) * (Kd4p * N), Kd4p, true); +#endif + } + else + { #if ncnn_VK_KHR_cooperative_matrix - coopMatLoad(B[zn], tmp_b, ((sgni * UNROLL_SG_K + zk) * UNROLL_SG_N + zn) * (Nd8p * K), Nd8p, gl_CooperativeMatrixLayoutRowMajor); + coopMatLoad(B[zn], tmp_b[sgni], (zk * UNROLL_SG_N + zn) * (Nd4p * K), Nd4p, gl_CooperativeMatrixLayoutRowMajor); #elif ncnn_VK_NV_cooperative_matrix - coopMatLoadNV(B[zn], tmp_b, ((sgni * UNROLL_SG_K + zk) * UNROLL_SG_N + zn) * (Nd8p * K), Nd8p, false); + coopMatLoadNV(B[zn], tmp_b[sgni], (zk * UNROLL_SG_N + zn) * (Nd4p * K), Nd4p, false); #endif + } } else { + if (psc(B_elempack) == 4) + { +#if ncnn_VK_KHR_cooperative_matrix + coopMatLoad(B[zn], tmp_b[sgni], (zk * UNROLL_SG_N + zn) * (Nd4p * K), Nd4p, gl_CooperativeMatrixLayoutRowMajor); +#elif ncnn_VK_NV_cooperative_matrix + coopMatLoadNV(B[zn], tmp_b[sgni], (zk * UNROLL_SG_N + zn) * (Nd4p * K), Nd4p, false); +#endif + } + else + { #if ncnn_VK_KHR_cooperative_matrix - coopMatLoad(B[zn], tmp_b, ((sgni * UNROLL_SG_K + zk) * UNROLL_SG_N + zn) * (Kd8p * N), Kd8p, gl_CooperativeMatrixLayoutColumnMajor); + coopMatLoad(B[zn], tmp_b[sgni], (zk * UNROLL_SG_N + zn) * (Kd4p * N), Kd4p, gl_CooperativeMatrixLayoutColumnMajor); #elif ncnn_VK_NV_cooperative_matrix - coopMatLoadNV(B[zn], tmp_b, ((sgni * UNROLL_SG_K + zk) * UNROLL_SG_N + zn) * (Kd8p * N), Kd8p, true); + coopMatLoadNV(B[zn], tmp_b[sgni], (zk * UNROLL_SG_N + zn) * (Kd4p * N), Kd4p, true); #endif + } } } @@ -1988,192 +2756,259 @@ void main() barrier(); // load A - if (transA == 0) + if (constantA 
== 1) { - // +-K-+ - // M | - // +SG_UM - // | | - // ^ +---+ - // | | | - // WG_UM+- -+ - // | | | - // v +---+ - - const uint Kd8_M_USGM = Kd8 * M * UNROLL_SG_M; - const uint Kd8_M_USGM_d_subgroupsize = (Kd8_M_USGM + (subgroup_size * UNROLL_WG_N - 1)) / (subgroup_size * UNROLL_WG_N); - [[unroll]] for (uint q = 0; q < Kd8_M_USGM_d_subgroupsize; q++) + const uint Kd4_M_USGM = Kd4 * M * UNROLL_SG_M; + const uint A_offset = ((wgmi * kk + k) * UNROLL_WG_M + sgmi) * Kd4_M_USGM; + const uint Kd4_M_USGM_d_subgroupsize = (Kd4_M_USGM + (subgroup_size * UNROLL_WG_N - 1)) / (subgroup_size * UNROLL_WG_N); + [[unroll]] for (uint q = 0; q < Kd4_M_USGM_d_subgroupsize; q++) { const uint siq = (q * UNROLL_WG_N + sgni) * subgroup_size + si; - if (Kd8_M_USGM % (subgroup_size * UNROLL_WG_N) == 0 || siq < Kd8_M_USGM) + if (Kd4_M_USGM % (subgroup_size * UNROLL_WG_N) == 0 || siq < Kd4_M_USGM) + { + const uint zm = siq / (Kd4 * M); + const uint ij = siq % (Kd4 * M); + const uint j = ij / Kd4; + const uint i = ij % Kd4; + + tmp_a[sgmi][(zm * M + j) * Kd4p + i] = A_blob_data[A_offset + siq]; + } + } + } + else if (transA == 0) + { + if (psc(A_elempack) == 1) + { + // +-K-+ + // M | + // +SG_UM + // | | + // ^ +---+ + // | | | + // WG_UM+- -+ + // | | | + // v +---+ + + const uint Kd4_M_USGM = Kd4 * M * UNROLL_SG_M; + const uint Kd4_M_USGM_d_subgroupsize = (Kd4_M_USGM + (subgroup_size * UNROLL_WG_N - 1)) / (subgroup_size * UNROLL_WG_N); + [[unroll]] for (uint q = 0; q < Kd4_M_USGM_d_subgroupsize; q++) + { + const uint siq = (q * UNROLL_WG_N + sgni) * subgroup_size + si; + + if (Kd4_M_USGM % (subgroup_size * UNROLL_WG_N) == 0 || siq < Kd4_M_USGM) + { + const uint zm = siq / (Kd4 * M); + const uint ij = siq % (Kd4 * M); + const uint j = ij / Kd4; + const uint i = ij % Kd4; + + uvec2 v = uvec2(0); + + const uint gk = ki / 4 + i; + const uint gm = (mi + zm) * M + j; + + if (gm < psc(GM)) + { + const uvec4 gk4 = gk * 4 + uvec4(0, 1, 2, 3); + + if (p.A_hstep % 4 == 0) + { + const uint ai = gm * (p.A_hstep / 4) + gk; + + v = A_blob_data[ai]; + + uvec4 mask4 = uvec4(lessThan(gk4, uvec4(psc(GK)))) * 0xFFFFu; + uvec2 packed_mask = uvec2(mask4.x | (mask4.y << 16), mask4.z | (mask4.w << 16)); + + v = v & packed_mask; + } + else + { + vec2 v4a = vec2(0.f); + vec2 v4b = vec2(0.f); + + const uvec4 ai4 = gm * p.A_hstep + gk4; + const uvec4 ai4d4 = ai4 / 4; + const uvec4 ai4m4d2 = (ai4 % 4) / 2; + const uvec4 ai4m2 = ai4 % 2; + +#if NCNN_bf16_storage || NCNN_bf16_packed + if (gk4.r < psc(GK)) v4a.r = unpackBFloat2x16(A_blob_data[ai4d4.r][ai4m4d2.r])[ai4m2.r]; + if (gk4.g < psc(GK)) v4a.g = unpackBFloat2x16(A_blob_data[ai4d4.g][ai4m4d2.g])[ai4m2.g]; + if (gk4.b < psc(GK)) v4b.r = unpackBFloat2x16(A_blob_data[ai4d4.b][ai4m4d2.b])[ai4m2.b]; + if (gk4.a < psc(GK)) v4b.g = unpackBFloat2x16(A_blob_data[ai4d4.a][ai4m4d2.a])[ai4m2.a]; + + v = uvec2(packBFloat2x16(v4a), packBFloat2x16(v4b)); +#else + if (gk4.r < psc(GK)) v4a.r = unpackHalf2x16(A_blob_data[ai4d4.r][ai4m4d2.r])[ai4m2.r]; + if (gk4.g < psc(GK)) v4a.g = unpackHalf2x16(A_blob_data[ai4d4.g][ai4m4d2.g])[ai4m2.g]; + if (gk4.b < psc(GK)) v4b.r = unpackHalf2x16(A_blob_data[ai4d4.b][ai4m4d2.b])[ai4m2.b]; + if (gk4.a < psc(GK)) v4b.g = unpackHalf2x16(A_blob_data[ai4d4.a][ai4m4d2.a])[ai4m2.a]; + + v = uvec2(packHalf2x16(v4a), packHalf2x16(v4b)); +#endif + } + } + + tmp_a[sgmi][(zm * M + j) * Kd4p + i] = v; + } + } + } + else // if (psc(A_elempack) == 4) + { + const uint Md4_K_USGM = Md4 * K * UNROLL_SG_M; + const uint Md4_K_USGM_d_subgroupsize = (Md4_K_USGM + (subgroup_size * 
UNROLL_WG_N - 1)) / (subgroup_size * UNROLL_WG_N); + [[unroll]] for (uint q = 0; q < Md4_K_USGM_d_subgroupsize; q++) { - const uint zm = siq / (Kd8 * M); - const uint ij = siq % (Kd8 * M); - const uint j = ij / Kd8; - const uint i = ij % Kd8; + const uint siq = (q * UNROLL_WG_N + sgni) * subgroup_size + si; - const uint gk = ki / 8 + i; - const uint gm = (mi + zm) * M + j; + if (Md4_K_USGM % (subgroup_size * UNROLL_WG_N) == 0 || siq < Md4_K_USGM) + { + const uint zm = siq / (Md4 * K); + const uint ij = siq % (Md4 * K); + const uint i = ij / Md4; + const uint j = ij % Md4; - uvec4 v = uvec4(0); + uvec2 v = uvec2(0); - if (gm < psc(GM)) + const uint gk = ki + i; + const uint gm = (mi + zm) * Md4 + j; + + if (gk < psc(GK)) + { + v = A_blob_data[gm * p.A_hstep + gk]; + + const uvec4 gm4 = gm * 4 + uvec4(0, 1, 2, 3); + uvec4 mask4 = uvec4(lessThan(gm4, uvec4(psc(GM)))) * 0xFFFFu; + uvec2 packed_mask = uvec2(mask4.x | (mask4.y << 16), mask4.z | (mask4.w << 16)); + + v = v & packed_mask; + } + + tmp_a[sgmi][(zm * K + i) * Md4p + j] = v; + } + } + } + } + else + { + if (psc(A_elempack) == 1) + { + // +-M-+ + // K | + // +SG_UM + // | | + // ^ +---+ + // | | | + // WG_UM+- -+ + // | | | + // v +---+ + + const uint Md4_K_USGM = Md4 * K * UNROLL_SG_M; + const uint Md4_K_USGM_d_subgroupsize = (Md4_K_USGM + (subgroup_size * UNROLL_WG_N - 1)) / (subgroup_size * UNROLL_WG_N); + [[unroll]] for (uint q = 0; q < Md4_K_USGM_d_subgroupsize; q++) + { + const uint siq = (q * UNROLL_WG_N + sgni) * subgroup_size + si; + + if (Md4_K_USGM % (subgroup_size * UNROLL_WG_N) == 0 || siq < Md4_K_USGM) { - const uvec4 gk4 = gk * 8 + uvec4(0, 1, 2, 3); - const uvec4 gk8 = gk4 + 4; + const uint zm = siq / (Md4 * K); + const uint ij = siq % (Md4 * K); + const uint i = ij / Md4; + const uint j = ij % Md4; + + uvec2 v = uvec2(0); - if (p.A_hstep % 8 == 0) + const uint gk = ki + i; + const uint gm = (mi + zm) * Md4 + j; + + if (gk < psc(GK)) { - const uint ai = gm * (p.A_hstep / 8) + gk; + const uvec4 gm4 = gm * 4 + uvec4(0, 1, 2, 3); - v = A_blob_data[ai]; + if (p.A_hstep % 4 == 0) + { + const uint ai = gk * (p.A_hstep / 4) + gm; - uvec4 mask4 = uvec4(lessThan(gk4, uvec4(psc(GK)))) * 0xFFFFu; - uvec4 mask8 = uvec4(lessThan(gk8, uvec4(psc(GK)))) * 0xFFFFu; - uvec2 packed_mask4 = uvec2(mask4.x | (mask4.y << 16), mask4.z | (mask4.w << 16)); - uvec2 packed_mask8 = uvec2(mask8.x | (mask8.y << 16), mask8.z | (mask8.w << 16)); + v = A_blob_data[ai]; - v.rg = v.rg & packed_mask4; - v.ba = v.ba & packed_mask8; - } - else - { - vec4 v4 = vec4(0.f); - vec4 v8 = vec4(0.f); + uvec4 mask4 = uvec4(lessThan(gm4, uvec4(psc(GM)))) * 0xFFFFu; + uvec2 packed_mask = uvec2(mask4.x | (mask4.y << 16), mask4.z | (mask4.w << 16)); - const uvec4 ai4 = gm * p.A_hstep + gk4; - const uvec4 ai4d8 = ai4 / 8; - const uvec4 ai4m8d2 = (ai4 % 8) / 2; - const uvec4 ai4m2 = ai4 % 2; + v = v & packed_mask; + } + else + { + vec2 v4a = vec2(0.f); + vec2 v4b = vec2(0.f); - const uvec4 ai8 = gm * p.A_hstep + gk8; - const uvec4 ai8d8 = ai8 / 8; - const uvec4 ai8m8d2 = (ai8 % 8) / 2; - const uvec4 ai8m2 = ai8 % 2; + const uvec4 ai4 = gk * p.A_hstep + gm4; + const uvec4 ai4d4 = ai4 / 4; + const uvec4 ai4m4d2 = (ai4 % 4) / 2; + const uvec4 ai4m2 = ai4 % 2; #if NCNN_bf16_storage || NCNN_bf16_packed - if (gk4.r < psc(GK)) v4.r = unpackBFloat2x16(A_blob_data[ai4d8.r][ai4m8d2.r])[ai4m2.r]; - if (gk4.g < psc(GK)) v4.g = unpackBFloat2x16(A_blob_data[ai4d8.g][ai4m8d2.g])[ai4m2.g]; - if (gk4.b < psc(GK)) v4.b = unpackBFloat2x16(A_blob_data[ai4d8.b][ai4m8d2.b])[ai4m2.b]; - if 
(gk4.a < psc(GK)) v4.a = unpackBFloat2x16(A_blob_data[ai4d8.a][ai4m8d2.a])[ai4m2.a]; - - if (gk8.r < psc(GK)) v8.r = unpackBFloat2x16(A_blob_data[ai8d8.r][ai8m8d2.r])[ai8m2.r]; - if (gk8.g < psc(GK)) v8.g = unpackBFloat2x16(A_blob_data[ai8d8.g][ai8m8d2.g])[ai8m2.g]; - if (gk8.b < psc(GK)) v8.b = unpackBFloat2x16(A_blob_data[ai8d8.b][ai8m8d2.b])[ai8m2.b]; - if (gk8.a < psc(GK)) v8.a = unpackBFloat2x16(A_blob_data[ai8d8.a][ai8m8d2.a])[ai8m2.a]; + if (gm4.r < psc(GM)) v4a.r = unpackBFloat2x16(A_blob_data[ai4d4.r][ai4m4d2.r])[ai4m2.r]; + if (gm4.g < psc(GM)) v4a.g = unpackBFloat2x16(A_blob_data[ai4d4.g][ai4m4d2.g])[ai4m2.g]; + if (gm4.b < psc(GM)) v4b.r = unpackBFloat2x16(A_blob_data[ai4d4.b][ai4m4d2.b])[ai4m2.b]; + if (gm4.a < psc(GM)) v4b.g = unpackBFloat2x16(A_blob_data[ai4d4.a][ai4m4d2.a])[ai4m2.a]; - v = uvec4(packBFloat2x16(v4.rg), packBFloat2x16(v4.ba), packBFloat2x16(v8.rg), packBFloat2x16(v8.ba)); + v = uvec2(packBFloat2x16(v4a), packBFloat2x16(v4b)); #else - if (gk4.r < psc(GK)) v4.r = unpackHalf2x16(A_blob_data[ai4d8.r][ai4m8d2.r])[ai4m2.r]; - if (gk4.g < psc(GK)) v4.g = unpackHalf2x16(A_blob_data[ai4d8.g][ai4m8d2.g])[ai4m2.g]; - if (gk4.b < psc(GK)) v4.b = unpackHalf2x16(A_blob_data[ai4d8.b][ai4m8d2.b])[ai4m2.b]; - if (gk4.a < psc(GK)) v4.a = unpackHalf2x16(A_blob_data[ai4d8.a][ai4m8d2.a])[ai4m2.a]; + if (gm4.r < psc(GM)) v4a.r = unpackHalf2x16(A_blob_data[ai4d4.r][ai4m4d2.r])[ai4m2.r]; + if (gm4.g < psc(GM)) v4a.g = unpackHalf2x16(A_blob_data[ai4d4.g][ai4m4d2.g])[ai4m2.g]; + if (gm4.b < psc(GM)) v4b.r = unpackHalf2x16(A_blob_data[ai4d4.b][ai4m4d2.b])[ai4m2.b]; + if (gm4.a < psc(GM)) v4b.g = unpackHalf2x16(A_blob_data[ai4d4.a][ai4m4d2.a])[ai4m2.a]; - if (gk8.r < psc(GK)) v8.r = unpackHalf2x16(A_blob_data[ai8d8.r][ai8m8d2.r])[ai8m2.r]; - if (gk8.g < psc(GK)) v8.g = unpackHalf2x16(A_blob_data[ai8d8.g][ai8m8d2.g])[ai8m2.g]; - if (gk8.b < psc(GK)) v8.b = unpackHalf2x16(A_blob_data[ai8d8.b][ai8m8d2.b])[ai8m2.b]; - if (gk8.a < psc(GK)) v8.a = unpackHalf2x16(A_blob_data[ai8d8.a][ai8m8d2.a])[ai8m2.a]; - - v = uvec4(packHalf2x16(v4.rg), packHalf2x16(v4.ba), packHalf2x16(v8.rg), packHalf2x16(v8.ba)); + v = uvec2(packHalf2x16(v4a), packHalf2x16(v4b)); #endif + } } - } - tmp_a[((sgmi * UNROLL_SG_M + zm) * M + j) * Kd8p + i] = v; + tmp_a[sgmi][(zm * K + i) * Md4p + j] = v; + } } } - } - else - { - // +-M-+ - // K | - // +SG_UM - // | | - // ^ +---+ - // | | | - // WG_UM+- -+ - // | | | - // v +---+ - - const uint Md8_K_USGM = Md8 * K * UNROLL_SG_M; - const uint Md8_K_USGM_d_subgroupsize = (Md8_K_USGM + (subgroup_size * UNROLL_WG_N - 1)) / (subgroup_size * UNROLL_WG_N); - [[unroll]] for (uint q = 0; q < Md8_K_USGM_d_subgroupsize; q++) + else // if (psc(A_elempack) == 4) { - const uint siq = (q * UNROLL_WG_N + sgni) * subgroup_size + si; - - if (Md8_K_USGM % (subgroup_size * UNROLL_WG_N) == 0 || siq < Md8_K_USGM) + const uint Kd4_M_USGM = Kd4 * M * UNROLL_SG_M; + const uint Kd4_M_USGM_d_subgroupsize = (Kd4_M_USGM + (subgroup_size * UNROLL_WG_N - 1)) / (subgroup_size * UNROLL_WG_N); + [[unroll]] for (uint q = 0; q < Kd4_M_USGM_d_subgroupsize; q++) { - const uint zm = siq / (Md8 * K); - const uint ij = siq % (Md8 * K); - const uint i = ij / Md8; - const uint j = ij % Md8; - - const uint gk = ki + i; - const uint gm = (mi + zm) * Md8 + j; + const uint siq = (q * UNROLL_WG_N + sgni) * subgroup_size + si; - uvec4 v = uvec4(0); - if (gk < psc(GK)) + if (Kd4_M_USGM % (subgroup_size * UNROLL_WG_N) == 0 || siq < Kd4_M_USGM) { - if (p.A_hstep % 8 == 0) - { - const uint ai = gk * (p.A_hstep / 8) + gm; - - 
if (gm * 8 < psc(GM)) v = A_blob_data[ai]; - } - else - { - const uvec4 gm4 = gm * 8 + uvec4(0, 1, 2, 3); - const uvec4 gm8 = gm4 + 4; - - vec4 v4 = vec4(0.f); - vec4 v8 = vec4(0.f); - - const uvec4 ai4 = gk * p.A_hstep + gm4; - const uvec4 ai4d8 = ai4 / 8; - const uvec4 ai4m8d2 = (ai4 % 8) / 2; - const uvec4 ai4m2 = ai4 % 2; - - const uvec4 ai8 = gk * p.A_hstep + gm8; - const uvec4 ai8d8 = ai8 / 8; - const uvec4 ai8m8d2 = (ai8 % 8) / 2; - const uvec4 ai8m2 = ai8 % 2; + const uint zm = siq / (Kd4 * M); + const uint ij = siq % (Kd4 * M); + const uint j = ij / Kd4; + const uint i = ij % Kd4; -#if NCNN_bf16_storage || NCNN_bf16_packed - if (gm4.r < psc(GM)) v4.r = unpackBFloat2x16(A_blob_data[ai4d8.r][ai4m8d2.r])[ai4m2.r]; - if (gm4.g < psc(GM)) v4.g = unpackBFloat2x16(A_blob_data[ai4d8.g][ai4m8d2.g])[ai4m2.g]; - if (gm4.b < psc(GM)) v4.b = unpackBFloat2x16(A_blob_data[ai4d8.b][ai4m8d2.b])[ai4m2.b]; - if (gm4.a < psc(GM)) v4.a = unpackBFloat2x16(A_blob_data[ai4d8.a][ai4m8d2.a])[ai4m2.a]; + uvec2 v = uvec2(0); - if (gm8.r < psc(GM)) v8.r = unpackBFloat2x16(A_blob_data[ai8d8.r][ai8m8d2.r])[ai8m2.r]; - if (gm8.g < psc(GM)) v8.g = unpackBFloat2x16(A_blob_data[ai8d8.g][ai8m8d2.g])[ai8m2.g]; - if (gm8.b < psc(GM)) v8.b = unpackBFloat2x16(A_blob_data[ai8d8.b][ai8m8d2.b])[ai8m2.b]; - if (gm8.a < psc(GM)) v8.a = unpackBFloat2x16(A_blob_data[ai8d8.a][ai8m8d2.a])[ai8m2.a]; + const uint gk = ki / 4 + i; + const uint gm = (mi + zm) * M + j; - v = uvec4(packBFloat2x16(v4.rg), packBFloat2x16(v4.ba), packBFloat2x16(v8.rg), packBFloat2x16(v8.ba)); -#else - if (gm4.r < psc(GM)) v4.r = unpackHalf2x16(A_blob_data[ai4d8.r][ai4m8d2.r])[ai4m2.r]; - if (gm4.g < psc(GM)) v4.g = unpackHalf2x16(A_blob_data[ai4d8.g][ai4m8d2.g])[ai4m2.g]; - if (gm4.b < psc(GM)) v4.b = unpackHalf2x16(A_blob_data[ai4d8.b][ai4m8d2.b])[ai4m2.b]; - if (gm4.a < psc(GM)) v4.a = unpackHalf2x16(A_blob_data[ai4d8.a][ai4m8d2.a])[ai4m2.a]; + if (gm < psc(GM)) + { + v = A_blob_data[gk * p.A_hstep + gm]; - if (gm8.r < psc(GM)) v8.r = unpackHalf2x16(A_blob_data[ai8d8.r][ai8m8d2.r])[ai8m2.r]; - if (gm8.g < psc(GM)) v8.g = unpackHalf2x16(A_blob_data[ai8d8.g][ai8m8d2.g])[ai8m2.g]; - if (gm8.b < psc(GM)) v8.b = unpackHalf2x16(A_blob_data[ai8d8.b][ai8m8d2.b])[ai8m2.b]; - if (gm8.a < psc(GM)) v8.a = unpackHalf2x16(A_blob_data[ai8d8.a][ai8m8d2.a])[ai8m2.a]; + const uvec4 gk4 = gk * 4 + uvec4(0, 1, 2, 3); + uvec4 mask4 = uvec4(lessThan(gk4, uvec4(psc(GK)))) * 0xFFFFu; + uvec2 packed_mask = uvec2(mask4.x | (mask4.y << 16), mask4.z | (mask4.w << 16)); - v = uvec4(packHalf2x16(v4.rg), packHalf2x16(v4.ba), packHalf2x16(v8.rg), packHalf2x16(v8.ba)); -#endif + v = v & packed_mask; } - } - tmp_a[((sgmi * UNROLL_SG_M + zm) * K + i) * Md8p + j] = v; + tmp_a[sgmi][(zm * M + j) * Kd4p + i] = v; + } } } } // load B - if (transB == 0) + if (constantB == 1) { // +-N-+ // K | @@ -2185,173 +3020,253 @@ void main() // | | | // v +---+ - const uint Nd8_K_USGN = Nd8 * K * UNROLL_SG_N; - const uint Nd8_K_USGN_d_subgroupsize = (Nd8_K_USGN + (subgroup_size * UNROLL_WG_M - 1)) / (subgroup_size * UNROLL_WG_M); - [[unroll]] for (uint q = 0; q < Nd8_K_USGN_d_subgroupsize; q++) + // B_data coopmat_N * coopmat_K * UNROLL_SG_N * UNROLL_WG_N * kk, blocks_n + + const uint Nd4_K_USGN = Nd4 * K * UNROLL_SG_N; + const uint B_offset = ((wgni * kk + k) * UNROLL_WG_N + sgni) * Nd4_K_USGN; + const uint Nd4_K_USGN_d_subgroupsize = (Nd4_K_USGN + (subgroup_size * UNROLL_WG_M - 1)) / (subgroup_size * UNROLL_WG_M); + [[unroll]] for (uint q = 0; q < Nd4_K_USGN_d_subgroupsize; q++) { const uint siq = (q 
* UNROLL_WG_M + sgmi) * subgroup_size + si; - if (Nd8_K_USGN % (subgroup_size * UNROLL_WG_M) == 0 || siq < Nd8_K_USGN) + if (Nd4_K_USGN % (subgroup_size * UNROLL_WG_M) == 0 || siq < Nd4_K_USGN) { - const uint zn = siq / (Nd8 * K); - const uint ij = siq % (Nd8 * K); - const uint i = ij / Nd8; - const uint j = ij % Nd8; + const uint zn = siq / (Nd4 * K); + const uint ij = siq % (Nd4 * K); + const uint i = ij / Nd4; + const uint j = ij % Nd4; - const uint gk = ki + i; - const uint gn = (ni + zn) * Nd8 + j; + tmp_b[sgni][(zn * K + i) * Nd4p + j] = B_blob_data[B_offset + siq]; + } + } + } + else if (transB == 0) + { + if (psc(B_elempack) == 1) + { + // +-N-+ + // K | + // +SG_UN + // | | + // ^ +---+ + // | | | + // WG_UN+- -+ + // | | | + // v +---+ + + const uint Nd4_K_USGN = Nd4 * K * UNROLL_SG_N; + const uint Nd4_K_USGN_d_subgroupsize = (Nd4_K_USGN + (subgroup_size * UNROLL_WG_M - 1)) / (subgroup_size * UNROLL_WG_M); + [[unroll]] for (uint q = 0; q < Nd4_K_USGN_d_subgroupsize; q++) + { + const uint siq = (q * UNROLL_WG_M + sgmi) * subgroup_size + si; - uvec4 v = uvec4(0); - if (gk < psc(GK)) + if (Nd4_K_USGN % (subgroup_size * UNROLL_WG_M) == 0 || siq < Nd4_K_USGN) { - const uvec4 gn4 = gn * 8 + uvec4(0, 1, 2, 3); - const uvec4 gn8 = gn4 + 4; + const uint zn = siq / (Nd4 * K); + const uint ij = siq % (Nd4 * K); + const uint i = ij / Nd4; + const uint j = ij % Nd4; - if (p.B_hstep % 8 == 0) - { - const uint bi = gk * (p.B_hstep / 8) + gn; + uvec2 v = uvec2(0); - if (gn * 8 < psc(GN)) v = B_blob_data[bi]; - } - else + const uint gk = ki + i; + const uint gn = (ni + zn) * Nd4 + j; + + if (gk < psc(GK)) { - vec4 v4 = vec4(0.f); - vec4 v8 = vec4(0.f); + const uvec4 gn4 = gn * 4 + uvec4(0, 1, 2, 3); + + if (p.B_hstep % 4 == 0) + { + const uint bi = gk * (p.B_hstep / 4) + gn; - const uvec4 bi4 = gk * p.B_hstep + gn4; - const uvec4 bi4d8 = bi4 / 8; - const uvec4 bi4m8d2 = (bi4 % 8) / 2; - const uvec4 bi4m2 = bi4 % 2; + v = B_blob_data[bi]; - const uvec4 bi8 = gk * p.B_hstep + gn8; - const uvec4 bi8d8 = bi8 / 8; - const uvec4 bi8m8d2 = (bi8 % 8) / 2; - const uvec4 bi8m2 = bi8 % 2; + uvec4 mask4 = uvec4(lessThan(gn4, uvec4(psc(GN)))) * 0xFFFFu; + uvec2 packed_mask = uvec2(mask4.x | (mask4.y << 16), mask4.z | (mask4.w << 16)); -#if NCNN_bf16_storage || NCNN_bf16_packed - if (gn4.r < psc(GN)) v4.r = unpackBFloat2x16(B_blob_data[bi4d8.r][bi4m8d2.r])[bi4m2.r]; - if (gn4.g < psc(GN)) v4.g = unpackBFloat2x16(B_blob_data[bi4d8.g][bi4m8d2.g])[bi4m2.g]; - if (gn4.b < psc(GN)) v4.b = unpackBFloat2x16(B_blob_data[bi4d8.b][bi4m8d2.b])[bi4m2.b]; - if (gn4.a < psc(GN)) v4.a = unpackBFloat2x16(B_blob_data[bi4d8.a][bi4m8d2.a])[bi4m2.a]; + v = v & packed_mask; + } + else + { + vec2 v4a = vec2(0.f); + vec2 v4b = vec2(0.f); - if (gn8.r < psc(GN)) v8.r = unpackBFloat2x16(B_blob_data[bi8d8.r][bi8m8d2.r])[bi8m2.r]; - if (gn8.g < psc(GN)) v8.g = unpackBFloat2x16(B_blob_data[bi8d8.g][bi8m8d2.g])[bi8m2.g]; - if (gn8.b < psc(GN)) v8.b = unpackBFloat2x16(B_blob_data[bi8d8.b][bi8m8d2.b])[bi8m2.b]; - if (gn8.a < psc(GN)) v8.a = unpackBFloat2x16(B_blob_data[bi8d8.a][bi8m8d2.a])[bi8m2.a]; + const uvec4 bi4 = gk * p.B_hstep + gn4; + const uvec4 bi4d4 = bi4 / 4; + const uvec4 bi4m4d2 = (bi4 % 4) / 2; + const uvec4 bi4m2 = bi4 % 2; - v = uvec4(packBFloat2x16(v4.rg), packBFloat2x16(v4.ba), packBFloat2x16(v8.rg), packBFloat2x16(v8.ba)); -#else - if (gn4.r < psc(GN)) v4.r = unpackHalf2x16(B_blob_data[bi4d8.r][bi4m8d2.r])[bi4m2.r]; - if (gn4.g < psc(GN)) v4.g = unpackHalf2x16(B_blob_data[bi4d8.g][bi4m8d2.g])[bi4m2.g]; - if (gn4.b < 
psc(GN)) v4.b = unpackHalf2x16(B_blob_data[bi4d8.b][bi4m8d2.b])[bi4m2.b]; - if (gn4.a < psc(GN)) v4.a = unpackHalf2x16(B_blob_data[bi4d8.a][bi4m8d2.a])[bi4m2.a]; +#if NCNN_bf16_storage || NCNN_bf16_packed + if (gn4.r < psc(GN)) v4a.r = unpackBFloat2x16(B_blob_data[bi4d4.r][bi4m4d2.r])[bi4m2.r]; + if (gn4.g < psc(GN)) v4a.g = unpackBFloat2x16(B_blob_data[bi4d4.g][bi4m4d2.g])[bi4m2.g]; + if (gn4.b < psc(GN)) v4b.r = unpackBFloat2x16(B_blob_data[bi4d4.b][bi4m4d2.b])[bi4m2.b]; + if (gn4.a < psc(GN)) v4b.g = unpackBFloat2x16(B_blob_data[bi4d4.a][bi4m4d2.a])[bi4m2.a]; - if (gn8.r < psc(GN)) v8.r = unpackHalf2x16(B_blob_data[bi8d8.r][bi8m8d2.r])[bi8m2.r]; - if (gn8.g < psc(GN)) v8.g = unpackHalf2x16(B_blob_data[bi8d8.g][bi8m8d2.g])[bi8m2.g]; - if (gn8.b < psc(GN)) v8.b = unpackHalf2x16(B_blob_data[bi8d8.b][bi8m8d2.b])[bi8m2.b]; - if (gn8.a < psc(GN)) v8.a = unpackHalf2x16(B_blob_data[bi8d8.a][bi8m8d2.a])[bi8m2.a]; + v = uvec2(packBFloat2x16(v4a), packBFloat2x16(v4b)); +#else + if (gn4.r < psc(GN)) v4a.r = unpackHalf2x16(B_blob_data[bi4d4.r][bi4m4d2.r])[bi4m2.r]; + if (gn4.g < psc(GN)) v4a.g = unpackHalf2x16(B_blob_data[bi4d4.g][bi4m4d2.g])[bi4m2.g]; + if (gn4.b < psc(GN)) v4b.r = unpackHalf2x16(B_blob_data[bi4d4.b][bi4m4d2.b])[bi4m2.b]; + if (gn4.a < psc(GN)) v4b.g = unpackHalf2x16(B_blob_data[bi4d4.a][bi4m4d2.a])[bi4m2.a]; - v = uvec4(packHalf2x16(v4.rg), packHalf2x16(v4.ba), packHalf2x16(v8.rg), packHalf2x16(v8.ba)); + v = uvec2(packHalf2x16(v4a), packHalf2x16(v4b)); #endif + } } + + tmp_b[sgni][(zn * K + i) * Nd4p + j] = v; } + } + } + else // if (psc(B_elempack) == 4) + { + const uint Kd4_N_USGN = Kd4 * N * UNROLL_SG_N; + const uint Kd4_N_USGN_d_subgroupsize = (Kd4_N_USGN + (subgroup_size * UNROLL_WG_M - 1)) / (subgroup_size * UNROLL_WG_M); + [[unroll]] for (uint q = 0; q < Kd4_N_USGN_d_subgroupsize; q++) + { + const uint siq = (q * UNROLL_WG_M + sgmi) * subgroup_size + si; + + if (Kd4_N_USGN % (subgroup_size * UNROLL_WG_M) == 0 || siq < Kd4_N_USGN) + { + const uint zn = siq / (Kd4 * N); + const uint ij = siq % (Kd4 * N); + const uint i = ij / Kd4; + const uint j = ij % Kd4; + + uvec2 v = uvec2(0); - tmp_b[((sgni * UNROLL_SG_N + zn) * K + i) * Nd8p + j] = v; + const uint gn = (ni + zn) * N + i; + const uint gk = ki / 4 + j; + + if (gn < psc(GN)) + { + v = B_blob_data[gk * p.B_hstep + gn]; + + const uvec4 gk4 = gk * 4 + uvec4(0, 1, 2, 3); + uvec4 mask4 = uvec4(lessThan(gk4, uvec4(psc(GK)))) * 0xFFFFu; + uvec2 packed_mask = uvec2(mask4.x | (mask4.y << 16), mask4.z | (mask4.w << 16)); + + v = v & packed_mask; + } + + tmp_b[sgni][(zn * N + i) * Kd4p + j] = v; + } } } } else { - // +-K-+ - // N | - // +SG_UN - // | | - // ^ +---+ - // | | | - // WG_UN+- -+ - // | | | - // v +---+ - - const uint Kd8_N_USGN = Kd8 * N * UNROLL_SG_N; - const uint Kd8_N_USGN_d_subgroupsize = (Kd8_N_USGN + (subgroup_size * UNROLL_WG_M - 1)) / (subgroup_size * UNROLL_WG_M); - [[unroll]] for (uint q = 0; q < Kd8_N_USGN_d_subgroupsize; q++) + if (psc(B_elempack) == 1) { - const uint siq = (q * UNROLL_WG_M + sgmi) * subgroup_size + si; - - if (Kd8_N_USGN % (subgroup_size * UNROLL_WG_M) == 0 || siq < Kd8_N_USGN) + // +-K-+ + // N | + // +SG_UN + // | | + // ^ +---+ + // | | | + // WG_UN+- -+ + // | | | + // v +---+ + + const uint Kd4_N_USGN = Kd4 * N * UNROLL_SG_N; + const uint Kd4_N_USGN_d_subgroupsize = (Kd4_N_USGN + (subgroup_size * UNROLL_WG_M - 1)) / (subgroup_size * UNROLL_WG_M); + [[unroll]] for (uint q = 0; q < Kd4_N_USGN_d_subgroupsize; q++) { - const uint zn = siq / (Kd8 * N); - const uint ij = siq % (Kd8 * N); - 
const uint j = ij / Kd8; - const uint i = ij % Kd8; - - const uint gk = ki / 8 + i; - const uint gn = (ni + zn) * N + j; + const uint siq = (q * UNROLL_WG_M + sgmi) * subgroup_size + si; - uvec4 v = uvec4(0); - if (gn < psc(GN)) + if (Kd4_N_USGN % (subgroup_size * UNROLL_WG_M) == 0 || siq < Kd4_N_USGN) { - const uvec4 gk4 = gk * 8 + uvec4(0, 1, 2, 3); - const uvec4 gk8 = gk4 + 4; + const uint zn = siq / (Kd4 * N); + const uint ij = siq % (Kd4 * N); + const uint j = ij / Kd4; + const uint i = ij % Kd4; + + uvec2 v = uvec2(0); + + const uint gk = ki / 4 + i; + const uint gn = (ni + zn) * N + j; - if (p.B_hstep % 8 == 0) + if (gn < psc(GN)) { - const uint bi = gn * (p.B_hstep / 8) + gk; + const uvec4 gk4 = gk * 4 + uvec4(0, 1, 2, 3); - v = B_blob_data[bi]; + if (p.B_hstep % 4 == 0) + { + const uint bi = gn * (p.B_hstep / 4) + gk; - uvec4 mask4 = uvec4(lessThan(gk4, uvec4(psc(GK)))) * 0xFFFFu; - uvec4 mask8 = uvec4(lessThan(gk8, uvec4(psc(GK)))) * 0xFFFFu; - uvec2 packed_mask4 = uvec2(mask4.x | (mask4.y << 16), mask4.z | (mask4.w << 16)); - uvec2 packed_mask8 = uvec2(mask8.x | (mask8.y << 16), mask8.z | (mask8.w << 16)); + v = B_blob_data[bi]; - v.rg = v.rg & packed_mask4; - v.ba = v.ba & packed_mask8; - } - else - { - vec4 v4 = vec4(0.f); - vec4 v8 = vec4(0.f); + uvec4 mask4 = uvec4(lessThan(gk4, uvec4(psc(GK)))) * 0xFFFFu; + uvec2 packed_mask = uvec2(mask4.x | (mask4.y << 16), mask4.z | (mask4.w << 16)); - const uvec4 bi4 = gn * p.B_hstep + gk4; - const uvec4 bi4d8 = bi4 / 8; - const uvec4 bi4m8d2 = (bi4 % 8) / 2; - const uvec4 bi4m2 = bi4 % 2; + v = v & packed_mask; + } + else + { + vec2 v4a = vec2(0.f); + vec2 v4b = vec2(0.f); - const uvec4 bi8 = gn * p.B_hstep + gk8; - const uvec4 bi8d8 = bi8 / 8; - const uvec4 bi8m8d2 = (bi8 % 8) / 2; - const uvec4 bi8m2 = bi8 % 2; + const uvec4 bi4 = gn * p.B_hstep + gk4; + const uvec4 bi4d4 = bi4 / 4; + const uvec4 bi4m4d2 = (bi4 % 4) / 2; + const uvec4 bi4m2 = bi4 % 2; #if NCNN_bf16_storage || NCNN_bf16_packed - if (gk4.r < psc(GK)) v4.r = unpackBFloat2x16(B_blob_data[bi4d8.r][bi4m8d2.r])[bi4m2.r]; - if (gk4.g < psc(GK)) v4.g = unpackBFloat2x16(B_blob_data[bi4d8.g][bi4m8d2.g])[bi4m2.g]; - if (gk4.b < psc(GK)) v4.b = unpackBFloat2x16(B_blob_data[bi4d8.b][bi4m8d2.b])[bi4m2.b]; - if (gk4.a < psc(GK)) v4.a = unpackBFloat2x16(B_blob_data[bi4d8.a][bi4m8d2.a])[bi4m2.a]; + if (gk4.r < psc(GK)) v4a.r = unpackBFloat2x16(B_blob_data[bi4d4.r][bi4m4d2.r])[bi4m2.r]; + if (gk4.g < psc(GK)) v4a.g = unpackBFloat2x16(B_blob_data[bi4d4.g][bi4m4d2.g])[bi4m2.g]; + if (gk4.b < psc(GK)) v4b.r = unpackBFloat2x16(B_blob_data[bi4d4.b][bi4m4d2.b])[bi4m2.b]; + if (gk4.a < psc(GK)) v4b.g = unpackBFloat2x16(B_blob_data[bi4d4.a][bi4m4d2.a])[bi4m2.a]; - if (gk8.r < psc(GK)) v8.r = unpackBFloat2x16(B_blob_data[bi8d8.r][bi8m8d2.r])[bi8m2.r]; - if (gk8.g < psc(GK)) v8.g = unpackBFloat2x16(B_blob_data[bi8d8.g][bi8m8d2.g])[bi8m2.g]; - if (gk8.b < psc(GK)) v8.b = unpackBFloat2x16(B_blob_data[bi8d8.b][bi8m8d2.b])[bi8m2.b]; - if (gk8.a < psc(GK)) v8.a = unpackBFloat2x16(B_blob_data[bi8d8.a][bi8m8d2.a])[bi8m2.a]; - - v = uvec4(packBFloat2x16(v4.rg), packBFloat2x16(v4.ba), packBFloat2x16(v8.rg), packBFloat2x16(v8.ba)); + v = uvec2(packBFloat2x16(v4a), packBFloat2x16(v4b)); #else - if (gk4.r < psc(GK)) v4.r = unpackHalf2x16(B_blob_data[bi4d8.r][bi4m8d2.r])[bi4m2.r]; - if (gk4.g < psc(GK)) v4.g = unpackHalf2x16(B_blob_data[bi4d8.g][bi4m8d2.g])[bi4m2.g]; - if (gk4.b < psc(GK)) v4.b = unpackHalf2x16(B_blob_data[bi4d8.b][bi4m8d2.b])[bi4m2.b]; - if (gk4.a < psc(GK)) v4.a = 
unpackHalf2x16(B_blob_data[bi4d8.a][bi4m8d2.a])[bi4m2.a]; - - if (gk8.r < psc(GK)) v8.r = unpackHalf2x16(B_blob_data[bi8d8.r][bi8m8d2.r])[bi8m2.r]; - if (gk8.g < psc(GK)) v8.g = unpackHalf2x16(B_blob_data[bi8d8.g][bi8m8d2.g])[bi8m2.g]; - if (gk8.b < psc(GK)) v8.b = unpackHalf2x16(B_blob_data[bi8d8.b][bi8m8d2.b])[bi8m2.b]; - if (gk8.a < psc(GK)) v8.a = unpackHalf2x16(B_blob_data[bi8d8.a][bi8m8d2.a])[bi8m2.a]; + if (gk4.r < psc(GK)) v4a.r = unpackHalf2x16(B_blob_data[bi4d4.r][bi4m4d2.r])[bi4m2.r]; + if (gk4.g < psc(GK)) v4a.g = unpackHalf2x16(B_blob_data[bi4d4.g][bi4m4d2.g])[bi4m2.g]; + if (gk4.b < psc(GK)) v4b.r = unpackHalf2x16(B_blob_data[bi4d4.b][bi4m4d2.b])[bi4m2.b]; + if (gk4.a < psc(GK)) v4b.g = unpackHalf2x16(B_blob_data[bi4d4.a][bi4m4d2.a])[bi4m2.a]; - v = uvec4(packHalf2x16(v4.rg), packHalf2x16(v4.ba), packHalf2x16(v8.rg), packHalf2x16(v8.ba)); + v = uvec2(packHalf2x16(v4a), packHalf2x16(v4b)); #endif + } } + + tmp_b[sgni][(zn * N + j) * Kd4p + i] = v; } + } + } + else // if (psc(B_elempack) == 4) + { + const uint Nd4_K_USGN = Nd4 * K * UNROLL_SG_N; + const uint Nd4_K_USGN_d_subgroupsize = (Nd4_K_USGN + (subgroup_size * UNROLL_WG_M - 1)) / (subgroup_size * UNROLL_WG_M); + [[unroll]] for (uint q = 0; q < Nd4_K_USGN_d_subgroupsize; q++) + { + const uint siq = (q * UNROLL_WG_M + sgmi) * subgroup_size + si; + + if (Nd4_K_USGN % (subgroup_size * UNROLL_WG_M) == 0 || siq < Nd4_K_USGN) + { + const uint zn = siq / (Nd4 * K); + const uint ij = siq % (Nd4 * K); + const uint j = ij / Nd4; + const uint i = ij % Nd4; + + uvec2 v = uvec2(0); + + const uint gk = ki + j; + const uint gn = (ni + zn) * Nd4 + i; + + if (gk < psc(GK)) + { + v = B_blob_data[gn * p.B_hstep + gk]; + + const uvec4 gn4 = gn * 4 + uvec4(0, 1, 2, 3); + uvec4 mask4 = uvec4(lessThan(gn4, uvec4(psc(GN)))) * 0xFFFFu; + uvec2 packed_mask = uvec2(mask4.x | (mask4.y << 16), mask4.z | (mask4.w << 16)); + + v = v & packed_mask; + } - tmp_b[((sgni * UNROLL_SG_N + zn) * N + j) * Kd8p + i] = v; + tmp_b[sgni][(zn * K + j) * Nd4p + i] = v; + } } } } @@ -2373,41 +3288,101 @@ void main() [[unroll]] for (uint zm = 0; zm < UNROLL_SG_M; zm++) { - if (transA == 0) + if (constantA == 1) + { +#if ncnn_VK_KHR_cooperative_matrix + coopMatLoad(A[zm], tmp_a[sgmi], zm * (Kd4p * M), Kd4p, gl_CooperativeMatrixLayoutRowMajor); +#elif ncnn_VK_NV_cooperative_matrix + coopMatLoadNV(A[zm], tmp_a[sgmi], zm * (Kd4p * M), Kd4p, false); +#endif + } + else if (transA == 0) { + if (psc(A_elempack) == 4) + { #if ncnn_VK_KHR_cooperative_matrix - coopMatLoad(A[zm], tmp_a, (sgmi * UNROLL_SG_M + zm) * (Kd8p * M), Kd8p, gl_CooperativeMatrixLayoutRowMajor); + coopMatLoad(A[zm], tmp_a[sgmi], zm * (Md4p * K), Md4p, gl_CooperativeMatrixLayoutColumnMajor); #elif ncnn_VK_NV_cooperative_matrix - coopMatLoadNV(A[zm], tmp_a, (sgmi * UNROLL_SG_M + zm) * (Kd8p * M), Kd8p, false); + coopMatLoadNV(A[zm], tmp_a[sgmi], zm * (Md4p * K), Md4p, true); #endif + } + else + { +#if ncnn_VK_KHR_cooperative_matrix + coopMatLoad(A[zm], tmp_a[sgmi], zm * (Kd4p * M), Kd4p, gl_CooperativeMatrixLayoutRowMajor); +#elif ncnn_VK_NV_cooperative_matrix + coopMatLoadNV(A[zm], tmp_a[sgmi], zm * (Kd4p * M), Kd4p, false); +#endif + } } else { + if (psc(A_elempack) == 4) + { #if ncnn_VK_KHR_cooperative_matrix - coopMatLoad(A[zm], tmp_a, (sgmi * UNROLL_SG_M + zm) * (Md8p * K), Md8p, gl_CooperativeMatrixLayoutColumnMajor); + coopMatLoad(A[zm], tmp_a[sgmi], zm * (Kd4p * M), Kd4p, gl_CooperativeMatrixLayoutRowMajor); #elif ncnn_VK_NV_cooperative_matrix - coopMatLoadNV(A[zm], tmp_a, (sgmi * UNROLL_SG_M + zm) * 
(Md8p * K), Md8p, true); + coopMatLoadNV(A[zm], tmp_a[sgmi], zm * (Kd4p * M), Kd4p, false); #endif + } + else + { +#if ncnn_VK_KHR_cooperative_matrix + coopMatLoad(A[zm], tmp_a[sgmi], zm * (Md4p * K), Md4p, gl_CooperativeMatrixLayoutColumnMajor); +#elif ncnn_VK_NV_cooperative_matrix + coopMatLoadNV(A[zm], tmp_a[sgmi], zm * (Md4p * K), Md4p, true); +#endif + } } } [[unroll]] for (uint zn = 0; zn < UNROLL_SG_N; zn++) { - if (transB == 0) + if (constantB == 1) + { +#if ncnn_VK_KHR_cooperative_matrix + coopMatLoad(B[zn], tmp_b[sgni], zn * (Nd4p * K), Nd4p, gl_CooperativeMatrixLayoutRowMajor); +#elif ncnn_VK_NV_cooperative_matrix + coopMatLoadNV(B[zn], tmp_b[sgni], zn * (Nd4p * K), Nd4p, false); +#endif + } + else if (transB == 0) { + if (psc(B_elempack) == 4) + { +#if ncnn_VK_KHR_cooperative_matrix + coopMatLoad(B[zn], tmp_b[sgni], zn * (Kd4p * N), Kd4p, gl_CooperativeMatrixLayoutColumnMajor); +#elif ncnn_VK_NV_cooperative_matrix + coopMatLoadNV(B[zn], tmp_b[sgni], zn * (Kd4p * N), Kd4p, true); +#endif + } + else + { #if ncnn_VK_KHR_cooperative_matrix - coopMatLoad(B[zn], tmp_b, (sgni * UNROLL_SG_N + zn) * (Nd8p * K), Nd8p, gl_CooperativeMatrixLayoutRowMajor); + coopMatLoad(B[zn], tmp_b[sgni], zn * (Nd4p * K), Nd4p, gl_CooperativeMatrixLayoutRowMajor); #elif ncnn_VK_NV_cooperative_matrix - coopMatLoadNV(B[zn], tmp_b, (sgni * UNROLL_SG_N + zn) * (Nd8p * K), Nd8p, false); + coopMatLoadNV(B[zn], tmp_b[sgni], zn * (Nd4p * K), Nd4p, false); #endif + } } else { + if (psc(B_elempack) == 4) + { +#if ncnn_VK_KHR_cooperative_matrix + coopMatLoad(B[zn], tmp_b[sgni], zn * (Nd4p * K), Nd4p, gl_CooperativeMatrixLayoutRowMajor); +#elif ncnn_VK_NV_cooperative_matrix + coopMatLoadNV(B[zn], tmp_b[sgni], zn * (Nd4p * K), Nd4p, false); +#endif + } + else + { #if ncnn_VK_KHR_cooperative_matrix - coopMatLoad(B[zn], tmp_b, (sgni * UNROLL_SG_N + zn) * (Kd8p * N), Kd8p, gl_CooperativeMatrixLayoutColumnMajor); + coopMatLoad(B[zn], tmp_b[sgni], zn * (Kd4p * N), Kd4p, gl_CooperativeMatrixLayoutColumnMajor); #elif ncnn_VK_NV_cooperative_matrix - coopMatLoadNV(B[zn], tmp_b, (sgni * UNROLL_SG_N + zn) * (Kd8p * N), Kd8p, true); + coopMatLoadNV(B[zn], tmp_b[sgni], zn * (Kd4p * N), Kd4p, true); #endif + } } } @@ -2442,47 +3417,97 @@ void main() { if (output_transpose == 0) { + if (psc(out_elempack) == 4) + { +#if ncnn_VK_KHR_cooperative_matrix +#if NCNN_fp16_arithmetic + coopMatStore(sum[zn][zm], tmp_o[sgi], (zn * UNROLL_SG_M + zm) * (Md4 * N), Md4, gl_CooperativeMatrixLayoutColumnMajor); +#else +#if NCNN_bf16_storage || NCNN_bf16_packed + coopmat sum_fp16 = coopmat(sum[zn][zm]); +#else + coopmat sum_fp16 = coopmat(sum[zn][zm]); +#endif + coopMatStore(sum_fp16, tmp_o[sgi], (zn * UNROLL_SG_M + zm) * (Md4 * N), Md4, gl_CooperativeMatrixLayoutColumnMajor); +#endif +#elif ncnn_VK_NV_cooperative_matrix +#if NCNN_fp16_arithmetic + coopMatStoreNV(sum[zn][zm], tmp_o[sgi], (zn * UNROLL_SG_M + zm) * (Md4 * N), Md4, true); +#else + fcoopmatNV<16, gl_ScopeSubgroup, M, N> sum_fp16 = fcoopmatNV<16, gl_ScopeSubgroup, M, N>(sum[zn][zm]); + coopMatStoreNV(sum_fp16, tmp_o[sgi], (zn * UNROLL_SG_M + zm) * (Md4 * N), Md4, true); +#endif +#endif + } + else + { #if ncnn_VK_KHR_cooperative_matrix #if NCNN_fp16_arithmetic - coopMatStore(sum[zn][zm], tmp_o, ((sgi * UNROLL_SG_N + zn) * UNROLL_SG_M + zm) * (Nd8p * M), Nd8p, gl_CooperativeMatrixLayoutRowMajor); + coopMatStore(sum[zn][zm], tmp_o[sgi], (zn * UNROLL_SG_M + zm) * (Nd4 * M), Nd4, gl_CooperativeMatrixLayoutRowMajor); #else #if NCNN_bf16_storage || NCNN_bf16_packed - coopmat sum_fp16 = 
coopmat(sum[zn][zm]); + coopmat sum_fp16 = coopmat(sum[zn][zm]); #else - coopmat sum_fp16 = coopmat(sum[zn][zm]); + coopmat sum_fp16 = coopmat(sum[zn][zm]); #endif - coopMatStore(sum_fp16, tmp_o, ((sgi * UNROLL_SG_N + zn) * UNROLL_SG_M + zm) * (Nd8p * M), Nd8p, gl_CooperativeMatrixLayoutRowMajor); + coopMatStore(sum_fp16, tmp_o[sgi], (zn * UNROLL_SG_M + zm) * (Nd4 * M), Nd4, gl_CooperativeMatrixLayoutRowMajor); #endif #elif ncnn_VK_NV_cooperative_matrix #if NCNN_fp16_arithmetic - coopMatStoreNV(sum[zn][zm], tmp_o, ((sgi * UNROLL_SG_N + zn) * UNROLL_SG_M + zm) * (Nd8p * M), Nd8p, false); + coopMatStoreNV(sum[zn][zm], tmp_o[sgi], (zn * UNROLL_SG_M + zm) * (Nd4 * M), Nd4, false); #else - fcoopmatNV<16, gl_ScopeSubgroup, M, N> sum_fp16 = fcoopmatNV<16, gl_ScopeSubgroup, M, N>(sum[zn][zm]); - coopMatStoreNV(sum_fp16, tmp_o, ((sgi * UNROLL_SG_N + zn) * UNROLL_SG_M + zm) * (Nd8p * M), Nd8p, false); + fcoopmatNV<16, gl_ScopeSubgroup, M, N> sum_fp16 = fcoopmatNV<16, gl_ScopeSubgroup, M, N>(sum[zn][zm]); + coopMatStoreNV(sum_fp16, tmp_o[sgi], (zn * UNROLL_SG_M + zm) * (Nd4 * M), Nd4, false); #endif #endif + } } else { + if (psc(out_elempack) == 4) + { +#if ncnn_VK_KHR_cooperative_matrix +#if NCNN_fp16_arithmetic + coopMatStore(sum[zn][zm], tmp_o[sgi], (zn * UNROLL_SG_M + zm) * (Nd4 * M), Nd4, gl_CooperativeMatrixLayoutRowMajor); +#else +#if NCNN_bf16_storage || NCNN_bf16_packed + coopmat sum_fp16 = coopmat(sum[zn][zm]); +#else + coopmat sum_fp16 = coopmat(sum[zn][zm]); +#endif + coopMatStore(sum_fp16, tmp_o[sgi], (zn * UNROLL_SG_M + zm) * (Nd4 * M), Nd4, gl_CooperativeMatrixLayoutRowMajor); +#endif +#elif ncnn_VK_NV_cooperative_matrix +#if NCNN_fp16_arithmetic + coopMatStoreNV(sum[zn][zm], tmp_o[sgi], (zn * UNROLL_SG_M + zm) * (Nd4 * M), Nd4, false); +#else + fcoopmatNV<16, gl_ScopeSubgroup, M, N> sum_fp16 = fcoopmatNV<16, gl_ScopeSubgroup, M, N>(sum[zn][zm]); + coopMatStoreNV(sum_fp16, tmp_o[sgi], (zn * UNROLL_SG_M + zm) * (Nd4 * M), Nd4, false); +#endif +#endif + } + else + { #if ncnn_VK_KHR_cooperative_matrix #if NCNN_fp16_arithmetic - coopMatStore(sum[zn][zm], tmp_o, ((sgi * UNROLL_SG_N + zn) * UNROLL_SG_M + zm) * (Md8p * N), Md8p, gl_CooperativeMatrixLayoutColumnMajor); + coopMatStore(sum[zn][zm], tmp_o[sgi], (zn * UNROLL_SG_M + zm) * (Md4 * N), Md4, gl_CooperativeMatrixLayoutColumnMajor); #else #if NCNN_bf16_storage || NCNN_bf16_packed - coopmat sum_fp16 = coopmat(sum[zn][zm]); + coopmat sum_fp16 = coopmat(sum[zn][zm]); #else - coopmat sum_fp16 = coopmat(sum[zn][zm]); + coopmat sum_fp16 = coopmat(sum[zn][zm]); #endif - coopMatStore(sum_fp16, tmp_o, ((sgi * UNROLL_SG_N + zn) * UNROLL_SG_M + zm) * (Md8p * N), Md8p, gl_CooperativeMatrixLayoutColumnMajor); + coopMatStore(sum_fp16, tmp_o[sgi], (zn * UNROLL_SG_M + zm) * (Md4 * N), Md4, gl_CooperativeMatrixLayoutColumnMajor); #endif #elif ncnn_VK_NV_cooperative_matrix #if NCNN_fp16_arithmetic - coopMatStoreNV(sum[zn][zm], tmp_o, ((sgi * UNROLL_SG_N + zn) * UNROLL_SG_M + zm) * (Md8p * N), Md8p, true); + coopMatStoreNV(sum[zn][zm], tmp_o[sgi], (zn * UNROLL_SG_M + zm) * (Md4 * N), Md4, true); #else - fcoopmatNV<16, gl_ScopeSubgroup, M, N> sum_fp16 = fcoopmatNV<16, gl_ScopeSubgroup, M, N>(sum[zn][zm]); - coopMatStoreNV(sum_fp16, tmp_o, ((sgi * UNROLL_SG_N + zn) * UNROLL_SG_M + zm) * (Md8p * N), Md8p, true); + fcoopmatNV<16, gl_ScopeSubgroup, M, N> sum_fp16 = fcoopmatNV<16, gl_ScopeSubgroup, M, N>(sum[zn][zm]); + coopMatStoreNV(sum_fp16, tmp_o[sgi], (zn * UNROLL_SG_M + zm) * (Md4 * N), Md4, true); #endif #endif + } } } } @@ -2492,162 +3517,216 @@ void 
main() // store top_blob if (output_transpose == 0) { - // +-N-+ - // M | - // +SG_UM - // | | - // ^ +---+ - // | | | - // SG_UN+- -+ - // | | | - // ^ v +---+ - // | | | - // | +- -+ - // | | | - // WG_UM +- -+ - // | | | - // | +- -+ - // | | | - // ^ v +---+ - // | | | - // | +- -+ - // | | | - // | +---+ - // | | | - // | +- -+ - // | | | - // WG_UN +---+ - // | | | - // | +- -+ - // | | | - // | +---+ - // | | | - // | +- -+ - // | | | - // v +---+ - - const uint Nd8_M_USGM_USGN = Nd8 * M * UNROLL_SG_M * UNROLL_SG_N; - const uint Nd8_M_USGM_USGN_d_subgroupsize = (Nd8_M_USGM_USGN + subgroup_size - 1) / subgroup_size; - [[unroll]] for (uint q = 0; q < Nd8_M_USGM_USGN_d_subgroupsize; q++) + if (psc(out_elempack) == 4) { - const uint siq = si + q * subgroup_size; - - if (Nd8_M_USGM_USGN % subgroup_size == 0 || siq < Nd8_M_USGM_USGN) + const uint Md4_N_USGM_USGN = Md4 * N * UNROLL_SG_M * UNROLL_SG_N; + const uint Md4_N_USGM_USGN_d_subgroupsize = (Md4_N_USGM_USGN + subgroup_size - 1) / subgroup_size; + [[unroll]] for (uint q = 0; q < Md4_N_USGM_USGN_d_subgroupsize; q++) { - const uint zn = siq / (Nd8 * M * UNROLL_SG_M); - const uint zmij = siq % (Nd8 * M * UNROLL_SG_M); - const uint zm = zmij / (Nd8 * M); - const uint ij = zmij % (Nd8 * M); - const uint i = ij / Nd8; - const uint j = ij % Nd8; + const uint siq = si + q * subgroup_size; - const uint gm = (mi + zm) * M + i; - const uint gn = (ni + zn) * Nd8 + j; + if (Md4_N_USGM_USGN % subgroup_size == 0 || siq < Md4_N_USGM_USGN) + { + const uint zn = siq / (Md4 * N * UNROLL_SG_M); + const uint zmij = siq % (Md4 * N * UNROLL_SG_M); + const uint zm = zmij / (Md4 * N); + const uint ij = zmij % (Md4 * N); + const uint i = ij / Md4; + const uint j = ij % Md4; - if (gm < psc(GM)) + const uint gn = (ni + zn) * N + i; + const uint gm = (mi + zm) * Md4 + j; + + if (gm * 4 < psc(GM) && gn < psc(GN)) + { + uvec2 v = tmp_o[sgi][siq]; + + const uint gi = gm * p.outhstep + gn; + + top_blob_data_4[gi] = v; + } + } + } + } + else + { + // +-N-+ + // M | + // +SG_UM + // | | + // ^ +---+ + // | | | + // SG_UN+- -+ + // | | | + // ^ v +---+ + // | | | + // | +- -+ + // | | | + // WG_UM +- -+ + // | | | + // | +- -+ + // | | | + // ^ v +---+ + // | | | + // | +- -+ + // | | | + // | +---+ + // | | | + // | +- -+ + // | | | + // WG_UN +---+ + // | | | + // | +- -+ + // | | | + // | +---+ + // | | | + // | +- -+ + // | | | + // v +---+ + + const uint Nd4_M_USGM_USGN = Nd4 * M * UNROLL_SG_M * UNROLL_SG_N; + const uint Nd4_M_USGM_USGN_d_subgroupsize = (Nd4_M_USGM_USGN + subgroup_size - 1) / subgroup_size; + [[unroll]] for (uint q = 0; q < Nd4_M_USGM_USGN_d_subgroupsize; q++) + { + const uint siq = si + q * subgroup_size; + + if (Nd4_M_USGM_USGN % subgroup_size == 0 || siq < Nd4_M_USGM_USGN) { - uvec4 v = tmp_o[(((sgi * UNROLL_SG_N + zn) * UNROLL_SG_M + zm) * M + i) * Nd8p + j]; + const uint zn = siq / (Nd4 * M * UNROLL_SG_M); + const uint zmij = siq % (Nd4 * M * UNROLL_SG_M); + const uint zm = zmij / (Nd4 * M); + const uint ij = zmij % (Nd4 * M); + const uint i = ij / Nd4; + const uint j = ij % Nd4; + + const uint gm = (mi + zm) * M + i; + const uint gn = (ni + zn) * Nd4 + j; + + if (gm < psc(GM)) + { + uvec2 v = tmp_o[sgi][siq]; #if NCNN_bf16_storage || NCNN_bf16_packed - afpvec4 vab = afpvec4(unpackBFloat2x16(v.r), unpackBFloat2x16(v.g)); - afpvec4 vcd = afpvec4(unpackBFloat2x16(v.b), unpackBFloat2x16(v.a)); + afpvec4 vab = afpvec4(unpackBFloat2x16(v.r), unpackBFloat2x16(v.g)); #else - afpvec4 vab = afpvec4(unpackHalf2x16(v.r), unpackHalf2x16(v.g)); - afpvec4 
vcd = afpvec4(unpackHalf2x16(v.b), unpackHalf2x16(v.a)); + afpvec4 vab = afpvec4(unpackHalf2x16(v.r), unpackHalf2x16(v.g)); #endif - const uvec4 oi4 = gm * p.outhstep + gn * 8 + uvec4(0, 1, 2, 3); - const uvec4 oi8 = oi4 + 4; + const uvec4 oi4 = gm * p.outhstep + gn * 4 + uvec4(0, 1, 2, 3); - if (gn * 8 < psc(GN)) buffer_st1(top_blob_data, oi4.r, vab.r); - if (gn * 8 + 1 < psc(GN)) buffer_st1(top_blob_data, oi4.g, vab.g); - if (gn * 8 + 2 < psc(GN)) buffer_st1(top_blob_data, oi4.b, vab.b); - if (gn * 8 + 3 < psc(GN)) buffer_st1(top_blob_data, oi4.a, vab.a); - if (gn * 8 + 4 < psc(GN)) buffer_st1(top_blob_data, oi8.r, vcd.r); - if (gn * 8 + 5 < psc(GN)) buffer_st1(top_blob_data, oi8.g, vcd.g); - if (gn * 8 + 6 < psc(GN)) buffer_st1(top_blob_data, oi8.b, vcd.b); - if (gn * 8 + 7 < psc(GN)) buffer_st1(top_blob_data, oi8.a, vcd.a); + if (gn * 4 < psc(GN)) buffer_st1(top_blob_data, oi4.r, vab.r); + if (gn * 4 + 1 < psc(GN)) buffer_st1(top_blob_data, oi4.g, vab.g); + if (gn * 4 + 2 < psc(GN)) buffer_st1(top_blob_data, oi4.b, vab.b); + if (gn * 4 + 3 < psc(GN)) buffer_st1(top_blob_data, oi4.a, vab.a); + } } } } } else { - // +-M-+ - // N | - // +SG_UM - // | | - // ^ +---+ - // | | | - // SG_UN+- -+ - // | | | - // ^ v +---+ - // | | | - // | +- -+ - // | | | - // WG_UM +- -+ - // | | | - // | +- -+ - // | | | - // ^ v +---+ - // | | | - // | +- -+ - // | | | - // | +---+ - // | | | - // | +- -+ - // | | | - // WG_UN +---+ - // | | | - // | +- -+ - // | | | - // | +---+ - // | | | - // | +- -+ - // | | | - // v +---+ - - const uint Md8_N_USGM_USGN = Md8 * N * UNROLL_SG_M * UNROLL_SG_N; - const uint Md8_N_USGM_USGN_d_subgroupsize = (Md8_N_USGM_USGN + subgroup_size - 1) / subgroup_size; - [[unroll]] for (uint q = 0; q < Md8_N_USGM_USGN_d_subgroupsize; q++) + if (psc(out_elempack) == 4) { - const uint siq = si + q * subgroup_size; - - if (Md8_N_USGM_USGN % subgroup_size == 0 || siq < Md8_N_USGM_USGN) + const uint Nd4_M_USGM_USGN = Nd4 * M * UNROLL_SG_M * UNROLL_SG_N; + const uint Nd4_M_USGM_USGN_d_subgroupsize = (Nd4_M_USGM_USGN + subgroup_size - 1) / subgroup_size; + [[unroll]] for (uint q = 0; q < Nd4_M_USGM_USGN_d_subgroupsize; q++) { - const uint zn = siq / (Md8 * N * UNROLL_SG_M); - const uint zmij = siq % (Md8 * N * UNROLL_SG_M); - const uint zm = zmij / (Md8 * N); - const uint ij = zmij % (Md8 * N); - const uint i = ij / Md8; - const uint j = ij % Md8; + const uint siq = si + q * subgroup_size; + + if (Nd4_M_USGM_USGN % subgroup_size == 0 || siq < Nd4_M_USGM_USGN) + { + const uint zn = siq / (Nd4 * M * UNROLL_SG_M); + const uint zmij = siq % (Nd4 * M * UNROLL_SG_M); + const uint zm = zmij / (Nd4 * M); + const uint ij = zmij % (Nd4 * M); + const uint i = ij / Nd4; + const uint j = ij % Nd4; - const uint gn = (ni + zn) * N + i; - const uint gm = (mi + zm) * Md8 + j; + const uint gm = (mi + zm) * M + i; + const uint gn = (ni + zn) * Nd4 + j; - if (gn < psc(GN)) + if (gn * 4 < psc(GN) && gm < psc(GM)) + { + uvec2 v = tmp_o[sgi][siq]; + + const uint gi = gn * p.outhstep + gm; + + top_blob_data_4[gi] = v; + } + } + } + } + else + { + // +-M-+ + // N | + // +SG_UM + // | | + // ^ +---+ + // | | | + // SG_UN+- -+ + // | | | + // ^ v +---+ + // | | | + // | +- -+ + // | | | + // WG_UM +- -+ + // | | | + // | +- -+ + // | | | + // ^ v +---+ + // | | | + // | +- -+ + // | | | + // | +---+ + // | | | + // | +- -+ + // | | | + // WG_UN +---+ + // | | | + // | +- -+ + // | | | + // | +---+ + // | | | + // | +- -+ + // | | | + // v +---+ + + const uint Md4_N_USGM_USGN = Md4 * N * UNROLL_SG_M * UNROLL_SG_N; 
+ const uint Md4_N_USGM_USGN_d_subgroupsize = (Md4_N_USGM_USGN + subgroup_size - 1) / subgroup_size; + [[unroll]] for (uint q = 0; q < Md4_N_USGM_USGN_d_subgroupsize; q++) + { + const uint siq = si + q * subgroup_size; + + if (Md4_N_USGM_USGN % subgroup_size == 0 || siq < Md4_N_USGM_USGN) { - uvec4 v = tmp_o[(((sgi * UNROLL_SG_N + zn) * UNROLL_SG_M + zm) * N + i) * Md8p + j]; + const uint zn = siq / (Md4 * N * UNROLL_SG_M); + const uint zmij = siq % (Md4 * N * UNROLL_SG_M); + const uint zm = zmij / (Md4 * N); + const uint ij = zmij % (Md4 * N); + const uint i = ij / Md4; + const uint j = ij % Md4; + + const uint gn = (ni + zn) * N + i; + const uint gm = (mi + zm) * Md4 + j; + + if (gn < psc(GN)) + { + uvec2 v = tmp_o[sgi][siq]; #if NCNN_bf16_storage || NCNN_bf16_packed - afpvec4 vab = afpvec4(unpackBFloat2x16(v.r), unpackBFloat2x16(v.g)); - afpvec4 vcd = afpvec4(unpackBFloat2x16(v.b), unpackBFloat2x16(v.a)); + afpvec4 vab = afpvec4(unpackBFloat2x16(v.r), unpackBFloat2x16(v.g)); #else - afpvec4 vab = afpvec4(unpackHalf2x16(v.r), unpackHalf2x16(v.g)); - afpvec4 vcd = afpvec4(unpackHalf2x16(v.b), unpackHalf2x16(v.a)); + afpvec4 vab = afpvec4(unpackHalf2x16(v.r), unpackHalf2x16(v.g)); #endif - const uvec4 oi4 = gn * p.outhstep + (gm * 8 + uvec4(0, 1, 2, 3)); - const uvec4 oi8 = oi4 + 4; + const uvec4 oi4 = gn * p.outhstep + (gm * 4 + uvec4(0, 1, 2, 3)); - if (gm * 8 < psc(GM)) buffer_st1(top_blob_data, oi4.r, vab.r); - if (gm * 8 + 1 < psc(GM)) buffer_st1(top_blob_data, oi4.g, vab.g); - if (gm * 8 + 2 < psc(GM)) buffer_st1(top_blob_data, oi4.b, vab.b); - if (gm * 8 + 3 < psc(GM)) buffer_st1(top_blob_data, oi4.a, vab.a); - if (gm * 8 + 4 < psc(GM)) buffer_st1(top_blob_data, oi8.r, vcd.r); - if (gm * 8 + 5 < psc(GM)) buffer_st1(top_blob_data, oi8.g, vcd.g); - if (gm * 8 + 6 < psc(GM)) buffer_st1(top_blob_data, oi8.b, vcd.b); - if (gm * 8 + 7 < psc(GM)) buffer_st1(top_blob_data, oi8.a, vcd.a); + if (gm * 4 < psc(GM)) buffer_st1(top_blob_data, oi4.r, vab.r); + if (gm * 4 + 1 < psc(GM)) buffer_st1(top_blob_data, oi4.g, vab.g); + if (gm * 4 + 2 < psc(GM)) buffer_st1(top_blob_data, oi4.b, vab.b); + if (gm * 4 + 3 < psc(GM)) buffer_st1(top_blob_data, oi4.a, vab.a); + } } } } diff --git a/src/layer/vulkan/shader/gemm_sg.comp b/src/layer/vulkan/shader/gemm_sg.comp index 674a723da04..fdf65b9a1f0 100644 --- a/src/layer/vulkan/shader/gemm_sg.comp +++ b/src/layer/vulkan/shader/gemm_sg.comp @@ -35,6 +35,9 @@ layout(binding = 0) writeonly buffer top_blob { sfp top_blob_data[]; }; layout(binding = 1) readonly buffer A_blob { sfp A_blob_data[]; }; layout(binding = 2) readonly buffer B_blob { sfp B_blob_data[]; }; layout(binding = 3) readonly buffer C_blob { sfp C_blob_data[]; }; +layout(binding = 4) writeonly buffer top_blob_4 { sfpvec4 top_blob_data_4[]; }; +layout(binding = 5) readonly buffer A_blob_4 { sfpvec4 A_blob_data_4[]; }; +layout(binding = 6) readonly buffer B_blob_4 { sfpvec4 B_blob_data_4[]; }; layout(push_constant) uniform parameter { @@ -48,6 +51,9 @@ layout(push_constant) uniform parameter int B_hstep; int outdims; int outhstep; + int out_elempack; + int A_elempack; + int B_elempack; } p; void main() @@ -150,21 +156,64 @@ void main() { const uvec4 gy4 = wgmi * UNROLL_SG_M * 4 + smni * 4 + uvec4(0, 1, 2, 3); - const uvec4 ai4 = transA == 0 ? 
gy4 * p.A_hstep + slk : slk * p.A_hstep + gy4; - - a.r = buffer_ld1(A_blob_data, ai4.r); - a.g = buffer_ld1(A_blob_data, ai4.g); - a.b = buffer_ld1(A_blob_data, ai4.b); - a.a = buffer_ld1(A_blob_data, ai4.a); + if (p.A_elempack == 4) + { + if (transA == 0) + { + // A is (K, M), M packed. gy4 are consecutive M-indices. + // Single vec4 load: data[(gy4.r/4) * hstep + slk] = vec4(A[gy4.r..gy4.a][slk]) + a = buffer_ld4(A_blob_data_4, (gy4.r / 4) * p.A_hstep + slk); + } + else + { + // A is (M, K), K packed. slk is K-index, gy4 are M-indices. + const uint slk_d4 = slk / 4; + const uint slk_m4 = slk % 4; + a.r = buffer_ld4(A_blob_data_4, slk_d4 * p.A_hstep + gy4.r)[slk_m4]; + a.g = buffer_ld4(A_blob_data_4, slk_d4 * p.A_hstep + gy4.g)[slk_m4]; + a.b = buffer_ld4(A_blob_data_4, slk_d4 * p.A_hstep + gy4.b)[slk_m4]; + a.a = buffer_ld4(A_blob_data_4, slk_d4 * p.A_hstep + gy4.a)[slk_m4]; + } + } + else + { + const uvec4 ai4 = transA == 0 ? gy4 * p.A_hstep + slk : slk * p.A_hstep + gy4; + + a.r = buffer_ld1(A_blob_data, ai4.r); + a.g = buffer_ld1(A_blob_data, ai4.g); + a.b = buffer_ld1(A_blob_data, ai4.b); + a.a = buffer_ld1(A_blob_data, ai4.a); + } const uvec4 gx4 = wgni * UNROLL_SG_N * 4 + smni * 4 + uvec4(0, 1, 2, 3); - const uvec4 bi4 = transB == 0 ? slk * p.B_hstep + gx4 : gx4 * p.B_hstep + slk; - - b.r = buffer_ld1(B_blob_data, bi4.r); - b.g = buffer_ld1(B_blob_data, bi4.g); - b.b = buffer_ld1(B_blob_data, bi4.b); - b.a = buffer_ld1(B_blob_data, bi4.a); + if (p.B_elempack == 4) + { + if (transB == 1) + { + // B is (K, N), N packed. Single vec4 load. + b = buffer_ld4(B_blob_data_4, (gx4.r / 4) * p.B_hstep + slk); + } + else + { + // B is (N, K), K packed. slk is K-index, gx4 are N-indices. + const uint slk_d4 = slk / 4; + const uint slk_m4 = slk % 4; + b.r = buffer_ld4(B_blob_data_4, slk_d4 * p.B_hstep + gx4.r)[slk_m4]; + b.g = buffer_ld4(B_blob_data_4, slk_d4 * p.B_hstep + gx4.g)[slk_m4]; + b.b = buffer_ld4(B_blob_data_4, slk_d4 * p.B_hstep + gx4.b)[slk_m4]; + b.a = buffer_ld4(B_blob_data_4, slk_d4 * p.B_hstep + gx4.a)[slk_m4]; + } + } + else + { + const uvec4 bi4 = transB == 0 ? slk * p.B_hstep + gx4 : gx4 * p.B_hstep + slk; + + b.r = buffer_ld1(B_blob_data, bi4.r); + b.g = buffer_ld1(B_blob_data, bi4.g); + b.b = buffer_ld1(B_blob_data, bi4.b); + b.a = buffer_ld1(B_blob_data, bi4.a); + } } for (int z = 0; z < UNROLL_SG_K; z++) @@ -189,21 +238,59 @@ void main() { const uvec4 gy4 = wgmi * UNROLL_SG_M * 4 + smni * 4 + uvec4(0, 1, 2, 3); - const uvec4 ai4 = transA == 0 ? gy4 * p.A_hstep + slk : slk * p.A_hstep + gy4; - - a.r = buffer_ld1(A_blob_data, ai4.r); - a.g = buffer_ld1(A_blob_data, ai4.g); - a.b = buffer_ld1(A_blob_data, ai4.b); - a.a = buffer_ld1(A_blob_data, ai4.a); + if (p.A_elempack == 4) + { + if (transA == 0) + { + a = buffer_ld4(A_blob_data_4, (gy4.r / 4) * p.A_hstep + slk); + } + else + { + const uint slk_d4 = slk / 4; + const uint slk_m4 = slk % 4; + a.r = buffer_ld4(A_blob_data_4, slk_d4 * p.A_hstep + gy4.r)[slk_m4]; + a.g = buffer_ld4(A_blob_data_4, slk_d4 * p.A_hstep + gy4.g)[slk_m4]; + a.b = buffer_ld4(A_blob_data_4, slk_d4 * p.A_hstep + gy4.b)[slk_m4]; + a.a = buffer_ld4(A_blob_data_4, slk_d4 * p.A_hstep + gy4.a)[slk_m4]; + } + } + else + { + const uvec4 ai4 = transA == 0 ? 
gy4 * p.A_hstep + slk : slk * p.A_hstep + gy4; + + a.r = buffer_ld1(A_blob_data, ai4.r); + a.g = buffer_ld1(A_blob_data, ai4.g); + a.b = buffer_ld1(A_blob_data, ai4.b); + a.a = buffer_ld1(A_blob_data, ai4.a); + } const uvec4 gx4 = wgni * UNROLL_SG_N * 4 + smni * 4 + uvec4(0, 1, 2, 3); - const uvec4 bi4 = transB == 0 ? slk * p.B_hstep + gx4 : gx4 * p.B_hstep + slk; - - b.r = buffer_ld1(B_blob_data, bi4.r); - b.g = buffer_ld1(B_blob_data, bi4.g); - b.b = buffer_ld1(B_blob_data, bi4.b); - b.a = buffer_ld1(B_blob_data, bi4.a); + if (p.B_elempack == 4) + { + if (transB == 1) + { + b = buffer_ld4(B_blob_data_4, (gx4.r / 4) * p.B_hstep + slk); + } + else + { + const uint slk_d4 = slk / 4; + const uint slk_m4 = slk % 4; + b.r = buffer_ld4(B_blob_data_4, slk_d4 * p.B_hstep + gx4.r)[slk_m4]; + b.g = buffer_ld4(B_blob_data_4, slk_d4 * p.B_hstep + gx4.g)[slk_m4]; + b.b = buffer_ld4(B_blob_data_4, slk_d4 * p.B_hstep + gx4.b)[slk_m4]; + b.a = buffer_ld4(B_blob_data_4, slk_d4 * p.B_hstep + gx4.a)[slk_m4]; + } + } + else + { + const uvec4 bi4 = transB == 0 ? slk * p.B_hstep + gx4 : gx4 * p.B_hstep + slk; + + b.r = buffer_ld1(B_blob_data, bi4.r); + b.g = buffer_ld1(B_blob_data, bi4.g); + b.b = buffer_ld1(B_blob_data, bi4.b); + b.a = buffer_ld1(B_blob_data, bi4.a); + } } for (int z = 0; z < UNROLL_SG_K && k + z < psc(K); z++) @@ -234,62 +321,91 @@ void main() if (output_transpose == 1) { - const uvec4 gi4 = gx4 * p.outhstep; - - buffer_st1(top_blob_data, gi4.r + gy4.r, sum0.r); - if (gy4.g < psc(M)) buffer_st1(top_blob_data, gi4.r + gy4.g, sum1.r); - if (gy4.b < psc(M)) buffer_st1(top_blob_data, gi4.r + gy4.b, sum2.r); - if (gy4.a < psc(M)) buffer_st1(top_blob_data, gi4.r + gy4.a, sum3.r); - if (gx4.g < psc(N)) + if (p.out_elempack == 4) { - buffer_st1(top_blob_data, gi4.g + gy4.r, sum0.g); - if (gy4.g < psc(M)) buffer_st1(top_blob_data, gi4.g + gy4.g, sum1.g); - if (gy4.b < psc(M)) buffer_st1(top_blob_data, gi4.g + gy4.b, sum2.g); - if (gy4.a < psc(M)) buffer_st1(top_blob_data, gi4.g + gy4.a, sum3.g); + // transpose output, pack4 on N dimension + // sum_i = vec4(C[gy4[i]][gx4[0..3]]) + // store sum_i directly as vec4 + const uint gi = (gx4.r / 4) * p.outhstep + gy4.r; + + buffer_st4(top_blob_data_4, gi, sum0); + if (gy4.g < psc(M)) buffer_st4(top_blob_data_4, gi + 1, sum1); + if (gy4.b < psc(M)) buffer_st4(top_blob_data_4, gi + 2, sum2); + if (gy4.a < psc(M)) buffer_st4(top_blob_data_4, gi + 3, sum3); } - if (gx4.b < psc(N)) + else { - buffer_st1(top_blob_data, gi4.b + gy4.r, sum0.b); - if (gy4.g < psc(M)) buffer_st1(top_blob_data, gi4.b + gy4.g, sum1.b); - if (gy4.b < psc(M)) buffer_st1(top_blob_data, gi4.b + gy4.b, sum2.b); - if (gy4.a < psc(M)) buffer_st1(top_blob_data, gi4.b + gy4.a, sum3.b); - } - if (gx4.a < psc(N)) - { - buffer_st1(top_blob_data, gi4.a + gy4.r, sum0.a); - if (gy4.g < psc(M)) buffer_st1(top_blob_data, gi4.a + gy4.g, sum1.a); - if (gy4.b < psc(M)) buffer_st1(top_blob_data, gi4.a + gy4.b, sum2.a); - if (gy4.a < psc(M)) buffer_st1(top_blob_data, gi4.a + gy4.a, sum3.a); + const uvec4 gi4 = gx4 * p.outhstep; + + buffer_st1(top_blob_data, gi4.r + gy4.r, sum0.r); + if (gy4.g < psc(M)) buffer_st1(top_blob_data, gi4.r + gy4.g, sum1.r); + if (gy4.b < psc(M)) buffer_st1(top_blob_data, gi4.r + gy4.b, sum2.r); + if (gy4.a < psc(M)) buffer_st1(top_blob_data, gi4.r + gy4.a, sum3.r); + if (gx4.g < psc(N)) + { + buffer_st1(top_blob_data, gi4.g + gy4.r, sum0.g); + if (gy4.g < psc(M)) buffer_st1(top_blob_data, gi4.g + gy4.g, sum1.g); + if (gy4.b < psc(M)) buffer_st1(top_blob_data, gi4.g + gy4.b, 
sum2.g); + if (gy4.a < psc(M)) buffer_st1(top_blob_data, gi4.g + gy4.a, sum3.g); + } + if (gx4.b < psc(N)) + { + buffer_st1(top_blob_data, gi4.b + gy4.r, sum0.b); + if (gy4.g < psc(M)) buffer_st1(top_blob_data, gi4.b + gy4.g, sum1.b); + if (gy4.b < psc(M)) buffer_st1(top_blob_data, gi4.b + gy4.b, sum2.b); + if (gy4.a < psc(M)) buffer_st1(top_blob_data, gi4.b + gy4.a, sum3.b); + } + if (gx4.a < psc(N)) + { + buffer_st1(top_blob_data, gi4.a + gy4.r, sum0.a); + if (gy4.g < psc(M)) buffer_st1(top_blob_data, gi4.a + gy4.g, sum1.a); + if (gy4.b < psc(M)) buffer_st1(top_blob_data, gi4.a + gy4.b, sum2.a); + if (gy4.a < psc(M)) buffer_st1(top_blob_data, gi4.a + gy4.a, sum3.a); + } } } else { - const uvec4 gi4 = gy4 * p.outhstep; - - buffer_st1(top_blob_data, gi4.r + gx4.r, sum0.r); - if (gx4.g < psc(N)) buffer_st1(top_blob_data, gi4.r + gx4.g, sum0.g); - if (gx4.b < psc(N)) buffer_st1(top_blob_data, gi4.r + gx4.b, sum0.b); - if (gx4.a < psc(N)) buffer_st1(top_blob_data, gi4.r + gx4.a, sum0.a); - if (gy4.g < psc(M)) - { - buffer_st1(top_blob_data, gi4.g + gx4.r, sum1.r); - if (gx4.g < psc(N)) buffer_st1(top_blob_data, gi4.g + gx4.g, sum1.g); - if (gx4.b < psc(N)) buffer_st1(top_blob_data, gi4.g + gx4.b, sum1.b); - if (gx4.a < psc(N)) buffer_st1(top_blob_data, gi4.g + gx4.a, sum1.a); - } - if (gy4.b < psc(M)) + if (p.out_elempack == 4) { - buffer_st1(top_blob_data, gi4.b + gx4.r, sum2.r); - if (gx4.g < psc(N)) buffer_st1(top_blob_data, gi4.b + gx4.g, sum2.g); - if (gx4.b < psc(N)) buffer_st1(top_blob_data, gi4.b + gx4.b, sum2.b); - if (gx4.a < psc(N)) buffer_st1(top_blob_data, gi4.b + gx4.a, sum2.a); + // non-transpose output, pack4 on M dimension + // pack vec4(sum0[j], sum1[j], sum2[j], sum3[j]) for each column j + const uint gi = (gy4.r / 4) * p.outhstep + gx4.r; + + buffer_st4(top_blob_data_4, gi, afpvec4(sum0.r, sum1.r, sum2.r, sum3.r)); + if (gx4.g < psc(N)) buffer_st4(top_blob_data_4, gi + 1, afpvec4(sum0.g, sum1.g, sum2.g, sum3.g)); + if (gx4.b < psc(N)) buffer_st4(top_blob_data_4, gi + 2, afpvec4(sum0.b, sum1.b, sum2.b, sum3.b)); + if (gx4.a < psc(N)) buffer_st4(top_blob_data_4, gi + 3, afpvec4(sum0.a, sum1.a, sum2.a, sum3.a)); } - if (gy4.a < psc(M)) + else { - buffer_st1(top_blob_data, gi4.a + gx4.r, sum3.r); - if (gx4.g < psc(N)) buffer_st1(top_blob_data, gi4.a + gx4.g, sum3.g); - if (gx4.b < psc(N)) buffer_st1(top_blob_data, gi4.a + gx4.b, sum3.b); - if (gx4.a < psc(N)) buffer_st1(top_blob_data, gi4.a + gx4.a, sum3.a); + const uvec4 gi4 = gy4 * p.outhstep; + + buffer_st1(top_blob_data, gi4.r + gx4.r, sum0.r); + if (gx4.g < psc(N)) buffer_st1(top_blob_data, gi4.r + gx4.g, sum0.g); + if (gx4.b < psc(N)) buffer_st1(top_blob_data, gi4.r + gx4.b, sum0.b); + if (gx4.a < psc(N)) buffer_st1(top_blob_data, gi4.r + gx4.a, sum0.a); + if (gy4.g < psc(M)) + { + buffer_st1(top_blob_data, gi4.g + gx4.r, sum1.r); + if (gx4.g < psc(N)) buffer_st1(top_blob_data, gi4.g + gx4.g, sum1.g); + if (gx4.b < psc(N)) buffer_st1(top_blob_data, gi4.g + gx4.b, sum1.b); + if (gx4.a < psc(N)) buffer_st1(top_blob_data, gi4.g + gx4.a, sum1.a); + } + if (gy4.b < psc(M)) + { + buffer_st1(top_blob_data, gi4.b + gx4.r, sum2.r); + if (gx4.g < psc(N)) buffer_st1(top_blob_data, gi4.b + gx4.g, sum2.g); + if (gx4.b < psc(N)) buffer_st1(top_blob_data, gi4.b + gx4.b, sum2.b); + if (gx4.a < psc(N)) buffer_st1(top_blob_data, gi4.b + gx4.a, sum2.a); + } + if (gy4.a < psc(M)) + { + buffer_st1(top_blob_data, gi4.a + gx4.r, sum3.r); + if (gx4.g < psc(N)) buffer_st1(top_blob_data, gi4.a + gx4.g, sum3.g); + if (gx4.b < psc(N)) 
buffer_st1(top_blob_data, gi4.a + gx4.b, sum3.b); + if (gx4.a < psc(N)) buffer_st1(top_blob_data, gi4.a + gx4.a, sum3.a); + } } } } From 9f2e0c2a580fc02e4866b407047257f782923b8d Mon Sep 17 00:00:00 2001 From: nihui Date: Fri, 6 Mar 2026 19:00:13 +0800 Subject: [PATCH 05/36] vkmat add memory type index, is_device_local (#6581) --- src/allocator.cpp | 22 ++++++++++++++++++++++ src/allocator.h | 4 ++++ src/gpu.cpp | 7 +++++++ src/gpu.h | 1 + src/layer/vulkan/gemm_vulkan.cpp | 16 ++++++++++++++-- 5 files changed, 48 insertions(+), 2 deletions(-) diff --git a/src/allocator.cpp b/src/allocator.cpp index 137bca99bde..f8854530d4f 100644 --- a/src/allocator.cpp +++ b/src/allocator.cpp @@ -719,6 +719,7 @@ VkBufferMemory* VkBlobAllocator::fastMalloc(size_t size) ptr->memory = d->buffer_blocks[i]->memory; ptr->capacity = aligned_size; ptr->mapped_ptr = d->buffer_blocks[i]->mapped_ptr; + ptr->memory_type_index = d->buffer_blocks[i]->memory_type_index; ptr->access_flags = 0; ptr->stage_flags = VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT; @@ -797,6 +798,8 @@ VkBufferMemory* VkBlobAllocator::fastMalloc(size_t size) vkMapMemory(vkdev->vkdevice(), block->memory, 0, new_block_size, 0, &block->mapped_ptr); } + block->memory_type_index = buffer_memory_type_index; + d->buffer_blocks.push_back(block); // return sub buffer @@ -807,6 +810,7 @@ VkBufferMemory* VkBlobAllocator::fastMalloc(size_t size) ptr->memory = block->memory; ptr->capacity = aligned_size; ptr->mapped_ptr = block->mapped_ptr; + ptr->memory_type_index = block->memory_type_index; ptr->access_flags = 0; ptr->stage_flags = VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT; @@ -991,6 +995,7 @@ VkImageMemory* VkBlobAllocator::fastMalloc(int w, int h, int c, size_t elemsize, // do not allow host access to optimal tiling image ptr->mapped_ptr = 0; + ptr->memory_type_index = image_memory_type_index; ptr->imageview = create_imageview(ptr->image, format); @@ -1080,6 +1085,7 @@ VkImageMemory* VkBlobAllocator::fastMalloc(int w, int h, int c, size_t elemsize, // do not allow host access to optimal tiling image ptr->mapped_ptr = 0; + ptr->memory_type_index = image_memory_type_index; ptr->imageview = create_imageview(ptr->image, format); @@ -1364,6 +1370,7 @@ VkBufferMemory* VkWeightAllocator::fastMalloc(size_t size) ptr->memory = d->buffer_blocks[i]->memory; ptr->capacity = aligned_size; ptr->mapped_ptr = d->buffer_blocks[i]->mapped_ptr; + ptr->memory_type_index = d->buffer_blocks[i]->memory_type_index; ptr->access_flags = 0; ptr->stage_flags = VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT; @@ -1455,6 +1462,8 @@ VkBufferMemory* VkWeightAllocator::fastMalloc(size_t size) vkMapMemory(vkdev->vkdevice(), block->memory, 0, new_block_size, 0, &block->mapped_ptr); } + block->memory_type_index = buffer_memory_type_index; + d->dedicated_buffer_blocks.push_back(block); // return sub buffer @@ -1465,6 +1474,7 @@ VkBufferMemory* VkWeightAllocator::fastMalloc(size_t size) ptr->memory = block->memory; ptr->capacity = new_block_size; ptr->mapped_ptr = block->mapped_ptr; + ptr->memory_type_index = block->memory_type_index; ptr->access_flags = 0; ptr->stage_flags = VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT; @@ -1608,6 +1618,8 @@ VkBufferMemory* VkWeightAllocator::fastMalloc(size_t size) vkMapMemory(vkdev->vkdevice(), block->memory, 0, new_block_size, 0, &block->mapped_ptr); } + block->memory_type_index = buffer_memory_type_index; + d->buffer_blocks.push_back(block); d->buffer_block_free_spaces.push_back(new_block_size - aligned_size); @@ -1620,6 +1632,7 @@ VkBufferMemory* VkWeightAllocator::fastMalloc(size_t size) 
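// Editor's note -- illustrative sketch, not part of this patch: recording
// memory_type_index on each VkBufferMemory/VkImageMemory lets a consumer ask
// the device what properties that memory type has. Assuming only the standard
// Vulkan API (<vulkan/vulkan.h>) and a hypothetical helper name, the check that
// the is_device_local() added later in this patch performs boils down to:
static bool memory_type_is_device_local(VkPhysicalDevice physical_device, uint32_t memory_type_index)
{
    VkPhysicalDeviceMemoryProperties props;
    vkGetPhysicalDeviceMemoryProperties(physical_device, &props);
    // a set DEVICE_LOCAL bit means the allocation lives in GPU-local memory,
    // so it can be read by compute shaders without an extra staging copy
    return (props.memoryTypes[memory_type_index].propertyFlags & VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT) != 0;
}
// VulkanDevice::is_device_local() below applies the same flag test to the
// cached physicalDeviceMemoryProperties() instead of re-querying the device.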
ptr->memory = block->memory; ptr->capacity = aligned_size; ptr->mapped_ptr = block->mapped_ptr; + ptr->memory_type_index = block->memory_type_index; ptr->access_flags = 0; ptr->stage_flags = VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT; @@ -1774,6 +1787,7 @@ VkImageMemory* VkWeightAllocator::fastMalloc(int w, int h, int c, size_t elemsiz // do not allow host access to optimal tiling image ptr->mapped_ptr = 0; + ptr->memory_type_index = image_memory_type_index; ptr->imageview = create_imageview(ptr->image, format); @@ -1815,6 +1829,7 @@ VkImageMemory* VkWeightAllocator::fastMalloc(int w, int h, int c, size_t elemsiz // do not allow host access to optimal tiling image ptr->mapped_ptr = 0; + ptr->memory_type_index = image_memory_type_index; ptr->imageview = create_imageview(ptr->image, format); @@ -1974,6 +1989,7 @@ VkImageMemory* VkWeightAllocator::fastMalloc(int w, int h, int c, size_t elemsiz // do not allow host access to optimal tiling image ptr->mapped_ptr = 0; + ptr->memory_type_index = image_memory_type_index; ptr->imageview = create_imageview(ptr->image, format); @@ -2114,6 +2130,8 @@ VkBufferMemory* VkStagingAllocator::fastMalloc(size_t size) vkMapMemory(vkdev->vkdevice(), ptr->memory, 0, size, 0, &ptr->mapped_ptr); + ptr->memory_type_index = buffer_memory_type_index; + ptr->access_flags = 0; ptr->stage_flags = VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT; @@ -2149,6 +2167,7 @@ VkImageMemory* VkStagingAllocator::fastMalloc(int w, int h, int c, size_t elemsi ptr->bind_capacity = size; ptr->mapped_ptr = malloc(size); + ptr->memory_type_index = (uint32_t)-1; ptr->imageview = 0; @@ -2229,6 +2248,8 @@ VkBufferMemory* VkWeightStagingAllocator::fastMalloc(size_t size) vkMapMemory(vkdev->vkdevice(), ptr->memory, 0, size, 0, &ptr->mapped_ptr); + ptr->memory_type_index = buffer_memory_type_index; + ptr->access_flags = 0; ptr->stage_flags = VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT; @@ -2418,6 +2439,7 @@ VkImageMemory* VkAndroidHardwareBufferImageAllocator::fastMalloc(int /*w*/, int ptr->memory = memory; ptr->imageview = imageview; ptr->mapped_ptr = 0; + ptr->memory_type_index = (uint32_t)-1; ptr->access_flags = 0; ptr->image_layout = VK_IMAGE_LAYOUT_UNDEFINED; ptr->stage_flags = VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT; diff --git a/src/allocator.h b/src/allocator.h index 8c634c39502..cf8db861495 100644 --- a/src/allocator.h +++ b/src/allocator.h @@ -221,6 +221,8 @@ class NCNN_EXPORT VkBufferMemory VkDeviceMemory memory; void* mapped_ptr; + uint32_t memory_type_index; + // buffer state, modified by command functions internally mutable VkAccessFlags access_flags; mutable VkPipelineStageFlags stage_flags; @@ -244,6 +246,8 @@ class NCNN_EXPORT VkImageMemory VkDeviceMemory memory; void* mapped_ptr; + uint32_t memory_type_index; + // the base offset assigned by allocator size_t bind_offset; size_t bind_capacity; diff --git a/src/gpu.cpp b/src/gpu.cpp index ebe3b84fc6d..f43cbdc2f4f 100644 --- a/src/gpu.cpp +++ b/src/gpu.cpp @@ -4308,6 +4308,13 @@ bool VulkanDevice::is_coherent(uint32_t memory_type_index) const return memoryType.propertyFlags & VK_MEMORY_PROPERTY_HOST_COHERENT_BIT; } +bool VulkanDevice::is_device_local(uint32_t memory_type_index) const +{ + const VkMemoryType& memoryType = info.physicalDeviceMemoryProperties().memoryTypes[memory_type_index]; + + return memoryType.propertyFlags & VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT; +} + VkQueue VulkanDevice::acquire_queue(uint32_t queue_family_index) const { if (queue_family_index != info.compute_queue_family_index() && queue_family_index != info.transfer_queue_family_index()) diff 
--git a/src/gpu.h b/src/gpu.h index 8ff0fa54664..8839a076d3c 100644 --- a/src/gpu.h +++ b/src/gpu.h @@ -445,6 +445,7 @@ class NCNN_EXPORT VulkanDevice uint32_t find_memory_index(uint32_t memory_type_bits, VkFlags required, VkFlags preferred, VkFlags preferred_not) const; bool is_mappable(uint32_t memory_type_index) const; bool is_coherent(uint32_t memory_type_index) const; + bool is_device_local(uint32_t memory_type_index) const; VkQueue acquire_queue(uint32_t queue_family_index) const; void reclaim_queue(uint32_t queue_family_index, VkQueue queue) const; diff --git a/src/layer/vulkan/gemm_vulkan.cpp b/src/layer/vulkan/gemm_vulkan.cpp index 06cb843d3e7..d1b9991d513 100644 --- a/src/layer/vulkan/gemm_vulkan.cpp +++ b/src/layer/vulkan/gemm_vulkan.cpp @@ -643,8 +643,20 @@ int Gemm_vulkan::upload_model(VkTransfer& cmd, const Option& opt) int Gemm_vulkan::forward(const std::vector& bottom_blobs, std::vector& top_blobs, VkCompute& cmd, const Option& opt) const { - const VkMat& A = constantA ? A_data_gpu : bottom_blobs[0]; - const VkMat& B = constantB ? B_data_gpu : constantA ? bottom_blobs[0] : bottom_blobs[1]; + const VkMat& A0 = constantA ? A_data_gpu : bottom_blobs[0]; + const VkMat& B0 = constantB ? B_data_gpu : constantA ? bottom_blobs[0] : bottom_blobs[1]; + + VkMat A = A0; + VkMat B = B0; + + if (constantA && !vkdev->is_device_local(A0.data->memory_type_index)) + { + cmd.record_clone(A0, A, opt); + } + if (constantB && !vkdev->is_device_local(B0.data->memory_type_index)) + { + cmd.record_clone(B0, B, opt); + } const int A_elempack = A.elempack; const int B_elempack = B.elempack; From 09dfd1bfe8d4c0332a29aa61f01687b8bbaa6653 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=A2=A8=E7=9A=84=E5=BD=B7=E5=BE=A8?= <56149058+futz12@users.noreply.github.com> Date: Sun, 8 Mar 2026 00:05:11 +0800 Subject: [PATCH 06/36] fmod logaddexp floor_divide remainder support for binaryop (#6549) --- src/layer/arm/binaryop_arm.cpp | 40 +- src/layer/arm/binaryop_arm_asimdhp.cpp | 103 ++++ src/layer/arm/neon_mathfun.h | 89 ++++ src/layer/binaryop.cpp | 72 +++ src/layer/binaryop.h | 9 +- src/layer/loongarch/binaryop_loongarch.cpp | 23 + src/layer/loongarch/lsx_mathfun.h | 61 +++ src/layer/mips/binaryop_mips.cpp | 21 + src/layer/mips/msa_mathfun.h | 61 +++ src/layer/riscv/binaryop_riscv.cpp | 22 + src/layer/riscv/binaryop_riscv_zfh.cpp | 22 + src/layer/riscv/rvv_mathfun.h | 125 +++++ src/layer/riscv/rvv_mathfun_fp16s.h | 125 +++++ src/layer/vulkan/binaryop_vulkan.cpp | 7 + src/layer/vulkan/shader/binaryop.comp | 7 + .../vulkan/shader/binaryop_broadcast.comp | 7 + .../shader/binaryop_broadcast_pack1to4.comp | 7 + .../shader/binaryop_broadcast_pack4.comp | 7 + src/layer/vulkan/shader/binaryop_pack4.comp | 7 + src/layer/x86/avx512_mathfun.h | 44 ++ src/layer/x86/avx_mathfun.h | 44 ++ src/layer/x86/binaryop_x86.cpp | 199 ++++++++ src/layer/x86/sse_mathfun.h | 64 +++ src/simplemath.cpp | 49 ++ src/simplemath.h | 2 + tests/test_binaryop.cpp | 2 +- tests/test_binaryop_1.cpp | 2 +- tests/test_binaryop_2.cpp | 2 +- tests/test_binaryop_3.cpp | 4 +- tests/test_binaryop_4.cpp | 482 ++++++++++++++++++ .../pnnx/src/pass_ncnn/expand_expression.cpp | 12 +- 31 files changed, 1709 insertions(+), 12 deletions(-) create mode 100644 tests/test_binaryop_4.cpp diff --git a/src/layer/arm/binaryop_arm.cpp b/src/layer/arm/binaryop_arm.cpp index d90db5b93fc..ef5562dd4c0 100644 --- a/src/layer/arm/binaryop_arm.cpp +++ b/src/layer/arm/binaryop_arm.cpp @@ -272,8 +272,15 @@ MAKE_FUNCTION(binary_op_rdiv, y / x, vdivq_f32(y, x)) 
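// Editor's note -- illustrative sketch, not part of this patch: the scalar
// semantics that the new fmod / remainder / floor_divide / logaddexp operations
// follow, assuming only <cmath> and <algorithm>. The vectorised fmod_ps,
// remainder_ps, floor_divide_ps and logaddexp_ps helpers added below for each
// architecture reproduce these formulas lane by lane; the ref_* names here are
// hypothetical.
#include <algorithm>
#include <cmath>
static inline float ref_fmod(float x, float y)
{
    return fmodf(x, y); // x - trunc(x / y) * y, result takes the sign of x
}
static inline float ref_remainder(float x, float y)
{
    return remainderf(x, y); // x - rint(x / y) * y, quotient rounded to nearest even
}
static inline float ref_floor_divide(float x, float y)
{
    return floorf(x / y); // quotient rounded toward negative infinity
}
static inline float ref_logaddexp(float x, float y)
{
    // numerically stable log(exp(x) + exp(y)): factor out the larger argument
    float m = std::max(x, y);
    return m + log1pf(expf(std::min(x, y) - m));
}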
MAKE_FUNCTION(binary_op_rdiv, y / x, div_ps(y, x)) #endif MAKE_FUNCTION(binary_op_rpow, (float)powf(y, x), pow_ps(y, x)) -MAKE_FUNCTION(binary_op_atan2, (float)atan2f(x, y), atan2_ps(x, y)) -MAKE_FUNCTION(binary_op_ratan2, (float)atan2f(y, x), atan2_ps(y, x)) +MAKE_FUNCTION(binary_op_atan2, atan2f(x, y), atan2_ps(x, y)) +MAKE_FUNCTION(binary_op_ratan2, atan2f(y, x), atan2_ps(y, x)) +MAKE_FUNCTION(binary_op_fmod, (float)fmodf(x, y), fmod_ps(x, y)) +MAKE_FUNCTION(binary_op_rfmod, (float)fmodf(y, x), fmod_ps(y, x)) +MAKE_FUNCTION(binary_op_logaddexp, (float)(std::max(x, y) + log1pf(expf(std::min(x, y) - std::max(x, y)))), logaddexp_ps(x, y)) +MAKE_FUNCTION(binary_op_floor_divide, (float)floorf(x / y), floor_divide_ps(x, y)) +MAKE_FUNCTION(binary_op_rfloor_divide, (float)floorf(y / x), floor_divide_ps(y, x)) +MAKE_FUNCTION(binary_op_remainder, (float)remainderf(x, y), remainder_ps(x, y)) +MAKE_FUNCTION(binary_op_rremainder, (float)remainderf(y, x), remainder_ps(y, x)) // *INDENT-ON* // clang-format on @@ -297,6 +304,13 @@ static void binary_op_vector(const float* ptr, const float* ptr1, float* outptr, if (op_type == BinaryOp::Operation_RPOW) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); if (op_type == BinaryOp::Operation_ATAN2) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); if (op_type == BinaryOp::Operation_RATAN2) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_FMOD) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_RFMOD) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_LOGADDEXP) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_FLOOR_DIVIDE) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_RFLOOR_DIVIDE) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_REMAINDER) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_RREMAINDER) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); // should never reach here } @@ -441,10 +455,18 @@ static int get_reverse_op_type(int op_type) if (op_type == BinaryOp::Operation_DIV) return BinaryOp::Operation_RDIV; if (op_type == BinaryOp::Operation_POW) return BinaryOp::Operation_RPOW; if (op_type == BinaryOp::Operation_ATAN2) return BinaryOp::Operation_RATAN2; + if (op_type == BinaryOp::Operation_FMOD) return BinaryOp::Operation_RFMOD; + if (op_type == BinaryOp::Operation_FLOOR_DIVIDE) return BinaryOp::Operation_RFLOOR_DIVIDE; + if (op_type == BinaryOp::Operation_REMAINDER) return BinaryOp::Operation_RREMAINDER; + if (op_type == BinaryOp::Operation_RSUB) return BinaryOp::Operation_SUB; if (op_type == BinaryOp::Operation_RDIV) return BinaryOp::Operation_DIV; if (op_type == BinaryOp::Operation_RPOW) return BinaryOp::Operation_POW; if (op_type == BinaryOp::Operation_RATAN2) return BinaryOp::Operation_ATAN2; + if (op_type == BinaryOp::Operation_RFMOD) return BinaryOp::Operation_FMOD; + if (op_type == BinaryOp::Operation_RFLOOR_DIVIDE) return BinaryOp::Operation_FLOOR_DIVIDE; + if (op_type == BinaryOp::Operation_RREMAINDER) return BinaryOp::Operation_REMAINDER; + return op_type; } @@ -844,6 +866,13 @@ static void binary_op_vector_bf16s(const unsigned short* ptr, const unsigned sho if (op_type == BinaryOp::Operation_RPOW) return binary_op_vector_bf16s(ptr, ptr1, outptr, aw, bw, ap, bp); if (op_type == 
BinaryOp::Operation_ATAN2) return binary_op_vector_bf16s(ptr, ptr1, outptr, aw, bw, ap, bp); if (op_type == BinaryOp::Operation_RATAN2) return binary_op_vector_bf16s(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_FMOD) return binary_op_vector_bf16s(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_RFMOD) return binary_op_vector_bf16s(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_LOGADDEXP) return binary_op_vector_bf16s(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_FLOOR_DIVIDE) return binary_op_vector_bf16s(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_RFLOOR_DIVIDE) return binary_op_vector_bf16s(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_REMAINDER) return binary_op_vector_bf16s(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_RREMAINDER) return binary_op_vector_bf16s(ptr, ptr1, outptr, aw, bw, ap, bp); // should never reach here } @@ -889,6 +918,13 @@ static void binary_op_vector_scalar_b_bf16s(const unsigned short* ptr, float b, if (op_type == BinaryOp::Operation_RPOW) return binary_op_vector_scalar_b_bf16s(ptr, b, outptr, size); if (op_type == BinaryOp::Operation_ATAN2) return binary_op_vector_scalar_b_bf16s(ptr, b, outptr, size); if (op_type == BinaryOp::Operation_RATAN2) return binary_op_vector_scalar_b_bf16s(ptr, b, outptr, size); + if (op_type == BinaryOp::Operation_FMOD) return binary_op_vector_scalar_b_bf16s(ptr, b, outptr, size); + if (op_type == BinaryOp::Operation_RFMOD) return binary_op_vector_scalar_b_bf16s(ptr, b, outptr, size); + if (op_type == BinaryOp::Operation_LOGADDEXP) return binary_op_vector_scalar_b_bf16s(ptr, b, outptr, size); + if (op_type == BinaryOp::Operation_FLOOR_DIVIDE) return binary_op_vector_scalar_b_bf16s(ptr, b, outptr, size); + if (op_type == BinaryOp::Operation_RFLOOR_DIVIDE) return binary_op_vector_scalar_b_bf16s(ptr, b, outptr, size); + if (op_type == BinaryOp::Operation_REMAINDER) return binary_op_vector_scalar_b_bf16s(ptr, b, outptr, size); + if (op_type == BinaryOp::Operation_RREMAINDER) return binary_op_vector_scalar_b_bf16s(ptr, b, outptr, size); // should never reach here } diff --git a/src/layer/arm/binaryop_arm_asimdhp.cpp b/src/layer/arm/binaryop_arm_asimdhp.cpp index 0860e2ac2b6..d5440474d58 100644 --- a/src/layer/arm/binaryop_arm_asimdhp.cpp +++ b/src/layer/arm/binaryop_arm_asimdhp.cpp @@ -12,6 +12,87 @@ namespace ncnn { #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +static inline float16x4_t fmod_f16(const float16x4_t& x, const float16x4_t& y) +{ + float32x4_t fx = vcvt_f32_f16(x); + float32x4_t fy = vcvt_f32_f16(y); + return vcvt_f16_f32(fmod_ps(fx, fy)); +} + +static inline float16x8_t fmodq_f16(const float16x8_t& x, const float16x8_t& y) +{ + float16x4_t xl = vget_low_f16(x); + float16x4_t xh = vget_high_f16(x); + float16x4_t yl = vget_low_f16(y); + float16x4_t yh = vget_high_f16(y); + + float16x4_t rl = fmod_f16(xl, yl); + float16x4_t rh = fmod_f16(xh, yh); + return vcombine_f16(rl, rh); +} + +static inline float16x4_t round_f16(const float16x4_t& x) +{ + return vcvt_f16_f32(round_ps(vcvt_f32_f16(x))); +} + +static inline float16x8_t roundq_f16(const float16x8_t& x) +{ + float16x4_t xl = vget_low_f16(x); + float16x4_t xh = vget_high_f16(x); + float16x4_t rl = round_f16(xl); + float16x4_t rh = round_f16(xh); + return vcombine_f16(rl, rh); +} + +static inline float16x4_t logaddexp_f16(const float16x4_t& x, const float16x4_t& y) +{ + return 
vcvt_f16_f32(logaddexp_ps(vcvt_f32_f16(x), vcvt_f32_f16(y))); +} + +static inline float16x8_t logaddexpq_f16(const float16x8_t& x, const float16x8_t& y) +{ + float16x4_t xl = vget_low_f16(x); + float16x4_t xh = vget_high_f16(x); + float16x4_t yl = vget_low_f16(y); + float16x4_t yh = vget_high_f16(y); + float16x4_t rl = logaddexp_f16(xl, yl); + float16x4_t rh = logaddexp_f16(xh, yh); + return vcombine_f16(rl, rh); +} + +static inline float16x4_t floor_divide_f16(const float16x4_t& x, const float16x4_t& y) +{ + return vcvt_f16_f32(floor_divide_ps(vcvt_f32_f16(x), vcvt_f32_f16(y))); +} + +static inline float16x8_t floor_divideq_f16(const float16x8_t& x, const float16x8_t& y) +{ + float16x4_t xl = vget_low_f16(x); + float16x4_t xh = vget_high_f16(x); + float16x4_t yl = vget_low_f16(y); + float16x4_t yh = vget_high_f16(y); + float16x4_t rl = floor_divide_f16(xl, yl); + float16x4_t rh = floor_divide_f16(xh, yh); + return vcombine_f16(rl, rh); +} + +static inline float16x4_t remainder_f16(const float16x4_t& x, const float16x4_t& y) +{ + return vcvt_f16_f32(remainder_ps(vcvt_f32_f16(x), vcvt_f32_f16(y))); +} + +static inline float16x8_t remainderq_f16(const float16x8_t& x, const float16x8_t& y) +{ + float16x4_t xl = vget_low_f16(x); + float16x4_t xh = vget_high_f16(x); + float16x4_t yl = vget_low_f16(y); + float16x4_t yh = vget_high_f16(y); + float16x4_t rl = remainder_f16(xl, yl); + float16x4_t rh = remainder_f16(xh, yh); + return vcombine_f16(rl, rh); +} + template static void binary_op_vector_no_broadcast_fp16s(const __fp16* ptr, const __fp16* ptr1, __fp16* outptr, int size) { @@ -318,6 +399,13 @@ MAKE_FUNCTION(binary_op_rdiv_fp16s, y / x, vdiv_f16(y, x), vdivq_f16(y, x)) MAKE_FUNCTION(binary_op_rpow_fp16s, (__fp16)powf(y, x), vcvt_f16_f32(pow_ps(vcvt_f32_f16(y), vcvt_f32_f16(x))), vcombine_f16(vcvt_f16_f32(pow_ps(vcvt_f32_f16(vget_low_f16(y)), vcvt_f32_f16(vget_low_f16(x)))), vcvt_f16_f32(pow_ps(vcvt_f32_f16(vget_high_f16(y)), vcvt_f32_f16(vget_high_f16(x)))))) MAKE_FUNCTION(binary_op_atan2_fp16s, (__fp16)atan2f(x, y), vcvt_f16_f32(atan2_ps(vcvt_f32_f16(x), vcvt_f32_f16(y))), vcombine_f16(vcvt_f16_f32(atan2_ps(vcvt_f32_f16(vget_low_f16(x)), vcvt_f32_f16(vget_low_f16(y)))), vcvt_f16_f32(atan2_ps(vcvt_f32_f16(vget_high_f16(x)), vcvt_f32_f16(vget_high_f16(y)))))) MAKE_FUNCTION(binary_op_ratan2_fp16s, (__fp16)atan2f(y, x), vcvt_f16_f32(atan2_ps(vcvt_f32_f16(y), vcvt_f32_f16(x))), vcombine_f16(vcvt_f16_f32(atan2_ps(vcvt_f32_f16(vget_low_f16(y)), vcvt_f32_f16(vget_low_f16(x)))), vcvt_f16_f32(atan2_ps(vcvt_f32_f16(vget_high_f16(y)), vcvt_f32_f16(vget_high_f16(x)))))) +MAKE_FUNCTION(binary_op_fmod_fp16s, (__fp16)fmodf((float)x, (float)y), fmod_f16(x, y), fmodq_f16(x, y)) +MAKE_FUNCTION(binary_op_rfmod_fp16s, (__fp16)fmodf((float)y, (float)x), fmod_f16(y, x), fmodq_f16(y, x)) +MAKE_FUNCTION(binary_op_logaddexp_fp16s, (__fp16)(std::max((float)x, (float)y) + log1pf(expf(std::min((float)x, (float)y) - std::max((float)x, (float)y)))), logaddexp_f16(x, y), logaddexpq_f16(x, y)) +MAKE_FUNCTION(binary_op_floor_divide_fp16s, (__fp16)floorf((float)x / (float)y), floor_divide_f16(x, y), floor_divideq_f16(x, y)) +MAKE_FUNCTION(binary_op_rfloor_divide_fp16s, (__fp16)floorf((float)y / (float)x), floor_divide_f16(y, x), floor_divideq_f16(y, x)) +MAKE_FUNCTION(binary_op_remainder_fp16s, (__fp16)remainderf((float)x, (float)y), remainder_f16(x, y), remainderq_f16(x, y)) +MAKE_FUNCTION(binary_op_rremainder_fp16s, (__fp16)remainderf((float)y, (float)x), remainder_f16(y, x), remainderq_f16(y, x)) // *INDENT-ON* // 
clang-format on @@ -341,6 +429,13 @@ static void binary_op_vector_fp16s(const __fp16* ptr, const __fp16* ptr1, __fp16 if (op_type == BinaryOp::Operation_RPOW) return binary_op_vector_fp16s(ptr, ptr1, outptr, aw, bw, ap, bp); if (op_type == BinaryOp::Operation_ATAN2) return binary_op_vector_fp16s(ptr, ptr1, outptr, aw, bw, ap, bp); if (op_type == BinaryOp::Operation_RATAN2) return binary_op_vector_fp16s(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_FMOD) return binary_op_vector_fp16s(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_RFMOD) return binary_op_vector_fp16s(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_LOGADDEXP) return binary_op_vector_fp16s(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_FLOOR_DIVIDE) return binary_op_vector_fp16s(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_RFLOOR_DIVIDE) return binary_op_vector_fp16s(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_REMAINDER) return binary_op_vector_fp16s(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_RREMAINDER) return binary_op_vector_fp16s(ptr, ptr1, outptr, aw, bw, ap, bp); // should never reach here } @@ -485,10 +580,18 @@ static int get_reverse_op_type(int op_type) if (op_type == BinaryOp::Operation_DIV) return BinaryOp::Operation_RDIV; if (op_type == BinaryOp::Operation_POW) return BinaryOp::Operation_RPOW; if (op_type == BinaryOp::Operation_ATAN2) return BinaryOp::Operation_RATAN2; + if (op_type == BinaryOp::Operation_FMOD) return BinaryOp::Operation_RFMOD; + if (op_type == BinaryOp::Operation_FLOOR_DIVIDE) return BinaryOp::Operation_RFLOOR_DIVIDE; + if (op_type == BinaryOp::Operation_REMAINDER) return BinaryOp::Operation_RREMAINDER; + if (op_type == BinaryOp::Operation_RSUB) return BinaryOp::Operation_SUB; if (op_type == BinaryOp::Operation_RDIV) return BinaryOp::Operation_DIV; if (op_type == BinaryOp::Operation_RPOW) return BinaryOp::Operation_POW; if (op_type == BinaryOp::Operation_RATAN2) return BinaryOp::Operation_ATAN2; + if (op_type == BinaryOp::Operation_RFMOD) return BinaryOp::Operation_FMOD; + if (op_type == BinaryOp::Operation_RFLOOR_DIVIDE) return BinaryOp::Operation_FLOOR_DIVIDE; + if (op_type == BinaryOp::Operation_RREMAINDER) return BinaryOp::Operation_REMAINDER; + return op_type; } diff --git a/src/layer/arm/neon_mathfun.h b/src/layer/arm/neon_mathfun.h index ba9c17aaec5..8697c993646 100644 --- a/src/layer/arm/neon_mathfun.h +++ b/src/layer/arm/neon_mathfun.h @@ -404,6 +404,95 @@ static inline float32x4_t atan2_ps(float32x4_t a, float32x4_t b) return vld1q_f32(tmpx); } +static inline float32x4_t trunc_ps(const float32x4_t& x) +{ + // truncate toward zero +#if __aarch64__ + return vrndq_f32(x); +#else + int32x4_t xi = vcvtq_s32_f32(x); + return vcvtq_f32_s32(xi); +#endif +} + +static inline float32x4_t fmod_ps(const float32x4_t& x, const float32x4_t& y) +{ + // fmod(x,y) = x - trunc(x/y) * y +#if __aarch64__ + float32x4_t q = vdivq_f32(x, y); +#else + float32x4_t q = div_ps(x, y); +#endif + float32x4_t tq = trunc_ps(q); + return vsubq_f32(x, vmulq_f32(tq, y)); +} + +static inline float32x4_t round_ps(const float32x4_t& x) +{ +#if __aarch64__ + return vrndnq_f32(x); +#else + float32x4_t half = vdupq_n_f32(0.5f); + float32x4_t one = vdupq_n_f32(1.0f); + uint32x4_t sign_mask = vcltq_f32(x, vdupq_n_f32(0)); + float32x4_t abs_x = vabsq_f32(x); + int32x4_t xi = vcvtq_s32_f32(abs_x); + float32x4_t truncated = vcvtq_f32_s32(xi); + 
float32x4_t diff = vsubq_f32(abs_x, truncated); + uint32x4_t diff_gt_half = vcgtq_f32(diff, half); + uint32x4_t diff_eq_half = vceqq_f32(diff, half); + int32x4_t xi_and_1 = vandq_s32(xi, vdupq_n_s32(1)); + uint32x4_t is_odd = vcgtq_s32(xi_and_1, vdupq_n_s32(0)); + uint32x4_t round_up = vorrq_u32(diff_gt_half, vandq_u32(diff_eq_half, is_odd)); + float32x4_t rounded = vaddq_f32(truncated, vreinterpretq_f32_u32(vandq_u32(round_up, vreinterpretq_u32_f32(one)))); + return vbslq_f32(sign_mask, vnegq_f32(rounded), rounded); +#endif +} + +static inline float32x4_t logaddexp_ps(const float32x4_t& x, const float32x4_t& y) +{ + float32x4_t max_xy = vmaxq_f32(x, y); + float32x4_t min_xy = vminq_f32(x, y); + float32x4_t diff = vsubq_f32(min_xy, max_xy); + float32x4_t exp_diff = exp_ps(diff); + float32x4_t one_plus_exp = vaddq_f32(vdupq_n_f32(1.0f), exp_diff); + float32x4_t log_result = log_ps(one_plus_exp); + return vaddq_f32(max_xy, log_result); +} + +static inline float32x4_t floor_ps(const float32x4_t& x) +{ +#if __aarch64__ + return vrndmq_f32(x); +#else + float32x4_t truncated = vcvtq_f32_s32(vcvtq_s32_f32(x)); + uint32x4_t need_adjust = vcltq_f32(x, truncated); + float32x4_t adjusted = vsubq_f32(truncated, vdupq_n_f32(1.0f)); + return vbslq_f32(need_adjust, adjusted, truncated); +#endif +} + +static inline float32x4_t floor_divide_ps(const float32x4_t& x, const float32x4_t& y) +{ +#if __aarch64__ + float32x4_t q = vdivq_f32(x, y); +#else + float32x4_t q = div_ps(x, y); +#endif + return floor_ps(q); +} + +static inline float32x4_t remainder_ps(const float32x4_t& x, const float32x4_t& y) +{ +#if __aarch64__ + float32x4_t q = vdivq_f32(x, y); +#else + float32x4_t q = div_ps(x, y); +#endif + float32x4_t rq = round_ps(q); + return vsubq_f32(x, vmulq_f32(rq, y)); +} + #include "neon_mathfun_tanh.h" // Clean up macros diff --git a/src/layer/binaryop.cpp b/src/layer/binaryop.cpp index 75ed67ae61e..1bf0a7f1327 100644 --- a/src/layer/binaryop.cpp +++ b/src/layer/binaryop.cpp @@ -226,6 +226,64 @@ struct binary_op_ratan2 } }; +struct binary_op_fmod +{ + float operator()(const float& x, const float& y) const + { + return (float)fmodf(x, y); + } +}; + +struct binary_op_logaddexp +{ + float operator()(const float& x, const float& y) const + { + float max_xy = std::max(x, y); + float min_xy = std::min(x, y); + return (float)(max_xy + log1pf(expf(min_xy - max_xy))); + } +}; + +struct binary_op_floor_divide +{ + float operator()(const float& x, const float& y) const + { + return (float)floorf(x / y); + } +}; + +struct binary_op_remainder +{ + float operator()(const float& x, const float& y) const + { + return (float)remainderf(x, y); + } +}; + +struct binary_op_rfmod +{ + float operator()(const float& x, const float& y) const + { + return (float)fmodf(y, x); + } +}; + +struct binary_op_rfloor_divide +{ + float operator()(const float& x, const float& y) const + { + return (float)floorf(y / x); + } +}; + +struct binary_op_rremainder +{ + float operator()(const float& x, const float& y) const + { + return (float)remainderf(y, x); + } +}; + static void binary_op_broadcast(const Mat& a, const Mat& b, Mat& c, int op_type, const Option& opt) { if (op_type == BinaryOp::Operation_ADD) return binary_op_broadcast(a, b, c, opt); @@ -240,6 +298,13 @@ static void binary_op_broadcast(const Mat& a, const Mat& b, Mat& c, int op_type, if (op_type == BinaryOp::Operation_RPOW) return binary_op_broadcast(b, a, c, opt); if (op_type == BinaryOp::Operation_ATAN2) return binary_op_broadcast(a, b, c, opt); if (op_type == 
BinaryOp::Operation_RATAN2) return binary_op_broadcast(b, a, c, opt); + if (op_type == BinaryOp::Operation_FMOD) return binary_op_broadcast(a, b, c, opt); + if (op_type == BinaryOp::Operation_RFMOD) return binary_op_broadcast(b, a, c, opt); + if (op_type == BinaryOp::Operation_LOGADDEXP) return binary_op_broadcast(a, b, c, opt); + if (op_type == BinaryOp::Operation_FLOOR_DIVIDE) return binary_op_broadcast(a, b, c, opt); + if (op_type == BinaryOp::Operation_RFLOOR_DIVIDE) return binary_op_broadcast(b, a, c, opt); + if (op_type == BinaryOp::Operation_REMAINDER) return binary_op_broadcast(a, b, c, opt); + if (op_type == BinaryOp::Operation_RREMAINDER) return binary_op_broadcast(b, a, c, opt); // should never reach here } @@ -258,6 +323,13 @@ static void binary_op_scalar_inplace(Mat& bottom_top_blob, float b, int op_type, if (op_type == BinaryOp::Operation_RPOW) return binary_op_scalar_inplace(bottom_top_blob, b, opt); if (op_type == BinaryOp::Operation_ATAN2) return binary_op_scalar_inplace(bottom_top_blob, b, opt); if (op_type == BinaryOp::Operation_RATAN2) return binary_op_scalar_inplace(bottom_top_blob, b, opt); + if (op_type == BinaryOp::Operation_FMOD) return binary_op_scalar_inplace(bottom_top_blob, b, opt); + if (op_type == BinaryOp::Operation_RFMOD) return binary_op_scalar_inplace(bottom_top_blob, b, opt); + if (op_type == BinaryOp::Operation_LOGADDEXP) return binary_op_scalar_inplace(bottom_top_blob, b, opt); + if (op_type == BinaryOp::Operation_FLOOR_DIVIDE) return binary_op_scalar_inplace(bottom_top_blob, b, opt); + if (op_type == BinaryOp::Operation_RFLOOR_DIVIDE) return binary_op_scalar_inplace(bottom_top_blob, b, opt); + if (op_type == BinaryOp::Operation_REMAINDER) return binary_op_scalar_inplace(bottom_top_blob, b, opt); + if (op_type == BinaryOp::Operation_RREMAINDER) return binary_op_scalar_inplace(bottom_top_blob, b, opt); // should never reach here } diff --git a/src/layer/binaryop.h b/src/layer/binaryop.h index 20706cfe1c0..adcc9196b52 100644 --- a/src/layer/binaryop.h +++ b/src/layer/binaryop.h @@ -34,7 +34,14 @@ class BinaryOp : public Layer Operation_RDIV = 8, Operation_RPOW = 9, Operation_ATAN2 = 10, - Operation_RATAN2 = 11 + Operation_RATAN2 = 11, + Operation_FMOD = 12, + Operation_RFMOD = 13, + Operation_LOGADDEXP = 14, + Operation_FLOOR_DIVIDE = 15, + Operation_RFLOOR_DIVIDE = 16, + Operation_REMAINDER = 17, + Operation_RREMAINDER = 18 }; public: diff --git a/src/layer/loongarch/binaryop_loongarch.cpp b/src/layer/loongarch/binaryop_loongarch.cpp index 6f7ed45c535..adfd181b156 100644 --- a/src/layer/loongarch/binaryop_loongarch.cpp +++ b/src/layer/loongarch/binaryop_loongarch.cpp @@ -301,6 +301,13 @@ MAKE_FUNCTION(binary_op_rdiv, y / x, __lsx_vfdiv_s(y, x)) MAKE_FUNCTION(binary_op_rpow, (float)powf(y, x), pow_ps(y, x)) MAKE_FUNCTION(binary_op_atan2, (float)atan2f(x, y), atan2_ps(x, y)) MAKE_FUNCTION(binary_op_ratan2, (float)atan2f(y, x), atan2_ps(y, x)) +MAKE_FUNCTION(binary_op_fmod, (float)fmodf(x, y), fmod_ps(x, y)) +MAKE_FUNCTION(binary_op_rfmod, (float)fmodf(y, x), fmod_ps(y, x)) +MAKE_FUNCTION(binary_op_logaddexp, (float)(std::max(x, y) + log1pf(expf(std::min(x, y) - std::max(x, y)))), logaddexp_ps(x, y)) +MAKE_FUNCTION(binary_op_floor_divide, (float)floorf(x / y), floor_divide_ps(x, y)) +MAKE_FUNCTION(binary_op_rfloor_divide, (float)floorf(y / x), floor_divide_ps(y, x)) +MAKE_FUNCTION(binary_op_remainder, (float)remainderf(x, y), remainder_ps(x, y)) +MAKE_FUNCTION(binary_op_rremainder, (float)remainderf(y, x), remainder_ps(y, x)) // *INDENT-ON* // clang-format 
on @@ -324,6 +331,13 @@ static void binary_op_vector(const float* ptr, const float* ptr1, float* outptr, if (op_type == BinaryOp::Operation_RPOW) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); if (op_type == BinaryOp::Operation_ATAN2) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); if (op_type == BinaryOp::Operation_RATAN2) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_FMOD) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_RFMOD) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_LOGADDEXP) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_FLOOR_DIVIDE) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_RFLOOR_DIVIDE) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_REMAINDER) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_RREMAINDER) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); // should never reach here } @@ -468,10 +482,19 @@ static int get_reverse_op_type(int op_type) if (op_type == BinaryOp::Operation_DIV) return BinaryOp::Operation_RDIV; if (op_type == BinaryOp::Operation_POW) return BinaryOp::Operation_RPOW; if (op_type == BinaryOp::Operation_ATAN2) return BinaryOp::Operation_RATAN2; + if (op_type == BinaryOp::Operation_FMOD) return BinaryOp::Operation_RFMOD; + if (op_type == BinaryOp::Operation_LOGADDEXP) return BinaryOp::Operation_LOGADDEXP; + if (op_type == BinaryOp::Operation_FLOOR_DIVIDE) return BinaryOp::Operation_RFLOOR_DIVIDE; + if (op_type == BinaryOp::Operation_REMAINDER) return BinaryOp::Operation_RREMAINDER; + if (op_type == BinaryOp::Operation_RSUB) return BinaryOp::Operation_SUB; if (op_type == BinaryOp::Operation_RDIV) return BinaryOp::Operation_DIV; if (op_type == BinaryOp::Operation_RPOW) return BinaryOp::Operation_POW; if (op_type == BinaryOp::Operation_RATAN2) return BinaryOp::Operation_ATAN2; + if (op_type == BinaryOp::Operation_RFMOD) return BinaryOp::Operation_FMOD; + if (op_type == BinaryOp::Operation_RFLOOR_DIVIDE) return BinaryOp::Operation_FLOOR_DIVIDE; + if (op_type == BinaryOp::Operation_RREMAINDER) return BinaryOp::Operation_REMAINDER; + return op_type; } diff --git a/src/layer/loongarch/lsx_mathfun.h b/src/layer/loongarch/lsx_mathfun.h index 5f30bdad7a3..aeb7f188ce4 100644 --- a/src/layer/loongarch/lsx_mathfun.h +++ b/src/layer/loongarch/lsx_mathfun.h @@ -626,4 +626,65 @@ static inline __m128 atan2_ps(__m128 y, __m128 x) return final_result; } +static inline __m128 fmod_ps(__m128 a, __m128 b) +{ + // fmod(a,b) = a - trunc(a/b)*b (trunc toward 0) + __m128 q = __lsx_vfdiv_s(a, b); + __m128i qi = __lsx_vftintrz_w_s(q); // float -> int32 trunc toward zero + __m128 qf = __lsx_vffint_s_w(qi); // int32 -> float + return __lsx_vfsub_s(a, __lsx_vfmul_s(qf, b)); +} + +static inline __m128 round_ps(__m128 x) +{ + __m128 half = (__m128)__lsx_vreplgr2vr_w(c_0p5.i); + __m128 one = (__m128)__lsx_vreplgr2vr_w(c_1.i); + __m128i sign_mask = __lsx_vfcmp_clt_s(x, (__m128)__lsx_vreplgr2vr_w(0)); + __m128 abs_x = (__m128)__lsx_vbitclri_w((__m128i)x, 31); + __m128i xi = __lsx_vftintrz_w_s(abs_x); + __m128 xf = __lsx_vffint_s_w(xi); + __m128 diff = __lsx_vfsub_s(abs_x, xf); + __m128i diff_gt_half = __lsx_vfcmp_clt_s(half, diff); + __m128i diff_eq_half = __lsx_vfcmp_ceq_s(diff, half); + __m128i xi_and_1 = 
__lsx_vand_v(xi, __lsx_vreplgr2vr_w(1)); + __m128i is_odd = __lsx_vseq_w(xi_and_1, __lsx_vreplgr2vr_w(1)); + __m128i round_up = __lsx_vor_v(diff_gt_half, __lsx_vand_v(diff_eq_half, is_odd)); + __m128 rounded = __lsx_vfadd_s(xf, (__m128)__lsx_vand_v(round_up, (__m128i)one)); + return (__m128)__lsx_vbitsel_v((__m128i)rounded, (__m128i)__lsx_vbitrevi_w((__m128i)rounded, 31), sign_mask); +} + +static inline __m128 logaddexp_ps(__m128 a, __m128 b) +{ + __m128 one = (__m128)__lsx_vreplgr2vr_w(c_1.i); + __m128 max_xy = __lsx_vfmax_s(a, b); + __m128 min_xy = __lsx_vfmin_s(a, b); + __m128 diff = __lsx_vfsub_s(min_xy, max_xy); + __m128 exp_diff = exp_ps(diff); + __m128 one_plus_exp = __lsx_vfadd_s(one, exp_diff); + __m128 log_result = log_ps(one_plus_exp); + return __lsx_vfadd_s(max_xy, log_result); +} + +static inline __m128 floor_ps(__m128 x) +{ + __m128i xi = __lsx_vftintrz_w_s(x); + __m128 xf = __lsx_vffint_s_w(xi); + __m128i need_adjust = __lsx_vfcmp_clt_s(x, xf); + __m128 one = (__m128)__lsx_vreplgr2vr_w(c_1.i); + return __lsx_vfsub_s(xf, (__m128)__lsx_vand_v(need_adjust, (__m128i)one)); +} + +static inline __m128 floor_divide_ps(__m128 a, __m128 b) +{ + __m128 q = __lsx_vfdiv_s(a, b); + return floor_ps(q); +} + +static inline __m128 remainder_ps(__m128 a, __m128 b) +{ + __m128 q = __lsx_vfdiv_s(a, b); + __m128 rq = round_ps(q); + return __lsx_vfsub_s(a, __lsx_vfmul_s(rq, b)); +} + #endif // LSX_MATHFUN_H diff --git a/src/layer/mips/binaryop_mips.cpp b/src/layer/mips/binaryop_mips.cpp index b734adf4d99..a2766abfed1 100644 --- a/src/layer/mips/binaryop_mips.cpp +++ b/src/layer/mips/binaryop_mips.cpp @@ -301,6 +301,13 @@ MAKE_FUNCTION(binary_op_rdiv, y / x, __msa_fdiv_w(y, x)) MAKE_FUNCTION(binary_op_rpow, (float)powf(y, x), pow_ps(y, x)) MAKE_FUNCTION(binary_op_atan2, (float)atan2f(x, y), atan2_ps(x, y)) MAKE_FUNCTION(binary_op_ratan2, (float)atan2f(y, x), atan2_ps(y, x)) +MAKE_FUNCTION(binary_op_fmod, (float)fmodf(x, y), fmod_ps(x, y)) +MAKE_FUNCTION(binary_op_rfmod, (float)fmodf(y, x), fmod_ps(y, x)) +MAKE_FUNCTION(binary_op_logaddexp, (float)(std::max(x, y) + log1pf(expf(std::min(x, y) - std::max(x, y)))), logaddexp_ps(x, y)) +MAKE_FUNCTION(binary_op_floor_divide, (float)floorf(x / y), floor_divide_ps(x, y)) +MAKE_FUNCTION(binary_op_rfloor_divide, (float)floorf(y / x), floor_divide_ps(y, x)) +MAKE_FUNCTION(binary_op_remainder, (float)remainderf(x, y), remainder_ps(x, y)) +MAKE_FUNCTION(binary_op_rremainder, (float)remainderf(y, x), remainder_ps(y, x)) // *INDENT-ON* // clang-format on @@ -324,6 +331,13 @@ static void binary_op_vector(const float* ptr, const float* ptr1, float* outptr, if (op_type == BinaryOp::Operation_RPOW) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); if (op_type == BinaryOp::Operation_ATAN2) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); if (op_type == BinaryOp::Operation_RATAN2) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_FMOD) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_RFMOD) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_LOGADDEXP) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_FLOOR_DIVIDE) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_RFLOOR_DIVIDE) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_REMAINDER) return binary_op_vector(ptr, 
ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_RREMAINDER) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); // should never reach here } @@ -468,10 +482,17 @@ static int get_reverse_op_type(int op_type) if (op_type == BinaryOp::Operation_DIV) return BinaryOp::Operation_RDIV; if (op_type == BinaryOp::Operation_POW) return BinaryOp::Operation_RPOW; if (op_type == BinaryOp::Operation_ATAN2) return BinaryOp::Operation_RATAN2; + if (op_type == BinaryOp::Operation_FMOD) return BinaryOp::Operation_RFMOD; + if (op_type == BinaryOp::Operation_FLOOR_DIVIDE) return BinaryOp::Operation_RFLOOR_DIVIDE; + if (op_type == BinaryOp::Operation_REMAINDER) return BinaryOp::Operation_RREMAINDER; + if (op_type == BinaryOp::Operation_RSUB) return BinaryOp::Operation_SUB; if (op_type == BinaryOp::Operation_RDIV) return BinaryOp::Operation_DIV; if (op_type == BinaryOp::Operation_RPOW) return BinaryOp::Operation_POW; if (op_type == BinaryOp::Operation_RATAN2) return BinaryOp::Operation_ATAN2; + if (op_type == BinaryOp::Operation_RFMOD) return BinaryOp::Operation_FMOD; + if (op_type == BinaryOp::Operation_RFLOOR_DIVIDE) return BinaryOp::Operation_FLOOR_DIVIDE; + if (op_type == BinaryOp::Operation_RREMAINDER) return BinaryOp::Operation_REMAINDER; return op_type; } diff --git a/src/layer/mips/msa_mathfun.h b/src/layer/mips/msa_mathfun.h index 8abfe282568..cc4b65fe0d6 100644 --- a/src/layer/mips/msa_mathfun.h +++ b/src/layer/mips/msa_mathfun.h @@ -267,4 +267,65 @@ static inline v4f32 atan2_ps(v4f32 a, v4f32 b) return (v4f32)__msa_ld_w(tmpx, 0); } +static inline v4f32 fmod_ps(v4f32 a, v4f32 b) +{ + // fmod(a,b) = a - trunc(a/b)*b (trunc toward 0) + v4f32 q = __msa_fdiv_w(a, b); + v4i32 qi = __msa_ftrunc_s_w(q); // trunc toward zero (independent of RM) + v4f32 qf = __msa_ffint_s_w(qi); + return __msa_fsub_w(a, __msa_fmul_w(qf, b)); +} + +static inline v4f32 round_ps(v4f32 x) +{ + v4f32 half = (v4f32)__msa_fill_w(c_0p5.i); + v4f32 one = (v4f32)__msa_fill_w(c_1.i); + v4i32 sign_mask = __msa_fclt_w(x, (v4f32)__msa_fill_w(0)); + v4f32 abs_x = (v4f32)__msa_bclri_w((v4u32)x, 31); + v4i32 xi = __msa_ftrunc_s_w(abs_x); + v4f32 xf = __msa_ffint_s_w(xi); + v4f32 diff = __msa_fsub_w(abs_x, xf); + v4i32 diff_gt_half = __msa_fclt_w(half, diff); + v4i32 diff_eq_half = __msa_fceq_w(diff, half); + v4i32 xi_and_1 = (v4i32)__msa_and_v((v16u8)xi, (v16u8)__msa_fill_w(1)); + v4i32 is_odd = __msa_ceqi_w(xi_and_1, 1); + v4i32 round_up = (v4i32)__msa_or_v((v16u8)diff_gt_half, __msa_and_v((v16u8)diff_eq_half, (v16u8)is_odd)); + v4f32 rounded = __msa_fadd_w(xf, (v4f32)__msa_and_v((v16u8)one, (v16u8)round_up)); + return (v4f32)__msa_bsel_v((v16u8)sign_mask, (v16u8)rounded, (v16u8)__msa_bnegi_w((v4u32)rounded, 31)); +} + +static inline v4f32 logaddexp_ps(v4f32 a, v4f32 b) +{ + v4f32 one = (v4f32)__msa_fill_w(c_1.i); + v4f32 max_xy = __msa_fmax_w(a, b); + v4f32 min_xy = __msa_fmin_w(a, b); + v4f32 diff = __msa_fsub_w(min_xy, max_xy); + v4f32 exp_diff = exp_ps(diff); + v4f32 one_plus_exp = __msa_fadd_w(one, exp_diff); + v4f32 log_result = log_ps(one_plus_exp); + return __msa_fadd_w(max_xy, log_result); +} + +static inline v4f32 floor_ps(v4f32 x) +{ + v4i32 xi = __msa_ftrunc_s_w(x); + v4f32 xf = __msa_ffint_s_w(xi); + v4i32 need_adjust = __msa_fclt_w(x, xf); + v4f32 one = (v4f32)__msa_fill_w(c_1.i); + return __msa_fsub_w(xf, (v4f32)__msa_and_v((v16u8)one, (v16u8)need_adjust)); +} + +static inline v4f32 floor_divide_ps(v4f32 a, v4f32 b) +{ + v4f32 q = __msa_fdiv_w(a, b); + return floor_ps(q); +} + +static inline v4f32 
remainder_ps(v4f32 a, v4f32 b) +{ + v4f32 q = __msa_fdiv_w(a, b); + v4f32 rq = round_ps(q); + return __msa_fsub_w(a, __msa_fmul_w(rq, b)); +} + #endif // MSA_MATHFUN_H diff --git a/src/layer/riscv/binaryop_riscv.cpp b/src/layer/riscv/binaryop_riscv.cpp index 7d57c5b89f8..f0a372836dd 100644 --- a/src/layer/riscv/binaryop_riscv.cpp +++ b/src/layer/riscv/binaryop_riscv.cpp @@ -283,6 +283,13 @@ MAKE_FUNCTION(binary_op_rdiv, y / x, __riscv_vfdiv_vv_f32m8(y, x, vl), __riscv_v MAKE_FUNCTION(binary_op_rpow, (float)powf(y, x), pow_ps(y, x, vl), pow_ps(__riscv_vfmv_v_f_f32m8(y, vl), x, vl), pow_ps(y, __riscv_vfmv_v_f_f32m8(x, vl), vl)) MAKE_FUNCTION(binary_op_atan2, (float)atan2f(x, y), atan2_ps(x, y, vl), atan2_ps(x, __riscv_vfmv_v_f_f32m8(y, vl), vl), atan2_ps(__riscv_vfmv_v_f_f32m8(x, vl), y, vl)) MAKE_FUNCTION(binary_op_ratan2, (float)atan2f(y, x), atan2_ps(y, x, vl), atan2_ps(__riscv_vfmv_v_f_f32m8(y, vl), x, vl), atan2_ps(y, __riscv_vfmv_v_f_f32m8(x, vl), vl)) +MAKE_FUNCTION(binary_op_fmod, (float)fmodf(x, y), fmod_ps(x, y, vl), fmod_ps(x, __riscv_vfmv_v_f_f32m8(y, vl), vl), fmod_ps(__riscv_vfmv_v_f_f32m8(x, vl), y, vl)) +MAKE_FUNCTION(binary_op_rfmod, (float)fmodf(y, x), fmod_ps(y, x, vl), fmod_ps(__riscv_vfmv_v_f_f32m8(y, vl), x, vl), fmod_ps(y, __riscv_vfmv_v_f_f32m8(x, vl), vl)) +MAKE_FUNCTION(binary_op_logaddexp, (float)(std::max(x, y) + log1pf(expf(std::min(x, y) - std::max(x, y)))), logaddexp_ps(x, y, vl), logaddexp_ps(x, __riscv_vfmv_v_f_f32m8(y, vl), vl), logaddexp_ps(__riscv_vfmv_v_f_f32m8(x, vl), y, vl)) +MAKE_FUNCTION(binary_op_floor_divide, (float)floorf(x / y), floor_divide_ps(x, y, vl), floor_divide_ps(x, __riscv_vfmv_v_f_f32m8(y, vl), vl), floor_divide_ps(__riscv_vfmv_v_f_f32m8(x, vl), y, vl)) +MAKE_FUNCTION(binary_op_rfloor_divide, (float)floorf(y / x), floor_divide_ps(y, x, vl), floor_divide_ps(__riscv_vfmv_v_f_f32m8(y, vl), x, vl), floor_divide_ps(y, __riscv_vfmv_v_f_f32m8(x, vl), vl)) +MAKE_FUNCTION(binary_op_remainder, (float)remainderf(x, y), remainder_ps(x, y, vl), remainder_ps(x, __riscv_vfmv_v_f_f32m8(y, vl), vl), remainder_ps(__riscv_vfmv_v_f_f32m8(x, vl), y, vl)) +MAKE_FUNCTION(binary_op_rremainder, (float)remainderf(y, x), remainder_ps(y, x, vl), remainder_ps(__riscv_vfmv_v_f_f32m8(y, vl), x, vl), remainder_ps(y, __riscv_vfmv_v_f_f32m8(x, vl), vl)) // *INDENT-ON* // clang-format on @@ -306,6 +313,13 @@ static void binary_op_vector(const float* ptr, const float* ptr1, float* outptr, if (op_type == BinaryOp::Operation_RPOW) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); if (op_type == BinaryOp::Operation_ATAN2) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); if (op_type == BinaryOp::Operation_RATAN2) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_FMOD) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_RFMOD) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_LOGADDEXP) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_FLOOR_DIVIDE) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_RFLOOR_DIVIDE) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_REMAINDER) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_RREMAINDER) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); // should never reach here } @@ -450,10 
+464,18 @@ static int get_reverse_op_type(int op_type) if (op_type == BinaryOp::Operation_DIV) return BinaryOp::Operation_RDIV; if (op_type == BinaryOp::Operation_POW) return BinaryOp::Operation_RPOW; if (op_type == BinaryOp::Operation_ATAN2) return BinaryOp::Operation_RATAN2; + if (op_type == BinaryOp::Operation_FMOD) return BinaryOp::Operation_RFMOD; + if (op_type == BinaryOp::Operation_FLOOR_DIVIDE) return BinaryOp::Operation_RFLOOR_DIVIDE; + if (op_type == BinaryOp::Operation_REMAINDER) return BinaryOp::Operation_RREMAINDER; + if (op_type == BinaryOp::Operation_RSUB) return BinaryOp::Operation_SUB; if (op_type == BinaryOp::Operation_RDIV) return BinaryOp::Operation_DIV; if (op_type == BinaryOp::Operation_RPOW) return BinaryOp::Operation_POW; if (op_type == BinaryOp::Operation_RATAN2) return BinaryOp::Operation_ATAN2; + if (op_type == BinaryOp::Operation_RFMOD) return BinaryOp::Operation_FMOD; + if (op_type == BinaryOp::Operation_RFLOOR_DIVIDE) return BinaryOp::Operation_FLOOR_DIVIDE; + if (op_type == BinaryOp::Operation_RREMAINDER) return BinaryOp::Operation_REMAINDER; + return op_type; } diff --git a/src/layer/riscv/binaryop_riscv_zfh.cpp b/src/layer/riscv/binaryop_riscv_zfh.cpp index 7287798af74..e9045db804b 100644 --- a/src/layer/riscv/binaryop_riscv_zfh.cpp +++ b/src/layer/riscv/binaryop_riscv_zfh.cpp @@ -271,6 +271,13 @@ MAKE_FUNCTION(binary_op_rdiv_fp16s, y / x, __riscv_vfdiv_vv_f16m8(y, x, vl), __r MAKE_FUNCTION(binary_op_rpow_fp16s, (__fp16)powf((float)y, (float)x), pow_ps(y, x, vl), pow_ps(__riscv_vfmv_v_f_f16m8(y, vl), x, vl), pow_ps(y, __riscv_vfmv_v_f_f16m8(x, vl), vl)) MAKE_FUNCTION(binary_op_atan2_fp16s, (__fp16)atan2f((float)x, (float)y), atan2_ps(x, y, vl), atan2_ps(x, __riscv_vfmv_v_f_f16m8(y, vl), vl), atan2_ps(__riscv_vfmv_v_f_f16m8(x, vl), y, vl)) MAKE_FUNCTION(binary_op_ratan2_fp16s, (__fp16)atan2f((float)y, (float)x), atan2_ps(y, x, vl), atan2_ps(__riscv_vfmv_v_f_f16m8(y, vl), x, vl), atan2_ps(y, __riscv_vfmv_v_f_f16m8(x, vl), vl)) +MAKE_FUNCTION(binary_op_fmod_fp16s, (__fp16)fmodf((float)x, (float)y), fmod_ps(x, y, vl), fmod_ps(x, __riscv_vfmv_v_f_f16m8(y, vl), vl), fmod_ps(__riscv_vfmv_v_f_f16m8(x, vl), y, vl)) +MAKE_FUNCTION(binary_op_rfmod_fp16s, (__fp16)fmodf((float)y, (float)x), fmod_ps(y, x, vl), fmod_ps(__riscv_vfmv_v_f_f16m8(y, vl), x, vl), fmod_ps(y, __riscv_vfmv_v_f_f16m8(x, vl), vl)) +MAKE_FUNCTION(binary_op_logaddexp_fp16s, (__fp16)(std::max((float)x, (float)y) + log1pf(expf(std::min((float)x, (float)y) - std::max((float)x, (float)y)))), logaddexp_ps(x, y, vl), logaddexp_ps(x, __riscv_vfmv_v_f_f16m8(y, vl), vl), logaddexp_ps(__riscv_vfmv_v_f_f16m8(x, vl), y, vl)) +MAKE_FUNCTION(binary_op_floor_divide_fp16s, (__fp16)floorf((float)x / (float)y), floor_divide_ps(x, y, vl), floor_divide_ps(x, __riscv_vfmv_v_f_f16m8(y, vl), vl), floor_divide_ps(__riscv_vfmv_v_f_f16m8(x, vl), y, vl)) +MAKE_FUNCTION(binary_op_rfloor_divide_fp16s, (__fp16)floorf((float)y / (float)x), floor_divide_ps(y, x, vl), floor_divide_ps(__riscv_vfmv_v_f_f16m8(y, vl), x, vl), floor_divide_ps(y, __riscv_vfmv_v_f_f16m8(x, vl), vl)) +MAKE_FUNCTION(binary_op_remainder_fp16s, (__fp16)remainderf((float)x, (float)y), remainder_ps(x, y, vl), remainder_ps(x, __riscv_vfmv_v_f_f16m8(y, vl), vl), remainder_ps(__riscv_vfmv_v_f_f16m8(x, vl), y, vl)) +MAKE_FUNCTION(binary_op_rremainder_fp16s, (__fp16)remainderf((float)y, (float)x), remainder_ps(y, x, vl), remainder_ps(__riscv_vfmv_v_f_f16m8(y, vl), x, vl), remainder_ps(y, __riscv_vfmv_v_f_f16m8(x, vl), vl)) // *INDENT-ON* // clang-format on @@ -294,6 
+301,13 @@ static void binary_op_vector_fp16s(const __fp16* ptr, const __fp16* ptr1, __fp16 if (op_type == BinaryOp::Operation_RPOW) return binary_op_vector_fp16s(ptr, ptr1, outptr, aw, bw, ap, bp); if (op_type == BinaryOp::Operation_ATAN2) return binary_op_vector_fp16s(ptr, ptr1, outptr, aw, bw, ap, bp); if (op_type == BinaryOp::Operation_RATAN2) return binary_op_vector_fp16s(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_FMOD) return binary_op_vector_fp16s(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_RFMOD) return binary_op_vector_fp16s(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_LOGADDEXP) return binary_op_vector_fp16s(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_FLOOR_DIVIDE) return binary_op_vector_fp16s(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_RFLOOR_DIVIDE) return binary_op_vector_fp16s(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_REMAINDER) return binary_op_vector_fp16s(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_RREMAINDER) return binary_op_vector_fp16s(ptr, ptr1, outptr, aw, bw, ap, bp); // should never reach here } @@ -438,10 +452,18 @@ static int get_reverse_op_type(int op_type) if (op_type == BinaryOp::Operation_DIV) return BinaryOp::Operation_RDIV; if (op_type == BinaryOp::Operation_POW) return BinaryOp::Operation_RPOW; if (op_type == BinaryOp::Operation_ATAN2) return BinaryOp::Operation_RATAN2; + if (op_type == BinaryOp::Operation_FMOD) return BinaryOp::Operation_RFMOD; + if (op_type == BinaryOp::Operation_FLOOR_DIVIDE) return BinaryOp::Operation_RFLOOR_DIVIDE; + if (op_type == BinaryOp::Operation_REMAINDER) return BinaryOp::Operation_RREMAINDER; + if (op_type == BinaryOp::Operation_RSUB) return BinaryOp::Operation_SUB; if (op_type == BinaryOp::Operation_RDIV) return BinaryOp::Operation_DIV; if (op_type == BinaryOp::Operation_RPOW) return BinaryOp::Operation_POW; if (op_type == BinaryOp::Operation_RATAN2) return BinaryOp::Operation_ATAN2; + if (op_type == BinaryOp::Operation_RFMOD) return BinaryOp::Operation_FMOD; + if (op_type == BinaryOp::Operation_RFLOOR_DIVIDE) return BinaryOp::Operation_FLOOR_DIVIDE; + if (op_type == BinaryOp::Operation_RREMAINDER) return BinaryOp::Operation_REMAINDER; + return op_type; } diff --git a/src/layer/riscv/rvv_mathfun.h b/src/layer/riscv/rvv_mathfun.h index 4dc3244aa95..6a5a9ff1570 100644 --- a/src/layer/riscv/rvv_mathfun.h +++ b/src/layer/riscv/rvv_mathfun.h @@ -567,4 +567,129 @@ _RVV_FLOAT32_ATAN2_OP(2, 16) _RVV_FLOAT32_ATAN2_OP(4, 8) _RVV_FLOAT32_ATAN2_OP(8, 4) +/* fmod(a,b) = a - trunc(a/b)*b (trunc toward 0) */ +#if __riscv_xtheadvector +// simulate trunc with floor positives and ceil negative +// xi = round(x) +// floorx = xi - (xi > x) +// ceilx = xi + (xi < x) +// truncx = x >= 0 ? 
floorx : ceilx +#define _RVV_FLOAT32_FMOD_OP(LMUL, MLEN) \ + static inline vfloat32m##LMUL##_t fmod_ps(vfloat32m##LMUL##_t a, vfloat32m##LMUL##_t b, size_t vl) \ + { \ + vfloat32m##LMUL##_t q = __riscv_vfdiv_vv_f32m##LMUL(a, b, vl); \ + vint32m##LMUL##_t qi = __riscv_vfcvt_x_f_v_i32m##LMUL(q, vl); \ + vfloat32m##LMUL##_t qf = __riscv_vfcvt_f_x_v_f32m##LMUL(qi, vl); \ + vbool##MLEN##_t _floormask = __riscv_vmfgt_vv_f32m##LMUL##_b##MLEN(qf, q, vl); \ + vint32m##LMUL##_t _floorx = __riscv_vsub_vx_i32m##LMUL##_mu(_floormask, qi, qi, 1, vl); \ + vbool##MLEN##_t _ceilmask = __riscv_vmflt_vv_f32m##LMUL##_b##MLEN(qf, q, vl); \ + vint32m##LMUL##_t _ceilx = __riscv_vadd_vx_i32m##LMUL##_mu(_ceilmask, qi, qi, 1, vl); \ + vbool##MLEN##_t _negative = __riscv_vmflt_vf_f32m##LMUL##_b##MLEN(q, 0.f, vl); \ + vint32m##LMUL##_t trunc_qi = __riscv_vmerge_vvm_i32m##LMUL(_floorx, _ceilx, _negative, vl); \ + vfloat32m##LMUL##_t trunc_q = __riscv_vfcvt_f_x_v_f32m##LMUL(trunc_qi, vl); \ + return __riscv_vfsub_vv_f32m##LMUL(a, __riscv_vfmul_vv_f32m##LMUL(trunc_q, b, vl), vl); \ + } +#else +#define _RVV_FLOAT32_FMOD_OP(LMUL, MLEN) \ + static inline vfloat32m##LMUL##_t fmod_ps(vfloat32m##LMUL##_t a, vfloat32m##LMUL##_t b, size_t vl) \ + { \ + vfloat32m##LMUL##_t q = __riscv_vfdiv_vv_f32m##LMUL(a, b, vl); \ + vint32m##LMUL##_t qi = __riscv_vfcvt_rtz_x_f_v_i32m##LMUL(q, vl); \ + vfloat32m##LMUL##_t qf = __riscv_vfcvt_f_x_v_f32m##LMUL(qi, vl); \ + return __riscv_vfsub_vv_f32m##LMUL(a, __riscv_vfmul_vv_f32m##LMUL(qf, b, vl), vl); \ + } +#endif + +_RVV_FLOAT32_FMOD_OP(1, 32) +_RVV_FLOAT32_FMOD_OP(2, 16) +_RVV_FLOAT32_FMOD_OP(4, 8) +_RVV_FLOAT32_FMOD_OP(8, 4) + +/* round to nearest, ties to even (banker's rounding) */ +#define _RVV_FLOAT32_ROUND_OP(LMUL, MLEN) \ + static inline vfloat32m##LMUL##_t round_ps(vfloat32m##LMUL##_t x, size_t vl) \ + { \ + vfloat32m##LMUL##_t absx = __riscv_vfsgnjx_vv_f32m##LMUL(x, x, vl); \ + vfloat32m##LMUL##_t half = __riscv_vfmv_v_f_f32m##LMUL(0.5f, vl); \ + vint32m##LMUL##_t xi = __riscv_vfcvt_x_f_v_i32m##LMUL(absx, vl); \ + vfloat32m##LMUL##_t xf = __riscv_vfcvt_f_x_v_f32m##LMUL(xi, vl); \ + vfloat32m##LMUL##_t diff = __riscv_vfsub_vv_f32m##LMUL(absx, xf, vl); \ + vbool##MLEN##_t diff_gt_half = __riscv_vmfgt_vv_f32m##LMUL##_b##MLEN(diff, half, vl); \ + vbool##MLEN##_t diff_eq_half = __riscv_vmfeq_vv_f32m##LMUL##_b##MLEN(diff, half, vl); \ + vint32m##LMUL##_t one_i = __riscv_vmv_v_x_i32m##LMUL(1, vl); \ + vint32m##LMUL##_t xi_and_1 = __riscv_vand_vv_i32m##LMUL(xi, one_i, vl); \ + vbool##MLEN##_t is_odd = __riscv_vmsne_vx_i32m##LMUL##_b##MLEN(xi_and_1, 0, vl); \ + vbool##MLEN##_t round_up = __riscv_vmor_mm_b##MLEN(diff_gt_half, \ + __riscv_vmand_mm_b##MLEN(diff_eq_half, is_odd, vl), vl); \ + vfloat32m##LMUL##_t one = __riscv_vfmv_v_f_f32m##LMUL(1.f, vl); \ + vfloat32m##LMUL##_t rounded = __riscv_vfadd_vv_f32m##LMUL##_mu(round_up, xf, xf, one, vl); \ + return __riscv_vfsgnj_vv_f32m##LMUL(rounded, x, vl); \ + } + +_RVV_FLOAT32_ROUND_OP(1, 32) +_RVV_FLOAT32_ROUND_OP(2, 16) +_RVV_FLOAT32_ROUND_OP(4, 8) +_RVV_FLOAT32_ROUND_OP(8, 4) + +/* logaddexp(a,b) = max(a,b) + log1p(exp(min(a,b) - max(a,b))) */ +#define _RVV_FLOAT32_LOGADDEXP_OP(LMUL, MLEN) \ + static inline vfloat32m##LMUL##_t logaddexp_ps(vfloat32m##LMUL##_t a, vfloat32m##LMUL##_t b, size_t vl) \ + { \ + vfloat32m##LMUL##_t max_xy = __riscv_vfmax_vv_f32m##LMUL(a, b, vl); \ + vfloat32m##LMUL##_t min_xy = __riscv_vfmin_vv_f32m##LMUL(a, b, vl); \ + vfloat32m##LMUL##_t diff = __riscv_vfsub_vv_f32m##LMUL(min_xy, max_xy, vl); \ + 
vfloat32m##LMUL##_t exp_diff = exp_ps(diff, vl); \ + vfloat32m##LMUL##_t one = __riscv_vfmv_v_f_f32m##LMUL(1.f, vl); \ + vfloat32m##LMUL##_t one_plus_exp = __riscv_vfadd_vv_f32m##LMUL(one, exp_diff, vl); \ + vfloat32m##LMUL##_t log_result = log_ps(one_plus_exp, vl); \ + return __riscv_vfadd_vv_f32m##LMUL(max_xy, log_result, vl); \ + } + +_RVV_FLOAT32_LOGADDEXP_OP(1, 32) +_RVV_FLOAT32_LOGADDEXP_OP(2, 16) +_RVV_FLOAT32_LOGADDEXP_OP(4, 8) +_RVV_FLOAT32_LOGADDEXP_OP(8, 4) + +/* floor */ +#define _RVV_FLOAT32_FLOOR_OP(LMUL, MLEN) \ + static inline vfloat32m##LMUL##_t floor_ps(vfloat32m##LMUL##_t x, size_t vl) \ + { \ + vint32m##LMUL##_t xi = __riscv_vfcvt_x_f_v_i32m##LMUL(x, vl); \ + vfloat32m##LMUL##_t xf = __riscv_vfcvt_f_x_v_f32m##LMUL(xi, vl); \ + vbool##MLEN##_t need_adjust = __riscv_vmfgt_vv_f32m##LMUL##_b##MLEN(xf, x, vl); \ + vfloat32m##LMUL##_t one = __riscv_vfmv_v_f_f32m##LMUL(1.f, vl); \ + return __riscv_vfsub_vv_f32m##LMUL##_mu(need_adjust, xf, xf, one, vl); \ + } + +_RVV_FLOAT32_FLOOR_OP(1, 32) +_RVV_FLOAT32_FLOOR_OP(2, 16) +_RVV_FLOAT32_FLOOR_OP(4, 8) +_RVV_FLOAT32_FLOOR_OP(8, 4) + +#define _RVV_FLOAT32_FLOOR_DIVIDE_OP(LMUL, MLEN) \ + static inline vfloat32m##LMUL##_t floor_divide_ps(vfloat32m##LMUL##_t a, vfloat32m##LMUL##_t b, size_t vl) \ + { \ + vfloat32m##LMUL##_t q = __riscv_vfdiv_vv_f32m##LMUL(a, b, vl); \ + return floor_ps(q, vl); \ + } + +_RVV_FLOAT32_FLOOR_DIVIDE_OP(1, 32) +_RVV_FLOAT32_FLOOR_DIVIDE_OP(2, 16) +_RVV_FLOAT32_FLOOR_DIVIDE_OP(4, 8) +_RVV_FLOAT32_FLOOR_DIVIDE_OP(8, 4) + +/* remainder(a,b) = a - round(a/b) * b */ +#define _RVV_FLOAT32_REMAINDER_OP(LMUL, MLEN) \ + static inline vfloat32m##LMUL##_t remainder_ps(vfloat32m##LMUL##_t a, vfloat32m##LMUL##_t b, size_t vl) \ + { \ + vfloat32m##LMUL##_t q = __riscv_vfdiv_vv_f32m##LMUL(a, b, vl); \ + vfloat32m##LMUL##_t rq = round_ps(q, vl); \ + return __riscv_vfsub_vv_f32m##LMUL(a, __riscv_vfmul_vv_f32m##LMUL(rq, b, vl), vl); \ + } + +_RVV_FLOAT32_REMAINDER_OP(1, 32) +_RVV_FLOAT32_REMAINDER_OP(2, 16) +_RVV_FLOAT32_REMAINDER_OP(4, 8) +_RVV_FLOAT32_REMAINDER_OP(8, 4) + #endif // RVV_MATHFUN_H diff --git a/src/layer/riscv/rvv_mathfun_fp16s.h b/src/layer/riscv/rvv_mathfun_fp16s.h index e0068264dd9..df60312888f 100644 --- a/src/layer/riscv/rvv_mathfun_fp16s.h +++ b/src/layer/riscv/rvv_mathfun_fp16s.h @@ -403,4 +403,129 @@ _RVV_FLOAT16_ATAN2_OP(2, 8) _RVV_FLOAT16_ATAN2_OP(4, 4) _RVV_FLOAT16_ATAN2_OP(8, 2) +/* fmod(a,b) = a - trunc(a/b)*b (trunc toward 0) */ +#if __riscv_xtheadvector +// simulate trunc with floor positives and ceil negative +// xi = round(x) +// floorx = xi - (xi > x) +// ceilx = xi + (xi < x) +// truncx = x >= 0 ? 
floorx : ceilx +#define _RVV_FLOAT16_FMOD_OP(LMUL, MLEN) \ + static inline vfloat16m##LMUL##_t fmod_ps(vfloat16m##LMUL##_t a, vfloat16m##LMUL##_t b, size_t vl) \ + { \ + vfloat16m##LMUL##_t q = __riscv_vfdiv_vv_f16m##LMUL(a, b, vl); \ + vint16m##LMUL##_t qi = __riscv_vfcvt_x_f_v_i16m##LMUL(q, vl); \ + vfloat16m##LMUL##_t qf = __riscv_vfcvt_f_x_v_f16m##LMUL(qi, vl); \ + vbool##MLEN##_t _floormask = __riscv_vmfgt_vv_f16m##LMUL##_b##MLEN(qf, q, vl); \ + vint16m##LMUL##_t _floorx = __riscv_vsub_vx_i16m##LMUL##_mu(_floormask, qi, qi, 1, vl); \ + vbool##MLEN##_t _ceilmask = __riscv_vmflt_vv_f16m##LMUL##_b##MLEN(qf, q, vl); \ + vint16m##LMUL##_t _ceilx = __riscv_vadd_vx_i16m##LMUL##_mu(_ceilmask, qi, qi, 1, vl); \ + vbool##MLEN##_t _negative = __riscv_vmflt_vf_f16m##LMUL##_b##MLEN(q, (__fp16)0.f, vl); \ + vint16m##LMUL##_t trunc_qi = __riscv_vmerge_vvm_i16m##LMUL(_floorx, _ceilx, _negative, vl); \ + vfloat16m##LMUL##_t trunc_q = __riscv_vfcvt_f_x_v_f16m##LMUL(trunc_qi, vl); \ + return __riscv_vfsub_vv_f16m##LMUL(a, __riscv_vfmul_vv_f16m##LMUL(trunc_q, b, vl), vl); \ + } +#else +#define _RVV_FLOAT16_FMOD_OP(LMUL, MLEN) \ + static inline vfloat16m##LMUL##_t fmod_ps(vfloat16m##LMUL##_t a, vfloat16m##LMUL##_t b, size_t vl) \ + { \ + vfloat16m##LMUL##_t q = __riscv_vfdiv_vv_f16m##LMUL(a, b, vl); \ + vint16m##LMUL##_t qi = __riscv_vfcvt_rtz_x_f_v_i16m##LMUL(q, vl); \ + vfloat16m##LMUL##_t qf = __riscv_vfcvt_f_x_v_f16m##LMUL(qi, vl); \ + return __riscv_vfsub_vv_f16m##LMUL(a, __riscv_vfmul_vv_f16m##LMUL(qf, b, vl), vl); \ + } +#endif + +_RVV_FLOAT16_FMOD_OP(1, 16) +_RVV_FLOAT16_FMOD_OP(2, 8) +_RVV_FLOAT16_FMOD_OP(4, 4) +_RVV_FLOAT16_FMOD_OP(8, 2) + +/* round to nearest, ties to even (banker's rounding) */ +#define _RVV_FLOAT16_ROUND_OP(LMUL, MLEN) \ + static inline vfloat16m##LMUL##_t round_ps(vfloat16m##LMUL##_t x, size_t vl) \ + { \ + vfloat16m##LMUL##_t absx = __riscv_vfsgnjx_vv_f16m##LMUL(x, x, vl); \ + vfloat16m##LMUL##_t half = __riscv_vfmv_v_f_f16m##LMUL((__fp16)0.5f, vl); \ + vint16m##LMUL##_t xi = __riscv_vfcvt_x_f_v_i16m##LMUL(absx, vl); \ + vfloat16m##LMUL##_t xf = __riscv_vfcvt_f_x_v_f16m##LMUL(xi, vl); \ + vfloat16m##LMUL##_t diff = __riscv_vfsub_vv_f16m##LMUL(absx, xf, vl); \ + vbool##MLEN##_t diff_gt_half = __riscv_vmfgt_vv_f16m##LMUL##_b##MLEN(diff, half, vl); \ + vbool##MLEN##_t diff_eq_half = __riscv_vmfeq_vv_f16m##LMUL##_b##MLEN(diff, half, vl); \ + vint16m##LMUL##_t one_i = __riscv_vmv_v_x_i16m##LMUL(1, vl); \ + vint16m##LMUL##_t xi_and_1 = __riscv_vand_vv_i16m##LMUL(xi, one_i, vl); \ + vbool##MLEN##_t is_odd = __riscv_vmsne_vx_i16m##LMUL##_b##MLEN(xi_and_1, 0, vl); \ + vbool##MLEN##_t round_up = __riscv_vmor_mm_b##MLEN(diff_gt_half, \ + __riscv_vmand_mm_b##MLEN(diff_eq_half, is_odd, vl), vl); \ + vfloat16m##LMUL##_t one = __riscv_vfmv_v_f_f16m##LMUL((__fp16)1.f, vl); \ + vfloat16m##LMUL##_t rounded = __riscv_vfadd_vv_f16m##LMUL##_mu(round_up, xf, xf, one, vl); \ + return __riscv_vfsgnj_vv_f16m##LMUL(rounded, x, vl); \ + } + +_RVV_FLOAT16_ROUND_OP(1, 16) +_RVV_FLOAT16_ROUND_OP(2, 8) +_RVV_FLOAT16_ROUND_OP(4, 4) +_RVV_FLOAT16_ROUND_OP(8, 2) + +/* logaddexp(a,b) = max(a,b) + log1p(exp(min(a,b) - max(a,b))) */ +#define _RVV_FLOAT16_LOGADDEXP_OP(LMUL, MLEN) \ + static inline vfloat16m##LMUL##_t logaddexp_ps(vfloat16m##LMUL##_t a, vfloat16m##LMUL##_t b, size_t vl) \ + { \ + vfloat16m##LMUL##_t max_xy = __riscv_vfmax_vv_f16m##LMUL(a, b, vl); \ + vfloat16m##LMUL##_t min_xy = __riscv_vfmin_vv_f16m##LMUL(a, b, vl); \ + vfloat16m##LMUL##_t diff = __riscv_vfsub_vv_f16m##LMUL(min_xy, max_xy, vl); 
\ + vfloat16m##LMUL##_t exp_diff = exp_ps(diff, vl); \ + vfloat16m##LMUL##_t one = __riscv_vfmv_v_f_f16m##LMUL((__fp16)1.f, vl); \ + vfloat16m##LMUL##_t one_plus_exp = __riscv_vfadd_vv_f16m##LMUL(one, exp_diff, vl); \ + vfloat16m##LMUL##_t log_result = log_ps(one_plus_exp, vl); \ + return __riscv_vfadd_vv_f16m##LMUL(max_xy, log_result, vl); \ + } + +_RVV_FLOAT16_LOGADDEXP_OP(1, 16) +_RVV_FLOAT16_LOGADDEXP_OP(2, 8) +_RVV_FLOAT16_LOGADDEXP_OP(4, 4) +_RVV_FLOAT16_LOGADDEXP_OP(8, 2) + +/* floor */ +#define _RVV_FLOAT16_FLOOR_OP(LMUL, MLEN) \ + static inline vfloat16m##LMUL##_t floor_ps(vfloat16m##LMUL##_t x, size_t vl) \ + { \ + vint16m##LMUL##_t xi = __riscv_vfcvt_x_f_v_i16m##LMUL(x, vl); \ + vfloat16m##LMUL##_t xf = __riscv_vfcvt_f_x_v_f16m##LMUL(xi, vl); \ + vbool##MLEN##_t need_adjust = __riscv_vmfgt_vv_f16m##LMUL##_b##MLEN(xf, x, vl); \ + vfloat16m##LMUL##_t one = __riscv_vfmv_v_f_f16m##LMUL((__fp16)1.f, vl); \ + return __riscv_vfsub_vv_f16m##LMUL##_mu(need_adjust, xf, xf, one, vl); \ + } + +_RVV_FLOAT16_FLOOR_OP(1, 16) +_RVV_FLOAT16_FLOOR_OP(2, 8) +_RVV_FLOAT16_FLOOR_OP(4, 4) +_RVV_FLOAT16_FLOOR_OP(8, 2) + +#define _RVV_FLOAT16_FLOOR_DIVIDE_OP(LMUL, MLEN) \ + static inline vfloat16m##LMUL##_t floor_divide_ps(vfloat16m##LMUL##_t a, vfloat16m##LMUL##_t b, size_t vl) \ + { \ + vfloat16m##LMUL##_t q = __riscv_vfdiv_vv_f16m##LMUL(a, b, vl); \ + return floor_ps(q, vl); \ + } + +_RVV_FLOAT16_FLOOR_DIVIDE_OP(1, 16) +_RVV_FLOAT16_FLOOR_DIVIDE_OP(2, 8) +_RVV_FLOAT16_FLOOR_DIVIDE_OP(4, 4) +_RVV_FLOAT16_FLOOR_DIVIDE_OP(8, 2) + +/* remainder(a,b) = a - round(a/b) * b */ +#define _RVV_FLOAT16_REMAINDER_OP(LMUL, MLEN) \ + static inline vfloat16m##LMUL##_t remainder_ps(vfloat16m##LMUL##_t a, vfloat16m##LMUL##_t b, size_t vl) \ + { \ + vfloat16m##LMUL##_t q = __riscv_vfdiv_vv_f16m##LMUL(a, b, vl); \ + vfloat16m##LMUL##_t rq = round_ps(q, vl); \ + return __riscv_vfsub_vv_f16m##LMUL(a, __riscv_vfmul_vv_f16m##LMUL(rq, b, vl), vl); \ + } + +_RVV_FLOAT16_REMAINDER_OP(1, 16) +_RVV_FLOAT16_REMAINDER_OP(2, 8) +_RVV_FLOAT16_REMAINDER_OP(4, 4) +_RVV_FLOAT16_REMAINDER_OP(8, 2) + #endif // RVV_MATHFUN_FP16S_H diff --git a/src/layer/vulkan/binaryop_vulkan.cpp b/src/layer/vulkan/binaryop_vulkan.cpp index 10504839478..f0f662098f1 100644 --- a/src/layer/vulkan/binaryop_vulkan.cpp +++ b/src/layer/vulkan/binaryop_vulkan.cpp @@ -29,10 +29,17 @@ static int get_reverse_op_type(int op_type) if (op_type == BinaryOp::Operation_DIV) return BinaryOp::Operation_RDIV; if (op_type == BinaryOp::Operation_POW) return BinaryOp::Operation_RPOW; if (op_type == BinaryOp::Operation_ATAN2) return BinaryOp::Operation_RATAN2; + if (op_type == BinaryOp::Operation_FMOD) return BinaryOp::Operation_RFMOD; + if (op_type == BinaryOp::Operation_FLOOR_DIVIDE) return BinaryOp::Operation_RFLOOR_DIVIDE; + if (op_type == BinaryOp::Operation_REMAINDER) return BinaryOp::Operation_RREMAINDER; + if (op_type == BinaryOp::Operation_RSUB) return BinaryOp::Operation_SUB; if (op_type == BinaryOp::Operation_RDIV) return BinaryOp::Operation_DIV; if (op_type == BinaryOp::Operation_RPOW) return BinaryOp::Operation_POW; if (op_type == BinaryOp::Operation_RATAN2) return BinaryOp::Operation_ATAN2; + if (op_type == BinaryOp::Operation_RFMOD) return BinaryOp::Operation_FMOD; + if (op_type == BinaryOp::Operation_RFLOOR_DIVIDE) return BinaryOp::Operation_FLOOR_DIVIDE; + if (op_type == BinaryOp::Operation_RREMAINDER) return BinaryOp::Operation_REMAINDER; return op_type; } diff --git a/src/layer/vulkan/shader/binaryop.comp b/src/layer/vulkan/shader/binaryop.comp index 
d53e051c1bd..e5c09bcf241 100644 --- a/src/layer/vulkan/shader/binaryop.comp +++ b/src/layer/vulkan/shader/binaryop.comp @@ -101,6 +101,13 @@ void main() if (op_type == 10) res = atan(v1, v2); if (op_type == 11) res = atan(v2, v1); #endif + if (op_type == 12) res = v1 - trunc(v1 / v2) * v2; + if (op_type == 13) res = v2 - trunc(v2 / v1) * v1; + if (op_type == 14) res = max(v1, v2) + log(1.0 + exp(min(v1, v2) - max(v1, v2))); + if (op_type == 15) res = floor(v1 / v2); + if (op_type == 16) res = floor(v2 / v1); + if (op_type == 17) res = v1 - roundEven(v1 / v2) * v2; + if (op_type == 18) res = v2 - roundEven(v2 / v1) * v1; buffer_st1(top_blob_data, gi, res); } diff --git a/src/layer/vulkan/shader/binaryop_broadcast.comp b/src/layer/vulkan/shader/binaryop_broadcast.comp index d77470cbf09..c71d512d35d 100644 --- a/src/layer/vulkan/shader/binaryop_broadcast.comp +++ b/src/layer/vulkan/shader/binaryop_broadcast.comp @@ -170,6 +170,13 @@ void main() if (op_type == 10) res = atan(v1, v2); if (op_type == 11) res = atan(v2, v1); #endif + if (op_type == 12) res = v1 - trunc(v1 / v2) * v2; + if (op_type == 13) res = v2 - trunc(v2 / v1) * v1; + if (op_type == 14) res = max(v1, v2) + log(1.0 + exp(min(v1, v2) - max(v1, v2))); + if (op_type == 15) res = floor(v1 / v2); + if (op_type == 16) res = floor(v2 / v1); + if (op_type == 17) res = v1 - roundEven(v1 / v2) * v2; + if (op_type == 18) res = v2 - roundEven(v2 / v1) * v1; int gi = gz * psc(outcstep) + gy * psc(outw) + gx; buffer_st1(top_blob_data, gi, res); diff --git a/src/layer/vulkan/shader/binaryop_broadcast_pack1to4.comp b/src/layer/vulkan/shader/binaryop_broadcast_pack1to4.comp index 9f5db7e966b..fe45c61ad75 100644 --- a/src/layer/vulkan/shader/binaryop_broadcast_pack1to4.comp +++ b/src/layer/vulkan/shader/binaryop_broadcast_pack1to4.comp @@ -101,6 +101,13 @@ void main() if (op_type == 10) res = atan(v1, v2); if (op_type == 11) res = atan(v2, v1); #endif + if (op_type == 12) res = v1 - trunc(v1 / v2) * v2; + if (op_type == 13) res = v2 - trunc(v2 / v1) * v1; + if (op_type == 14) res = max(v1, v2) + log(1.0 + exp(min(v1, v2) - max(v1, v2))); + if (op_type == 15) res = floor(v1 / v2); + if (op_type == 16) res = floor(v2 / v1); + if (op_type == 17) res = v1 - roundEven(v1 / v2) * v2; + if (op_type == 18) res = v2 - roundEven(v2 / v1) * v1; int gi = gz * psc(outcstep) + gy * psc(outw) + gx; buffer_st4(top_blob_data, gi, res); diff --git a/src/layer/vulkan/shader/binaryop_broadcast_pack4.comp b/src/layer/vulkan/shader/binaryop_broadcast_pack4.comp index 9b08f2c09b7..de9df649188 100644 --- a/src/layer/vulkan/shader/binaryop_broadcast_pack4.comp +++ b/src/layer/vulkan/shader/binaryop_broadcast_pack4.comp @@ -170,6 +170,13 @@ void main() if (op_type == 10) res = atan(v1, v2); if (op_type == 11) res = atan(v2, v1); #endif + if (op_type == 12) res = v1 - trunc(v1 / v2) * v2; + if (op_type == 13) res = v2 - trunc(v2 / v1) * v1; + if (op_type == 14) res = max(v1, v2) + log(1.0 + exp(min(v1, v2) - max(v1, v2))); + if (op_type == 15) res = floor(v1 / v2); + if (op_type == 16) res = floor(v2 / v1); + if (op_type == 17) res = v1 - roundEven(v1 / v2) * v2; + if (op_type == 18) res = v2 - roundEven(v2 / v1) * v1; int gi = gz * psc(outcstep) + gy * psc(outw) + gx; buffer_st4(top_blob_data, gi, res); diff --git a/src/layer/vulkan/shader/binaryop_pack4.comp b/src/layer/vulkan/shader/binaryop_pack4.comp index 75870b4324e..3a4c230acb2 100644 --- a/src/layer/vulkan/shader/binaryop_pack4.comp +++ b/src/layer/vulkan/shader/binaryop_pack4.comp @@ -96,6 +96,13 @@ void main() 
if (op_type == 10) res = atan(v1, v2); if (op_type == 11) res = atan(v2, v1); #endif + if (op_type == 12) res = v1 - trunc(v1 / v2) * v2; + if (op_type == 13) res = v2 - trunc(v2 / v1) * v1; + if (op_type == 14) res = max(v1, v2) + log(1.0 + exp(min(v1, v2) - max(v1, v2))); + if (op_type == 15) res = floor(v1 / v2); + if (op_type == 16) res = floor(v2 / v1); + if (op_type == 17) res = v1 - roundEven(v1 / v2) * v2; + if (op_type == 18) res = v2 - roundEven(v2 / v1) * v1; buffer_st4(top_blob_data, gi, res); } diff --git a/src/layer/x86/avx512_mathfun.h b/src/layer/x86/avx512_mathfun.h index 0423ecedcf4..0068edf79c7 100644 --- a/src/layer/x86/avx512_mathfun.h +++ b/src/layer/x86/avx512_mathfun.h @@ -854,4 +854,48 @@ static NCNN_FORCEINLINE __m512 abs512_ps(const __m512& x) return _mm512_and_ps(abs_mask, x); } +static NCNN_FORCEINLINE __m512 trunc512_ps(const __m512& x) +{ + // truncate toward zero + return _mm512_roundscale_ps(x, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC); +} + +static NCNN_FORCEINLINE __m512 fmod512_ps(const __m512& x, const __m512& y) +{ + __m512 q = _mm512_div_ps(x, y); + __m512 tq = trunc512_ps(q); + return _mm512_sub_ps(x, _mm512_mul_ps(tq, y)); +} + +static NCNN_FORCEINLINE __m512 round512_ps(const __m512& x) +{ + return _mm512_roundscale_ps(x, _MM_FROUND_NINT); +} + +static NCNN_FORCEINLINE __m512 logaddexp512_ps(const __m512& x, const __m512& y) +{ + const __m512 magic_one = _mm512_set1_ps(1.0f); + + __m512 max_xy = _mm512_max_ps(x, y); + __m512 min_xy = _mm512_min_ps(x, y); + __m512 diff = _mm512_sub_ps(min_xy, max_xy); + __m512 exp_diff = exp512_ps(diff); + __m512 one_plus_exp = _mm512_add_ps(magic_one, exp_diff); + __m512 log_result = log512_ps(one_plus_exp); + return _mm512_add_ps(max_xy, log_result); +} + +static NCNN_FORCEINLINE __m512 floor_divide512_ps(const __m512& x, const __m512& y) +{ + __m512 q = _mm512_div_ps(x, y); + return _mm512_roundscale_ps(q, _MM_FROUND_TO_NEG_INF); +} + +static NCNN_FORCEINLINE __m512 remainder512_ps(const __m512& x, const __m512& y) +{ + __m512 q = _mm512_div_ps(x, y); + __m512 rq = round512_ps(q); + return _mm512_sub_ps(x, _mm512_mul_ps(rq, y)); +} + #endif // AVX512_MATHFUN_H diff --git a/src/layer/x86/avx_mathfun.h b/src/layer/x86/avx_mathfun.h index fcace35a129..4f5ef64012b 100644 --- a/src/layer/x86/avx_mathfun.h +++ b/src/layer/x86/avx_mathfun.h @@ -1096,4 +1096,48 @@ static NCNN_FORCEINLINE __m256 abs256_ps(const __m256& x) return _mm256_and_ps(abs_mask, x); } +static NCNN_FORCEINLINE __m256 trunc256_ps(const __m256& x) +{ + // truncate toward zero + return _mm256_round_ps(x, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC); +} + +static NCNN_FORCEINLINE __m256 fmod256_ps(const __m256& x, const __m256& y) +{ + __m256 q = _mm256_div_ps(x, y); + __m256 tq = trunc256_ps(q); + return _mm256_sub_ps(x, _mm256_mul_ps(tq, y)); +} + +static NCNN_FORCEINLINE __m256 round256_ps(const __m256& x) +{ + return _mm256_round_ps(x, _MM_FROUND_NINT); +} + +static NCNN_FORCEINLINE __m256 logaddexp256_ps(const __m256& x, const __m256& y) +{ + const __m256 magic_one = _mm256_set1_ps(1.0f); + + __m256 max_xy = _mm256_max_ps(x, y); + __m256 min_xy = _mm256_min_ps(x, y); + __m256 diff = _mm256_sub_ps(min_xy, max_xy); + __m256 exp_diff = exp256_ps(diff); + __m256 one_plus_exp = _mm256_add_ps(magic_one, exp_diff); + __m256 log_result = log256_ps(one_plus_exp); + return _mm256_add_ps(max_xy, log_result); +} + +static NCNN_FORCEINLINE __m256 floor_divide256_ps(const __m256& x, const __m256& y) +{ + __m256 q = _mm256_div_ps(x, y); + return _mm256_floor_ps(q); +} + 
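As a quick sanity check on the identities these AVX helpers implement (an illustrative scalar sketch, not part of the patch; the ref_* names are invented here), the same math in plain C looks like this, assuming the default round-to-nearest-even FP mode for nearbyintf:

#include <math.h>

// scalar reference for the vector helpers above (illustration only)
static inline float ref_fmod(float x, float y)
{
    return x - truncf(x / y) * y; // quotient truncated toward zero, matches fmodf for moderate finite inputs
}

static inline float ref_floor_divide(float x, float y)
{
    return floorf(x / y); // quotient rounded toward negative infinity
}

static inline float ref_remainder(float x, float y)
{
    return x - nearbyintf(x / y) * y; // quotient rounded to nearest, ties to even, matches remainderf
}
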
+static NCNN_FORCEINLINE __m256 remainder256_ps(const __m256& x, const __m256& y) +{ + __m256 q = _mm256_div_ps(x, y); + __m256 rq = round256_ps(q); + return _mm256_sub_ps(x, _mm256_mul_ps(rq, y)); +} + #endif // AVX_MATHFUN_H diff --git a/src/layer/x86/binaryop_x86.cpp b/src/layer/x86/binaryop_x86.cpp index 06dd4e254c7..ef03cd37428 100644 --- a/src/layer/x86/binaryop_x86.cpp +++ b/src/layer/x86/binaryop_x86.cpp @@ -774,6 +774,190 @@ struct binary_op_ratan2 #endif // __SSE2__ }; +struct binary_op_fmod +{ + NCNN_FORCEINLINE float func(const float& x, const float& y) const + { + return (float)fmodf(x, y); + } +#if __SSE2__ + NCNN_FORCEINLINE __m128 func_pack4(const __m128& x, const __m128& y) const + { + return fmod_ps(x, y); + } +#if __AVX__ + NCNN_FORCEINLINE __m256 func_pack8(const __m256& x, const __m256& y) const + { + return fmod256_ps(x, y); + } +#if __AVX512F__ + NCNN_FORCEINLINE __m512 func_pack16(const __m512& x, const __m512& y) const + { + return fmod512_ps(x, y); + } +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ +}; + +struct binary_op_rfmod +{ + NCNN_FORCEINLINE float func(const float& x, const float& y) const + { + return (float)fmodf(y, x); + } +#if __SSE2__ + NCNN_FORCEINLINE __m128 func_pack4(const __m128& x, const __m128& y) const + { + return fmod_ps(y, x); + } +#if __AVX__ + NCNN_FORCEINLINE __m256 func_pack8(const __m256& x, const __m256& y) const + { + return fmod256_ps(y, x); + } +#if __AVX512F__ + NCNN_FORCEINLINE __m512 func_pack16(const __m512& x, const __m512& y) const + { + return fmod512_ps(y, x); + } +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ +}; + +struct binary_op_logaddexp +{ + NCNN_FORCEINLINE float func(const float& x, const float& y) const + { + float max_xy = std::max(x, y); + float min_xy = std::min(x, y); + return (float)(max_xy + log1pf(expf(min_xy - max_xy))); + } +#if __SSE2__ + NCNN_FORCEINLINE __m128 func_pack4(const __m128& x, const __m128& y) const + { + return logaddexp_ps(x, y); + } +#if __AVX__ + NCNN_FORCEINLINE __m256 func_pack8(const __m256& x, const __m256& y) const + { + return logaddexp256_ps(x, y); + } +#if __AVX512F__ + NCNN_FORCEINLINE __m512 func_pack16(const __m512& x, const __m512& y) const + { + return logaddexp512_ps(x, y); + } +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ +}; + +struct binary_op_floor_divide +{ + NCNN_FORCEINLINE float func(const float& x, const float& y) const + { + return (float)floorf(x / y); + } +#if __SSE2__ + NCNN_FORCEINLINE __m128 func_pack4(const __m128& x, const __m128& y) const + { + return floor_divide_ps(x, y); + } +#if __AVX__ + NCNN_FORCEINLINE __m256 func_pack8(const __m256& x, const __m256& y) const + { + return floor_divide256_ps(x, y); + } +#if __AVX512F__ + NCNN_FORCEINLINE __m512 func_pack16(const __m512& x, const __m512& y) const + { + return floor_divide512_ps(x, y); + } +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ +}; + +struct binary_op_rfloor_divide +{ + NCNN_FORCEINLINE float func(const float& x, const float& y) const + { + return (float)floorf(y / x); + } +#if __SSE2__ + NCNN_FORCEINLINE __m128 func_pack4(const __m128& x, const __m128& y) const + { + return floor_divide_ps(y, x); + } +#if __AVX__ + NCNN_FORCEINLINE __m256 func_pack8(const __m256& x, const __m256& y) const + { + return floor_divide256_ps(y, x); + } +#if __AVX512F__ + NCNN_FORCEINLINE __m512 func_pack16(const __m512& x, const __m512& y) const + { + return floor_divide512_ps(y, x); + } +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ +}; + 
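For context on why both functor families exist (a usage illustration, not code from this patch): the two C library routines they wrap differ only in how the quotient is rounded before the multiply-subtract, which can change both the magnitude and the sign of the result.

#include <math.h>
#include <stdio.h>

int main()
{
    // fmodf truncates the quotient toward zero, remainderf rounds it to nearest (ties to even)
    printf("%f\n", fmodf(5.5f, 2.f));      // 1.5  = 5.5 - trunc(2.75) * 2 = 5.5 - 4
    printf("%f\n", remainderf(5.5f, 2.f)); // -0.5 = 5.5 - round(2.75) * 2 = 5.5 - 6
    return 0;
}
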
+struct binary_op_remainder +{ + NCNN_FORCEINLINE float func(const float& x, const float& y) const + { + return (float)remainderf(x, y); + } +#if __SSE2__ + NCNN_FORCEINLINE __m128 func_pack4(const __m128& x, const __m128& y) const + { + return remainder_ps(x, y); + } +#if __AVX__ + NCNN_FORCEINLINE __m256 func_pack8(const __m256& x, const __m256& y) const + { + return remainder256_ps(x, y); + } +#if __AVX512F__ + NCNN_FORCEINLINE __m512 func_pack16(const __m512& x, const __m512& y) const + { + return remainder512_ps(x, y); + } +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ +}; + +struct binary_op_rremainder +{ + NCNN_FORCEINLINE float func(const float& x, const float& y) const + { + return (float)remainderf(y, x); + } +#if __SSE2__ + NCNN_FORCEINLINE __m128 func_pack4(const __m128& x, const __m128& y) const + { + return remainder_ps(y, x); + } +#if __AVX__ + NCNN_FORCEINLINE __m256 func_pack8(const __m256& x, const __m256& y) const + { + return remainder256_ps(y, x); + } +#if __AVX512F__ + NCNN_FORCEINLINE __m512 func_pack16(const __m512& x, const __m512& y) const + { + return remainder512_ps(y, x); + } +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ +}; + } // namespace BinaryOp_x86_functor static void binary_op_vector(const float* ptr, const float* ptr1, float* outptr, int aw, int bw, int ap, int bp, int op_type) @@ -792,6 +976,13 @@ static void binary_op_vector(const float* ptr, const float* ptr1, float* outptr, if (op_type == BinaryOp::Operation_RPOW) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); if (op_type == BinaryOp::Operation_ATAN2) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); if (op_type == BinaryOp::Operation_RATAN2) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_FMOD) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_RFMOD) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_LOGADDEXP) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_FLOOR_DIVIDE) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_RFLOOR_DIVIDE) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_REMAINDER) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_RREMAINDER) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); // should never reach here } @@ -936,10 +1127,18 @@ static int get_reverse_op_type(int op_type) if (op_type == BinaryOp::Operation_DIV) return BinaryOp::Operation_RDIV; if (op_type == BinaryOp::Operation_POW) return BinaryOp::Operation_RPOW; if (op_type == BinaryOp::Operation_ATAN2) return BinaryOp::Operation_RATAN2; + if (op_type == BinaryOp::Operation_FMOD) return BinaryOp::Operation_RFMOD; + if (op_type == BinaryOp::Operation_FLOOR_DIVIDE) return BinaryOp::Operation_RFLOOR_DIVIDE; + if (op_type == BinaryOp::Operation_REMAINDER) return BinaryOp::Operation_RREMAINDER; + if (op_type == BinaryOp::Operation_RSUB) return BinaryOp::Operation_SUB; if (op_type == BinaryOp::Operation_RDIV) return BinaryOp::Operation_DIV; if (op_type == BinaryOp::Operation_RPOW) return BinaryOp::Operation_POW; if (op_type == BinaryOp::Operation_RATAN2) return BinaryOp::Operation_ATAN2; + if (op_type == BinaryOp::Operation_RFMOD) return BinaryOp::Operation_FMOD; + if (op_type == BinaryOp::Operation_RFLOOR_DIVIDE) return 
BinaryOp::Operation_FLOOR_DIVIDE; + if (op_type == BinaryOp::Operation_RREMAINDER) return BinaryOp::Operation_REMAINDER; + return op_type; } diff --git a/src/layer/x86/sse_mathfun.h b/src/layer/x86/sse_mathfun.h index 4744a55eaf2..2fc3ae3e27d 100644 --- a/src/layer/x86/sse_mathfun.h +++ b/src/layer/x86/sse_mathfun.h @@ -1166,4 +1166,68 @@ static NCNN_FORCEINLINE __m128 abs_ps(const __m128& x) return _mm_and_ps(abs_mask, x); } +static NCNN_FORCEINLINE __m128 trunc_ps(const __m128& x) +{ + // truncate toward zero + __m128i xi = _mm_cvttps_epi32(x); + return _mm_cvtepi32_ps(xi); +} + +static NCNN_FORCEINLINE __m128 fmod_ps(const __m128& x, const __m128& y) +{ + __m128 q = _mm_div_ps(x, y); + __m128 tq = trunc_ps(q); + return _mm_sub_ps(x, _mm_mul_ps(tq, y)); +} + +static NCNN_FORCEINLINE __m128 round_ps(const __m128& x) +{ +#if __SSE4_1__ + return _mm_round_ps(x, _MM_FROUND_NINT); +#endif // __SSE4_1__ + + const __m128 magic_negative_zero = _mm_set_ps1(-0.0f); + const __m128 magic_half = _mm_set_ps1(0.5f); + const __m128 magic_one = _mm_set_ps1(1.0f); + + __m128 negative_mask = _mm_and_ps(magic_negative_zero, x); + __m128 absolute = _mm_andnot_ps(magic_negative_zero, x); + __m128i xi = _mm_cvttps_epi32(absolute); + __m128 truncated = _mm_cvtepi32_ps(xi); + __m128 diff = _mm_sub_ps(absolute, truncated); + __m128 diff_gt_half = _mm_cmpgt_ps(diff, magic_half); + __m128 diff_eq_half = _mm_cmpeq_ps(diff, magic_half); + __m128i xi_and_1 = _mm_and_si128(xi, _mm_set1_epi32(1)); + __m128i is_odd = _mm_cmpeq_epi32(xi_and_1, _mm_set1_epi32(1)); + __m128 round_up = _mm_or_ps(diff_gt_half, _mm_and_ps(diff_eq_half, _mm_castsi128_ps(is_odd))); + __m128 rounded = _mm_add_ps(truncated, _mm_and_ps(round_up, magic_one)); + return _mm_or_ps(rounded, negative_mask); +} + +static NCNN_FORCEINLINE __m128 logaddexp_ps(const __m128& x, const __m128& y) +{ + const __m128 magic_one = _mm_set_ps1(1.0f); + + __m128 max_xy = _mm_max_ps(x, y); + __m128 min_xy = _mm_min_ps(x, y); + __m128 diff = _mm_sub_ps(min_xy, max_xy); + __m128 exp_diff = exp_ps(diff); + __m128 one_plus_exp = _mm_add_ps(magic_one, exp_diff); + __m128 log_result = log_ps(one_plus_exp); + return _mm_add_ps(max_xy, log_result); +} + +static NCNN_FORCEINLINE __m128 floor_divide_ps(const __m128& x, const __m128& y) +{ + __m128 q = _mm_div_ps(x, y); + return floor_ps(q); +} + +static NCNN_FORCEINLINE __m128 remainder_ps(const __m128& x, const __m128& y) +{ + __m128 q = _mm_div_ps(x, y); + __m128 rq = round_ps(q); + return _mm_sub_ps(x, _mm_mul_ps(rq, y)); +} + #endif // SSE_MATHFUN_H diff --git a/src/simplemath.cpp b/src/simplemath.cpp index 2fc73d3fede..fe34ff77190 100644 --- a/src/simplemath.cpp +++ b/src/simplemath.cpp @@ -131,6 +131,33 @@ float fmodf(float x, float y) return (x < 0) ? -m : m; } +float remainderf(float x, float y) +{ + if (y == 0.0f) + { + return x; + } + float q = x / y; + float rq; + float absq = fabsf(q); + float intpart = floorf(absq); + float fracpart = absq - intpart; + if (fracpart > 0.5f) + { + intpart += 1.0f; + } + else if (fracpart == 0.5f) + { + int n = (int)intpart; + if (n % 2 != 0) + { + intpart += 1.0f; + } + } + rq = (q >= 0) ? 
intpart : -intpart; + return x - rq * y; +} + /* * ==================================================== * trigonometric functions @@ -497,6 +524,28 @@ float log10f(float x) return logf(x) / ln10; } +float log1pf(float x) +{ + if (x == 0.0f) + { + return x; + } + if (x < -1.0f) + { + return (x - x) / (x - x); // NaN + } + if (x == -1.0f) + { + return -INFINITY; + } + float u = 1.0f + x; + if (u == 1.0f) + { + return x; + } + return logf(u) * (x / (u - 1.0f)); +} + /* * ==================================================== * probability functions diff --git a/src/simplemath.h b/src/simplemath.h index 561deb124e0..1ad5fae4d10 100644 --- a/src/simplemath.h +++ b/src/simplemath.h @@ -29,6 +29,7 @@ NCNN_EXPORT float fmaxf(float, float); NCNN_EXPORT float truncf(float); NCNN_EXPORT float frac(float); NCNN_EXPORT float fmodf(float, float); +NCNN_EXPORT float remainderf(float, float); /* * ==================================================== * trigonometric functions @@ -67,6 +68,7 @@ NCNN_EXPORT float frexp(float, int*); NCNN_EXPORT float logf(float); NCNN_EXPORT float log(float); NCNN_EXPORT float log10f(float); +NCNN_EXPORT float log1pf(float); /* * ==================================================== diff --git a/tests/test_binaryop.cpp b/tests/test_binaryop.cpp index a5cc698f779..176d0ae980b 100644 --- a/tests/test_binaryop.cpp +++ b/tests/test_binaryop.cpp @@ -3,7 +3,7 @@ #include "testutil.h" -#define OP_TYPE_MAX 12 +#define OP_TYPE_MAX 14 static int op_type = 0; diff --git a/tests/test_binaryop_1.cpp b/tests/test_binaryop_1.cpp index 6ef3dc2accf..7951249f443 100644 --- a/tests/test_binaryop_1.cpp +++ b/tests/test_binaryop_1.cpp @@ -3,7 +3,7 @@ #include "testutil.h" -#define OP_TYPE_MAX 12 +#define OP_TYPE_MAX 19 static int op_type = 0; diff --git a/tests/test_binaryop_2.cpp b/tests/test_binaryop_2.cpp index 87cf7e1f679..bc152495508 100644 --- a/tests/test_binaryop_2.cpp +++ b/tests/test_binaryop_2.cpp @@ -3,7 +3,7 @@ #include "testutil.h" -#define OP_TYPE_MAX 12 +#define OP_TYPE_MAX 19 static int op_type = 0; diff --git a/tests/test_binaryop_3.cpp b/tests/test_binaryop_3.cpp index c7233d71da8..fdfd00b7559 100644 --- a/tests/test_binaryop_3.cpp +++ b/tests/test_binaryop_3.cpp @@ -3,7 +3,7 @@ #include "testutil.h" -#define OP_TYPE_MAX 12 +#define OP_TYPE_MAX 19 static int op_type = 0; @@ -370,7 +370,7 @@ int main() { SRAND(7767517); - for (op_type = 9; op_type < OP_TYPE_MAX; op_type++) + for (op_type = 9; op_type < 12; op_type++) { int ret = 0 || test_binaryop_1() diff --git a/tests/test_binaryop_4.cpp b/tests/test_binaryop_4.cpp new file mode 100644 index 00000000000..74a8521db8d --- /dev/null +++ b/tests/test_binaryop_4.cpp @@ -0,0 +1,482 @@ +// Copyright 2020 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#include "testutil.h" + +#define OP_TYPE_MAX 19 + +static int op_type = 0; + +static int test_binaryop(const ncnn::Mat& _a, const ncnn::Mat& _b, int flag) +{ + ncnn::Mat a = _a; + ncnn::Mat b = _b; + if (op_type == 12 || op_type == 13) + { + // value must be non-zero for fmod/rfmod + a = a.clone(); + b = b.clone(); + + if (op_type == 12) + { + // fmod(a, b) -> b must be non-zero + for (int i = 0; i < b.total(); i++) + { + if (b[i] == 0.f) + b[i] = 0.001f; + } + } + else + { + // rfmod(a, b) = fmod(b, a) -> a must be non-zero + for (int i = 0; i < a.total(); i++) + { + if (a[i] == 0.f) + a[i] = 0.001f; + } + } + } + else if (op_type == 16 || op_type == 17) + { + // value must be non-zero for floor_divide/rfloor_divide + a = a.clone(); + b = b.clone(); + + if (op_type == 16) + { + // 
floor_divide(a, b) -> b must be non-zero + for (int i = 0; i < b.total(); i++) + { + if (b[i] == 0.f) + b[i] = 0.001f; + } + } + else + { + // rfloor_divide(a, b) = floor_divide(b, a) -> a must be non-zero + for (int i = 0; i < a.total(); i++) + { + if (a[i] == 0.f) + a[i] = 0.001f; + } + } + } + else if (op_type == 18 || op_type == 19) + { + // value must be non-zero for remainder/rremainder + a = a.clone(); + b = b.clone(); + + if (op_type == 18) + { + // remainder(a, b) -> b must be non-zero + for (int i = 0; i < b.total(); i++) + { + if (b[i] == 0.f) + b[i] = 0.001f; + } + } + else + { + // rremainder(a, b) = remainder(b, a) -> a must be non-zero + for (int i = 0; i < a.total(); i++) + { + if (a[i] == 0.f) + a[i] = 0.001f; + } + } + } + + ncnn::ParamDict pd; + pd.set(0, op_type); + pd.set(1, 0); // with_scalar + pd.set(2, 0.f); // b + + std::vector weights(0); + + std::vector ab(2); + ab[0] = a; + ab[1] = b; + + int ret = test_layer("BinaryOp", pd, weights, ab, 1, 0.0001, flag); + if (ret != 0) + { + fprintf(stderr, "test_binaryop failed a.dims=%d a=(%d %d %d %d) b.dims=%d b=(%d %d %d %d) op_type=%d\n", a.dims, a.w, a.h, a.d, a.c, b.dims, b.w, b.h, b.d, b.c, op_type); + } + + return ret; +} + +static int test_binaryop(const ncnn::Mat& _a, float b, int flag) +{ + ncnn::Mat a = _a; + if (op_type == 12 || op_type == 13) + { + // value must be non-zero for fmod/rfmod + a = a.clone(); + + if (op_type == 12) + { + // fmod(a, b) -> b must be non-zero + if (b == 0.f) + b = 0.001f; + } + else + { + // rfmod(a, b) = fmod(b, a) -> a must be non-zero + for (int i = 0; i < a.total(); i++) + { + if (a[i] == 0.f) + a[i] = 0.001f; + } + } + } + else if (op_type == 16 || op_type == 17) + { + // value must be non-zero for floor_divide/rfloor_divide + a = a.clone(); + if (op_type == 16) + { + // floor_divide(a, b) -> b must be non-zero + if (b == 0.f) + b = 0.001f; + } + else + { + // rfloor_divide(a, b) = floor_divide(b, a) -> a must be non-zero + for (int i = 0; i < a.total(); i++) + { + if (a[i] == 0.f) + a[i] = 0.001f; + } + } + } + else if (op_type == 18 || op_type == 19) + { + // value must be non-zero for remainder/rremainder + a = a.clone(); + if (op_type == 18) + { + // remainder(a, b) -> b must be non-zero + if (b == 0.f) + b = 0.001f; + } + else + { + // rremainder(a, b) = remainder(b, a) -> a must be non-zero + for (int i = 0; i < a.total(); i++) + { + if (a[i] == 0.f) + a[i] = 0.001f; + } + } + } + + ncnn::ParamDict pd; + pd.set(0, op_type); + pd.set(1, 1); // with_scalar + pd.set(2, b); // b + + std::vector weights(0); + + int ret = test_layer("BinaryOp", pd, weights, a, 0.0001, flag); + if (ret != 0) + { + fprintf(stderr, "test_binaryop failed a.dims=%d a=(%d %d %d %d) b=%f op_type=%d\n", a.dims, a.w, a.h, a.d, a.c, b, op_type); + } + + return ret; +} + +static int test_binaryop_1() +{ + const int ws[] = {31, 28, 24, 32}; + + for (int i = 0; i < 4; i++) + { + const int w = ws[i]; + const int flag = w == 32 ? TEST_LAYER_DISABLE_GPU_TESTING : 0; + + ncnn::Mat a[2]; + ncnn::Mat b[2]; + for (int j = 0; j < 2; j++) + { + int bw = j % 2 == 0 ? 
w : 1; + a[j] = RandomMat(bw, 1.0f, 1.1f); + b[j] = RandomMat(bw, 0.8f, 0.9f); + } + + for (int j = 0; j < 2; j++) + { + for (int k = 0; k < 2; k++) + { + int ret = test_binaryop(a[j], b[k], flag); + if (ret != 0) + return ret; + } + + int ret = test_binaryop(a[j], 0.7f, flag); + if (ret != 0) + return ret; + } + } + + return 0; +} + +static int test_binaryop_2() +{ + const int ws[] = {13, 14, 15, 16}; + const int hs[] = {31, 28, 24, 32}; + + for (int i = 0; i < 4; i++) + { + const int w = ws[i]; + const int h = hs[i]; + const int flag = h == 32 ? TEST_LAYER_DISABLE_GPU_TESTING : 0; + + ncnn::Mat a[4]; + ncnn::Mat b[4]; + for (int j = 0; j < 2; j++) + { + for (int k = 0; k < 2; k++) + { + int bw = j % 2 == 0 ? w : 1; + int bh = k % 2 == 0 ? h : 1; + a[j * 2 + k] = RandomMat(bw, bh, 1.0f, 1.1f); + b[j * 2 + k] = RandomMat(bw, bh, 0.8f, 0.9f); + } + } + + for (int j = 0; j < 4; j++) + { + for (int k = 0; k < 4; k++) + { + int ret = test_binaryop(a[j], b[k], flag); + if (ret != 0) + return ret; + } + + int ret = test_binaryop(a[j], 0.7f, flag); + if (ret != 0) + return ret; + } + } + + return 0; +} + +static int test_binaryop_3() +{ + const int ws[] = {7, 6, 5, 4}; + const int hs[] = {3, 4, 5, 6}; + const int cs[] = {31, 28, 24, 32}; + + for (int i = 0; i < 4; i++) + { + const int w = ws[i]; + const int h = hs[i]; + const int c = cs[i]; + const int flag = c == 32 ? TEST_LAYER_DISABLE_GPU_TESTING : 0; + + ncnn::Mat a[8]; + ncnn::Mat b[8]; + for (int j = 0; j < 2; j++) + { + for (int k = 0; k < 2; k++) + { + for (int l = 0; l < 2; l++) + { + int bw = j % 2 == 0 ? w : 1; + int bh = k % 2 == 0 ? h : 1; + int bc = l % 2 == 0 ? c : 1; + a[j * 4 + k * 2 + l] = RandomMat(bw, bh, bc, 1.0f, 1.1f); + b[j * 4 + k * 2 + l] = RandomMat(bw, bh, bc, 0.8f, 0.9f); + } + } + } + + for (int j = 0; j < 8; j++) + { + for (int k = 0; k < 8; k++) + { + int ret = test_binaryop(a[j], b[k], flag); + if (ret != 0) + return ret; + } + + int ret = test_binaryop(a[j], 0.7f, flag); + if (ret != 0) + return ret; + } + } + + return 0; +} + +static int test_binaryop_4() +{ + const int ws[] = {2, 3, 4, 5}; + const int hs[] = {7, 6, 5, 4}; + const int ds[] = {3, 4, 5, 6}; + const int cs[] = {31, 28, 24, 32}; + + for (int i = 0; i < 4; i++) + { + const int w = ws[i]; + const int h = hs[i]; + const int d = ds[i]; + const int c = cs[i]; + const int flag = c == 32 ? TEST_LAYER_DISABLE_GPU_TESTING : 0; + + ncnn::Mat a[16]; + ncnn::Mat b[16]; + for (int j = 0; j < 2; j++) + { + for (int k = 0; k < 2; k++) + { + for (int l = 0; l < 2; l++) + { + for (int m = 0; m < 2; m++) + { + int bw = j % 2 == 0 ? w : 1; + int bh = k % 2 == 0 ? h : 1; + int bd = l % 2 == 0 ? d : 1; + int bc = m % 2 == 0 ? c : 1; + a[j * 8 + k * 4 + l * 2 + m] = RandomMat(bw, bh, bd, bc, 1.0f, 1.1f); + b[j * 8 + k * 4 + l * 2 + m] = RandomMat(bw, bh, bd, bc, 0.8f, 0.9f); + } + } + } + } + + for (int j = 0; j < 16; j++) + { + for (int k = 0; k < 16; k++) + { + int ret = test_binaryop(a[j], b[k], flag); + if (ret != 0) + return ret; + } + + int ret = test_binaryop(a[j], 0.7f, flag); + if (ret != 0) + return ret; + } + } + + return 0; +} + +static int test_binaryop_5() +{ + const int ws[] = {2, 3, 4, 5}; + const int hs[] = {7, 6, 5, 4}; + const int ds[] = {3, 4, 5, 6}; + const int cs[] = {31, 28, 24, 32}; + + for (int i = 0; i < 4; i++) + { + const int w = ws[i]; + const int h = hs[i]; + const int d = ds[i]; + const int c = cs[i]; + const int flag = c == 32 ? 
TEST_LAYER_DISABLE_GPU_TESTING : 0; + + ncnn::Mat a[4] = { + RandomMat(c, 1.0f, 1.1f), + RandomMat(d, c, 1.0f, 1.1f), + RandomMat(h, d, c, 1.0f, 1.1f), + RandomMat(w, h, d, c, 1.0f, 1.1f), + }; + + ncnn::Mat b[4] = { + RandomMat(c, 0.8f, 0.9f), + RandomMat(d, c, 0.8f, 0.9f), + RandomMat(h, d, c, 0.8f, 0.9f), + RandomMat(w, h, d, c, 0.8f, 0.9f), + }; + + for (int j = 0; j < 4; j++) + { + for (int k = 0; k < 4; k++) + { + int ret = test_binaryop(a[j], b[k], flag); + if (ret != 0) + return ret; + } + } + } + + return 0; +} + +static int test_binaryop_6() +{ + const int ws[] = {16, 12, 16, 15}; + const int hs[] = {15, 16, 15, 12}; + const int ds[] = {12, 14, 12, 16}; + const int cs[] = {31, 28, 24, 32}; + + for (int i = 0; i < 4; i++) + { + const int w = ws[i]; + const int h = hs[i]; + const int d = ds[i]; + const int c = cs[i]; + const int flag = c == 32 ? TEST_LAYER_DISABLE_GPU_TESTING : 0; + + ncnn::Mat a[3] = { + RandomMat(d, c, 1.0f, 1.1f), + RandomMat(h, d, c, 1.0f, 1.1f), + RandomMat(w, h, d, c, 1.0f, 1.1f), + }; + + for (int j = 0; j < 3; j++) + { + ncnn::Mat b = RandomMat(a[j].w, 0.8f, 0.9f); + + int ret = test_binaryop(a[j], b, flag) || test_binaryop(b, a[j], flag); + if (ret != 0) + return ret; + } + + ncnn::Mat aa[3] = { + RandomMat(c, c, 1.0f, 1.1f), + RandomMat(c, d, c, 1.0f, 1.1f), + RandomMat(c, h, d, c, 1.0f, 1.1f), + }; + + for (int j = 0; j < 3; j++) + { + ncnn::Mat b = RandomMat(aa[j].w, 0.8f, 0.9f); + + int ret = test_binaryop(aa[j], b, flag) || test_binaryop(b, aa[j], flag); + if (ret != 0) + return ret; + } + } + + return 0; +} + +int main() +{ + SRAND(7767517); + + for (op_type = 12; op_type < 19; op_type++) + { + int ret = 0 + || test_binaryop_1() + || test_binaryop_2() + || test_binaryop_3() + || test_binaryop_4() + || test_binaryop_5() + || test_binaryop_6(); + + if (ret != 0) + return ret; + } + + return 0; +} diff --git a/tools/pnnx/src/pass_ncnn/expand_expression.cpp b/tools/pnnx/src/pass_ncnn/expand_expression.cpp index 9978c978978..4c12436ef68 100644 --- a/tools/pnnx/src/pass_ncnn/expand_expression.cpp +++ b/tools/pnnx/src/pass_ncnn/expand_expression.cpp @@ -203,12 +203,12 @@ static std::string expand_expression(Graph& graph, const Operator* op, int& pnnx if (t == "sub") op_binary->params["0"] = 1; if (t == "mul") op_binary->params["0"] = 2; if (t == "div") op_binary->params["0"] = 3; - if (t == "logaddexp") fprintf(stderr, "BinaryOp logaddexp not supported yet\n"); // TODO + if (t == "logaddexp") op_binary->params["0"] = 14; if (t == "max" || t == "maximum") op_binary->params["0"] = 4; if (t == "min" || t == "minimum") op_binary->params["0"] = 5; - if (t == "floor_divide") fprintf(stderr, "BinaryOp floor_divide not supported yet\n"); // TODO - if (t == "fmod") fprintf(stderr, "BinaryOp fmod not supported yet\n"); // TODO - if (t == "remainder") fprintf(stderr, "BinaryOp remainder not supported yet\n"); // TODO + if (t == "floor_divide") op_binary->params["0"] = 15; + if (t == "fmod") op_binary->params["0"] = 12; + if (t == "remainder") op_binary->params["0"] = 17; if (t == "pow") op_binary->params["0"] = 6; if (t == "atan2") op_binary->params["0"] = 10; @@ -218,6 +218,10 @@ static std::string expand_expression(Graph& graph, const Operator* op, int& pnnx if (t == "div") op_binary->params["0"] = 8; if (t == "pow") op_binary->params["0"] = 9; if (t == "atan2") op_binary->params["0"] = 11; + if (t == "fmod") op_binary->params["0"] = 13; + if (t == "logaddexp") op_binary->params["0"] = 14; + if (t == "floor_divide") op_binary->params["0"] = 16; + if (t == 
"remainder") op_binary->params["0"] = 18; Operand* op_binary_inb = token_is_argument(b) ? op->inputs[std::stoi(b.substr(1))] : graph.get_operand(op->name + "_" + b); op_binary_inb->consumers.push_back(op_binary); From 329484c06b471af59773205cfb04a73d29362a13 Mon Sep 17 00:00:00 2001 From: ihb2032 <40718643+ihb2032@users.noreply.github.com> Date: Sun, 8 Mar 2026 12:46:08 +0800 Subject: [PATCH 07/36] fix: add missing NCNN_MALLOC_OVERREAD padding for MSVC (#6583) Signed-off-by: ihb2032 --- src/allocator.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/allocator.h b/src/allocator.h index cf8db861495..14ff4d5cab5 100644 --- a/src/allocator.h +++ b/src/allocator.h @@ -56,7 +56,7 @@ static NCNN_FORCEINLINE size_t alignSize(size_t sz, int n) static NCNN_FORCEINLINE void* fastMalloc(size_t size) { #if _MSC_VER - return _aligned_malloc(size, NCNN_MALLOC_ALIGN); + return _aligned_malloc(size + NCNN_MALLOC_OVERREAD, NCNN_MALLOC_ALIGN); #elif (defined(__unix__) || defined(__APPLE__)) && _POSIX_C_SOURCE >= 200112L || (__ANDROID__ && __ANDROID_API__ >= 17) void* ptr = 0; if (posix_memalign(&ptr, NCNN_MALLOC_ALIGN, size + NCNN_MALLOC_OVERREAD)) From bf57baa3216a657a034bd1efb4dce890bbada513 Mon Sep 17 00:00:00 2001 From: nihui Date: Sun, 8 Mar 2026 14:39:26 +0800 Subject: [PATCH 08/36] x86: add AbsVal_x86 with fp16s and bf16s storage support (#6584) --- src/layer/x86/absval_x86.cpp | 157 +++++++++++++++++++++++++++++++++++ src/layer/x86/absval_x86.h | 24 ++++++ 2 files changed, 181 insertions(+) create mode 100644 src/layer/x86/absval_x86.cpp create mode 100644 src/layer/x86/absval_x86.h diff --git a/src/layer/x86/absval_x86.cpp b/src/layer/x86/absval_x86.cpp new file mode 100644 index 00000000000..4e55ba502d6 --- /dev/null +++ b/src/layer/x86/absval_x86.cpp @@ -0,0 +1,157 @@ +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#include "absval_x86.h" + +#if __SSE2__ +#include +#if __AVX__ +#include +#endif // __AVX__ +#endif // __SSE2__ + +#include "cpu.h" + +namespace ncnn { + +AbsVal_x86::AbsVal_x86() +{ +#if __SSE2__ + support_packing = true; +#endif // __SSE2__ + support_fp16_storage = cpu_support_x86_f16c(); +#if NCNN_BF16 + support_bf16_storage = true; +#endif +} + +int AbsVal_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const +{ + int elembits = bottom_top_blob.elembits(); + + if (elembits == 16) + return forward_inplace_bf16s_fp16s(bottom_top_blob, opt); + + const int w = bottom_top_blob.w; + const int h = bottom_top_blob.h; + const int d = bottom_top_blob.d; + const int channels = bottom_top_blob.c; + const int elempack = bottom_top_blob.elempack; + const int size = w * h * d * elempack; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + float* ptr = bottom_top_blob.channel(q); + + int i = 0; +#if __AVX512F__ + __m512 _sign_mask_avx512 = _mm512_castsi512_ps(_mm512_set1_epi32(0x7fffffff)); + for (; i + 15 < size; i += 16) + { + __m512 _p = _mm512_loadu_ps(ptr); + _mm512_storeu_ps(ptr, _mm512_and_ps(_p, _sign_mask_avx512)); + ptr += 16; + } + if (i < size) + { + const __mmask16 _mask = (__mmask16)((1u << (size - i)) - 1); + __m512 _p = _mm512_maskz_loadu_ps(_mask, ptr); + _mm512_mask_storeu_ps(ptr, _mask, _mm512_and_ps(_p, _sign_mask_avx512)); + } +#else // __AVX512F__ +#if __SSE2__ +#if __AVX__ + __m256 _sign_mask_avx = _mm256_castsi256_ps(_mm256_set1_epi32(0x7fffffff)); + for (; i + 7 < size; i += 8) + { + __m256 _p = _mm256_loadu_ps(ptr); + _mm256_storeu_ps(ptr, _mm256_and_ps(_p, 
_sign_mask_avx)); + ptr += 8; + } +#endif // __AVX__ + __m128 _sign_mask = _mm_castsi128_ps(_mm_set1_epi32(0x7fffffff)); + for (; i + 3 < size; i += 4) + { + __m128 _p = _mm_load_ps(ptr); + _mm_store_ps(ptr, _mm_and_ps(_p, _sign_mask)); + ptr += 4; + } +#endif // __SSE2__ + for (; i < size; i++) + { + *ptr = *ptr > 0.f ? *ptr : -*ptr; + ptr++; + } +#endif // __AVX512F__ + } + + return 0; +} + +int AbsVal_x86::forward_inplace_bf16s_fp16s(Mat& bottom_top_blob, const Option& opt) const +{ + const int w = bottom_top_blob.w; + const int h = bottom_top_blob.h; + const int d = bottom_top_blob.d; + const int channels = bottom_top_blob.c; + const int elempack = bottom_top_blob.elempack; + const int size = w * h * d * elempack; + + // fp16/bf16 abs: sign bit is bit 15 for both formats. + // Reinterpret pairs of 16-bit values as float and apply AND with + // 0x7fff7fff to clear both sign bits in one 32-bit operation. + // No fp32 round-trip required, no F16C instructions needed. + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + unsigned short* ptr = bottom_top_blob.channel(q); + + int i = 0; +#if __AVX512F__ + __m512i _sign_mask_avx512 = _mm512_set1_epi32(0x7fff7fff); + for (; i + 31 < size; i += 32) + { + __m512i _p = _mm512_loadu_si512((const __m512i*)ptr); + _mm512_storeu_si512((__m512i*)ptr, _mm512_and_si512(_p, _sign_mask_avx512)); + ptr += 32; + } + if (i < size) + { + const unsigned int remain = size - i; + const __mmask16 _mask = (__mmask16)((1u << ((remain + 1) / 2)) - 1); + __m512i _p = _mm512_maskz_loadu_epi32(_mask, (const __m512i*)ptr); + _mm512_mask_storeu_epi32((__m512i*)ptr, _mask, _mm512_and_si512(_p, _sign_mask_avx512)); + } +#else // __AVX512F__ +#if __SSE2__ +#if __AVX__ + __m256 _sign_mask_avx = _mm256_castsi256_ps(_mm256_set1_epi32(0x7fff7fff)); + for (; i + 15 < size; i += 16) + { + __m256 _p = _mm256_castsi256_ps(_mm256_loadu_si256((const __m256i*)ptr)); + _mm256_storeu_si256((__m256i*)ptr, _mm256_castps_si256(_mm256_and_ps(_p, _sign_mask_avx))); + ptr += 16; + } +#endif // __AVX__ + __m128i _sign_mask = _mm_set1_epi32(0x7fff7fff); + for (; i + 7 < size; i += 8) + { + __m128i _p = _mm_load_si128((const __m128i*)ptr); + _mm_store_si128((__m128i*)ptr, _mm_and_si128(_p, _sign_mask)); + ptr += 8; + } +#endif // __SSE2__ + for (; i < size; i++) + { + *ptr = *ptr & 0x7fffu; + ptr++; + } +#endif // __AVX512F__ + } + + return 0; +} + +} // namespace ncnn diff --git a/src/layer/x86/absval_x86.h b/src/layer/x86/absval_x86.h new file mode 100644 index 00000000000..c151a2b6939 --- /dev/null +++ b/src/layer/x86/absval_x86.h @@ -0,0 +1,24 @@ +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef LAYER_ABSVAL_X86_H +#define LAYER_ABSVAL_X86_H + +#include "absval.h" + +namespace ncnn { + +class AbsVal_x86 : public AbsVal +{ +public: + AbsVal_x86(); + + virtual int forward_inplace(Mat& bottom_top_blob, const Option& opt) const; + +protected: + int forward_inplace_bf16s_fp16s(Mat& bottom_top_blob, const Option& opt) const; +}; + +} // namespace ncnn + +#endif // LAYER_ABSVAL_X86_H From 5b66db5c9122edebdf3d47098443d980b01dabaa Mon Sep 17 00:00:00 2001 From: nihui Date: Sun, 8 Mar 2026 16:08:59 +0800 Subject: [PATCH 09/36] x86: add LayerNorm_x86 bf16s storage support with avx512bf16 dispatch (#6585) --- src/layer/x86/layernorm_bf16s.h | 489 +++++++++++++++++++++ src/layer/x86/layernorm_x86.cpp | 69 +++ src/layer/x86/layernorm_x86.h | 5 + src/layer/x86/layernorm_x86_avx512bf16.cpp | 17 + src/net.cpp | 12 + 
tests/testutil.cpp | 24 + 6 files changed, 616 insertions(+) create mode 100644 src/layer/x86/layernorm_bf16s.h create mode 100644 src/layer/x86/layernorm_x86_avx512bf16.cpp diff --git a/src/layer/x86/layernorm_bf16s.h b/src/layer/x86/layernorm_bf16s.h new file mode 100644 index 00000000000..45e79fd61f9 --- /dev/null +++ b/src/layer/x86/layernorm_bf16s.h @@ -0,0 +1,489 @@ +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ +void layernorm_bf16s_sse_avx512bf16(unsigned short* ptr, const float* gamma_ptr, const float* beta_ptr, float eps, int elemcount, int elempack); +#endif + +static void layernorm_bf16s_sse(unsigned short* ptr, const float* gamma_ptr, const float* beta_ptr, float eps, int elemcount, int elempack) +{ +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ + if (ncnn::cpu_support_x86_avx512_bf16()) + { + layernorm_bf16s_sse_avx512bf16(ptr, gamma_ptr, beta_ptr, eps, elemcount, elempack); + return; + } +#endif + + const int size = elemcount * elempack; + + // convert bf16 -> fp32, accumulate mean +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + __m512 _mean_avx512 = _mm512_setzero_ps(); +#endif // __AVX512F__ + __m256 _mean_avx = _mm256_setzero_ps(); +#endif // __AVX__ + __m128 _mean = _mm_setzero_ps(); +#endif // __SSE2__ + float mean = 0.f; + { + const unsigned short* ptr0 = ptr; + + int i = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + for (; i + 15 < size; i += 16) + { + __m512 _p = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr0)); + _mean_avx512 = _mm512_add_ps(_mean_avx512, _p); + ptr0 += 16; + } +#endif // __AVX512F__ + for (; i + 7 < size; i += 8) + { + __m256 _p = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr0)); + _mean_avx = _mm256_add_ps(_mean_avx, _p); + ptr0 += 8; + } +#endif // __AVX__ + for (; i + 3 < size; i += 4) + { + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr0)); + _mean = _mm_add_ps(_mean, _p); + ptr0 += 4; + } +#endif // __SSE2__ + for (; i < size; i++) + { + mean += bfloat16_to_float32(*ptr0++); + } + } + +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + if (elempack == 16) + { + __m512 _elemcount = _mm512_set1_ps((float)elemcount); + _mean_avx512 = _mm512_div_ps(_mean_avx512, _elemcount); + } +#endif // __AVX512F__ + if (elempack == 8) + { +#if __AVX512F__ + { + __m256 _mean0 = _mm512_castps512_ps256(_mean_avx512); + __m256 _mean1 = _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(_mean_avx512), 1)); + _mean_avx = _mm256_add_ps(_mean_avx, _mean0); + _mean_avx = _mm256_add_ps(_mean_avx, _mean1); + } +#endif // __AVX512F__ + __m256 _elemcount = _mm256_set1_ps((float)elemcount); + _mean_avx = _mm256_div_ps(_mean_avx, _elemcount); +#if __AVX512F__ + _mean_avx512 = combine8x2_ps(_mean_avx, _mean_avx); +#endif // __AVX512F__ + } +#endif // __AVX__ + if (elempack == 4) + { +#if __AVX__ +#if __AVX512F__ + { + __m256 _mean0 = _mm512_castps512_ps256(_mean_avx512); + __m256 _mean1 = _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(_mean_avx512), 1)); + _mean_avx = _mm256_add_ps(_mean_avx, _mean0); + _mean_avx = _mm256_add_ps(_mean_avx, _mean1); + } +#endif // __AVX512F__ + { + __m128 _mean0 = _mm256_castps256_ps128(_mean_avx); + __m128 _mean1 = _mm256_extractf128_ps(_mean_avx, 1); + _mean = _mm_add_ps(_mean, _mean0); + _mean = _mm_add_ps(_mean, _mean1); + } +#endif // __AVX__ + __m128 _elemcount = _mm_set1_ps((float)elemcount); + _mean = _mm_div_ps(_mean, _elemcount); +#if __AVX__ + _mean_avx = 
combine4x2_ps(_mean, _mean); +#if __AVX512F__ + _mean_avx512 = combine8x2_ps(_mean_avx, _mean_avx); +#endif // __AVX512F__ +#endif // __AVX__ + } +#endif // __SSE2__ + if (elempack == 1) + { +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + mean += _mm512_comp_reduce_add_ps(_mean_avx512); +#endif // __AVX512F__ + mean += _mm256_reduce_add_ps(_mean_avx); +#endif // __AVX__ + mean += _mm_reduce_add_ps(_mean); +#endif // __SSE2__ + mean = mean / elemcount; +#if __SSE2__ + _mean = _mm_set1_ps(mean); +#if __AVX__ + _mean_avx = combine4x2_ps(_mean, _mean); +#if __AVX512F__ + _mean_avx512 = combine8x2_ps(_mean_avx, _mean_avx); +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ + } + + // accumulate var +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + __m512 _var_avx512 = _mm512_setzero_ps(); +#endif // __AVX512F__ + __m256 _var_avx = _mm256_setzero_ps(); +#endif // __AVX__ + __m128 _var = _mm_setzero_ps(); +#endif // __SSE2__ + float var = 0.f; + { + const unsigned short* ptr0 = ptr; + + int i = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + for (; i + 15 < size; i += 16) + { + __m512 _p = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr0)); + _p = _mm512_sub_ps(_p, _mean_avx512); + _var_avx512 = _mm512_fmadd_ps(_p, _p, _var_avx512); + ptr0 += 16; + } +#endif // __AVX512F__ + for (; i + 7 < size; i += 8) + { + __m256 _p = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr0)); + _p = _mm256_sub_ps(_p, _mean_avx); + _var_avx = _mm256_comp_fmadd_ps(_p, _p, _var_avx); + ptr0 += 8; + } +#endif // __AVX__ + for (; i + 3 < size; i += 4) + { + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr0)); + _p = _mm_sub_ps(_p, _mean); + _var = _mm_comp_fmadd_ps(_p, _p, _var); + ptr0 += 4; + } +#endif // __SSE2__ + for (; i < size; i++) + { + float v = bfloat16_to_float32(*ptr0++) - mean; + var += v * v; + } + } + +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + if (elempack == 16) + { + __m512 _elemcount = _mm512_set1_ps((float)elemcount); + __m512 _eps = _mm512_set1_ps(eps); + _var_avx512 = _mm512_div_ps(_var_avx512, _elemcount); + _var_avx512 = _mm512_add_ps(_var_avx512, _eps); + __m256 _var0 = _mm256_rsqrt_ps(_mm512_extractf32x8_ps(_var_avx512, 0)); + __m256 _var1 = _mm256_rsqrt_ps(_mm512_extractf32x8_ps(_var_avx512, 1)); + _var_avx512 = combine8x2_ps(_var0, _var1); + _mean_avx512 = _mm512_mul_ps(_mean_avx512, _var_avx512); + } +#endif // __AVX512F__ + if (elempack == 8) + { +#if __AVX512F__ + { + __m256 _var0 = _mm512_castps512_ps256(_var_avx512); + __m256 _var1 = _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(_var_avx512), 1)); + _var_avx = _mm256_add_ps(_var_avx, _var0); + _var_avx = _mm256_add_ps(_var_avx, _var1); + } +#endif // __AVX512F__ + __m256 _elemcount = _mm256_set1_ps((float)elemcount); + __m256 _eps = _mm256_set1_ps(eps); + _var_avx = _mm256_div_ps(_var_avx, _elemcount); + _var_avx = _mm256_add_ps(_var_avx, _eps); + _var_avx = _mm256_rsqrt_ps(_var_avx); + _mean_avx = _mm256_mul_ps(_mean_avx, _var_avx); +#if __AVX512F__ + _var_avx512 = combine8x2_ps(_var_avx, _var_avx); + _mean_avx512 = combine8x2_ps(_mean_avx, _mean_avx); +#endif // __AVX512F__ + } +#endif // __AVX__ + if (elempack == 4) + { +#if __AVX__ +#if __AVX512F__ + { + __m256 _var0 = _mm512_castps512_ps256(_var_avx512); + __m256 _var1 = _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(_var_avx512), 1)); + _var_avx = _mm256_add_ps(_var_avx, _var0); + _var_avx = _mm256_add_ps(_var_avx, _var1); + } +#endif // __AVX512F__ + { + __m128 _var0 = _mm256_castps256_ps128(_var_avx); + __m128 _var1 = 
_mm256_extractf128_ps(_var_avx, 1); + _var = _mm_add_ps(_var, _var0); + _var = _mm_add_ps(_var, _var1); + } +#endif // __AVX__ + __m128 _elemcount = _mm_set1_ps((float)elemcount); + __m128 _eps = _mm_set1_ps(eps); + _var = _mm_div_ps(_var, _elemcount); + _var = _mm_add_ps(_var, _eps); + _var = _mm_rsqrt_ps(_var); + _mean = _mm_mul_ps(_mean, _var); +#if __AVX__ + _var_avx = combine4x2_ps(_var, _var); + _mean_avx = combine4x2_ps(_mean, _mean); +#if __AVX512F__ + _var_avx512 = combine8x2_ps(_var_avx, _var_avx); + _mean_avx512 = combine8x2_ps(_mean_avx, _mean_avx); +#endif // __AVX512F__ +#endif // __AVX__ + } +#endif // __SSE2__ + if (elempack == 1) + { +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + var += _mm512_comp_reduce_add_ps(_var_avx512); +#endif // __AVX512F__ + var += _mm256_reduce_add_ps(_var_avx); +#endif // __AVX__ + var += _mm_reduce_add_ps(_var); +#endif // __SSE2__ + var = 1.f / sqrtf(var / elemcount + eps); + mean = mean * var; +#if __SSE2__ + _var = _mm_set1_ps(var); + _mean = _mm_set1_ps(mean); +#if __AVX__ + _var_avx = combine4x2_ps(_var, _var); + _mean_avx = combine4x2_ps(_mean, _mean); +#if __AVX512F__ + _var_avx512 = combine8x2_ps(_var_avx, _var_avx); + _mean_avx512 = combine8x2_ps(_mean_avx, _mean_avx); +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ + } + + // norm and store bf16 + if (gamma_ptr && beta_ptr) + { + int i = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + if (elempack == 16) + { + for (; i + 15 < size; i += 16) + { + __m512 _p = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr)); + __m512 _gamma = _mm512_set1_ps(gamma_ptr[0]); + __m512 _beta = _mm512_set1_ps(beta_ptr[0]); + _p = _mm512_fmsub_ps(_p, _var_avx512, _mean_avx512); + _p = _mm512_fmadd_ps(_p, _gamma, _beta); + _mm256_storeu_si256((__m256i*)ptr, float2bfloat_avx512(_p)); + ptr += 16; + gamma_ptr += 1; + beta_ptr += 1; + } + } +#endif // __AVX512F__ + if (elempack == 8) + { +#if __AVX512F__ + for (; i + 15 < size; i += 16) + { + __m512 _p = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr)); + __m256 _gamma0 = _mm256_set1_ps(gamma_ptr[0]); + __m256 _gamma1 = _mm256_set1_ps(gamma_ptr[1]); + __m512 _gamma = combine8x2_ps(_gamma0, _gamma1); + __m256 _beta0 = _mm256_set1_ps(beta_ptr[0]); + __m256 _beta1 = _mm256_set1_ps(beta_ptr[1]); + __m512 _beta = combine8x2_ps(_beta0, _beta1); + _p = _mm512_fmsub_ps(_p, _var_avx512, _mean_avx512); + _p = _mm512_fmadd_ps(_p, _gamma, _beta); + _mm256_storeu_si256((__m256i*)ptr, float2bfloat_avx512(_p)); + ptr += 16; + gamma_ptr += 2; + beta_ptr += 2; + } +#endif // __AVX512F__ + for (; i + 7 < size; i += 8) + { + __m256 _p = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr)); + __m256 _gamma = _mm256_set1_ps(gamma_ptr[0]); + __m256 _beta = _mm256_set1_ps(beta_ptr[0]); + _p = _mm256_comp_fmsub_ps(_p, _var_avx, _mean_avx); + _p = _mm256_comp_fmadd_ps(_p, _gamma, _beta); + _mm_storeu_si128((__m128i*)ptr, float2bfloat_avx(_p)); + ptr += 8; + gamma_ptr += 1; + beta_ptr += 1; + } + } +#endif // __AVX__ + if (elempack == 4) + { +#if __AVX__ +#if __AVX512F__ + for (; i + 15 < size; i += 16) + { + __m512 _p = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr)); + __m128 _gamma0 = _mm_set1_ps(gamma_ptr[0]); + __m128 _gamma1 = _mm_set1_ps(gamma_ptr[1]); + __m128 _gamma2 = _mm_set1_ps(gamma_ptr[2]); + __m128 _gamma3 = _mm_set1_ps(gamma_ptr[3]); + __m512 _gamma = combine4x4_ps(_gamma0, _gamma1, _gamma2, _gamma3); + __m128 _beta0 = _mm_set1_ps(beta_ptr[0]); + __m128 _beta1 = _mm_set1_ps(beta_ptr[1]); + __m128 _beta2 = 
_mm_set1_ps(beta_ptr[2]); + __m128 _beta3 = _mm_set1_ps(beta_ptr[3]); + __m512 _beta = combine4x4_ps(_beta0, _beta1, _beta2, _beta3); + _p = _mm512_fmsub_ps(_p, _var_avx512, _mean_avx512); + _p = _mm512_fmadd_ps(_p, _gamma, _beta); + _mm256_storeu_si256((__m256i*)ptr, float2bfloat_avx512(_p)); + ptr += 16; + gamma_ptr += 4; + beta_ptr += 4; + } +#endif // __AVX512F__ + for (; i + 7 < size; i += 8) + { + __m256 _p = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr)); + __m128 _gamma0 = _mm_set1_ps(gamma_ptr[0]); + __m128 _gamma1 = _mm_set1_ps(gamma_ptr[1]); + __m256 _gamma = combine4x2_ps(_gamma0, _gamma1); + __m128 _beta0 = _mm_set1_ps(beta_ptr[0]); + __m128 _beta1 = _mm_set1_ps(beta_ptr[1]); + __m256 _beta = combine4x2_ps(_beta0, _beta1); + _p = _mm256_comp_fmsub_ps(_p, _var_avx, _mean_avx); + _p = _mm256_comp_fmadd_ps(_p, _gamma, _beta); + _mm_storeu_si128((__m128i*)ptr, float2bfloat_avx(_p)); + ptr += 8; + gamma_ptr += 2; + beta_ptr += 2; + } +#endif // __AVX__ + for (; i + 3 < size; i += 4) + { + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr)); + __m128 _gamma = _mm_set1_ps(gamma_ptr[0]); + __m128 _beta = _mm_set1_ps(beta_ptr[0]); + _p = _mm_comp_fmsub_ps(_p, _var, _mean); + _p = _mm_comp_fmadd_ps(_p, _gamma, _beta); + _mm_storel_epi64((__m128i*)ptr, float2bfloat_sse(_p, _p)); + ptr += 4; + gamma_ptr += 1; + beta_ptr += 1; + } + } + if (elempack == 1) + { +#if __AVX__ +#if __AVX512F__ + for (; i + 15 < size; i += 16) + { + __m512 _p = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr)); + __m512 _gamma = _mm512_loadu_ps(gamma_ptr); + __m512 _beta = _mm512_loadu_ps(beta_ptr); + _p = _mm512_fmsub_ps(_p, _var_avx512, _mean_avx512); + _p = _mm512_fmadd_ps(_p, _gamma, _beta); + _mm256_storeu_si256((__m256i*)ptr, float2bfloat_avx512(_p)); + ptr += 16; + gamma_ptr += 16; + beta_ptr += 16; + } +#endif // __AVX512F__ + for (; i + 7 < size; i += 8) + { + __m256 _p = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr)); + __m256 _gamma = _mm256_loadu_ps(gamma_ptr); + __m256 _beta = _mm256_loadu_ps(beta_ptr); + _p = _mm256_comp_fmsub_ps(_p, _var_avx, _mean_avx); + _p = _mm256_comp_fmadd_ps(_p, _gamma, _beta); + _mm_storeu_si128((__m128i*)ptr, float2bfloat_avx(_p)); + ptr += 8; + gamma_ptr += 8; + beta_ptr += 8; + } +#endif // __AVX__ + for (; i + 3 < size; i += 4) + { + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr)); + __m128 _gamma = _mm_loadu_ps(gamma_ptr); + __m128 _beta = _mm_loadu_ps(beta_ptr); + _p = _mm_comp_fmsub_ps(_p, _var, _mean); + _p = _mm_comp_fmadd_ps(_p, _gamma, _beta); + _mm_storel_epi64((__m128i*)ptr, float2bfloat_sse(_p, _p)); + ptr += 4; + gamma_ptr += 4; + beta_ptr += 4; + } + } +#endif // __SSE2__ + for (; i < size; i++) + { + *ptr = float32_to_bfloat16((bfloat16_to_float32(*ptr) * var - mean) * *gamma_ptr + *beta_ptr); + ptr++; + gamma_ptr++; + beta_ptr++; + } + } + else + { + int i = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + for (; i + 15 < size; i += 16) + { + __m512 _p = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr)); + _p = _mm512_fmsub_ps(_p, _var_avx512, _mean_avx512); + _mm256_storeu_si256((__m256i*)ptr, float2bfloat_avx512(_p)); + ptr += 16; + } +#endif // __AVX512F__ + for (; i + 7 < size; i += 8) + { + __m256 _p = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr)); + _p = _mm256_comp_fmsub_ps(_p, _var_avx, _mean_avx); + _mm_storeu_si128((__m128i*)ptr, float2bfloat_avx(_p)); + ptr += 8; + } +#endif // __AVX__ + for (; i + 3 < size; i += 4) + { + __m128 _p = 
bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr)); + _p = _mm_comp_fmsub_ps(_p, _var, _mean); + _mm_storel_epi64((__m128i*)ptr, float2bfloat_sse(_p, _p)); + ptr += 4; + } +#endif // __SSE2__ + for (; i < size; i++) + { + *ptr = float32_to_bfloat16(bfloat16_to_float32(*ptr) * var - mean); + ptr++; + } + } +} diff --git a/src/layer/x86/layernorm_x86.cpp b/src/layer/x86/layernorm_x86.cpp index f18ec2562ce..f54e8b61786 100644 --- a/src/layer/x86/layernorm_x86.cpp +++ b/src/layer/x86/layernorm_x86.cpp @@ -11,14 +11,22 @@ #endif // __SSE2__ #include "x86_usability.h" +#include "cpu.h" namespace ncnn { +#if NCNN_BF16 +#include "layernorm_bf16s.h" +#endif + LayerNorm_x86::LayerNorm_x86() { #if __SSE2__ support_packing = true; #endif // __SSE2__ +#if NCNN_BF16 + support_bf16_storage = true; +#endif } static void layernorm(float* ptr, const float* gamma_ptr, const float* beta_ptr, float eps, int elemcount, int elempack) @@ -509,6 +517,11 @@ int LayerNorm_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) cons const int h = bottom_top_blob.h; const int channels = bottom_top_blob.c; +#if NCNN_BF16 + if (opt.use_bf16_storage && bottom_top_blob.elembits() == 16) + return forward_inplace_bf16s(bottom_top_blob, opt); +#endif + if (dims == 1) { // assert affine_size == w @@ -557,4 +570,60 @@ int LayerNorm_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) cons return 0; } +#if NCNN_BF16 +int LayerNorm_x86::forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt) const +{ + const int dims = bottom_top_blob.dims; + const int elempack = bottom_top_blob.elempack; + const int w = bottom_top_blob.w; + const int h = bottom_top_blob.h; + const int channels = bottom_top_blob.c; + + if (dims == 1) + { + // assert affine_size == w + unsigned short* ptr = bottom_top_blob; + layernorm_bf16s_sse(ptr, gamma_data, beta_data, eps, w * elempack, 1); + } + + if (dims == 2) + { + // assert affine_size == w + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = 0; i < h; i++) + { + unsigned short* ptr = bottom_top_blob.row(i); + layernorm_bf16s_sse(ptr, gamma_data, beta_data, eps, w, elempack); + } + } + + if (dims == 3) + { + if (affine_size == w) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + for (int i = 0; i < h; i++) + { + unsigned short* ptr = bottom_top_blob.channel(q).row(i); + layernorm_bf16s_sse(ptr, gamma_data, beta_data, eps, w, elempack); + } + } + } + else // if (affine_size == w * h) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + unsigned short* ptr = bottom_top_blob.channel(q); + layernorm_bf16s_sse(ptr, gamma_data, beta_data, eps, w * h, elempack); + } + } + } + + return 0; +} +#endif // NCNN_BF16 + } // namespace ncnn diff --git a/src/layer/x86/layernorm_x86.h b/src/layer/x86/layernorm_x86.h index 97ac4fa70d9..f6dd2038e63 100644 --- a/src/layer/x86/layernorm_x86.h +++ b/src/layer/x86/layernorm_x86.h @@ -14,6 +14,11 @@ class LayerNorm_x86 : public LayerNorm LayerNorm_x86(); virtual int forward_inplace(Mat& bottom_top_blob, const Option& opt) const; + +protected: +#if NCNN_BF16 + int forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt) const; +#endif }; } // namespace ncnn diff --git a/src/layer/x86/layernorm_x86_avx512bf16.cpp b/src/layer/x86/layernorm_x86_avx512bf16.cpp new file mode 100644 index 00000000000..0dd5ad1a630 --- /dev/null +++ b/src/layer/x86/layernorm_x86_avx512bf16.cpp @@ -0,0 +1,17 @@ +// Copyright 2026 Tencent +// 
SPDX-License-Identifier: BSD-3-Clause + +#include "cpu.h" +#include "mat.h" +#include "x86_usability.h" + +namespace ncnn { + +#include "layernorm_bf16s.h" + +void layernorm_bf16s_sse_avx512bf16(unsigned short* ptr, const float* gamma_ptr, const float* beta_ptr, float eps, int elemcount, int elempack) +{ + layernorm_bf16s_sse(ptr, gamma_ptr, beta_ptr, eps, elemcount, elempack); +} + +} // namespace ncnn diff --git a/src/net.cpp b/src/net.cpp index 25126f52cb0..4394132040e 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -455,6 +455,18 @@ int NetPrivate::convert_layout(Mat& bottom_blob, const Layer* layer, const Optio dst_elempack = 8; else if (elemcount % 4 == 0) dst_elempack = 4; +#elif NCNN_AVX512 + if (elemcount % 16 == 0 && ncnn::cpu_support_x86_avx512()) + dst_elempack = 16; + else if (elemcount % 8 == 0 && ncnn::cpu_support_x86_avx()) + dst_elempack = 8; + else if (elemcount % 4 == 0) + dst_elempack = 4; +#elif NCNN_AVX + if (elemcount % 8 == 0 && ncnn::cpu_support_x86_avx()) + dst_elempack = 8; + else if (elemcount % 4 == 0) + dst_elempack = 4; #elif NCNN_RVV || NCNN_XTHEADVECTOR const int packn = ncnn::cpu_riscv_vlenb() / 2; if (elemcount % packn == 0) diff --git a/tests/testutil.cpp b/tests/testutil.cpp index 2192ddcfe30..878142a2d2c 100644 --- a/tests/testutil.cpp +++ b/tests/testutil.cpp @@ -403,6 +403,18 @@ static int convert_to_optimal_layout(const ncnn::Mat& a, ncnn::Mat& a4, ncnn::Ma dst_elempack = 8; else if (elemcount % 4 == 0) dst_elempack = 4; +#elif NCNN_AVX512 + if (elemcount % 16 == 0 && ncnn::cpu_support_x86_avx512()) + dst_elempack = 16; + else if (elemcount % 8 == 0 && ncnn::cpu_support_x86_avx()) + dst_elempack = 8; + else if (elemcount % 4 == 0) + dst_elempack = 4; +#elif NCNN_AVX + if (elemcount % 8 == 0 && ncnn::cpu_support_x86_avx()) + dst_elempack = 8; + else if (elemcount % 4 == 0) + dst_elempack = 4; #elif NCNN_RVV || NCNN_XTHEADVECTOR const int packn = ncnn::cpu_riscv_vlenb() / 2; if (elemcount % packn == 0) @@ -461,6 +473,18 @@ static int convert_to_optimal_layout(const ncnn::Mat& a, ncnn::Mat& a4, ncnn::Ma any_elempack = 4; else if (elemcount % 4 == 0) any_elempack = 1; +#elif NCNN_AVX512 + if (elemcount % 16 == 0 && ncnn::cpu_support_x86_avx512()) + any_elempack = 8; + else if (elemcount % 8 == 0 && ncnn::cpu_support_x86_avx()) + any_elempack = 4; + else if (elemcount % 4 == 0) + any_elempack = 1; +#elif NCNN_AVX + if (elemcount % 8 == 0 && ncnn::cpu_support_x86_avx()) + any_elempack = 4; + else if (elemcount % 4 == 0) + any_elempack = 1; #elif NCNN_RVV || NCNN_XTHEADVECTOR const int packn = ncnn::cpu_riscv_vlenb() / 2; if (elemcount % packn == 0) From 12396c8a2a12dfa03404139159c76c78e91155bc Mon Sep 17 00:00:00 2001 From: nihui Date: Sun, 8 Mar 2026 20:59:32 +0800 Subject: [PATCH 10/36] x86: add RMSNorm_x86 bf16s storage support with avx512bf16 dispatch (#6586) --- src/layer/x86/rmsnorm_bf16s.h | 327 +++++++++++++++++++++++ src/layer/x86/rmsnorm_x86.cpp | 69 +++++ src/layer/x86/rmsnorm_x86.h | 5 + src/layer/x86/rmsnorm_x86_avx512bf16.cpp | 17 ++ 4 files changed, 418 insertions(+) create mode 100644 src/layer/x86/rmsnorm_bf16s.h create mode 100644 src/layer/x86/rmsnorm_x86_avx512bf16.cpp diff --git a/src/layer/x86/rmsnorm_bf16s.h b/src/layer/x86/rmsnorm_bf16s.h new file mode 100644 index 00000000000..bafe4783b5e --- /dev/null +++ b/src/layer/x86/rmsnorm_bf16s.h @@ -0,0 +1,327 @@ +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ +void 
rmsnorm_bf16s_sse_avx512bf16(unsigned short* ptr, const float* gamma_ptr, float eps, int elemcount, int elempack); +#endif + +static void rmsnorm_bf16s_sse(unsigned short* ptr, const float* gamma_ptr, float eps, int elemcount, int elempack) +{ +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ + if (ncnn::cpu_support_x86_avx512_bf16()) + { + rmsnorm_bf16s_sse_avx512bf16(ptr, gamma_ptr, eps, elemcount, elempack); + return; + } +#endif + + const int size = elemcount * elempack; + + // accumulate rms +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + __m512 _rms_avx512 = _mm512_setzero_ps(); +#endif // __AVX512F__ + __m256 _rms_avx = _mm256_setzero_ps(); +#endif // __AVX__ + __m128 _rms = _mm_setzero_ps(); +#endif // __SSE2__ + float rms = 0.f; + { + const unsigned short* ptr0 = ptr; + + int i = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + for (; i + 15 < size; i += 16) + { + __m512 _p = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr0)); + _rms_avx512 = _mm512_fmadd_ps(_p, _p, _rms_avx512); + ptr0 += 16; + } +#endif // __AVX512F__ + for (; i + 7 < size; i += 8) + { + __m256 _p = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr0)); + _rms_avx = _mm256_comp_fmadd_ps(_p, _p, _rms_avx); + ptr0 += 8; + } +#endif // __AVX__ + for (; i + 3 < size; i += 4) + { + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr0)); + _rms = _mm_comp_fmadd_ps(_p, _p, _rms); + ptr0 += 4; + } +#endif // __SSE2__ + for (; i < size; i++) + { + float v = bfloat16_to_float32(*ptr0++); + rms += v * v; + } + } + +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + if (elempack == 16) + { + __m512 _elemcount = _mm512_set1_ps((float)elemcount); + __m512 _eps = _mm512_set1_ps(eps); + _rms_avx512 = _mm512_div_ps(_rms_avx512, _elemcount); + _rms_avx512 = _mm512_add_ps(_rms_avx512, _eps); + __m256 _rms0 = _mm256_rsqrt_ps(_mm512_extractf32x8_ps(_rms_avx512, 0)); + __m256 _rms1 = _mm256_rsqrt_ps(_mm512_extractf32x8_ps(_rms_avx512, 1)); + _rms_avx512 = combine8x2_ps(_rms0, _rms1); + } +#endif // __AVX512F__ + if (elempack == 8) + { +#if __AVX512F__ + { + __m256 _rms0 = _mm512_castps512_ps256(_rms_avx512); + __m256 _rms1 = _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(_rms_avx512), 1)); + _rms_avx = _mm256_add_ps(_rms_avx, _rms0); + _rms_avx = _mm256_add_ps(_rms_avx, _rms1); + } +#endif // __AVX512F__ + __m256 _elemcount = _mm256_set1_ps((float)elemcount); + __m256 _eps = _mm256_set1_ps(eps); + _rms_avx = _mm256_div_ps(_rms_avx, _elemcount); + _rms_avx = _mm256_add_ps(_rms_avx, _eps); + _rms_avx = _mm256_rsqrt_ps(_rms_avx); +#if __AVX512F__ + _rms_avx512 = combine8x2_ps(_rms_avx, _rms_avx); +#endif // __AVX512F__ + } +#endif // __AVX__ + if (elempack == 4) + { +#if __AVX__ +#if __AVX512F__ + { + __m256 _rms0 = _mm512_castps512_ps256(_rms_avx512); + __m256 _rms1 = _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(_rms_avx512), 1)); + _rms_avx = _mm256_add_ps(_rms_avx, _rms0); + _rms_avx = _mm256_add_ps(_rms_avx, _rms1); + } +#endif // __AVX512F__ + { + __m128 _rms0 = _mm256_castps256_ps128(_rms_avx); + __m128 _rms1 = _mm256_extractf128_ps(_rms_avx, 1); + _rms = _mm_add_ps(_rms, _rms0); + _rms = _mm_add_ps(_rms, _rms1); + } +#endif // __AVX__ + __m128 _elemcount = _mm_set1_ps((float)elemcount); + __m128 _eps = _mm_set1_ps(eps); + _rms = _mm_div_ps(_rms, _elemcount); + _rms = _mm_add_ps(_rms, _eps); + _rms = _mm_rsqrt_ps(_rms); +#if __AVX__ + _rms_avx = combine4x2_ps(_rms, _rms); +#if __AVX512F__ + _rms_avx512 = combine8x2_ps(_rms_avx, _rms_avx); +#endif // __AVX512F__ 
+#endif // __AVX__ + } +#endif // __SSE2__ + if (elempack == 1) + { +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + rms += _mm512_comp_reduce_add_ps(_rms_avx512); +#endif // __AVX512F__ + rms += _mm256_reduce_add_ps(_rms_avx); +#endif // __AVX__ + rms += _mm_reduce_add_ps(_rms); +#endif // __SSE2__ + rms = 1.f / sqrtf(rms / elemcount + eps); +#if __SSE2__ + _rms = _mm_set1_ps(rms); +#if __AVX__ + _rms_avx = combine4x2_ps(_rms, _rms); +#if __AVX512F__ + _rms_avx512 = combine8x2_ps(_rms_avx, _rms_avx); +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ + } + + // normalize and store bf16 + if (gamma_ptr) + { + int i = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + if (elempack == 16) + { + for (; i + 15 < size; i += 16) + { + __m512 _p = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr)); + __m512 _gamma = _mm512_set1_ps(gamma_ptr[0]); + _p = _mm512_mul_ps(_p, _rms_avx512); + _p = _mm512_mul_ps(_p, _gamma); + _mm256_storeu_si256((__m256i*)ptr, float2bfloat_avx512(_p)); + ptr += 16; + gamma_ptr += 1; + } + } +#endif // __AVX512F__ + if (elempack == 8) + { +#if __AVX512F__ + for (; i + 15 < size; i += 16) + { + __m512 _p = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr)); + __m256 _gamma0 = _mm256_set1_ps(gamma_ptr[0]); + __m256 _gamma1 = _mm256_set1_ps(gamma_ptr[1]); + __m512 _gamma = combine8x2_ps(_gamma0, _gamma1); + _p = _mm512_mul_ps(_p, _rms_avx512); + _p = _mm512_mul_ps(_p, _gamma); + _mm256_storeu_si256((__m256i*)ptr, float2bfloat_avx512(_p)); + ptr += 16; + gamma_ptr += 2; + } +#endif // __AVX512F__ + for (; i + 7 < size; i += 8) + { + __m256 _p = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr)); + __m256 _gamma = _mm256_set1_ps(gamma_ptr[0]); + _p = _mm256_mul_ps(_p, _rms_avx); + _p = _mm256_mul_ps(_p, _gamma); + _mm_storeu_si128((__m128i*)ptr, float2bfloat_avx(_p)); + ptr += 8; + gamma_ptr += 1; + } + } +#endif // __AVX__ + if (elempack == 4) + { +#if __AVX__ +#if __AVX512F__ + for (; i + 15 < size; i += 16) + { + __m512 _p = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr)); + __m128 _gamma0 = _mm_set1_ps(gamma_ptr[0]); + __m128 _gamma1 = _mm_set1_ps(gamma_ptr[1]); + __m128 _gamma2 = _mm_set1_ps(gamma_ptr[2]); + __m128 _gamma3 = _mm_set1_ps(gamma_ptr[3]); + __m512 _gamma = combine4x4_ps(_gamma0, _gamma1, _gamma2, _gamma3); + _p = _mm512_mul_ps(_p, _rms_avx512); + _p = _mm512_mul_ps(_p, _gamma); + _mm256_storeu_si256((__m256i*)ptr, float2bfloat_avx512(_p)); + ptr += 16; + gamma_ptr += 4; + } +#endif // __AVX512F__ + for (; i + 7 < size; i += 8) + { + __m256 _p = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr)); + __m128 _gamma0 = _mm_set1_ps(gamma_ptr[0]); + __m128 _gamma1 = _mm_set1_ps(gamma_ptr[1]); + __m256 _gamma = combine4x2_ps(_gamma0, _gamma1); + _p = _mm256_mul_ps(_p, _rms_avx); + _p = _mm256_mul_ps(_p, _gamma); + _mm_storeu_si128((__m128i*)ptr, float2bfloat_avx(_p)); + ptr += 8; + gamma_ptr += 2; + } +#endif // __AVX__ + for (; i + 3 < size; i += 4) + { + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr)); + __m128 _gamma = _mm_set1_ps(gamma_ptr[0]); + _p = _mm_mul_ps(_p, _rms); + _p = _mm_mul_ps(_p, _gamma); + _mm_storel_epi64((__m128i*)ptr, float2bfloat_sse(_p, _p)); + ptr += 4; + gamma_ptr += 1; + } + } + if (elempack == 1) + { +#if __AVX__ +#if __AVX512F__ + for (; i + 15 < size; i += 16) + { + __m512 _p = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr)); + __m512 _gamma = _mm512_loadu_ps(gamma_ptr); + _p = _mm512_mul_ps(_p, _rms_avx512); + _p = _mm512_mul_ps(_p, _gamma); + 
_mm256_storeu_si256((__m256i*)ptr, float2bfloat_avx512(_p)); + ptr += 16; + gamma_ptr += 16; + } +#endif // __AVX512F__ + for (; i + 7 < size; i += 8) + { + __m256 _p = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr)); + __m256 _gamma = _mm256_loadu_ps(gamma_ptr); + _p = _mm256_mul_ps(_p, _rms_avx); + _p = _mm256_mul_ps(_p, _gamma); + _mm_storeu_si128((__m128i*)ptr, float2bfloat_avx(_p)); + ptr += 8; + gamma_ptr += 8; + } +#endif // __AVX__ + for (; i + 3 < size; i += 4) + { + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr)); + __m128 _gamma = _mm_loadu_ps(gamma_ptr); + _p = _mm_mul_ps(_p, _rms); + _p = _mm_mul_ps(_p, _gamma); + _mm_storel_epi64((__m128i*)ptr, float2bfloat_sse(_p, _p)); + ptr += 4; + gamma_ptr += 4; + } + } +#endif // __SSE2__ + for (; i < size; i++) + { + *ptr = float32_to_bfloat16(bfloat16_to_float32(*ptr) * rms * *gamma_ptr); + ptr++; + gamma_ptr++; + } + } + else + { + int i = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + for (; i + 15 < size; i += 16) + { + __m512 _p = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr)); + _p = _mm512_mul_ps(_p, _rms_avx512); + _mm256_storeu_si256((__m256i*)ptr, float2bfloat_avx512(_p)); + ptr += 16; + } +#endif // __AVX512F__ + for (; i + 7 < size; i += 8) + { + __m256 _p = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr)); + _p = _mm256_mul_ps(_p, _rms_avx); + _mm_storeu_si128((__m128i*)ptr, float2bfloat_avx(_p)); + ptr += 8; + } +#endif // __AVX__ + for (; i + 3 < size; i += 4) + { + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr)); + _p = _mm_mul_ps(_p, _rms); + _mm_storel_epi64((__m128i*)ptr, float2bfloat_sse(_p, _p)); + ptr += 4; + } +#endif // __SSE2__ + for (; i < size; i++) + { + *ptr = float32_to_bfloat16(bfloat16_to_float32(*ptr) * rms); + ptr++; + } + } +} diff --git a/src/layer/x86/rmsnorm_x86.cpp b/src/layer/x86/rmsnorm_x86.cpp index 51524e9ee64..9975a66c216 100644 --- a/src/layer/x86/rmsnorm_x86.cpp +++ b/src/layer/x86/rmsnorm_x86.cpp @@ -11,14 +11,22 @@ #endif // __SSE2__ #include "x86_usability.h" +#include "cpu.h" namespace ncnn { +#if NCNN_BF16 +#include "rmsnorm_bf16s.h" +#endif + RMSNorm_x86::RMSNorm_x86() { #if __SSE2__ support_packing = true; #endif // __SSE2__ +#if NCNN_BF16 + support_bf16_storage = true; +#endif } static void rmsnorm(float* ptr, const float* gamma_ptr, float eps, int elemcount, int elempack) @@ -349,6 +357,11 @@ int RMSNorm_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const const int channels = bottom_top_blob.c; const int elempack = bottom_top_blob.elempack; +#if NCNN_BF16 + if (opt.use_bf16_storage && bottom_top_blob.elembits() == 16) + return forward_inplace_bf16s(bottom_top_blob, opt); +#endif + if (dims == 1) { // assert affine_size == w @@ -397,4 +410,60 @@ int RMSNorm_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const return 0; } +#if NCNN_BF16 +int RMSNorm_x86::forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt) const +{ + const int dims = bottom_top_blob.dims; + const int w = bottom_top_blob.w; + const int h = bottom_top_blob.h; + const int channels = bottom_top_blob.c; + const int elempack = bottom_top_blob.elempack; + + if (dims == 1) + { + // assert affine_size == w + unsigned short* ptr = bottom_top_blob; + rmsnorm_bf16s_sse(ptr, gamma_data, eps, w * elempack, 1); + } + + if (dims == 2) + { + // assert affine_size == w + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = 0; i < h; i++) + { + unsigned short* ptr = bottom_top_blob.row(i); + rmsnorm_bf16s_sse(ptr, 
gamma_data, eps, w, elempack); + } + } + + if (dims == 3) + { + if (affine_size == w) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + for (int i = 0; i < h; i++) + { + unsigned short* ptr = bottom_top_blob.channel(q).row(i); + rmsnorm_bf16s_sse(ptr, gamma_data, eps, w, elempack); + } + } + } + else // if (affine_size == w * h) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + unsigned short* ptr = bottom_top_blob.channel(q); + rmsnorm_bf16s_sse(ptr, gamma_data, eps, w * h, elempack); + } + } + } + + return 0; +} +#endif // NCNN_BF16 + } // namespace ncnn diff --git a/src/layer/x86/rmsnorm_x86.h b/src/layer/x86/rmsnorm_x86.h index 70bc164ff67..82f909336dc 100644 --- a/src/layer/x86/rmsnorm_x86.h +++ b/src/layer/x86/rmsnorm_x86.h @@ -14,6 +14,11 @@ class RMSNorm_x86 : public RMSNorm RMSNorm_x86(); virtual int forward_inplace(Mat& bottom_top_blob, const Option& opt) const; + +protected: +#if NCNN_BF16 + int forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt) const; +#endif }; } // namespace ncnn diff --git a/src/layer/x86/rmsnorm_x86_avx512bf16.cpp b/src/layer/x86/rmsnorm_x86_avx512bf16.cpp new file mode 100644 index 00000000000..2a92a694be5 --- /dev/null +++ b/src/layer/x86/rmsnorm_x86_avx512bf16.cpp @@ -0,0 +1,17 @@ +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#include "cpu.h" +#include "mat.h" +#include "x86_usability.h" + +namespace ncnn { + +#include "rmsnorm_bf16s.h" + +void rmsnorm_bf16s_sse_avx512bf16(unsigned short* ptr, const float* gamma_ptr, float eps, int elemcount, int elempack) +{ + rmsnorm_bf16s_sse(ptr, gamma_ptr, eps, elemcount, elempack); +} + +} // namespace ncnn From f28502757e4f8f170ddb9d3b01b8016d49fa8856 Mon Sep 17 00:00:00 2001 From: nihui Date: Mon, 9 Mar 2026 12:20:56 +0800 Subject: [PATCH 11/36] x86: add UnaryOp_x86 bf16s storage support (#6588) --- src/layer/x86/unaryop_bf16s.h | 75 +++ src/layer/x86/unaryop_functor.h | 534 +++++++++++++++++++ src/layer/x86/unaryop_x86.cpp | 648 ++++------------------- src/layer/x86/unaryop_x86.h | 5 + src/layer/x86/unaryop_x86_avx512bf16.cpp | 111 ++++ 5 files changed, 838 insertions(+), 535 deletions(-) create mode 100644 src/layer/x86/unaryop_bf16s.h create mode 100644 src/layer/x86/unaryop_functor.h create mode 100644 src/layer/x86/unaryop_x86_avx512bf16.cpp diff --git a/src/layer/x86/unaryop_bf16s.h b/src/layer/x86/unaryop_bf16s.h new file mode 100644 index 00000000000..f43ee05271c --- /dev/null +++ b/src/layer/x86/unaryop_bf16s.h @@ -0,0 +1,75 @@ +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +template +static int unary_op_inplace_bf16s(Mat& a, const Option& opt) +{ + Op op; + + int w = a.w; + int h = a.h; + int d = a.d; + int channels = a.c; + int elempack = a.elempack; + int size = w * h * d * elempack; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + unsigned short* ptr = a.channel(q); + + int i = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + for (; i + 15 < size; i += 16) + { + __m512 _p = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr)); + _p = op.func_pack16(_p); + _mm256_storeu_si256((__m256i*)ptr, float2bfloat_avx512(_p)); + ptr += 16; + } + if (i < size) + { + const unsigned int remain = size - i; + __mmask16 _mask = (__mmask16)((1u << remain) - 1); + __m512 _p = bfloat2float_avx512(_mm256_maskz_loadu_epi16(_mask, ptr)); + _p = op.func_pack16(_p); + _mm256_mask_storeu_epi16(ptr, 
_mask, float2bfloat_avx512(_p)); + i += remain; + } +#else // __AVX512F__ + for (; i + 7 < size; i += 8) + { + __m256 _p = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr)); + _p = op.func_pack8(_p); + _mm_storeu_si128((__m128i*)ptr, float2bfloat_avx(_p)); + ptr += 8; + } + for (; i + 3 < size; i += 4) + { + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr)); + _p = op.func_pack4(_p); + _mm_storel_epi64((__m128i*)ptr, float2bfloat_sse(_p, _p)); + ptr += 4; + } +#endif // __AVX512F__ +#else // __AVX__ + for (; i + 3 < size; i += 4) + { + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr)); + _p = op.func_pack4(_p); + _mm_storel_epi64((__m128i*)ptr, float2bfloat_sse(_p, _p)); + ptr += 4; + } +#endif // __AVX__ +#endif // __SSE2__ + for (; i < size; i++) + { + *ptr = float32_to_bfloat16(op.func(bfloat16_to_float32(*ptr))); + ptr++; + } + } + + return 0; +} diff --git a/src/layer/x86/unaryop_functor.h b/src/layer/x86/unaryop_functor.h new file mode 100644 index 00000000000..9ccbfc8a782 --- /dev/null +++ b/src/layer/x86/unaryop_functor.h @@ -0,0 +1,534 @@ +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +struct unary_op_abs +{ + NCNN_FORCEINLINE float func(const float& x) const + { + return (float)fabsf(x); + } +#if __SSE2__ + NCNN_FORCEINLINE __m128 func_pack4(const __m128& x) const + { + return abs_ps(x); + } +#if __AVX__ + NCNN_FORCEINLINE __m256 func_pack8(const __m256& x) const + { + return abs256_ps(x); + } +#if __AVX512F__ + NCNN_FORCEINLINE __m512 func_pack16(const __m512& x) const + { + return abs512_ps(x); + } +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ +}; + +struct unary_op_neg +{ + NCNN_FORCEINLINE float func(const float& x) const + { + return -x; + } +#if __SSE2__ + NCNN_FORCEINLINE __m128 func_pack4(const __m128& x) const + { + return _mm_sub_ps(_mm_setzero_ps(), x); + } +#if __AVX__ + NCNN_FORCEINLINE __m256 func_pack8(const __m256& x) const + { + return _mm256_sub_ps(_mm256_setzero_ps(), x); + } +#if __AVX512F__ + NCNN_FORCEINLINE __m512 func_pack16(const __m512& x) const + { + return _mm512_sub_ps(_mm512_setzero_ps(), x); + } +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ +}; + +struct unary_op_floor +{ + NCNN_FORCEINLINE float func(const float& x) const + { + return (float)floorf(x); + } +#if __SSE2__ + NCNN_FORCEINLINE __m128 func_pack4(const __m128& x) const + { + return floor_ps(x); + } +#if __AVX__ + NCNN_FORCEINLINE __m256 func_pack8(const __m256& x) const + { + return _mm256_floor_ps(x); + } +#if __AVX512F__ + NCNN_FORCEINLINE __m512 func_pack16(const __m512& x) const + { + return _mm512_roundscale_ps(x, _MM_FROUND_TO_NEG_INF); + } +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ +}; + +struct unary_op_ceil +{ + NCNN_FORCEINLINE float func(const float& x) const + { + return (float)ceilf(x); + } +#if __SSE2__ + NCNN_FORCEINLINE __m128 func_pack4(const __m128& x) const + { + return ceil_ps(x); + } +#if __AVX__ + NCNN_FORCEINLINE __m256 func_pack8(const __m256& x) const + { + return _mm256_ceil_ps(x); + } +#if __AVX512F__ + NCNN_FORCEINLINE __m512 func_pack16(const __m512& x) const + { + return _mm512_roundscale_ps(x, _MM_FROUND_TO_POS_INF); + } +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ +}; + +struct unary_op_square +{ + NCNN_FORCEINLINE float func(const float& x) const + { + return x * x; + } +#if __SSE2__ + NCNN_FORCEINLINE __m128 func_pack4(const __m128& x) const + { + return _mm_mul_ps(x, x); + } +#if __AVX__ + NCNN_FORCEINLINE __m256 
func_pack8(const __m256& x) const + { + return _mm256_mul_ps(x, x); + } +#if __AVX512F__ + NCNN_FORCEINLINE __m512 func_pack16(const __m512& x) const + { + return _mm512_mul_ps(x, x); + } +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ +}; + +struct unary_op_sqrt +{ + NCNN_FORCEINLINE float func(const float& x) const + { + return (float)sqrtf(x); + } +#if __SSE2__ + NCNN_FORCEINLINE __m128 func_pack4(const __m128& x) const + { + return _mm_sqrt_ps(x); + } +#if __AVX__ + NCNN_FORCEINLINE __m256 func_pack8(const __m256& x) const + { + return _mm256_sqrt_ps(x); + } +#if __AVX512F__ + NCNN_FORCEINLINE __m512 func_pack16(const __m512& x) const + { + return _mm512_sqrt_ps(x); + } +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ +}; + +struct unary_op_rsqrt +{ + NCNN_FORCEINLINE float func(const float& x) const + { + return 1.f / sqrtf(x); + } +#if __SSE2__ + NCNN_FORCEINLINE __m128 func_pack4(const __m128& x) const + { + return _mm_rsqrt_ps(x); + } +#if __AVX__ + NCNN_FORCEINLINE __m256 func_pack8(const __m256& x) const + { + return _mm256_rsqrt_ps(x); + } +#if __AVX512F__ + NCNN_FORCEINLINE __m512 func_pack16(const __m512& x) const + { + __m256 _x0 = _mm512_extractf32x8_ps(x, 0); + __m256 _x1 = _mm512_extractf32x8_ps(x, 1); + _x0 = _mm256_rsqrt_ps(_x0); + _x1 = _mm256_rsqrt_ps(_x1); + return combine8x2_ps(_x0, _x1); + } +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ +}; + +struct unary_op_exp +{ + NCNN_FORCEINLINE float func(const float& x) const + { + return (float)expf(x); + } +#if __SSE2__ + NCNN_FORCEINLINE __m128 func_pack4(const __m128& x) const + { + return exp_ps(x); + } +#if __AVX__ + NCNN_FORCEINLINE __m256 func_pack8(const __m256& x) const + { + return exp256_ps(x); + } +#if __AVX512F__ + NCNN_FORCEINLINE __m512 func_pack16(const __m512& x) const + { + return exp512_ps(x); + } +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ +}; + +struct unary_op_log +{ + NCNN_FORCEINLINE float func(const float& x) const + { + return (float)logf(x); + } +#if __SSE2__ + NCNN_FORCEINLINE __m128 func_pack4(const __m128& x) const + { + return log_ps(x); + } +#if __AVX__ + NCNN_FORCEINLINE __m256 func_pack8(const __m256& x) const + { + return log256_ps(x); + } +#if __AVX512F__ + NCNN_FORCEINLINE __m512 func_pack16(const __m512& x) const + { + return log512_ps(x); + } +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ +}; + +struct unary_op_sin +{ + NCNN_FORCEINLINE float func(const float& x) const + { + return (float)sinf(x); + } +#if __SSE2__ + NCNN_FORCEINLINE __m128 func_pack4(const __m128& x) const + { + return sin_ps(x); + } +#if __AVX__ + NCNN_FORCEINLINE __m256 func_pack8(const __m256& x) const + { + return sin256_ps(x); + } +#if __AVX512F__ + NCNN_FORCEINLINE __m512 func_pack16(const __m512& x) const + { + return sin512_ps(x); + } +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ +}; + +struct unary_op_cos +{ + NCNN_FORCEINLINE float func(const float& x) const + { + return (float)cosf(x); + } +#if __SSE2__ + NCNN_FORCEINLINE __m128 func_pack4(const __m128& x) const + { + return cos_ps(x); + } +#if __AVX__ + NCNN_FORCEINLINE __m256 func_pack8(const __m256& x) const + { + return cos256_ps(x); + } +#if __AVX512F__ + NCNN_FORCEINLINE __m512 func_pack16(const __m512& x) const + { + return cos512_ps(x); + } +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ +}; + +struct unary_op_tan +{ + NCNN_FORCEINLINE float func(const float& x) const + { + return (float)tanf(x); + } +#if __SSE2__ + NCNN_FORCEINLINE __m128 
func_pack4(const __m128& x) const + { + return tan_ps(x); + } +#if __AVX__ + NCNN_FORCEINLINE __m256 func_pack8(const __m256& x) const + { + return tan256_ps(x); + } +#if __AVX512F__ + NCNN_FORCEINLINE __m512 func_pack16(const __m512& x) const + { + return tan512_ps(x); + } +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ +}; + +struct unary_op_asin +{ + NCNN_FORCEINLINE float func(const float& x) const + { + return (float)asinf(x); + } +#if __SSE2__ + NCNN_FORCEINLINE __m128 func_pack4(const __m128& x) const + { + return asin_ps(x); + } +#if __AVX__ + NCNN_FORCEINLINE __m256 func_pack8(const __m256& x) const + { + return asin256_ps(x); + } +#if __AVX512F__ + NCNN_FORCEINLINE __m512 func_pack16(const __m512& x) const + { + return asin512_ps(x); + } +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ +}; + +struct unary_op_acos +{ + NCNN_FORCEINLINE float func(const float& x) const + { + return (float)acosf(x); + } +#if __SSE2__ + NCNN_FORCEINLINE __m128 func_pack4(const __m128& x) const + { + return acos_ps(x); + } +#if __AVX__ + NCNN_FORCEINLINE __m256 func_pack8(const __m256& x) const + { + return acos256_ps(x); + } +#if __AVX512F__ + NCNN_FORCEINLINE __m512 func_pack16(const __m512& x) const + { + return acos512_ps(x); + } +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ +}; + +struct unary_op_atan +{ + NCNN_FORCEINLINE float func(const float& x) const + { + return (float)atanf(x); + } +#if __SSE2__ + NCNN_FORCEINLINE __m128 func_pack4(const __m128& x) const + { + return atan_ps(x); + } +#if __AVX__ + NCNN_FORCEINLINE __m256 func_pack8(const __m256& x) const + { + return atan256_ps(x); + } +#if __AVX512F__ + NCNN_FORCEINLINE __m512 func_pack16(const __m512& x) const + { + return atan512_ps(x); + } +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ +}; + +struct unary_op_reciprocal +{ + NCNN_FORCEINLINE float func(const float& x) const + { + return 1.f / x; + } +#if __SSE2__ + NCNN_FORCEINLINE __m128 func_pack4(const __m128& x) const + { + return _mm_div_ps(*(__m128*)_ps_1, x); + } +#if __AVX__ + NCNN_FORCEINLINE __m256 func_pack8(const __m256& x) const + { + return _mm256_div_ps(*(__m256*)_ps256_1, x); + } +#if __AVX512F__ + NCNN_FORCEINLINE __m512 func_pack16(const __m512& x) const + { + return _mm512_div_ps(*(__m512*)_ps512_1, x); + } +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ +}; + +struct unary_op_tanh +{ + NCNN_FORCEINLINE float func(const float& x) const + { + return (float)tanhf(x); + } +#if __SSE2__ + NCNN_FORCEINLINE __m128 func_pack4(const __m128& x) const + { + return tanh_sse(x); + } +#if __AVX__ + NCNN_FORCEINLINE __m256 func_pack8(const __m256& x) const + { + return tanh_avx(x); + } +#if __AVX512F__ + NCNN_FORCEINLINE __m512 func_pack16(const __m512& x) const + { + return tanh_avx512(x); + } +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ +}; + +struct unary_op_log10 +{ + NCNN_FORCEINLINE float func(const float& x) const + { + return (float)log10f(x); + } +#if __SSE2__ + NCNN_FORCEINLINE __m128 func_pack4(const __m128& x) const + { + return _mm_mul_ps(log_ps(x), _mm_set1_ps(0.434294481903)); + } +#if __AVX__ + NCNN_FORCEINLINE __m256 func_pack8(const __m256& x) const + { + return _mm256_mul_ps(log256_ps(x), _mm256_set1_ps(0.434294481903)); + } +#if __AVX512F__ + NCNN_FORCEINLINE __m512 func_pack16(const __m512& x) const + { + return _mm512_mul_ps(log512_ps(x), _mm512_set1_ps(0.434294481903)); + } +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ +}; + +struct unary_op_round +{ + 
NCNN_FORCEINLINE float func(const float& x) const + { + return nearbyintf(x); + } +#if __SSE2__ + NCNN_FORCEINLINE __m128 func_pack4(const __m128& x) const + { +#if __SSE4_1__ + return _mm_round_ps(x, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); +#else + return _mm_cvtepi32_ps(_mm_cvtps_epi32(x)); +#endif + } +#if __AVX__ + NCNN_FORCEINLINE __m256 func_pack8(const __m256& x) const + { + return _mm256_round_ps(x, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); + } +#if __AVX512F__ + NCNN_FORCEINLINE __m512 func_pack16(const __m512& x) const + { + return _mm512_roundscale_ps(x, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); + } +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ +}; + +struct unary_op_trunc +{ + NCNN_FORCEINLINE float func(const float& x) const + { + return (float)truncf(x); + } +#if __SSE2__ + NCNN_FORCEINLINE __m128 func_pack4(const __m128& x) const + { +#if __SSE4_1__ + return _mm_round_ps(x, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC); +#else + return _mm_cvtepi32_ps(_mm_cvttps_epi32(x)); +#endif + } +#if __AVX__ + NCNN_FORCEINLINE __m256 func_pack8(const __m256& x) const + { + return _mm256_round_ps(x, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC); + } +#if __AVX512F__ + NCNN_FORCEINLINE __m512 func_pack16(const __m512& x) const + { + return _mm512_roundscale_ps(x, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC); + } +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ +}; diff --git a/src/layer/x86/unaryop_x86.cpp b/src/layer/x86/unaryop_x86.cpp index 524c509b555..b41bde42edf 100644 --- a/src/layer/x86/unaryop_x86.cpp +++ b/src/layer/x86/unaryop_x86.cpp @@ -22,14 +22,28 @@ #endif // __SSE2__ #include "x86_usability.h" #include "x86_activation.h" +#include "cpu.h" namespace ncnn { +namespace UnaryOp_x86_functor { + +#include "unaryop_functor.h" + +} // namespace UnaryOp_x86_functor + +#if NCNN_BF16 +#include "unaryop_bf16s.h" +#endif + UnaryOp_x86::UnaryOp_x86() { #if __SSE2__ support_packing = true; #endif // __SSE2__ +#if NCNN_BF16 + support_bf16_storage = true; +#endif } template @@ -96,544 +110,13 @@ static int unary_op_inplace(Mat& a, const Option& opt) return 0; } -namespace UnaryOp_x86_functor { -struct unary_op_abs -{ - NCNN_FORCEINLINE float func(const float& x) const - { - return (float)fabsf(x); - } -#if __SSE2__ - NCNN_FORCEINLINE __m128 func_pack4(const __m128& x) const - { - return abs_ps(x); - } -#if __AVX__ - NCNN_FORCEINLINE __m256 func_pack8(const __m256& x) const - { - return abs256_ps(x); - } -#if __AVX512F__ - NCNN_FORCEINLINE __m512 func_pack16(const __m512& x) const - { - return abs512_ps(x); - } -#endif // __AVX512F__ -#endif // __AVX__ -#endif // __SSE2__ -}; - -struct unary_op_neg -{ - NCNN_FORCEINLINE float func(const float& x) const - { - return -x; - } -#if __SSE2__ - NCNN_FORCEINLINE __m128 func_pack4(const __m128& x) const - { - return _mm_sub_ps(_mm_setzero_ps(), x); - } -#if __AVX__ - NCNN_FORCEINLINE __m256 func_pack8(const __m256& x) const - { - return _mm256_sub_ps(_mm256_setzero_ps(), x); - } -#if __AVX512F__ - NCNN_FORCEINLINE __m512 func_pack16(const __m512& x) const - { - return _mm512_sub_ps(_mm512_setzero_ps(), x); - } -#endif // __AVX512F__ -#endif // __AVX__ -#endif // __SSE2__ -}; - -struct unary_op_floor -{ - NCNN_FORCEINLINE float func(const float& x) const - { - return (float)floorf(x); - } -#if __SSE2__ - NCNN_FORCEINLINE __m128 func_pack4(const __m128& x) const - { - return floor_ps(x); - } -#if __AVX__ - NCNN_FORCEINLINE __m256 func_pack8(const __m256& x) const - { - return _mm256_floor_ps(x); - } -#if __AVX512F__ 
- NCNN_FORCEINLINE __m512 func_pack16(const __m512& x) const - { - return _mm512_roundscale_ps(x, _MM_FROUND_TO_NEG_INF); - } -#endif // __AVX512F__ -#endif // __AVX__ -#endif // __SSE2__ -}; - -struct unary_op_ceil -{ - NCNN_FORCEINLINE float func(const float& x) const - { - return (float)ceilf(x); - } -#if __SSE2__ - NCNN_FORCEINLINE __m128 func_pack4(const __m128& x) const - { - return ceil_ps(x); - } -#if __AVX__ - NCNN_FORCEINLINE __m256 func_pack8(const __m256& x) const - { - return _mm256_ceil_ps(x); - } -#if __AVX512F__ - NCNN_FORCEINLINE __m512 func_pack16(const __m512& x) const - { - return _mm512_roundscale_ps(x, _MM_FROUND_TO_POS_INF); - } -#endif // __AVX512F__ -#endif // __AVX__ -#endif // __SSE2__ -}; - -struct unary_op_square -{ - NCNN_FORCEINLINE float func(const float& x) const - { - return x * x; - } -#if __SSE2__ - NCNN_FORCEINLINE __m128 func_pack4(const __m128& x) const - { - return _mm_mul_ps(x, x); - } -#if __AVX__ - NCNN_FORCEINLINE __m256 func_pack8(const __m256& x) const - { - return _mm256_mul_ps(x, x); - } -#if __AVX512F__ - NCNN_FORCEINLINE __m512 func_pack16(const __m512& x) const - { - return _mm512_mul_ps(x, x); - } -#endif // __AVX512F__ -#endif // __AVX__ -#endif // __SSE2__ -}; - -struct unary_op_sqrt -{ - NCNN_FORCEINLINE float func(const float& x) const - { - return (float)sqrtf(x); - } -#if __SSE2__ - NCNN_FORCEINLINE __m128 func_pack4(const __m128& x) const - { - return _mm_sqrt_ps(x); - } -#if __AVX__ - NCNN_FORCEINLINE __m256 func_pack8(const __m256& x) const - { - return _mm256_sqrt_ps(x); - } -#if __AVX512F__ - NCNN_FORCEINLINE __m512 func_pack16(const __m512& x) const - { - return _mm512_sqrt_ps(x); - } -#endif // __AVX512F__ -#endif // __AVX__ -#endif // __SSE2__ -}; - -struct unary_op_rsqrt -{ - NCNN_FORCEINLINE float func(const float& x) const - { - return 1.f / sqrtf(x); - } -#if __SSE2__ - NCNN_FORCEINLINE __m128 func_pack4(const __m128& x) const - { - return _mm_rsqrt_ps(x); - } -#if __AVX__ - NCNN_FORCEINLINE __m256 func_pack8(const __m256& x) const - { - return _mm256_rsqrt_ps(x); - } -#if __AVX512F__ - NCNN_FORCEINLINE __m512 func_pack16(const __m512& x) const - { - __m256 _x0 = _mm512_extractf32x8_ps(x, 0); - __m256 _x1 = _mm512_extractf32x8_ps(x, 1); - _x0 = _mm256_rsqrt_ps(_x0); - _x1 = _mm256_rsqrt_ps(_x1); - return combine8x2_ps(_x0, _x1); - } -#endif // __AVX512F__ -#endif // __AVX__ -#endif // __SSE2__ -}; - -struct unary_op_exp -{ - NCNN_FORCEINLINE float func(const float& x) const - { - return (float)expf(x); - } -#if __SSE2__ - NCNN_FORCEINLINE __m128 func_pack4(const __m128& x) const - { - return exp_ps(x); - } -#if __AVX__ - NCNN_FORCEINLINE __m256 func_pack8(const __m256& x) const - { - return exp256_ps(x); - } -#if __AVX512F__ - NCNN_FORCEINLINE __m512 func_pack16(const __m512& x) const - { - return exp512_ps(x); - } -#endif // __AVX512F__ -#endif // __AVX__ -#endif // __SSE2__ -}; - -struct unary_op_log -{ - NCNN_FORCEINLINE float func(const float& x) const - { - return (float)logf(x); - } -#if __SSE2__ - NCNN_FORCEINLINE __m128 func_pack4(const __m128& x) const - { - return log_ps(x); - } -#if __AVX__ - NCNN_FORCEINLINE __m256 func_pack8(const __m256& x) const - { - return log256_ps(x); - } -#if __AVX512F__ - NCNN_FORCEINLINE __m512 func_pack16(const __m512& x) const - { - return log512_ps(x); - } -#endif // __AVX512F__ -#endif // __AVX__ -#endif // __SSE2__ -}; - -struct unary_op_sin -{ - NCNN_FORCEINLINE float func(const float& x) const - { - return (float)sinf(x); - } -#if __SSE2__ - NCNN_FORCEINLINE __m128 
func_pack4(const __m128& x) const - { - return sin_ps(x); - } -#if __AVX__ - NCNN_FORCEINLINE __m256 func_pack8(const __m256& x) const - { - return sin256_ps(x); - } -#if __AVX512F__ - NCNN_FORCEINLINE __m512 func_pack16(const __m512& x) const - { - return sin512_ps(x); - } -#endif // __AVX512F__ -#endif // __AVX__ -#endif // __SSE2__ -}; - -struct unary_op_cos -{ - NCNN_FORCEINLINE float func(const float& x) const - { - return (float)cosf(x); - } -#if __SSE2__ - NCNN_FORCEINLINE __m128 func_pack4(const __m128& x) const - { - return cos_ps(x); - } -#if __AVX__ - NCNN_FORCEINLINE __m256 func_pack8(const __m256& x) const - { - return cos256_ps(x); - } -#if __AVX512F__ - NCNN_FORCEINLINE __m512 func_pack16(const __m512& x) const - { - return cos512_ps(x); - } -#endif // __AVX512F__ -#endif // __AVX__ -#endif // __SSE2__ -}; - -struct unary_op_tan -{ - NCNN_FORCEINLINE float func(const float& x) const - { - return (float)tanf(x); - } -#if __SSE2__ - NCNN_FORCEINLINE __m128 func_pack4(const __m128& x) const - { - return tan_ps(x); - } -#if __AVX__ - NCNN_FORCEINLINE __m256 func_pack8(const __m256& x) const - { - return tan256_ps(x); - } -#if __AVX512F__ - NCNN_FORCEINLINE __m512 func_pack16(const __m512& x) const - { - return tan512_ps(x); - } -#endif // __AVX512F__ -#endif // __AVX__ -#endif // __SSE2__ -}; - -struct unary_op_asin -{ - NCNN_FORCEINLINE float func(const float& x) const - { - return (float)asinf(x); - } -#if __SSE2__ - NCNN_FORCEINLINE __m128 func_pack4(const __m128& x) const - { - return asin_ps(x); - } -#if __AVX__ - NCNN_FORCEINLINE __m256 func_pack8(const __m256& x) const - { - return asin256_ps(x); - } -#if __AVX512F__ - NCNN_FORCEINLINE __m512 func_pack16(const __m512& x) const - { - return asin512_ps(x); - } -#endif // __AVX512F__ -#endif // __AVX__ -#endif // __SSE2__ -}; - -struct unary_op_acos -{ - NCNN_FORCEINLINE float func(const float& x) const - { - return (float)acosf(x); - } -#if __SSE2__ - NCNN_FORCEINLINE __m128 func_pack4(const __m128& x) const - { - return acos_ps(x); - } -#if __AVX__ - NCNN_FORCEINLINE __m256 func_pack8(const __m256& x) const - { - return acos256_ps(x); - } -#if __AVX512F__ - NCNN_FORCEINLINE __m512 func_pack16(const __m512& x) const - { - return acos512_ps(x); - } -#endif // __AVX512F__ -#endif // __AVX__ -#endif // __SSE2__ -}; - -struct unary_op_atan -{ - NCNN_FORCEINLINE float func(const float& x) const - { - return (float)atanf(x); - } -#if __SSE2__ - NCNN_FORCEINLINE __m128 func_pack4(const __m128& x) const - { - return atan_ps(x); - } -#if __AVX__ - NCNN_FORCEINLINE __m256 func_pack8(const __m256& x) const - { - return atan256_ps(x); - } -#if __AVX512F__ - NCNN_FORCEINLINE __m512 func_pack16(const __m512& x) const - { - return atan512_ps(x); - } -#endif // __AVX512F__ -#endif // __AVX__ -#endif // __SSE2__ -}; - -struct unary_op_reciprocal -{ - NCNN_FORCEINLINE float func(const float& x) const - { - return 1.f / x; - } -#if __SSE2__ - NCNN_FORCEINLINE __m128 func_pack4(const __m128& x) const - { - return _mm_div_ps(*(__m128*)_ps_1, x); - } -#if __AVX__ - NCNN_FORCEINLINE __m256 func_pack8(const __m256& x) const - { - return _mm256_div_ps(*(__m256*)_ps256_1, x); - } -#if __AVX512F__ - NCNN_FORCEINLINE __m512 func_pack16(const __m512& x) const - { - return _mm512_div_ps(*(__m512*)_ps512_1, x); - } -#endif // __AVX512F__ -#endif // __AVX__ -#endif // __SSE2__ -}; - -struct unary_op_tanh -{ - NCNN_FORCEINLINE float func(const float& x) const - { - return (float)tanhf(x); - } -#if __SSE2__ - NCNN_FORCEINLINE __m128 func_pack4(const __m128& 
x) const - { - return tanh_sse(x); - } -#if __AVX__ - NCNN_FORCEINLINE __m256 func_pack8(const __m256& x) const - { - return tanh_avx(x); - } -#if __AVX512F__ - NCNN_FORCEINLINE __m512 func_pack16(const __m512& x) const - { - return tanh_avx512(x); - } -#endif // __AVX512F__ -#endif // __AVX__ -#endif // __SSE2__ -}; - -struct unary_op_log10 -{ - NCNN_FORCEINLINE float func(const float& x) const - { - return (float)log10f(x); - } -#if __SSE2__ - NCNN_FORCEINLINE __m128 func_pack4(const __m128& x) const - { - return _mm_mul_ps(log_ps(x), _mm_set1_ps(0.434294481903)); - } -#if __AVX__ - NCNN_FORCEINLINE __m256 func_pack8(const __m256& x) const - { - return _mm256_mul_ps(log256_ps(x), _mm256_set1_ps(0.434294481903)); - } -#if __AVX512F__ - NCNN_FORCEINLINE __m512 func_pack16(const __m512& x) const - { - return _mm512_mul_ps(log512_ps(x), _mm512_set1_ps(0.434294481903)); - } -#endif // __AVX512F__ -#endif // __AVX__ -#endif // __SSE2__ -}; - -struct unary_op_round -{ - NCNN_FORCEINLINE float func(const float& x) const - { - // return (x + 12582912.f) - 12582912.f; - return nearbyintf(x); - } -#if __SSE2__ - NCNN_FORCEINLINE __m128 func_pack4(const __m128& x) const - { -#if __SSE4_1__ - return _mm_round_ps(x, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); -#else - return _mm_cvtepi32_ps(_mm_cvtps_epi32(x)); -#endif - } -#if __AVX__ - NCNN_FORCEINLINE __m256 func_pack8(const __m256& x) const - { - return _mm256_round_ps(x, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); - } -#if __AVX512F__ - NCNN_FORCEINLINE __m512 func_pack16(const __m512& x) const - { - return _mm512_roundscale_ps(x, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); - } -#endif // __AVX512F__ -#endif // __AVX__ -#endif // __SSE2__ -}; - -struct unary_op_trunc +int UnaryOp_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const { - NCNN_FORCEINLINE float func(const float& x) const - { - return (float)truncf(x); - } -#if __SSE2__ - NCNN_FORCEINLINE __m128 func_pack4(const __m128& x) const - { -#if __SSE4_1__ - return _mm_round_ps(x, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC); -#else - return _mm_cvtepi32_ps(_mm_cvttps_epi32(x)); +#if NCNN_BF16 + if (opt.use_bf16_storage && bottom_top_blob.elembits() == 16) + return forward_inplace_bf16s(bottom_top_blob, opt); #endif - } -#if __AVX__ - NCNN_FORCEINLINE __m256 func_pack8(const __m256& x) const - { - return _mm256_round_ps(x, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC); - } -#if __AVX512F__ - NCNN_FORCEINLINE __m512 func_pack16(const __m512& x) const - { - return _mm512_roundscale_ps(x, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC); - } -#endif // __AVX512F__ -#endif // __AVX__ -#endif // __SSE2__ -}; - -} // namespace UnaryOp_x86_functor -int UnaryOp_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const -{ using namespace UnaryOp_x86_functor; if (op_type == Operation_ABS) return unary_op_inplace(bottom_top_blob, opt); @@ -709,4 +192,99 @@ int UnaryOp_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const return 0; } +#if NCNN_BF16 +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ +int unaryop_bf16s_sse_avx512bf16(Mat& bottom_top_blob, int op_type, const Option& opt); +#endif + +static int unaryop_bf16s_sse(Mat& bottom_top_blob, int op_type, const Option& opt) +{ +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ + if (ncnn::cpu_support_x86_avx512_bf16()) + { + return unaryop_bf16s_sse_avx512bf16(bottom_top_blob, op_type, opt); + } +#endif + + using namespace UnaryOp_x86_functor; + if (op_type == 
UnaryOp::Operation_ABS) + return unary_op_inplace_bf16s(bottom_top_blob, opt); + + if (op_type == UnaryOp::Operation_NEG) + return unary_op_inplace_bf16s(bottom_top_blob, opt); + + if (op_type == UnaryOp::Operation_FLOOR) + return unary_op_inplace_bf16s(bottom_top_blob, opt); + + if (op_type == UnaryOp::Operation_CEIL) + return unary_op_inplace_bf16s(bottom_top_blob, opt); + + if (op_type == UnaryOp::Operation_SQUARE) + return unary_op_inplace_bf16s(bottom_top_blob, opt); + + if (op_type == UnaryOp::Operation_SQRT) + return unary_op_inplace_bf16s(bottom_top_blob, opt); + + if (op_type == UnaryOp::Operation_RSQRT) + return unary_op_inplace_bf16s(bottom_top_blob, opt); + + if (op_type == UnaryOp::Operation_EXP) + return unary_op_inplace_bf16s(bottom_top_blob, opt); + + if (op_type == UnaryOp::Operation_LOG) + return unary_op_inplace_bf16s(bottom_top_blob, opt); + + if (op_type == UnaryOp::Operation_SIN) + return unary_op_inplace_bf16s(bottom_top_blob, opt); + + if (op_type == UnaryOp::Operation_COS) + return unary_op_inplace_bf16s(bottom_top_blob, opt); + + if (op_type == UnaryOp::Operation_TAN) + return unary_op_inplace_bf16s(bottom_top_blob, opt); + + if (op_type == UnaryOp::Operation_ASIN) + return unary_op_inplace_bf16s(bottom_top_blob, opt); + + if (op_type == UnaryOp::Operation_ACOS) + return unary_op_inplace_bf16s(bottom_top_blob, opt); + + if (op_type == UnaryOp::Operation_ATAN) + return unary_op_inplace_bf16s(bottom_top_blob, opt); + + if (op_type == UnaryOp::Operation_RECIPROCAL) + return unary_op_inplace_bf16s(bottom_top_blob, opt); + + if (op_type == UnaryOp::Operation_TANH) + return unary_op_inplace_bf16s(bottom_top_blob, opt); + + if (op_type == UnaryOp::Operation_LOG10) + return unary_op_inplace_bf16s(bottom_top_blob, opt); + + if (op_type == UnaryOp::Operation_ROUND) + { + // round to nearest even +#ifdef FE_TONEAREST + int old_rm = fegetround(); + fesetround(FE_TONEAREST); +#endif + int ret = unary_op_inplace_bf16s(bottom_top_blob, opt); +#ifdef FE_TONEAREST + fesetround(old_rm); +#endif + return ret; + } + + if (op_type == UnaryOp::Operation_TRUNC) + return unary_op_inplace_bf16s(bottom_top_blob, opt); + + return 0; +} + +int UnaryOp_x86::forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt) const +{ + return unaryop_bf16s_sse(bottom_top_blob, op_type, opt); +} +#endif // NCNN_BF16 + } // namespace ncnn diff --git a/src/layer/x86/unaryop_x86.h b/src/layer/x86/unaryop_x86.h index dae07ba5042..c3b5fcd7905 100644 --- a/src/layer/x86/unaryop_x86.h +++ b/src/layer/x86/unaryop_x86.h @@ -14,6 +14,11 @@ class UnaryOp_x86 : public UnaryOp UnaryOp_x86(); virtual int forward_inplace(Mat& bottom_top_blob, const Option& opt) const; + +protected: +#if NCNN_BF16 + int forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt) const; +#endif }; } // namespace ncnn diff --git a/src/layer/x86/unaryop_x86_avx512bf16.cpp b/src/layer/x86/unaryop_x86_avx512bf16.cpp new file mode 100644 index 00000000000..9707ce82722 --- /dev/null +++ b/src/layer/x86/unaryop_x86_avx512bf16.cpp @@ -0,0 +1,111 @@ +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#include "unaryop_x86.h" + +#if __SSE2__ +#include +#include "sse_mathfun.h" +#if __SSE4_1__ +#include +#if __AVX__ +#include +#include "avx_mathfun.h" +#if __AVX512F__ +#include "avx512_mathfun.h" +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE4_1__ +#endif // __SSE2__ + +#include "x86_usability.h" +#include "x86_activation.h" + +namespace ncnn { + +namespace UnaryOp_x86_functor { + +#include 
"unaryop_functor.h" + +} // namespace UnaryOp_x86_functor + +#include "unaryop_bf16s.h" + +int unaryop_bf16s_sse_avx512bf16(Mat& bottom_top_blob, int op_type, const Option& opt) +{ + using namespace UnaryOp_x86_functor; + if (op_type == UnaryOp::Operation_ABS) + return unary_op_inplace_bf16s(bottom_top_blob, opt); + + if (op_type == UnaryOp::Operation_NEG) + return unary_op_inplace_bf16s(bottom_top_blob, opt); + + if (op_type == UnaryOp::Operation_FLOOR) + return unary_op_inplace_bf16s(bottom_top_blob, opt); + + if (op_type == UnaryOp::Operation_CEIL) + return unary_op_inplace_bf16s(bottom_top_blob, opt); + + if (op_type == UnaryOp::Operation_SQUARE) + return unary_op_inplace_bf16s(bottom_top_blob, opt); + + if (op_type == UnaryOp::Operation_SQRT) + return unary_op_inplace_bf16s(bottom_top_blob, opt); + + if (op_type == UnaryOp::Operation_RSQRT) + return unary_op_inplace_bf16s(bottom_top_blob, opt); + + if (op_type == UnaryOp::Operation_EXP) + return unary_op_inplace_bf16s(bottom_top_blob, opt); + + if (op_type == UnaryOp::Operation_LOG) + return unary_op_inplace_bf16s(bottom_top_blob, opt); + + if (op_type == UnaryOp::Operation_SIN) + return unary_op_inplace_bf16s(bottom_top_blob, opt); + + if (op_type == UnaryOp::Operation_COS) + return unary_op_inplace_bf16s(bottom_top_blob, opt); + + if (op_type == UnaryOp::Operation_TAN) + return unary_op_inplace_bf16s(bottom_top_blob, opt); + + if (op_type == UnaryOp::Operation_ASIN) + return unary_op_inplace_bf16s(bottom_top_blob, opt); + + if (op_type == UnaryOp::Operation_ACOS) + return unary_op_inplace_bf16s(bottom_top_blob, opt); + + if (op_type == UnaryOp::Operation_ATAN) + return unary_op_inplace_bf16s(bottom_top_blob, opt); + + if (op_type == UnaryOp::Operation_RECIPROCAL) + return unary_op_inplace_bf16s(bottom_top_blob, opt); + + if (op_type == UnaryOp::Operation_TANH) + return unary_op_inplace_bf16s(bottom_top_blob, opt); + + if (op_type == UnaryOp::Operation_LOG10) + return unary_op_inplace_bf16s(bottom_top_blob, opt); + + if (op_type == UnaryOp::Operation_ROUND) + { + // round to nearest even +#ifdef FE_TONEAREST + int old_rm = fegetround(); + fesetround(FE_TONEAREST); +#endif + int ret = unary_op_inplace_bf16s(bottom_top_blob, opt); +#ifdef FE_TONEAREST + fesetround(old_rm); +#endif + return ret; + } + + if (op_type == UnaryOp::Operation_TRUNC) + return unary_op_inplace_bf16s(bottom_top_blob, opt); + + return 0; +} + +} // namespace ncnn From 55e1948172fd097c1d198a75357883b8604d0b29 Mon Sep 17 00:00:00 2001 From: nihui Date: Mon, 9 Mar 2026 12:21:09 +0800 Subject: [PATCH 12/36] clip relu sigmoid x86 bf16s (#6589) --- src/layer/x86/clip_bf16s.h | 98 +++++++++++++ src/layer/x86/clip_x86.cpp | 27 +++- src/layer/x86/clip_x86.h | 5 + src/layer/x86/clip_x86_avx512bf16.cpp | 20 +++ src/layer/x86/relu_bf16s.h | 169 +++++++++++++++++++++++ src/layer/x86/relu_x86.cpp | 27 +++- src/layer/x86/relu_x86.h | 3 + src/layer/x86/relu_x86_avx512bf16.cpp | 20 +++ src/layer/x86/sigmoid_bf16s.h | 92 ++++++++++++ src/layer/x86/sigmoid_x86.cpp | 25 ++++ src/layer/x86/sigmoid_x86.h | 5 + src/layer/x86/sigmoid_x86_avx512bf16.cpp | 32 +++++ 12 files changed, 521 insertions(+), 2 deletions(-) create mode 100644 src/layer/x86/clip_bf16s.h create mode 100644 src/layer/x86/clip_x86_avx512bf16.cpp create mode 100644 src/layer/x86/relu_bf16s.h create mode 100644 src/layer/x86/relu_x86_avx512bf16.cpp create mode 100644 src/layer/x86/sigmoid_bf16s.h create mode 100644 src/layer/x86/sigmoid_x86_avx512bf16.cpp diff --git a/src/layer/x86/clip_bf16s.h 
b/src/layer/x86/clip_bf16s.h new file mode 100644 index 00000000000..c558e8924eb --- /dev/null +++ b/src/layer/x86/clip_bf16s.h @@ -0,0 +1,98 @@ +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ +void clip_bf16s_avx512bf16(Mat& a, float min, float max, const Option& opt); +#endif + +static void clip_bf16s(Mat& a, float min, float max, const Option& opt) +{ +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ + if (ncnn::cpu_support_x86_avx512_bf16()) + { + clip_bf16s_avx512bf16(a, min, max, opt); + return; + } +#endif + + int w = a.w; + int h = a.h; + int d = a.d; + int channels = a.c; + int elempack = a.elempack; + int size = w * h * d * elempack; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + unsigned short* ptr = a.channel(q); + + int i = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + __m512 _min_avx512 = _mm512_set1_ps(min); + __m512 _max_avx512 = _mm512_set1_ps(max); + for (; i + 15 < size; i += 16) + { + __m512 _p = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr)); + _p = _mm512_max_ps(_p, _min_avx512); + _p = _mm512_min_ps(_p, _max_avx512); + _mm256_storeu_si256((__m256i*)ptr, float2bfloat_avx512(_p)); + ptr += 16; + } + if (i < size) + { + const unsigned int remain = size - i; + __mmask16 _mask = (__mmask16)((1u << remain) - 1); + __m512 _p = bfloat2float_avx512(_mm256_maskz_loadu_epi16(_mask, ptr)); + _p = _mm512_max_ps(_p, _min_avx512); + _p = _mm512_min_ps(_p, _max_avx512); + _mm256_mask_storeu_epi16(ptr, _mask, float2bfloat_avx512(_p)); + i += remain; + } +#else // __AVX512F__ + __m256 _min_avx = _mm256_set1_ps(min); + __m256 _max_avx = _mm256_set1_ps(max); + for (; i + 7 < size; i += 8) + { + __m256 _p = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr)); + _p = _mm256_max_ps(_p, _min_avx); + _p = _mm256_min_ps(_p, _max_avx); + _mm_storeu_si128((__m128i*)ptr, float2bfloat_avx(_p)); + ptr += 8; + } + __m128 _min = _mm_set1_ps(min); + __m128 _max = _mm_set1_ps(max); + for (; i + 3 < size; i += 4) + { + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr)); + _p = _mm_max_ps(_p, _min); + _p = _mm_min_ps(_p, _max); + _mm_storel_epi64((__m128i*)ptr, float2bfloat_sse(_p, _p)); + ptr += 4; + } +#endif // __AVX512F__ +#else // __AVX__ + __m128 _min = _mm_set1_ps(min); + __m128 _max = _mm_set1_ps(max); + for (; i + 3 < size; i += 4) + { + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr)); + _p = _mm_max_ps(_p, _min); + _p = _mm_min_ps(_p, _max); + _mm_storel_epi64((__m128i*)ptr, float2bfloat_sse(_p, _p)); + ptr += 4; + } +#endif // __AVX__ +#endif // __SSE2__ + for (; i < size; i++) + { + float v = bfloat16_to_float32(*ptr); + if (v < min) v = min; + if (v > max) v = max; + *ptr = float32_to_bfloat16(v); + ptr++; + } + } +} diff --git a/src/layer/x86/clip_x86.cpp b/src/layer/x86/clip_x86.cpp index 41e50be7ec7..5c781f47dda 100644 --- a/src/layer/x86/clip_x86.cpp +++ b/src/layer/x86/clip_x86.cpp @@ -10,18 +10,34 @@ #endif // __AVX__ #endif // __SSE2__ +#include "x86_usability.h" + +#include "cpu.h" + namespace ncnn { +#if NCNN_BF16 +#include "clip_bf16s.h" +#endif + Clip_x86::Clip_x86() { #if __SSE2__ support_packing = true; support_any_packing = true; #endif // __SSE2__ +#if NCNN_BF16 + support_bf16_storage = true; +#endif } int Clip_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const { +#if NCNN_BF16 + if (opt.use_bf16_storage && 
bottom_top_blob.elembits() == 16) + return forward_inplace_bf16s(bottom_top_blob, opt); +#endif + const int w = bottom_top_blob.w; const int h = bottom_top_blob.h; const int d = bottom_top_blob.d; @@ -94,4 +110,13 @@ int Clip_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const return 0; } -} //namespace ncnn +#if NCNN_BF16 +int Clip_x86::forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt) const +{ + clip_bf16s(bottom_top_blob, min, max, opt); + + return 0; +} +#endif // NCNN_BF16 + +} // namespace ncnn diff --git a/src/layer/x86/clip_x86.h b/src/layer/x86/clip_x86.h index 07b806c2b97..e6ece4e0f49 100644 --- a/src/layer/x86/clip_x86.h +++ b/src/layer/x86/clip_x86.h @@ -14,6 +14,11 @@ class Clip_x86 : public Clip Clip_x86(); virtual int forward_inplace(Mat& bottom_top_blob, const Option& opt) const; + +protected: +#if NCNN_BF16 + int forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt) const; +#endif }; } // namespace ncnn diff --git a/src/layer/x86/clip_x86_avx512bf16.cpp b/src/layer/x86/clip_x86_avx512bf16.cpp new file mode 100644 index 00000000000..3641ef761c8 --- /dev/null +++ b/src/layer/x86/clip_x86_avx512bf16.cpp @@ -0,0 +1,20 @@ +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#include "clip_x86.h" + +#include "x86_usability.h" + +#include "cpu.h" +#include "mat.h" + +namespace ncnn { + +#include "clip_bf16s.h" + +void clip_bf16s_avx512bf16(Mat& a, float min, float max, const Option& opt) +{ + clip_bf16s(a, min, max, opt); +} + +} // namespace ncnn diff --git a/src/layer/x86/relu_bf16s.h b/src/layer/x86/relu_bf16s.h new file mode 100644 index 00000000000..30b8fafa6dc --- /dev/null +++ b/src/layer/x86/relu_bf16s.h @@ -0,0 +1,169 @@ +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ +void relu_bf16s_avx512bf16(Mat& a, float slope, const Option& opt); +#endif + +static void relu_bf16s(Mat& a, float slope, const Option& opt) +{ +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ + if (ncnn::cpu_support_x86_avx512_bf16()) + { + relu_bf16s_avx512bf16(a, slope, opt); + return; + } +#endif + + int w = a.w; + int h = a.h; + int d = a.d; + int channels = a.c; + int elempack = a.elempack; + int size = w * h * d * elempack; + + if (slope == 0.f) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + unsigned short* ptr = a.channel(q); + + int i = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + __m512 _zero_avx512 = _mm512_setzero_ps(); + for (; i + 15 < size; i += 16) + { + __m512 _p = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr)); + _p = _mm512_max_ps(_p, _zero_avx512); + _mm256_storeu_si256((__m256i*)ptr, float2bfloat_avx512(_p)); + ptr += 16; + } + if (i < size) + { + const unsigned int remain = size - i; + __mmask16 _mask = (__mmask16)((1u << remain) - 1); + __m512 _p = bfloat2float_avx512(_mm256_maskz_loadu_epi16(_mask, ptr)); + _p = _mm512_max_ps(_p, _zero_avx512); + _mm256_mask_storeu_epi16(ptr, _mask, float2bfloat_avx512(_p)); + i += remain; + } +#else // __AVX512F__ + __m256 _zero_avx = _mm256_setzero_ps(); + for (; i + 7 < size; i += 8) + { + __m256 _p = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr)); + _p = _mm256_max_ps(_p, _zero_avx); + _mm_storeu_si128((__m128i*)ptr, float2bfloat_avx(_p)); + ptr += 8; + } + __m128 _zero = _mm_setzero_ps(); + for (; i + 3 < size; i += 4) + { + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const 
__m128i*)ptr)); + _p = _mm_max_ps(_p, _zero); + _mm_storel_epi64((__m128i*)ptr, float2bfloat_sse(_p, _p)); + ptr += 4; + } +#endif // __AVX512F__ +#else // __AVX__ + __m128 _zero = _mm_setzero_ps(); + for (; i + 3 < size; i += 4) + { + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr)); + _p = _mm_max_ps(_p, _zero); + _mm_storel_epi64((__m128i*)ptr, float2bfloat_sse(_p, _p)); + ptr += 4; + } +#endif // __AVX__ +#endif // __SSE2__ + for (; i < size; i++) + { + float v = bfloat16_to_float32(*ptr); + if (v < 0.f) v = 0.f; + *ptr = float32_to_bfloat16(v); + ptr++; + } + } + } + else + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + unsigned short* ptr = a.channel(q); + + int i = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + __m512 _zero_avx512 = _mm512_setzero_ps(); + __m512 _slope_avx512 = _mm512_set1_ps(slope); + for (; i + 15 < size; i += 16) + { + __m512 _p = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr)); + __mmask16 _is_negative = _mm512_cmp_ps_mask(_p, _zero_avx512, _CMP_LT_OQ); + _p = _mm512_mask_mul_ps(_p, _is_negative, _p, _slope_avx512); + _mm256_storeu_si256((__m256i*)ptr, float2bfloat_avx512(_p)); + ptr += 16; + } + if (i < size) + { + const unsigned int remain = size - i; + __mmask16 _mask = (__mmask16)((1u << remain) - 1); + __m512 _p = bfloat2float_avx512(_mm256_maskz_loadu_epi16(_mask, ptr)); + __mmask16 _is_negative = _mm512_cmp_ps_mask(_p, _zero_avx512, _CMP_LT_OQ); + _p = _mm512_mask_mul_ps(_p, _is_negative, _p, _slope_avx512); + _mm256_mask_storeu_epi16(ptr, _mask, float2bfloat_avx512(_p)); + i += remain; + } +#else // __AVX512F__ + __m256 _zero_avx = _mm256_setzero_ps(); + __m256 _slope_avx = _mm256_set1_ps(slope); + for (; i + 7 < size; i += 8) + { + __m256 _p = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr)); + __m256 _pos = _mm256_max_ps(_zero_avx, _p); + __m256 _neg = _mm256_min_ps(_zero_avx, _p); + _p = _mm256_add_ps(_pos, _mm256_mul_ps(_slope_avx, _neg)); + _mm_storeu_si128((__m128i*)ptr, float2bfloat_avx(_p)); + ptr += 8; + } + __m128 _zero = _mm_setzero_ps(); + __m128 _slope = _mm_set1_ps(slope); + for (; i + 3 < size; i += 4) + { + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr)); + __m128 _pos = _mm_max_ps(_zero, _p); + __m128 _neg = _mm_min_ps(_zero, _p); + _p = _mm_add_ps(_pos, _mm_mul_ps(_slope, _neg)); + _mm_storel_epi64((__m128i*)ptr, float2bfloat_sse(_p, _p)); + ptr += 4; + } +#endif // __AVX512F__ +#else // __AVX__ + __m128 _zero = _mm_setzero_ps(); + __m128 _slope = _mm_set1_ps(slope); + for (; i + 3 < size; i += 4) + { + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr)); + __m128 _pos = _mm_max_ps(_zero, _p); + __m128 _neg = _mm_min_ps(_zero, _p); + _p = _mm_add_ps(_pos, _mm_mul_ps(_slope, _neg)); + _mm_storel_epi64((__m128i*)ptr, float2bfloat_sse(_p, _p)); + ptr += 4; + } +#endif // __AVX__ +#endif // __SSE2__ + for (; i < size; i++) + { + float v = bfloat16_to_float32(*ptr); + if (v < 0.f) v *= slope; + *ptr = float32_to_bfloat16(v); + ptr++; + } + } + } +} diff --git a/src/layer/x86/relu_x86.cpp b/src/layer/x86/relu_x86.cpp index 014fa3995fc..7b9bf7fd6d9 100644 --- a/src/layer/x86/relu_x86.cpp +++ b/src/layer/x86/relu_x86.cpp @@ -10,13 +10,24 @@ #endif // __AVX__ #endif // __SSE2__ +#include "x86_usability.h" + +#include "cpu.h" + namespace ncnn { +#if NCNN_BF16 +#include "relu_bf16s.h" +#endif + ReLU_x86::ReLU_x86() { #if __SSE2__ support_packing = true; #endif // __SSE2__ +#if NCNN_BF16 + support_bf16_storage = true; 
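A note on the AVX-512 tail handling used throughout these bf16s kernels: the last 1-15 elements are processed with a lane mask built as (1u << remain) - 1 and fed to _mm256_maskz_loadu_epi16 / _mm256_mask_storeu_epi16, so no scalar cleanup loop is needed on that path. A minimal standalone sketch of the same idea, written against plain fp32 data with _mm512_maskz_loadu_ps / _mm512_mask_storeu_ps (so it is not the patch's code, only the pattern), assuming an AVX-512F build:

// Tail handling with AVX-512 lane masks, sketched on fp32 data.
// Analogous to the (1u << remain) - 1 masks in the bf16 kernels above.
#if defined(__AVX512F__)
#include <immintrin.h>

static void clamp_max_f32(float* ptr, int size, float cap)
{
    const __m512 _cap = _mm512_set1_ps(cap);
    int i = 0;
    for (; i + 15 < size; i += 16)
    {
        __m512 _p = _mm512_loadu_ps(ptr + i);
        _mm512_storeu_ps(ptr + i, _mm512_min_ps(_p, _cap));
    }
    if (i < size)
    {
        // enable only the remaining size - i lanes; masked-off lanes are left untouched on store
        const __mmask16 mask = (__mmask16)((1u << (size - i)) - 1);
        __m512 _p = _mm512_maskz_loadu_ps(mask, ptr + i);
        _mm512_mask_storeu_ps(ptr + i, mask, _mm512_min_ps(_p, _cap));
    }
}
#endif // __AVX512F__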
+#endif } int ReLU_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const @@ -26,6 +37,11 @@ int ReLU_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const if (elembits == 8) return forward_inplace_int8(bottom_top_blob, opt); +#if NCNN_BF16 + if (opt.use_bf16_storage && elembits == 16) + return forward_inplace_bf16s(bottom_top_blob, opt); +#endif + int w = bottom_top_blob.w; int h = bottom_top_blob.h; int d = bottom_top_blob.d; @@ -210,4 +226,13 @@ int ReLU_x86::forward_inplace_int8(Mat& bottom_top_blob, const Option& opt) cons return 0; } -} //namespace ncnn +#if NCNN_BF16 +int ReLU_x86::forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt) const +{ + relu_bf16s(bottom_top_blob, slope, opt); + + return 0; +} +#endif // NCNN_BF16 + +} // namespace ncnn diff --git a/src/layer/x86/relu_x86.h b/src/layer/x86/relu_x86.h index 49139812f47..fe3d82ddf73 100644 --- a/src/layer/x86/relu_x86.h +++ b/src/layer/x86/relu_x86.h @@ -17,6 +17,9 @@ class ReLU_x86 : public ReLU protected: int forward_inplace_int8(Mat& bottom_top_blob, const Option& opt) const; +#if NCNN_BF16 + int forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt) const; +#endif }; } // namespace ncnn diff --git a/src/layer/x86/relu_x86_avx512bf16.cpp b/src/layer/x86/relu_x86_avx512bf16.cpp new file mode 100644 index 00000000000..7cd9976329e --- /dev/null +++ b/src/layer/x86/relu_x86_avx512bf16.cpp @@ -0,0 +1,20 @@ +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#include "relu_x86.h" + +#include "x86_usability.h" + +#include "cpu.h" +#include "mat.h" + +namespace ncnn { + +#include "relu_bf16s.h" + +void relu_bf16s_avx512bf16(Mat& a, float slope, const Option& opt) +{ + relu_bf16s(a, slope, opt); +} + +} // namespace ncnn diff --git a/src/layer/x86/sigmoid_bf16s.h b/src/layer/x86/sigmoid_bf16s.h new file mode 100644 index 00000000000..ffc78f7259d --- /dev/null +++ b/src/layer/x86/sigmoid_bf16s.h @@ -0,0 +1,92 @@ +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ +void sigmoid_bf16s_avx512bf16(Mat& a, const Option& opt); +#endif + +static void sigmoid_bf16s(Mat& a, const Option& opt) +{ +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ + if (ncnn::cpu_support_x86_avx512_bf16()) + { + sigmoid_bf16s_avx512bf16(a, opt); + return; + } +#endif + + int w = a.w; + int h = a.h; + int d = a.d; + int channels = a.c; + int elempack = a.elempack; + int size = w * h * d * elempack; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + unsigned short* ptr = a.channel(q); + + int i = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + __m512 _one_avx512 = _mm512_set1_ps(1.f); + __m512 _zero_avx512 = _mm512_setzero_ps(); + for (; i + 15 < size; i += 16) + { + __m512 _p = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr)); + _p = _mm512_div_ps(_one_avx512, _mm512_add_ps(_one_avx512, exp512_ps(_mm512_sub_ps(_zero_avx512, _p)))); + _mm256_storeu_si256((__m256i*)ptr, float2bfloat_avx512(_p)); + ptr += 16; + } + if (i < size) + { + const unsigned int remain = size - i; + __mmask16 _mask = (__mmask16)((1u << remain) - 1); + __m512 _p = bfloat2float_avx512(_mm256_maskz_loadu_epi16(_mask, ptr)); + _p = _mm512_div_ps(_one_avx512, _mm512_add_ps(_one_avx512, exp512_ps(_mm512_sub_ps(_zero_avx512, _p)))); + _mm256_mask_storeu_epi16(ptr, _mask, float2bfloat_avx512(_p)); + i += remain; + } +#else // __AVX512F__ + __m256 _one_avx = 
_mm256_set1_ps(1.f); + __m256 _zero_avx = _mm256_setzero_ps(); + for (; i + 7 < size; i += 8) + { + __m256 _p = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr)); + _p = _mm256_div_ps(_one_avx, _mm256_add_ps(_one_avx, exp256_ps(_mm256_sub_ps(_zero_avx, _p)))); + _mm_storeu_si128((__m128i*)ptr, float2bfloat_avx(_p)); + ptr += 8; + } + __m128 _one = _mm_set1_ps(1.f); + __m128 _zero = _mm_setzero_ps(); + for (; i + 3 < size; i += 4) + { + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr)); + _p = _mm_div_ps(_one, _mm_add_ps(_one, exp_ps(_mm_sub_ps(_zero, _p)))); + _mm_storel_epi64((__m128i*)ptr, float2bfloat_sse(_p, _p)); + ptr += 4; + } +#endif // __AVX512F__ +#else // __AVX__ + __m128 _one = _mm_set1_ps(1.f); + __m128 _zero = _mm_setzero_ps(); + for (; i + 3 < size; i += 4) + { + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr)); + _p = _mm_div_ps(_one, _mm_add_ps(_one, exp_ps(_mm_sub_ps(_zero, _p)))); + _mm_storel_epi64((__m128i*)ptr, float2bfloat_sse(_p, _p)); + ptr += 4; + } +#endif // __AVX__ +#endif // __SSE2__ + for (; i < size; i++) + { + float v = bfloat16_to_float32(*ptr); + v = 1.f / (1.f + expf(-v)); + *ptr = float32_to_bfloat16(v); + ptr++; + } + } +} diff --git a/src/layer/x86/sigmoid_x86.cpp b/src/layer/x86/sigmoid_x86.cpp index dee0bbe5fdf..66284cec074 100644 --- a/src/layer/x86/sigmoid_x86.cpp +++ b/src/layer/x86/sigmoid_x86.cpp @@ -15,17 +15,32 @@ #endif // __AVX__ #endif // __SSE2__ +#include "x86_usability.h" + +#include "cpu.h" + namespace ncnn { +#if NCNN_BF16 +#include "sigmoid_bf16s.h" +#endif + Sigmoid_x86::Sigmoid_x86() { #if __SSE2__ support_packing = true; #endif // __SSE2__ +#if NCNN_BF16 + support_bf16_storage = true; +#endif } int Sigmoid_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const { +#if NCNN_BF16 + if (opt.use_bf16_storage && bottom_top_blob.elembits() == 16) + return forward_inplace_bf16s(bottom_top_blob, opt); +#endif int w = bottom_top_blob.w; int h = bottom_top_blob.h; int d = bottom_top_blob.d; @@ -91,4 +106,14 @@ int Sigmoid_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const return 0; } +#if NCNN_BF16 +int Sigmoid_x86::forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt) const +{ + sigmoid_bf16s(bottom_top_blob, opt); + + return 0; +} + +#endif // NCNN_BF16 + } // namespace ncnn diff --git a/src/layer/x86/sigmoid_x86.h b/src/layer/x86/sigmoid_x86.h index 2f4493fb01f..673a6dead02 100644 --- a/src/layer/x86/sigmoid_x86.h +++ b/src/layer/x86/sigmoid_x86.h @@ -14,6 +14,11 @@ class Sigmoid_x86 : public Sigmoid Sigmoid_x86(); virtual int forward_inplace(Mat& bottom_top_blob, const Option& opt) const; + +protected: +#if NCNN_BF16 + int forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt) const; +#endif }; } // namespace ncnn diff --git a/src/layer/x86/sigmoid_x86_avx512bf16.cpp b/src/layer/x86/sigmoid_x86_avx512bf16.cpp new file mode 100644 index 00000000000..0b86d15bdef --- /dev/null +++ b/src/layer/x86/sigmoid_x86_avx512bf16.cpp @@ -0,0 +1,32 @@ +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#include "sigmoid_x86.h" + +#if __SSE2__ +#include +#include "sse_mathfun.h" +#if __AVX__ +#include +#include "avx_mathfun.h" +#if __AVX512F__ +#include "avx512_mathfun.h" +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ + +#include "x86_usability.h" + +#include "cpu.h" +#include "mat.h" + +namespace ncnn { + +#include "sigmoid_bf16s.h" + +void sigmoid_bf16s_avx512bf16(Mat& a, const Option& opt) +{ + sigmoid_bf16s(a, opt); +} + +} // 
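All three activations added in this patch (clip, relu, sigmoid) share one bf16-storage pattern: widen each 16-bit value to fp32, apply the op in fp32, then narrow back to bf16. Below is a minimal scalar sketch of that round trip, independent of ncnn's bfloat16_to_float32 / float32_to_bfloat16 helpers; the narrowing here simply truncates, whereas the vector helpers and the AVX-512 BF16 hardware conversion may round to nearest even.

// Scalar sketch of the bf16-storage round trip (not the patch's code).
#include <cmath>
#include <cstring>
#include <cstdio>

static float bf16_to_f32(unsigned short v)
{
    unsigned int bits = (unsigned int)v << 16; // bf16 is the high half of an fp32
    float f;
    std::memcpy(&f, &bits, 4);
    return f;
}

static unsigned short f32_to_bf16(float f)
{
    unsigned int bits;
    std::memcpy(&bits, &f, 4);
    return (unsigned short)(bits >> 16); // drop the low 16 mantissa bits (truncation)
}

// clip, relu and sigmoid all reduce to: widen, apply the op in fp32, narrow
static void sigmoid_bf16_inplace(unsigned short* ptr, int size)
{
    for (int i = 0; i < size; i++)
    {
        float v = bf16_to_f32(ptr[i]);
        v = 1.f / (1.f + std::exp(-v));
        ptr[i] = f32_to_bf16(v);
    }
}

int main()
{
    unsigned short x[4] = {f32_to_bf16(-2.f), f32_to_bf16(-0.5f), f32_to_bf16(0.f), f32_to_bf16(3.f)};
    sigmoid_bf16_inplace(x, 4);
    for (int i = 0; i < 4; i++)
        std::printf("%f\n", bf16_to_f32(x[i]));
    return 0;
}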
namespace ncnn From 723fc18974ab292e6bb6feb131ea8c22c6c4f6ec Mon Sep 17 00:00:00 2001 From: nihui Date: Mon, 9 Mar 2026 13:04:47 +0800 Subject: [PATCH 13/36] drop virtual inheritance (#6590) --- src/layer/arm/groupnorm_arm.h | 2 +- src/layer/riscv/shufflechannel_riscv.h | 2 +- src/layer/vulkan/dequantize_vulkan.h | 2 +- src/layer/vulkan/quantize_vulkan.h | 2 +- src/layer/vulkan/requantize_vulkan.h | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/layer/arm/groupnorm_arm.h b/src/layer/arm/groupnorm_arm.h index 9037cc42608..10706ae9cf3 100644 --- a/src/layer/arm/groupnorm_arm.h +++ b/src/layer/arm/groupnorm_arm.h @@ -8,7 +8,7 @@ namespace ncnn { -class GroupNorm_arm : virtual public GroupNorm +class GroupNorm_arm : public GroupNorm { public: GroupNorm_arm(); diff --git a/src/layer/riscv/shufflechannel_riscv.h b/src/layer/riscv/shufflechannel_riscv.h index 499f0b99b26..501221f9403 100644 --- a/src/layer/riscv/shufflechannel_riscv.h +++ b/src/layer/riscv/shufflechannel_riscv.h @@ -8,7 +8,7 @@ namespace ncnn { -class ShuffleChannel_riscv : virtual public ShuffleChannel +class ShuffleChannel_riscv : public ShuffleChannel { public: ShuffleChannel_riscv(); diff --git a/src/layer/vulkan/dequantize_vulkan.h b/src/layer/vulkan/dequantize_vulkan.h index d0f318b3466..c22b48eefb0 100644 --- a/src/layer/vulkan/dequantize_vulkan.h +++ b/src/layer/vulkan/dequantize_vulkan.h @@ -8,7 +8,7 @@ namespace ncnn { -class Dequantize_vulkan : virtual public Dequantize +class Dequantize_vulkan : public Dequantize { public: Dequantize_vulkan(); diff --git a/src/layer/vulkan/quantize_vulkan.h b/src/layer/vulkan/quantize_vulkan.h index e94d366709b..7c6d12b65f2 100644 --- a/src/layer/vulkan/quantize_vulkan.h +++ b/src/layer/vulkan/quantize_vulkan.h @@ -8,7 +8,7 @@ namespace ncnn { -class Quantize_vulkan : virtual public Quantize +class Quantize_vulkan : public Quantize { public: Quantize_vulkan(); diff --git a/src/layer/vulkan/requantize_vulkan.h b/src/layer/vulkan/requantize_vulkan.h index a9f7142b9a5..909bb6d7017 100644 --- a/src/layer/vulkan/requantize_vulkan.h +++ b/src/layer/vulkan/requantize_vulkan.h @@ -8,7 +8,7 @@ namespace ncnn { -class Requantize_vulkan : virtual public Requantize +class Requantize_vulkan : public Requantize { public: Requantize_vulkan(); From b1ce9f95481a7d514e628767d98246ebac4b87cd Mon Sep 17 00:00:00 2001 From: nihui Date: Mon, 9 Mar 2026 15:21:11 +0800 Subject: [PATCH 14/36] x86: add BinaryOp_x86 bf16s storage support with avx512bf16 dispatch (#6591) --- src/layer/x86/binaryop_bf16s.h | 333 ++++++++ src/layer/x86/binaryop_functor.h | 498 ++++++++++++ src/layer/x86/binaryop_x86.cpp | 898 +++++++++------------- src/layer/x86/binaryop_x86.h | 6 + src/layer/x86/binaryop_x86_avx512bf16.cpp | 57 ++ 5 files changed, 1272 insertions(+), 520 deletions(-) create mode 100644 src/layer/x86/binaryop_bf16s.h create mode 100644 src/layer/x86/binaryop_functor.h create mode 100644 src/layer/x86/binaryop_x86_avx512bf16.cpp diff --git a/src/layer/x86/binaryop_bf16s.h b/src/layer/x86/binaryop_bf16s.h new file mode 100644 index 00000000000..218d8812ffa --- /dev/null +++ b/src/layer/x86/binaryop_bf16s.h @@ -0,0 +1,333 @@ +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +template +static void binary_op_vector_no_broadcast_bf16s(const unsigned short* ptr, const unsigned short* ptr1, unsigned short* outptr, int size) +{ + const Op op; + + int i = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + for (; i + 15 < size; i += 16) + { + __m512 _p = 
bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr)); + __m512 _b = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr1)); + __m256i _outp = float2bfloat_avx512(op.func_pack16(_p, _b)); + _mm256_storeu_si256((__m256i*)outptr, _outp); + ptr += 16; + ptr1 += 16; + outptr += 16; + } +#endif // __AVX512F__ + for (; i + 7 < size; i += 8) + { + __m256 _p = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr)); + __m256 _b = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr1)); + __m128i _outp = float2bfloat_avx(op.func_pack8(_p, _b)); + _mm_storeu_si128((__m128i*)outptr, _outp); + ptr += 8; + ptr1 += 8; + outptr += 8; + } +#endif // __AVX__ + for (; i + 3 < size; i += 4) + { + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr)); + __m128 _b = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr1)); + __m128 _outp4 = op.func_pack4(_p, _b); + _mm_storel_epi64((__m128i*)outptr, float2bfloat_sse(_outp4, _outp4)); + ptr += 4; + ptr1 += 4; + outptr += 4; + } +#endif // __SSE2__ + for (; i < size; i++) + { + *outptr++ = float32_to_bfloat16(op.func(bfloat16_to_float32(*ptr++), bfloat16_to_float32(*ptr1++))); + } +} + +template +static void binary_op_vector_broadcast_b_bf16s(const unsigned short* ptr, const unsigned short* ptr1, unsigned short* outptr, int size, int elempack) +{ + const Op op; + + const float b = bfloat16_to_float32(*ptr1); + + int i = 0; +#if __SSE2__ + __m128 _b_128 = (elempack == 4) ? bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr1)) : _mm_set1_ps(bfloat16_to_float32((short)*ptr1)); +#if __AVX__ + __m256 _b_256 = (elempack == 8) ? bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr1)) : combine4x2_ps(_b_128, _b_128); +#if __AVX512F__ + __m512 _b_512 = (elempack == 16) ? bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr1)) : combine8x2_ps(_b_256, _b_256); + for (; i + 15 < size; i += 16) + { + __m512 _p = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr)); + __m256i _outp = float2bfloat_avx512(op.func_pack16(_p, _b_512)); + _mm256_storeu_si256((__m256i*)outptr, _outp); + ptr += 16; + outptr += 16; + } +#endif // __AVX512F__ + for (; i + 7 < size; i += 8) + { + __m256 _p = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr)); + __m128i _outp = float2bfloat_avx(op.func_pack8(_p, _b_256)); + _mm_storeu_si128((__m128i*)outptr, _outp); + ptr += 8; + outptr += 8; + } +#endif // __AVX__ + for (; i + 3 < size; i += 4) + { + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr)); + __m128 _outp4 = op.func_pack4(_p, _b_128); + _mm_storel_epi64((__m128i*)outptr, float2bfloat_sse(_outp4, _outp4)); + ptr += 4; + outptr += 4; + } +#endif // __SSE2__ + for (; i < size; i++) + { + *outptr++ = float32_to_bfloat16(op.func(bfloat16_to_float32(*ptr++), b)); + } +} + +template +static void binary_op_vector_broadcast_a_bf16s(const unsigned short* ptr, const unsigned short* ptr1, unsigned short* outptr, int size, int elempack) +{ + const Op op; + + const float a = bfloat16_to_float32(*ptr); + + int i = 0; +#if __SSE2__ + __m128 _a_128 = (elempack == 4) ? bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr)) : _mm_set1_ps(bfloat16_to_float32((short)*ptr)); +#if __AVX__ + __m256 _a_256 = (elempack == 8) ? bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr)) : combine4x2_ps(_a_128, _a_128); +#if __AVX512F__ + __m512 _a_512 = (elempack == 16) ? 
bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr)) : combine8x2_ps(_a_256, _a_256); + for (; i + 15 < size; i += 16) + { + __m512 _b = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr1)); + __m256i _outp = float2bfloat_avx512(op.func_pack16(_a_512, _b)); + _mm256_storeu_si256((__m256i*)outptr, _outp); + ptr1 += 16; + outptr += 16; + } +#endif // __AVX512F__ + for (; i + 7 < size; i += 8) + { + __m256 _b = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr1)); + __m128i _outp = float2bfloat_avx(op.func_pack8(_a_256, _b)); + _mm_storeu_si128((__m128i*)outptr, _outp); + ptr1 += 8; + outptr += 8; + } +#endif // __AVX__ + for (; i + 3 < size; i += 4) + { + __m128 _b = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr1)); + __m128 _outp4 = op.func_pack4(_a_128, _b); + _mm_storel_epi64((__m128i*)outptr, float2bfloat_sse(_outp4, _outp4)); + ptr1 += 4; + outptr += 4; + } +#endif // __SSE2__ + for (; i < size; i++) + { + *outptr++ = float32_to_bfloat16(op.func(a, bfloat16_to_float32(*ptr1++))); + } +} + +template +static void binary_op_vector_broadcast_pb_bf16s(const unsigned short* ptr, const unsigned short* ptr1, unsigned short* outptr, int w, int elempack) +{ + const Op op; + +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + if (elempack == 16) + { + for (int i = 0; i < w; i++) + { + __m512 _p = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr)); + __m512 _b = _mm512_set1_ps(bfloat16_to_float32(*ptr1)); + __m256i _outp = float2bfloat_avx512(op.func_pack16(_p, _b)); + _mm256_storeu_si256((__m256i*)outptr, _outp); + ptr += 16; + ptr1 += 1; + outptr += 16; + } + return; + } +#endif // __AVX512F__ + if (elempack == 8) + { + for (int i = 0; i < w; i++) + { + __m256 _p = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr)); + __m256 _b = _mm256_set1_ps(bfloat16_to_float32(*ptr1)); + __m128i _outp = float2bfloat_avx(op.func_pack8(_p, _b)); + _mm_storeu_si128((__m128i*)outptr, _outp); + ptr += 8; + ptr1 += 1; + outptr += 8; + } + return; + } +#endif // __AVX__ + if (elempack == 4) + { + for (int i = 0; i < w; i++) + { + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr)); + __m128 _b = _mm_set1_ps(bfloat16_to_float32(*ptr1)); + __m128 _outp4 = op.func_pack4(_p, _b); + _mm_storel_epi64((__m128i*)outptr, float2bfloat_sse(_outp4, _outp4)); + ptr += 4; + ptr1 += 1; + outptr += 4; + } + return; + } +#endif // __SSE2__ +} + +template +static void binary_op_vector_broadcast_pb_b_bf16s(const unsigned short* ptr, const unsigned short* ptr1, unsigned short* outptr, int w, int elempack) +{ + const Op op; + + const int size = w * elempack; + const float b = bfloat16_to_float32(*ptr1); + + int i = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + { + __m512 _b_avx512 = _mm512_set1_ps(b); + for (; i + 15 < size; i += 16) + { + __m512 _p = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr)); + __m256i _outp = float2bfloat_avx512(op.func_pack16(_p, _b_avx512)); + _mm256_storeu_si256((__m256i*)outptr, _outp); + ptr += 16; + outptr += 16; + } + } +#endif // __AVX512F__ + { + __m256 _b_avx = _mm256_set1_ps(b); + for (; i + 7 < size; i += 8) + { + __m256 _p = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr)); + __m128i _outp = float2bfloat_avx(op.func_pack8(_p, _b_avx)); + _mm_storeu_si128((__m128i*)outptr, _outp); + ptr += 8; + outptr += 8; + } + } +#endif // __AVX__ + { + __m128 _b_sse = _mm_set1_ps(b); + for (; i + 3 < size; i += 4) + { + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr)); + __m128 _outp4 = op.func_pack4(_p, _b_sse); + 
_mm_storel_epi64((__m128i*)outptr, float2bfloat_sse(_outp4, _outp4)); + ptr += 4; + outptr += 4; + } + } +#endif // __SSE2__ + for (; i < size; i++) + { + *outptr++ = float32_to_bfloat16(op.func(bfloat16_to_float32(*ptr++), b)); + } +} + +template +static void binary_op_vector_broadcast_pb_a_bf16s(const unsigned short* ptr, const unsigned short* ptr1, unsigned short* outptr, int w, int elempack) +{ + const Op op; + +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + if (elempack == 16) + { + __m512 _p = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr)); + for (int i = 0; i < w; i++) + { + __m512 _b = _mm512_set1_ps(bfloat16_to_float32(*ptr1)); + __m256i _outp = float2bfloat_avx512(op.func_pack16(_p, _b)); + _mm256_storeu_si256((__m256i*)outptr, _outp); + ptr1 += 1; + outptr += 16; + } + return; + } +#endif // __AVX512F__ + if (elempack == 8) + { + __m256 _p = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr)); + for (int i = 0; i < w; i++) + { + __m256 _b = _mm256_set1_ps(bfloat16_to_float32(*ptr1)); + __m128i _outp = float2bfloat_avx(op.func_pack8(_p, _b)); + _mm_storeu_si128((__m128i*)outptr, _outp); + ptr1 += 1; + outptr += 8; + } + return; + } +#endif // __AVX__ + if (elempack == 4) + { + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr)); + for (int i = 0; i < w; i++) + { + __m128 _b = _mm_set1_ps(bfloat16_to_float32(*ptr1)); + __m128 _outp4 = op.func_pack4(_p, _b); + _mm_storel_epi64((__m128i*)outptr, float2bfloat_sse(_outp4, _outp4)); + ptr1 += 1; + outptr += 4; + } + return; + } +#endif // __SSE2__ +} + +template +static void binary_op_vector_bf16s(const unsigned short* ptr, const unsigned short* ptr1, unsigned short* outptr, int aw, int bw, int ap, int bp) +{ + const int w = std::max(aw, bw); + const int elempack = std::max(ap, bp); + const int size = w * elempack; + + if (ap == bp) + { + if (aw == bw) + return binary_op_vector_no_broadcast_bf16s(ptr, ptr1, outptr, size); + if (bw == 1) + return binary_op_vector_broadcast_b_bf16s(ptr, ptr1, outptr, size, elempack); + if (aw == 1) + return binary_op_vector_broadcast_a_bf16s(ptr, ptr1, outptr, size, elempack); + } + + if (bp == 1) + { + if (aw == bw) + return binary_op_vector_broadcast_pb_bf16s(ptr, ptr1, outptr, w, elempack); + if (bw == 1) + return binary_op_vector_broadcast_pb_b_bf16s(ptr, ptr1, outptr, w, elempack); + if (aw == 1) + return binary_op_vector_broadcast_pb_a_bf16s(ptr, ptr1, outptr, w, elempack); + } +} diff --git a/src/layer/x86/binaryop_functor.h b/src/layer/x86/binaryop_functor.h new file mode 100644 index 00000000000..729b13ed0fe --- /dev/null +++ b/src/layer/x86/binaryop_functor.h @@ -0,0 +1,498 @@ +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +struct binary_op_add +{ + NCNN_FORCEINLINE float func(const float& x, const float& y) const + { + return x + y; + } +#if __SSE2__ + NCNN_FORCEINLINE __m128 func_pack4(const __m128& x, const __m128& y) const + { + return _mm_add_ps(x, y); + } +#if __AVX__ + NCNN_FORCEINLINE __m256 func_pack8(const __m256& x, const __m256& y) const + { + return _mm256_add_ps(x, y); + } +#if __AVX512F__ + NCNN_FORCEINLINE __m512 func_pack16(const __m512& x, const __m512& y) const + { + return _mm512_add_ps(x, y); + } +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ +}; + +struct binary_op_sub +{ + NCNN_FORCEINLINE float func(const float& x, const float& y) const + { + return x - y; + } +#if __SSE2__ + NCNN_FORCEINLINE __m128 func_pack4(const __m128& x, const __m128& y) const + { + return _mm_sub_ps(x, y); + } +#if __AVX__ + 
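binary_op_vector_bf16s selects one of the specializations above from the (aw, bw, ap, bp) shape pair: equal widths take the no-broadcast path, a width of 1 broadcasts that whole operand, and bp == 1 broadcasts each element of B across one packed group of A. A scalar fp32 reference of those selection rules, with hypothetical names and a single add op, is sketched below; it is only meant to make the indexing explicit, not to mirror the vectorized helpers.

// Scalar reference for the broadcast rules behind binary_op_vector_bf16s.
#include <algorithm>
#include <cstdio>

static float op_add(float x, float y) { return x + y; }

// aw/bw: element counts per row of A and B, ap/bp: their elempack
static void binary_op_vector_ref(const float* a, const float* b, float* out,
                                 int aw, int bw, int ap, int bp)
{
    const int w = std::max(aw, bw);
    const int elempack = std::max(ap, bp);

    for (int i = 0; i < w; i++)
    {
        for (int k = 0; k < elempack; k++)
        {
            // a width of 1 broadcasts the whole operand,
            // an elempack of 1 broadcasts it across the packed lanes
            const float x = a[(aw == 1 ? 0 : i) * ap + (ap == 1 ? 0 : k)];
            const float y = b[(bw == 1 ? 0 : i) * bp + (bp == 1 ? 0 : k)];
            out[i * elempack + k] = op_add(x, y);
        }
    }
}

int main()
{
    const float a[8] = {0, 1, 2, 3, 4, 5, 6, 7}; // w=2, elempack=4
    const float b[2] = {10, 20};                 // w=2, elempack=1
    float c[8];
    binary_op_vector_ref(a, b, c, 2, 2, 4, 1);
    for (int i = 0; i < 8; i++)
        std::printf("%g ", c[i]); // 10 11 12 13 24 25 26 27
    std::printf("\n");
    return 0;
}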
NCNN_FORCEINLINE __m256 func_pack8(const __m256& x, const __m256& y) const + { + return _mm256_sub_ps(x, y); + } +#if __AVX512F__ + NCNN_FORCEINLINE __m512 func_pack16(const __m512& x, const __m512& y) const + { + return _mm512_sub_ps(x, y); + } +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ +}; + +struct binary_op_mul +{ + NCNN_FORCEINLINE float func(const float& x, const float& y) const + { + return x * y; + } +#if __SSE2__ + NCNN_FORCEINLINE __m128 func_pack4(const __m128& x, const __m128& y) const + { + return _mm_mul_ps(x, y); + } +#if __AVX__ + NCNN_FORCEINLINE __m256 func_pack8(const __m256& x, const __m256& y) const + { + return _mm256_mul_ps(x, y); + } +#if __AVX512F__ + NCNN_FORCEINLINE __m512 func_pack16(const __m512& x, const __m512& y) const + { + return _mm512_mul_ps(x, y); + } +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ +}; + +struct binary_op_div +{ + NCNN_FORCEINLINE float func(const float& x, const float& y) const + { + return x / y; + } +#if __SSE2__ + NCNN_FORCEINLINE __m128 func_pack4(const __m128& x, const __m128& y) const + { + return _mm_div_ps(x, y); + } +#if __AVX__ + NCNN_FORCEINLINE __m256 func_pack8(const __m256& x, const __m256& y) const + { + return _mm256_div_ps(x, y); + } +#if __AVX512F__ + NCNN_FORCEINLINE __m512 func_pack16(const __m512& x, const __m512& y) const + { + return _mm512_div_ps(x, y); + } +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ +}; + +struct binary_op_max +{ + NCNN_FORCEINLINE float func(const float& x, const float& y) const + { + return std::max(x, y); + } +#if __SSE2__ + NCNN_FORCEINLINE __m128 func_pack4(const __m128& x, const __m128& y) const + { + return _mm_max_ps(x, y); + } +#if __AVX__ + NCNN_FORCEINLINE __m256 func_pack8(const __m256& x, const __m256& y) const + { + return _mm256_max_ps(x, y); + } +#if __AVX512F__ + NCNN_FORCEINLINE __m512 func_pack16(const __m512& x, const __m512& y) const + { + return _mm512_max_ps(x, y); + } +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ +}; + +struct binary_op_min +{ + NCNN_FORCEINLINE float func(const float& x, const float& y) const + { + return std::min(x, y); + } +#if __SSE2__ + NCNN_FORCEINLINE __m128 func_pack4(const __m128& x, const __m128& y) const + { + return _mm_min_ps(x, y); + } +#if __AVX__ + NCNN_FORCEINLINE __m256 func_pack8(const __m256& x, const __m256& y) const + { + return _mm256_min_ps(x, y); + } +#if __AVX512F__ + NCNN_FORCEINLINE __m512 func_pack16(const __m512& x, const __m512& y) const + { + return _mm512_min_ps(x, y); + } +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ +}; + +struct binary_op_pow +{ + NCNN_FORCEINLINE float func(const float& x, const float& y) const + { + return (float)powf(x, y); + } +#if __SSE2__ + NCNN_FORCEINLINE __m128 func_pack4(const __m128& x, const __m128& y) const + { + return pow_ps(x, y); + } +#if __AVX__ + NCNN_FORCEINLINE __m256 func_pack8(const __m256& x, const __m256& y) const + { + return pow256_ps(x, y); + } +#if __AVX512F__ + NCNN_FORCEINLINE __m512 func_pack16(const __m512& x, const __m512& y) const + { + return pow512_ps(x, y); + } +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ +}; + +struct binary_op_rsub +{ + NCNN_FORCEINLINE float func(const float& x, const float& y) const + { + return y - x; + } +#if __SSE2__ + NCNN_FORCEINLINE __m128 func_pack4(const __m128& x, const __m128& y) const + { + return _mm_sub_ps(y, x); + } +#if __AVX__ + NCNN_FORCEINLINE __m256 func_pack8(const __m256& x, const __m256& y) const + { + return 
_mm256_sub_ps(y, x); + } +#if __AVX512F__ + NCNN_FORCEINLINE __m512 func_pack16(const __m512& x, const __m512& y) const + { + return _mm512_sub_ps(y, x); + } +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ +}; + +struct binary_op_rdiv +{ + NCNN_FORCEINLINE float func(const float& x, const float& y) const + { + return y / x; + } +#if __SSE2__ + NCNN_FORCEINLINE __m128 func_pack4(const __m128& x, const __m128& y) const + { + return _mm_div_ps(y, x); + } +#if __AVX__ + NCNN_FORCEINLINE __m256 func_pack8(const __m256& x, const __m256& y) const + { + return _mm256_div_ps(y, x); + } +#if __AVX512F__ + NCNN_FORCEINLINE __m512 func_pack16(const __m512& x, const __m512& y) const + { + return _mm512_div_ps(y, x); + } +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ +}; + +struct binary_op_rpow +{ + NCNN_FORCEINLINE float func(const float& x, const float& y) const + { + return (float)powf(y, x); + } +#if __SSE2__ + NCNN_FORCEINLINE __m128 func_pack4(const __m128& x, const __m128& y) const + { + return pow_ps(y, x); + } +#if __AVX__ + NCNN_FORCEINLINE __m256 func_pack8(const __m256& x, const __m256& y) const + { + return pow256_ps(y, x); + } +#if __AVX512F__ + NCNN_FORCEINLINE __m512 func_pack16(const __m512& x, const __m512& y) const + { + return pow512_ps(y, x); + } +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ +}; + +struct binary_op_atan2 +{ + NCNN_FORCEINLINE float func(const float& x, const float& y) const + { + return (float)atan2f(x, y); + } +#if __SSE2__ + NCNN_FORCEINLINE __m128 func_pack4(const __m128& x, const __m128& y) const + { + return atan2_ps(x, y); + } +#if __AVX__ + NCNN_FORCEINLINE __m256 func_pack8(const __m256& x, const __m256& y) const + { + return atan2256_ps(x, y); + } +#if __AVX512F__ + NCNN_FORCEINLINE __m512 func_pack16(const __m512& x, const __m512& y) const + { + return atan2512_ps(x, y); + } +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ +}; + +struct binary_op_ratan2 +{ + NCNN_FORCEINLINE float func(const float& x, const float& y) const + { + return (float)atan2f(y, x); + } +#if __SSE2__ + NCNN_FORCEINLINE __m128 func_pack4(const __m128& x, const __m128& y) const + { + return atan2_ps(y, x); + } +#if __AVX__ + NCNN_FORCEINLINE __m256 func_pack8(const __m256& x, const __m256& y) const + { + return atan2256_ps(y, x); + } +#if __AVX512F__ + NCNN_FORCEINLINE __m512 func_pack16(const __m512& x, const __m512& y) const + { + return atan2512_ps(y, x); + } +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ +}; + +struct binary_op_fmod +{ + NCNN_FORCEINLINE float func(const float& x, const float& y) const + { + return (float)fmodf(x, y); + } +#if __SSE2__ + NCNN_FORCEINLINE __m128 func_pack4(const __m128& x, const __m128& y) const + { + return fmod_ps(x, y); + } +#if __AVX__ + NCNN_FORCEINLINE __m256 func_pack8(const __m256& x, const __m256& y) const + { + return fmod256_ps(x, y); + } +#if __AVX512F__ + NCNN_FORCEINLINE __m512 func_pack16(const __m512& x, const __m512& y) const + { + return fmod512_ps(x, y); + } +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ +}; + +struct binary_op_rfmod +{ + NCNN_FORCEINLINE float func(const float& x, const float& y) const + { + return (float)fmodf(y, x); + } +#if __SSE2__ + NCNN_FORCEINLINE __m128 func_pack4(const __m128& x, const __m128& y) const + { + return fmod_ps(y, x); + } +#if __AVX__ + NCNN_FORCEINLINE __m256 func_pack8(const __m256& x, const __m256& y) const + { + return fmod256_ps(y, x); + } +#if __AVX512F__ + NCNN_FORCEINLINE __m512 func_pack16(const 
__m512& x, const __m512& y) const + { + return fmod512_ps(y, x); + } +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ +}; + +struct binary_op_logaddexp +{ + NCNN_FORCEINLINE float func(const float& x, const float& y) const + { + float max_xy = std::max(x, y); + float min_xy = std::min(x, y); + return (float)(max_xy + log1pf(expf(min_xy - max_xy))); + } +#if __SSE2__ + NCNN_FORCEINLINE __m128 func_pack4(const __m128& x, const __m128& y) const + { + return logaddexp_ps(x, y); + } +#if __AVX__ + NCNN_FORCEINLINE __m256 func_pack8(const __m256& x, const __m256& y) const + { + return logaddexp256_ps(x, y); + } +#if __AVX512F__ + NCNN_FORCEINLINE __m512 func_pack16(const __m512& x, const __m512& y) const + { + return logaddexp512_ps(x, y); + } +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ +}; + +struct binary_op_floor_divide +{ + NCNN_FORCEINLINE float func(const float& x, const float& y) const + { + return (float)floorf(x / y); + } +#if __SSE2__ + NCNN_FORCEINLINE __m128 func_pack4(const __m128& x, const __m128& y) const + { + return floor_divide_ps(x, y); + } +#if __AVX__ + NCNN_FORCEINLINE __m256 func_pack8(const __m256& x, const __m256& y) const + { + return floor_divide256_ps(x, y); + } +#if __AVX512F__ + NCNN_FORCEINLINE __m512 func_pack16(const __m512& x, const __m512& y) const + { + return floor_divide512_ps(x, y); + } +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ +}; + +struct binary_op_rfloor_divide +{ + NCNN_FORCEINLINE float func(const float& x, const float& y) const + { + return (float)floorf(y / x); + } +#if __SSE2__ + NCNN_FORCEINLINE __m128 func_pack4(const __m128& x, const __m128& y) const + { + return floor_divide_ps(y, x); + } +#if __AVX__ + NCNN_FORCEINLINE __m256 func_pack8(const __m256& x, const __m256& y) const + { + return floor_divide256_ps(y, x); + } +#if __AVX512F__ + NCNN_FORCEINLINE __m512 func_pack16(const __m512& x, const __m512& y) const + { + return floor_divide512_ps(y, x); + } +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ +}; + +struct binary_op_remainder +{ + NCNN_FORCEINLINE float func(const float& x, const float& y) const + { + return (float)remainderf(x, y); + } +#if __SSE2__ + NCNN_FORCEINLINE __m128 func_pack4(const __m128& x, const __m128& y) const + { + return remainder_ps(x, y); + } +#if __AVX__ + NCNN_FORCEINLINE __m256 func_pack8(const __m256& x, const __m256& y) const + { + return remainder256_ps(x, y); + } +#if __AVX512F__ + NCNN_FORCEINLINE __m512 func_pack16(const __m512& x, const __m512& y) const + { + return remainder512_ps(x, y); + } +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ +}; + +struct binary_op_rremainder +{ + NCNN_FORCEINLINE float func(const float& x, const float& y) const + { + return (float)remainderf(y, x); + } +#if __SSE2__ + NCNN_FORCEINLINE __m128 func_pack4(const __m128& x, const __m128& y) const + { + return remainder_ps(y, x); + } +#if __AVX__ + NCNN_FORCEINLINE __m256 func_pack8(const __m256& x, const __m256& y) const + { + return remainder256_ps(y, x); + } +#if __AVX512F__ + NCNN_FORCEINLINE __m512 func_pack16(const __m512& x, const __m512& y) const + { + return remainder512_ps(y, x); + } +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ +}; diff --git a/src/layer/x86/binaryop_x86.cpp b/src/layer/x86/binaryop_x86.cpp index ef03cd37428..b88e43f4766 100644 --- a/src/layer/x86/binaryop_x86.cpp +++ b/src/layer/x86/binaryop_x86.cpp @@ -3,6 +3,8 @@ #include "binaryop_x86.h" +#include "cpu.h" + #if __SSE2__ #include #include "sse_mathfun.h" 
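The *_avx512bf16.cpp wrappers exist so the shared bf16s headers get compiled a second time with AVX-512 BF16 compile flags, and the generic kernels forward to that build when cpu_support_x86_avx512_bf16() reports support at runtime. The following self-contained sketch shows only the shape of that dispatch; the names are hypothetical, the "fast" body is a placeholder for what would live in the separately compiled translation unit, and __builtin_cpu_supports("avx512f") stands in for the finer-grained BF16 probe.

// Runtime dispatch sketch: one portable entry point branching on a CPU feature probe.
#include <cstdio>

static bool cpu_has_fast_path()
{
#if (defined(__x86_64__) || defined(__i386__)) && defined(__GNUC__)
    // coarse stand-in for ncnn's cpu_support_x86_avx512_bf16()
    return __builtin_cpu_supports("avx512f");
#else
    return false;
#endif
}

static void scale_generic(float* data, int n, float s)
{
    for (int i = 0; i < n; i++)
        data[i] *= s;
}

static void scale_fast(float* data, int n, float s)
{
    // placeholder for the specialised build of the same kernel
    scale_generic(data, n, s);
}

static void scale(float* data, int n, float s)
{
    if (cpu_has_fast_path())
        return scale_fast(data, n, s); // specialised translation unit
    scale_generic(data, n, s);         // portable fallback
}

int main()
{
    float x[4] = {1.f, 2.f, 3.f, 4.f};
    scale(x, 4, 0.5f);
    std::printf("%g %g %g %g\n", x[0], x[1], x[2], x[3]);
    return 0;
}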
@@ -15,13 +17,28 @@ #endif // __AVX__ #endif // __SSE2__ +#include "x86_usability.h" + namespace ncnn { +namespace BinaryOp_x86_functor { + +#include "binaryop_functor.h" + +} // namespace BinaryOp_x86_functor + +#if NCNN_BF16 +#include "binaryop_bf16s.h" +#endif + BinaryOp_x86::BinaryOp_x86() { #if __SSE2__ support_packing = true; #endif // __SSE2__ +#if NCNN_BF16 + support_bf16_storage = true; +#endif } template @@ -460,534 +477,395 @@ static void binary_op_vector(const float* ptr, const float* ptr1, float* outptr, // shall never reach here } -namespace BinaryOp_x86_functor { - -struct binary_op_add +static void binary_op_vector(const float* ptr, const float* ptr1, float* outptr, int aw, int bw, int ap, int bp, int op_type) { - NCNN_FORCEINLINE float func(const float& x, const float& y) const - { - return x + y; - } -#if __SSE2__ - NCNN_FORCEINLINE __m128 func_pack4(const __m128& x, const __m128& y) const - { - return _mm_add_ps(x, y); - } -#if __AVX__ - NCNN_FORCEINLINE __m256 func_pack8(const __m256& x, const __m256& y) const - { - return _mm256_add_ps(x, y); - } -#if __AVX512F__ - NCNN_FORCEINLINE __m512 func_pack16(const __m512& x, const __m512& y) const - { - return _mm512_add_ps(x, y); - } -#endif // __AVX512F__ -#endif // __AVX__ -#endif // __SSE2__ -}; + using namespace BinaryOp_x86_functor; -struct binary_op_sub -{ - NCNN_FORCEINLINE float func(const float& x, const float& y) const - { - return x - y; - } -#if __SSE2__ - NCNN_FORCEINLINE __m128 func_pack4(const __m128& x, const __m128& y) const - { - return _mm_sub_ps(x, y); - } -#if __AVX__ - NCNN_FORCEINLINE __m256 func_pack8(const __m256& x, const __m256& y) const - { - return _mm256_sub_ps(x, y); - } -#if __AVX512F__ - NCNN_FORCEINLINE __m512 func_pack16(const __m512& x, const __m512& y) const - { - return _mm512_sub_ps(x, y); - } -#endif // __AVX512F__ -#endif // __AVX__ -#endif // __SSE2__ -}; + if (op_type == BinaryOp::Operation_ADD) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_SUB) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_MUL) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_DIV) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_MAX) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_MIN) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_POW) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_RSUB) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_RDIV) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_RPOW) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_ATAN2) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_RATAN2) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_FMOD) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_RFMOD) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_LOGADDEXP) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_FLOOR_DIVIDE) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); + 
if (op_type == BinaryOp::Operation_RFLOOR_DIVIDE) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_REMAINDER) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_RREMAINDER) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); -struct binary_op_mul -{ - NCNN_FORCEINLINE float func(const float& x, const float& y) const - { - return x * y; - } -#if __SSE2__ - NCNN_FORCEINLINE __m128 func_pack4(const __m128& x, const __m128& y) const - { - return _mm_mul_ps(x, y); - } -#if __AVX__ - NCNN_FORCEINLINE __m256 func_pack8(const __m256& x, const __m256& y) const - { - return _mm256_mul_ps(x, y); - } -#if __AVX512F__ - NCNN_FORCEINLINE __m512 func_pack16(const __m512& x, const __m512& y) const - { - return _mm512_mul_ps(x, y); - } -#endif // __AVX512F__ -#endif // __AVX__ -#endif // __SSE2__ -}; + // should never reach here +} -struct binary_op_div +static void binary_op_scalar(const Mat& a, float b, Mat& c, int op_type, const Option& opt) { - NCNN_FORCEINLINE float func(const float& x, const float& y) const - { - return x / y; - } -#if __SSE2__ - NCNN_FORCEINLINE __m128 func_pack4(const __m128& x, const __m128& y) const - { - return _mm_div_ps(x, y); - } -#if __AVX__ - NCNN_FORCEINLINE __m256 func_pack8(const __m256& x, const __m256& y) const - { - return _mm256_div_ps(x, y); - } -#if __AVX512F__ - NCNN_FORCEINLINE __m512 func_pack16(const __m512& x, const __m512& y) const - { - return _mm512_div_ps(x, y); - } -#endif // __AVX512F__ -#endif // __AVX__ -#endif // __SSE2__ -}; + const int channels = a.c; + const int size = a.w * a.h * a.d * a.elempack; -struct binary_op_max -{ - NCNN_FORCEINLINE float func(const float& x, const float& y) const - { - return std::max(x, y); - } -#if __SSE2__ - NCNN_FORCEINLINE __m128 func_pack4(const __m128& x, const __m128& y) const - { - return _mm_max_ps(x, y); - } -#if __AVX__ - NCNN_FORCEINLINE __m256 func_pack8(const __m256& x, const __m256& y) const - { - return _mm256_max_ps(x, y); - } -#if __AVX512F__ - NCNN_FORCEINLINE __m512 func_pack16(const __m512& x, const __m512& y) const + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) { - return _mm512_max_ps(x, y); - } -#endif // __AVX512F__ -#endif // __AVX__ -#endif // __SSE2__ -}; + const float* ptr = a.channel(q); + float* outptr = c.channel(q); -struct binary_op_min -{ - NCNN_FORCEINLINE float func(const float& x, const float& y) const - { - return std::min(x, y); - } -#if __SSE2__ - NCNN_FORCEINLINE __m128 func_pack4(const __m128& x, const __m128& y) const - { - return _mm_min_ps(x, y); - } -#if __AVX__ - NCNN_FORCEINLINE __m256 func_pack8(const __m256& x, const __m256& y) const - { - return _mm256_min_ps(x, y); - } -#if __AVX512F__ - NCNN_FORCEINLINE __m512 func_pack16(const __m512& x, const __m512& y) const - { - return _mm512_min_ps(x, y); + binary_op_vector(ptr, &b, outptr, size, 1, 1, 1, op_type); } -#endif // __AVX512F__ -#endif // __AVX__ -#endif // __SSE2__ -}; +} -struct binary_op_pow +static void binary_op_no_broadcast(const Mat& a, const Mat& b, Mat& c, int op_type, const Option& opt) { - NCNN_FORCEINLINE float func(const float& x, const float& y) const - { - return (float)powf(x, y); - } -#if __SSE2__ - NCNN_FORCEINLINE __m128 func_pack4(const __m128& x, const __m128& y) const - { - return pow_ps(x, y); - } -#if __AVX__ - NCNN_FORCEINLINE __m256 func_pack8(const __m256& x, const __m256& y) const - { - return pow256_ps(x, y); - } -#if __AVX512F__ - 
NCNN_FORCEINLINE __m512 func_pack16(const __m512& x, const __m512& y) const - { - return pow512_ps(x, y); - } -#endif // __AVX512F__ -#endif // __AVX__ -#endif // __SSE2__ -}; + const int channels = a.c; + const int size = a.w * a.h * a.d * a.elempack; -struct binary_op_rsub -{ - NCNN_FORCEINLINE float func(const float& x, const float& y) const - { - return y - x; - } -#if __SSE2__ - NCNN_FORCEINLINE __m128 func_pack4(const __m128& x, const __m128& y) const - { - return _mm_sub_ps(y, x); - } -#if __AVX__ - NCNN_FORCEINLINE __m256 func_pack8(const __m256& x, const __m256& y) const - { - return _mm256_sub_ps(y, x); - } -#if __AVX512F__ - NCNN_FORCEINLINE __m512 func_pack16(const __m512& x, const __m512& y) const + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) { - return _mm512_sub_ps(y, x); - } -#endif // __AVX512F__ -#endif // __AVX__ -#endif // __SSE2__ -}; + const float* ptr = a.channel(q); + const float* ptr1 = b.channel(q); + float* outptr = c.channel(q); -struct binary_op_rdiv -{ - NCNN_FORCEINLINE float func(const float& x, const float& y) const - { - return y / x; - } -#if __SSE2__ - NCNN_FORCEINLINE __m128 func_pack4(const __m128& x, const __m128& y) const - { - return _mm_div_ps(y, x); - } -#if __AVX__ - NCNN_FORCEINLINE __m256 func_pack8(const __m256& x, const __m256& y) const - { - return _mm256_div_ps(y, x); - } -#if __AVX512F__ - NCNN_FORCEINLINE __m512 func_pack16(const __m512& x, const __m512& y) const - { - return _mm512_div_ps(y, x); + binary_op_vector(ptr, ptr1, outptr, size, size, 1, 1, op_type); } -#endif // __AVX512F__ -#endif // __AVX__ -#endif // __SSE2__ -}; +} -struct binary_op_rpow +static void binary_op_broadcast(const Mat& a, const Mat& b, Mat& c, int op_type, const Option& opt) { - NCNN_FORCEINLINE float func(const float& x, const float& y) const - { - return (float)powf(y, x); - } -#if __SSE2__ - NCNN_FORCEINLINE __m128 func_pack4(const __m128& x, const __m128& y) const - { - return pow_ps(y, x); - } -#if __AVX__ - NCNN_FORCEINLINE __m256 func_pack8(const __m256& x, const __m256& y) const - { - return pow256_ps(y, x); - } -#if __AVX512F__ - NCNN_FORCEINLINE __m512 func_pack16(const __m512& x, const __m512& y) const + if (b.w * b.h * b.d * b.c * b.elempack == 1) { - return pow512_ps(y, x); + return binary_op_scalar(a, b[0], c, op_type, opt); } -#endif // __AVX512F__ -#endif // __AVX__ -#endif // __SSE2__ -}; -struct binary_op_atan2 -{ - NCNN_FORCEINLINE float func(const float& x, const float& y) const - { - return (float)atan2f(x, y); - } -#if __SSE2__ - NCNN_FORCEINLINE __m128 func_pack4(const __m128& x, const __m128& y) const - { - return atan2_ps(x, y); - } -#if __AVX__ - NCNN_FORCEINLINE __m256 func_pack8(const __m256& x, const __m256& y) const - { - return atan2256_ps(x, y); - } -#if __AVX512F__ - NCNN_FORCEINLINE __m512 func_pack16(const __m512& x, const __m512& y) const + if (a.dims == b.dims && a.w == b.w && a.h == b.h && a.d == b.d && a.c == b.c && a.elempack == b.elempack) { - return atan2512_ps(x, y); + return binary_op_no_broadcast(a, b, c, op_type, opt); } -#endif // __AVX512F__ -#endif // __AVX__ -#endif // __SSE2__ -}; -struct binary_op_ratan2 -{ - NCNN_FORCEINLINE float func(const float& x, const float& y) const - { - return (float)atan2f(y, x); - } -#if __SSE2__ - NCNN_FORCEINLINE __m128 func_pack4(const __m128& x, const __m128& y) const - { - return atan2_ps(y, x); - } -#if __AVX__ - NCNN_FORCEINLINE __m256 func_pack8(const __m256& x, const __m256& y) const - { - return atan2256_ps(y, x); - } -#if 
__AVX512F__ - NCNN_FORCEINLINE __m512 func_pack16(const __m512& x, const __m512& y) const - { - return atan2512_ps(y, x); - } -#endif // __AVX512F__ -#endif // __AVX__ -#endif // __SSE2__ -}; + const int dims = c.dims; -struct binary_op_fmod -{ - NCNN_FORCEINLINE float func(const float& x, const float& y) const - { - return (float)fmodf(x, y); - } -#if __SSE2__ - NCNN_FORCEINLINE __m128 func_pack4(const __m128& x, const __m128& y) const - { - return fmod_ps(x, y); - } -#if __AVX__ - NCNN_FORCEINLINE __m256 func_pack8(const __m256& x, const __m256& y) const - { - return fmod256_ps(x, y); - } -#if __AVX512F__ - NCNN_FORCEINLINE __m512 func_pack16(const __m512& x, const __m512& y) const + if (dims == 2) { - return fmod512_ps(x, y); + const int h = c.h; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int y = 0; y < h; y++) + { + const int y0 = std::min(y, a.h - 1); + const int y1 = std::min(y, b.h - 1); + + const float* ptr = a.row(y0); + const float* ptr1 = b.row(y1); + float* outptr = c.row(y); + + binary_op_vector(ptr, ptr1, outptr, a.w, b.w, a.elempack, b.elempack, op_type); + } } -#endif // __AVX512F__ -#endif // __AVX__ -#endif // __SSE2__ -}; -struct binary_op_rfmod -{ - NCNN_FORCEINLINE float func(const float& x, const float& y) const + if (dims == 3 || dims == 4) { - return (float)fmodf(y, x); - } -#if __SSE2__ - NCNN_FORCEINLINE __m128 func_pack4(const __m128& x, const __m128& y) const - { - return fmod_ps(y, x); - } -#if __AVX__ - NCNN_FORCEINLINE __m256 func_pack8(const __m256& x, const __m256& y) const - { - return fmod256_ps(y, x); - } -#if __AVX512F__ - NCNN_FORCEINLINE __m512 func_pack16(const __m512& x, const __m512& y) const - { - return fmod512_ps(y, x); + const int channels = c.c; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const int q0 = std::min(q, a.c - 1); + const int q1 = std::min(q, b.c - 1); + + if (b.d * b.h * b.w == 1) + { + const float* ptr = a.channel(q0); + const float* ptr1 = b.channel(q1); + float* outptr = c.channel(q); + + binary_op_vector(ptr, ptr1, outptr, a.w * a.h * a.d, 1, a.elempack, b.elempack, op_type); + continue; + } + + if (b.h * b.w == 1) + { + for (int z = 0; z < c.d; z++) + { + const int z0 = std::min(z, a.d - 1); + const int z1 = std::min(z, b.d - 1); + + const float* ptr = a.channel(q0).depth(z0); + const float* ptr1 = b.channel(q1).depth(z1); + float* outptr = c.channel(q).depth(z); + + binary_op_vector(ptr, ptr1, outptr, a.w * a.h, 1, a.elempack, b.elempack, op_type); + } + continue; + } + + for (int z = 0; z < c.d; z++) + { + const int z0 = std::min(z, a.d - 1); + const int z1 = std::min(z, b.d - 1); + + for (int y = 0; y < c.h; y++) + { + const int y0 = std::min(y, a.h - 1); + const int y1 = std::min(y, b.h - 1); + + const float* ptr = a.channel(q0).depth(z0).row(y0); + const float* ptr1 = b.channel(q1).depth(z1).row(y1); + float* outptr = c.channel(q).depth(z).row(y); + + binary_op_vector(ptr, ptr1, outptr, a.w, b.w, a.elempack, b.elempack, op_type); + } + } + } } -#endif // __AVX512F__ -#endif // __AVX__ -#endif // __SSE2__ -}; +} -struct binary_op_logaddexp +static void binary_op_scalar_inplace(Mat& a, float b, int op_type, const Option& opt) { - NCNN_FORCEINLINE float func(const float& x, const float& y) const - { - float max_xy = std::max(x, y); - float min_xy = std::min(x, y); - return (float)(max_xy + log1pf(expf(min_xy - max_xy))); - } -#if __SSE2__ - NCNN_FORCEINLINE __m128 func_pack4(const __m128& x, const __m128& y) const - { - return logaddexp_ps(x, y); 
- } -#if __AVX__ - NCNN_FORCEINLINE __m256 func_pack8(const __m256& x, const __m256& y) const - { - return logaddexp256_ps(x, y); - } -#if __AVX512F__ - NCNN_FORCEINLINE __m512 func_pack16(const __m512& x, const __m512& y) const + const int channels = a.c; + const int size = a.w * a.h * a.d * a.elempack; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) { - return logaddexp512_ps(x, y); + float* ptr = a.channel(q); + + binary_op_vector(ptr, &b, ptr, size, 1, 1, 1, op_type); } -#endif // __AVX512F__ -#endif // __AVX__ -#endif // __SSE2__ -}; +} -struct binary_op_floor_divide +static int get_reverse_op_type(int op_type) { - NCNN_FORCEINLINE float func(const float& x, const float& y) const - { - return (float)floorf(x / y); - } -#if __SSE2__ - NCNN_FORCEINLINE __m128 func_pack4(const __m128& x, const __m128& y) const - { - return floor_divide_ps(x, y); - } -#if __AVX__ - NCNN_FORCEINLINE __m256 func_pack8(const __m256& x, const __m256& y) const - { - return floor_divide256_ps(x, y); - } -#if __AVX512F__ - NCNN_FORCEINLINE __m512 func_pack16(const __m512& x, const __m512& y) const - { - return floor_divide512_ps(x, y); - } -#endif // __AVX512F__ -#endif // __AVX__ -#endif // __SSE2__ -}; + if (op_type == BinaryOp::Operation_SUB) return BinaryOp::Operation_RSUB; + if (op_type == BinaryOp::Operation_DIV) return BinaryOp::Operation_RDIV; + if (op_type == BinaryOp::Operation_POW) return BinaryOp::Operation_RPOW; + if (op_type == BinaryOp::Operation_ATAN2) return BinaryOp::Operation_RATAN2; + if (op_type == BinaryOp::Operation_FMOD) return BinaryOp::Operation_RFMOD; + if (op_type == BinaryOp::Operation_FLOOR_DIVIDE) return BinaryOp::Operation_RFLOOR_DIVIDE; + if (op_type == BinaryOp::Operation_REMAINDER) return BinaryOp::Operation_RREMAINDER; + + if (op_type == BinaryOp::Operation_RSUB) return BinaryOp::Operation_SUB; + if (op_type == BinaryOp::Operation_RDIV) return BinaryOp::Operation_DIV; + if (op_type == BinaryOp::Operation_RPOW) return BinaryOp::Operation_POW; + if (op_type == BinaryOp::Operation_RATAN2) return BinaryOp::Operation_ATAN2; + if (op_type == BinaryOp::Operation_RFMOD) return BinaryOp::Operation_FMOD; + if (op_type == BinaryOp::Operation_RFLOOR_DIVIDE) return BinaryOp::Operation_FLOOR_DIVIDE; + if (op_type == BinaryOp::Operation_RREMAINDER) return BinaryOp::Operation_REMAINDER; + + return op_type; +} -struct binary_op_rfloor_divide +int BinaryOp_x86::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const { - NCNN_FORCEINLINE float func(const float& x, const float& y) const - { - return (float)floorf(y / x); - } -#if __SSE2__ - NCNN_FORCEINLINE __m128 func_pack4(const __m128& x, const __m128& y) const - { - return floor_divide_ps(y, x); - } -#if __AVX__ - NCNN_FORCEINLINE __m256 func_pack8(const __m256& x, const __m256& y) const +#if NCNN_BF16 + int elembits = std::max(bottom_blobs[0].elembits(), bottom_blobs[1].elembits()); + if (opt.use_bf16_storage && elembits == 16) + return forward_bf16s(bottom_blobs, top_blobs, opt); +#endif + + const Mat& A = bottom_blobs[0]; + const Mat& B = bottom_blobs[1]; + const int outdims = std::max(A.dims, B.dims); + + Mat A2 = A; + Mat B2 = B; + if (A.dims < outdims) { - return floor_divide256_ps(y, x); + // expand inner axes + if (outdims == 2) + { + if (A.w * A.elempack == B.h * B.elempack) + A2 = A.reshape(1, A.w, opt.workspace_allocator); + else // if (A.w == B.w) + { + A2.dims = 2; + A2.w = A.w * A.elempack; + A2.elempack = 1; + A2.elemsize = A.elemsize / A.elempack; 
+ A2.cstep = A.cstep * A.elempack; + } + } + if (outdims == 3 && A.dims == 1) + { + if (A.w * A.elempack == B.c * B.elempack) + A2 = A.reshape(1, 1, A.w, opt.workspace_allocator); + else // if (A.w == B.w) + { + A2.dims = 3; + A2.w = A.w * A.elempack; + A2.elempack = 1; + A2.elemsize = A.elemsize / A.elempack; + A2.cstep = A.cstep * A.elempack; + } + } + if (outdims == 3 && A.dims == 2) + A2 = A.reshape(1, A.w, A.h, opt.workspace_allocator); + if (outdims == 4 && A.dims == 1) + { + if (A.w * A.elempack == B.c * B.elempack) + A2 = A.reshape(1, 1, 1, A.w, opt.workspace_allocator); + else // if (A.w == B.w) + { + A2.dims = 4; + A2.w = A.w * A.elempack; + A2.elempack = 1; + A2.elemsize = A.elemsize / A.elempack; + A2.cstep = A.cstep * A.elempack; + } + } + if (outdims == 4 && A.dims == 2) + A2 = A.reshape(1, 1, A.w, A.h, opt.workspace_allocator); + if (outdims == 4 && A.dims == 3) + A2 = A.reshape(1, A.w, A.h, A.c, opt.workspace_allocator); } -#if __AVX512F__ - NCNN_FORCEINLINE __m512 func_pack16(const __m512& x, const __m512& y) const + if (B.dims < outdims) { - return floor_divide512_ps(y, x); + // expand inner axes + if (outdims == 2) + { + if (B.w * B.elempack == A.h * A.elempack) + B2 = B.reshape(1, B.w, opt.workspace_allocator); + else // if (B.w == A.w) + { + B2.dims = 2; + B2.w = B.w * B.elempack; + B2.elempack = 1; + B2.elemsize = B.elemsize / B.elempack; + B2.cstep = B.cstep * B.elempack; + } + } + if (outdims == 3 && B.dims == 1) + { + if (B.w * B.elempack == A.c * A.elempack) + B2 = B.reshape(1, 1, B.w, opt.workspace_allocator); + else // if (B.w == A.w) + { + B2.dims = 3; + B2.w = B.w * B.elempack; + B2.elempack = 1; + B2.elemsize = B.elemsize / B.elempack; + B2.cstep = B.cstep * B.elempack; + } + } + if (outdims == 3 && B.dims == 2) + B2 = B.reshape(1, B.w, B.h, opt.workspace_allocator); + if (outdims == 4 && B.dims == 1) + { + if (B.w * B.elempack == A.c * A.elempack) + B2 = B.reshape(1, 1, 1, B.w, opt.workspace_allocator); + else // if (B.w == A.w) + { + B2.dims = 4; + B2.w = B.w * B.elempack; + B2.elempack = 1; + B2.elemsize = B.elemsize / B.elempack; + B2.cstep = B.cstep * B.elempack; + } + } + if (outdims == 4 && B.dims == 2) + B2 = B.reshape(1, 1, B.w, B.h, opt.workspace_allocator); + if (outdims == 4 && B.dims == 3) + B2 = B.reshape(1, B.w, B.h, B.c, opt.workspace_allocator); } -#endif // __AVX512F__ -#endif // __AVX__ -#endif // __SSE2__ -}; -struct binary_op_remainder -{ - NCNN_FORCEINLINE float func(const float& x, const float& y) const + const int outw = std::max(A2.w, B2.w); + const int outh = std::max(A2.h, B2.h); + const int outd = std::max(A2.d, B2.d); + const int outc = std::max(A2.c, B2.c); + const size_t out_elemsize = std::max(A2.elemsize, B2.elemsize); + const int out_elempack = std::max(A2.elempack, B2.elempack); + + Mat& top_blob = top_blobs[0]; + if (outdims == 1) { - return (float)remainderf(x, y); + top_blob.create(outw, out_elemsize, out_elempack, opt.blob_allocator); } -#if __SSE2__ - NCNN_FORCEINLINE __m128 func_pack4(const __m128& x, const __m128& y) const + if (outdims == 2) { - return remainder_ps(x, y); + top_blob.create(outw, outh, out_elemsize, out_elempack, opt.blob_allocator); } -#if __AVX__ - NCNN_FORCEINLINE __m256 func_pack8(const __m256& x, const __m256& y) const + if (outdims == 3) { - return remainder256_ps(x, y); + top_blob.create(outw, outh, outc, out_elemsize, out_elempack, opt.blob_allocator); } -#if __AVX512F__ - NCNN_FORCEINLINE __m512 func_pack16(const __m512& x, const __m512& y) const + if (outdims == 4) { - return 
remainder512_ps(x, y); + top_blob.create(outw, outh, outd, outc, out_elemsize, out_elempack, opt.blob_allocator); } -#endif // __AVX512F__ -#endif // __AVX__ -#endif // __SSE2__ -}; + if (top_blob.empty()) + return -100; -struct binary_op_rremainder -{ - NCNN_FORCEINLINE float func(const float& x, const float& y) const - { - return (float)remainderf(y, x); - } -#if __SSE2__ - NCNN_FORCEINLINE __m128 func_pack4(const __m128& x, const __m128& y) const - { - return remainder_ps(y, x); - } -#if __AVX__ - NCNN_FORCEINLINE __m256 func_pack8(const __m256& x, const __m256& y) const + const bool a_pack_is_lower = A2.elempack < B2.elempack; + const bool a_pack_is_equal = A2.elempack == B2.elempack; + const bool a_size_is_lower = A2.w * A2.h * A2.d * A2.c * A2.elempack < B2.w * B2.h * B2.d * B2.c * B2.elempack; + if (a_pack_is_lower || (a_pack_is_equal && a_size_is_lower)) { - return remainder256_ps(y, x); + binary_op_broadcast(B2, A2, top_blob, get_reverse_op_type(op_type), opt); } -#if __AVX512F__ - NCNN_FORCEINLINE __m512 func_pack16(const __m512& x, const __m512& y) const + else { - return remainder512_ps(y, x); + binary_op_broadcast(A2, B2, top_blob, op_type, opt); } -#endif // __AVX512F__ -#endif // __AVX__ -#endif // __SSE2__ -}; -} // namespace BinaryOp_x86_functor + return 0; +} -static void binary_op_vector(const float* ptr, const float* ptr1, float* outptr, int aw, int bw, int ap, int bp, int op_type) +int BinaryOp_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const +{ +#if NCNN_BF16 + if (opt.use_bf16_storage && bottom_top_blob.elembits() == 16) + return forward_inplace_bf16s(bottom_top_blob, opt); +#endif + + binary_op_scalar_inplace(bottom_top_blob, b, op_type, opt); + + return 0; +} + +#if NCNN_BF16 +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ +void binary_op_vector_bf16s_avx512bf16(const unsigned short* ptr, const unsigned short* ptr1, unsigned short* outptr, int aw, int bw, int ap, int bp, int op_type); +#endif + +static void binary_op_vector_bf16s(const unsigned short* ptr, const unsigned short* ptr1, unsigned short* outptr, int aw, int bw, int ap, int bp, int op_type) { +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ + if (ncnn::cpu_support_x86_avx512_bf16()) + { + return binary_op_vector_bf16s_avx512bf16(ptr, ptr1, outptr, aw, bw, ap, bp, op_type); + } +#endif + using namespace BinaryOp_x86_functor; - if (op_type == BinaryOp::Operation_ADD) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); - if (op_type == BinaryOp::Operation_SUB) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); - if (op_type == BinaryOp::Operation_MUL) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); - if (op_type == BinaryOp::Operation_DIV) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); - if (op_type == BinaryOp::Operation_MAX) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); - if (op_type == BinaryOp::Operation_MIN) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); - if (op_type == BinaryOp::Operation_POW) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); - if (op_type == BinaryOp::Operation_RSUB) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); - if (op_type == BinaryOp::Operation_RDIV) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); - if (op_type == BinaryOp::Operation_RPOW) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); - if (op_type == BinaryOp::Operation_ATAN2) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); - if 
(op_type == BinaryOp::Operation_RATAN2) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); - if (op_type == BinaryOp::Operation_FMOD) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); - if (op_type == BinaryOp::Operation_RFMOD) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); - if (op_type == BinaryOp::Operation_LOGADDEXP) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); - if (op_type == BinaryOp::Operation_FLOOR_DIVIDE) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); - if (op_type == BinaryOp::Operation_RFLOOR_DIVIDE) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); - if (op_type == BinaryOp::Operation_REMAINDER) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); - if (op_type == BinaryOp::Operation_RREMAINDER) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_ADD) return binary_op_vector_bf16s(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_SUB) return binary_op_vector_bf16s(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_MUL) return binary_op_vector_bf16s(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_DIV) return binary_op_vector_bf16s(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_MAX) return binary_op_vector_bf16s(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_MIN) return binary_op_vector_bf16s(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_POW) return binary_op_vector_bf16s(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_RSUB) return binary_op_vector_bf16s(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_RDIV) return binary_op_vector_bf16s(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_RPOW) return binary_op_vector_bf16s(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_ATAN2) return binary_op_vector_bf16s(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_RATAN2) return binary_op_vector_bf16s(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_FMOD) return binary_op_vector_bf16s(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_RFMOD) return binary_op_vector_bf16s(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_LOGADDEXP) return binary_op_vector_bf16s(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_FLOOR_DIVIDE) return binary_op_vector_bf16s(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_RFLOOR_DIVIDE) return binary_op_vector_bf16s(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_REMAINDER) return binary_op_vector_bf16s(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_RREMAINDER) return binary_op_vector_bf16s(ptr, ptr1, outptr, aw, bw, ap, bp); // should never reach here } -static void binary_op_scalar(const Mat& a, float b, Mat& c, int op_type, const Option& opt) +static void binary_op_scalar_bf16s(const Mat& a, unsigned short b, Mat& c, int op_type, const Option& opt) { const int channels = a.c; const int size = a.w * a.h * a.d * a.elempack; @@ -995,14 +873,14 @@ static void binary_op_scalar(const Mat& a, float b, Mat& c, int op_type, const O #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; q++) { - const float* ptr = a.channel(q); - float* outptr = c.channel(q); + const unsigned short* ptr = 
a.channel(q); + unsigned short* outptr = c.channel(q); - binary_op_vector(ptr, &b, outptr, size, 1, 1, 1, op_type); + binary_op_vector_bf16s(ptr, &b, outptr, size, 1, 1, 1, op_type); } } -static void binary_op_no_broadcast(const Mat& a, const Mat& b, Mat& c, int op_type, const Option& opt) +static void binary_op_no_broadcast_bf16s(const Mat& a, const Mat& b, Mat& c, int op_type, const Option& opt) { const int channels = a.c; const int size = a.w * a.h * a.d * a.elempack; @@ -1010,24 +888,24 @@ static void binary_op_no_broadcast(const Mat& a, const Mat& b, Mat& c, int op_ty #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; q++) { - const float* ptr = a.channel(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); + const unsigned short* ptr = a.channel(q); + const unsigned short* ptr1 = b.channel(q); + unsigned short* outptr = c.channel(q); - binary_op_vector(ptr, ptr1, outptr, size, size, 1, 1, op_type); + binary_op_vector_bf16s(ptr, ptr1, outptr, size, size, 1, 1, op_type); } } -static void binary_op_broadcast(const Mat& a, const Mat& b, Mat& c, int op_type, const Option& opt) +static void binary_op_broadcast_bf16s(const Mat& a, const Mat& b, Mat& c, int op_type, const Option& opt) { if (b.w * b.h * b.d * b.c * b.elempack == 1) { - return binary_op_scalar(a, b[0], c, op_type, opt); + return binary_op_scalar_bf16s(a, ((const unsigned short*)b)[0], c, op_type, opt); } if (a.dims == b.dims && a.w == b.w && a.h == b.h && a.d == b.d && a.c == b.c && a.elempack == b.elempack) { - return binary_op_no_broadcast(a, b, c, op_type, opt); + return binary_op_no_broadcast_bf16s(a, b, c, op_type, opt); } const int dims = c.dims; @@ -1042,11 +920,11 @@ static void binary_op_broadcast(const Mat& a, const Mat& b, Mat& c, int op_type, const int y0 = std::min(y, a.h - 1); const int y1 = std::min(y, b.h - 1); - const float* ptr = a.row(y0); - const float* ptr1 = b.row(y1); - float* outptr = c.row(y); + const unsigned short* ptr = a.row(y0); + const unsigned short* ptr1 = b.row(y1); + unsigned short* outptr = c.row(y); - binary_op_vector(ptr, ptr1, outptr, a.w, b.w, a.elempack, b.elempack, op_type); + binary_op_vector_bf16s(ptr, ptr1, outptr, a.w, b.w, a.elempack, b.elempack, op_type); } } @@ -1062,11 +940,11 @@ static void binary_op_broadcast(const Mat& a, const Mat& b, Mat& c, int op_type, if (b.d * b.h * b.w == 1) { - const float* ptr = a.channel(q0); - const float* ptr1 = b.channel(q1); - float* outptr = c.channel(q); + const unsigned short* ptr = a.channel(q0); + const unsigned short* ptr1 = b.channel(q1); + unsigned short* outptr = c.channel(q); - binary_op_vector(ptr, ptr1, outptr, a.w * a.h * a.d, 1, a.elempack, b.elempack, op_type); + binary_op_vector_bf16s(ptr, ptr1, outptr, a.w * a.h * a.d, 1, a.elempack, b.elempack, op_type); continue; } @@ -1077,11 +955,11 @@ static void binary_op_broadcast(const Mat& a, const Mat& b, Mat& c, int op_type, const int z0 = std::min(z, a.d - 1); const int z1 = std::min(z, b.d - 1); - const float* ptr = a.channel(q0).depth(z0); - const float* ptr1 = b.channel(q1).depth(z1); - float* outptr = c.channel(q).depth(z); + const unsigned short* ptr = a.channel(q0).depth(z0); + const unsigned short* ptr1 = b.channel(q1).depth(z1); + unsigned short* outptr = c.channel(q).depth(z); - binary_op_vector(ptr, ptr1, outptr, a.w * a.h, 1, a.elempack, b.elempack, op_type); + binary_op_vector_bf16s(ptr, ptr1, outptr, a.w * a.h, 1, a.elempack, b.elempack, op_type); } continue; } @@ -1096,18 +974,18 @@ static void 
binary_op_broadcast(const Mat& a, const Mat& b, Mat& c, int op_type, const int y0 = std::min(y, a.h - 1); const int y1 = std::min(y, b.h - 1); - const float* ptr = a.channel(q0).depth(z0).row(y0); - const float* ptr1 = b.channel(q1).depth(z1).row(y1); - float* outptr = c.channel(q).depth(z).row(y); + const unsigned short* ptr = a.channel(q0).depth(z0).row(y0); + const unsigned short* ptr1 = b.channel(q1).depth(z1).row(y1); + unsigned short* outptr = c.channel(q).depth(z).row(y); - binary_op_vector(ptr, ptr1, outptr, a.w, b.w, a.elempack, b.elempack, op_type); + binary_op_vector_bf16s(ptr, ptr1, outptr, a.w, b.w, a.elempack, b.elempack, op_type); } } } } } -static void binary_op_scalar_inplace(Mat& a, float b, int op_type, const Option& opt) +static void binary_op_scalar_inplace_bf16s(Mat& a, unsigned short b, int op_type, const Option& opt) { const int channels = a.c; const int size = a.w * a.h * a.d * a.elempack; @@ -1115,34 +993,13 @@ static void binary_op_scalar_inplace(Mat& a, float b, int op_type, const Option& #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; q++) { - float* ptr = a.channel(q); + unsigned short* ptr = a.channel(q); - binary_op_vector(ptr, &b, ptr, size, 1, 1, 1, op_type); + binary_op_vector_bf16s(ptr, &b, ptr, size, 1, 1, 1, op_type); } } -static int get_reverse_op_type(int op_type) -{ - if (op_type == BinaryOp::Operation_SUB) return BinaryOp::Operation_RSUB; - if (op_type == BinaryOp::Operation_DIV) return BinaryOp::Operation_RDIV; - if (op_type == BinaryOp::Operation_POW) return BinaryOp::Operation_RPOW; - if (op_type == BinaryOp::Operation_ATAN2) return BinaryOp::Operation_RATAN2; - if (op_type == BinaryOp::Operation_FMOD) return BinaryOp::Operation_RFMOD; - if (op_type == BinaryOp::Operation_FLOOR_DIVIDE) return BinaryOp::Operation_RFLOOR_DIVIDE; - if (op_type == BinaryOp::Operation_REMAINDER) return BinaryOp::Operation_RREMAINDER; - - if (op_type == BinaryOp::Operation_RSUB) return BinaryOp::Operation_SUB; - if (op_type == BinaryOp::Operation_RDIV) return BinaryOp::Operation_DIV; - if (op_type == BinaryOp::Operation_RPOW) return BinaryOp::Operation_POW; - if (op_type == BinaryOp::Operation_RATAN2) return BinaryOp::Operation_ATAN2; - if (op_type == BinaryOp::Operation_RFMOD) return BinaryOp::Operation_FMOD; - if (op_type == BinaryOp::Operation_RFLOOR_DIVIDE) return BinaryOp::Operation_FLOOR_DIVIDE; - if (op_type == BinaryOp::Operation_RREMAINDER) return BinaryOp::Operation_REMAINDER; - - return op_type; -} - -int BinaryOp_x86::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const +int BinaryOp_x86::forward_bf16s(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const { const Mat& A = bottom_blobs[0]; const Mat& B = bottom_blobs[1]; @@ -1281,21 +1138,22 @@ int BinaryOp_x86::forward(const std::vector& bottom_blobs, std::vector const bool a_size_is_lower = A2.w * A2.h * A2.d * A2.c * A2.elempack < B2.w * B2.h * B2.d * B2.c * B2.elempack; if (a_pack_is_lower || (a_pack_is_equal && a_size_is_lower)) { - binary_op_broadcast(B2, A2, top_blob, get_reverse_op_type(op_type), opt); + binary_op_broadcast_bf16s(B2, A2, top_blob, get_reverse_op_type(op_type), opt); } else { - binary_op_broadcast(A2, B2, top_blob, op_type, opt); + binary_op_broadcast_bf16s(A2, B2, top_blob, op_type, opt); } return 0; } -int BinaryOp_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const +int BinaryOp_x86::forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt) const { - 
binary_op_scalar_inplace(bottom_top_blob, b, op_type, opt); + binary_op_scalar_inplace_bf16s(bottom_top_blob, float32_to_bfloat16(b), op_type, opt); return 0; } +#endif // NCNN_BF16 } // namespace ncnn diff --git a/src/layer/x86/binaryop_x86.h b/src/layer/x86/binaryop_x86.h index f841c373d18..0a5b4feb77c 100644 --- a/src/layer/x86/binaryop_x86.h +++ b/src/layer/x86/binaryop_x86.h @@ -16,6 +16,12 @@ class BinaryOp_x86 : public BinaryOp virtual int forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; virtual int forward_inplace(Mat& bottom_top_blob, const Option& opt) const; + +protected: +#if NCNN_BF16 + int forward_bf16s(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; + int forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt) const; +#endif }; } // namespace ncnn diff --git a/src/layer/x86/binaryop_x86_avx512bf16.cpp b/src/layer/x86/binaryop_x86_avx512bf16.cpp new file mode 100644 index 00000000000..6f6ebba76b4 --- /dev/null +++ b/src/layer/x86/binaryop_x86_avx512bf16.cpp @@ -0,0 +1,57 @@ +// Copyright 2024 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#include "binaryop_x86.h" + +#if __SSE2__ +#include +#include "sse_mathfun.h" +#if __AVX__ +#include +#include "avx_mathfun.h" +#if __AVX512F__ +#include "avx512_mathfun.h" +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ + +#include "x86_usability.h" + +namespace ncnn { + +namespace BinaryOp_x86_functor { + +#include "binaryop_functor.h" + +} // namespace BinaryOp_x86_functor + +#include "binaryop_bf16s.h" + +void binary_op_vector_bf16s_avx512bf16(const unsigned short* ptr, const unsigned short* ptr1, unsigned short* outptr, int aw, int bw, int ap, int bp, int op_type) +{ + using namespace BinaryOp_x86_functor; + + if (op_type == BinaryOp::Operation_ADD) return binary_op_vector_bf16s(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_SUB) return binary_op_vector_bf16s(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_MUL) return binary_op_vector_bf16s(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_DIV) return binary_op_vector_bf16s(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_MAX) return binary_op_vector_bf16s(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_MIN) return binary_op_vector_bf16s(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_POW) return binary_op_vector_bf16s(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_RSUB) return binary_op_vector_bf16s(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_RDIV) return binary_op_vector_bf16s(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_RPOW) return binary_op_vector_bf16s(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_ATAN2) return binary_op_vector_bf16s(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_RATAN2) return binary_op_vector_bf16s(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_FMOD) return binary_op_vector_bf16s(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_RFMOD) return binary_op_vector_bf16s(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_LOGADDEXP) return binary_op_vector_bf16s(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_FLOOR_DIVIDE) return binary_op_vector_bf16s(ptr, ptr1, outptr, aw, bw, ap, bp); + if 
(op_type == BinaryOp::Operation_RFLOOR_DIVIDE) return binary_op_vector_bf16s(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_REMAINDER) return binary_op_vector_bf16s(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_RREMAINDER) return binary_op_vector_bf16s(ptr, ptr1, outptr, aw, bw, ap, bp); + + // should never reach here +} + +} // namespace ncnn From eeb2a0b48c2cd8d2276bc1c7cd6cfe36e22357c9 Mon Sep 17 00:00:00 2001 From: nihui Date: Tue, 10 Mar 2026 11:49:14 +0800 Subject: [PATCH 15/36] update pnnx ci torch 2.10.0 (#6592) * Downgrade actions/cache from v5 to v4 --- .github/workflows/pnnx.yml | 26 +++++--- tools/pnnx/src/pass_level2/nn_GRU.cpp | 64 ++++++++++++++----- tools/pnnx/src/pass_level2/nn_LSTM.cpp | 56 ++++++++++++---- .../src/pass_level5/attribute_unpooling.cpp | 3 + .../tests/onnx/test_F_adaptive_avg_pool3d.py | 2 +- .../tests/onnx/test_F_adaptive_max_pool1d.py | 2 +- .../tests/onnx/test_F_adaptive_max_pool2d.py | 2 +- .../tests/onnx/test_F_adaptive_max_pool3d.py | 2 +- tools/pnnx/tests/onnx/test_convnext_tiny.py | 2 +- .../tests/onnx/test_nn_AdaptiveAvgPool3d.py | 2 +- .../tests/onnx/test_nn_AdaptiveMaxPool1d.py | 2 +- .../tests/onnx/test_nn_AdaptiveMaxPool2d.py | 2 +- .../tests/onnx/test_nn_AdaptiveMaxPool3d.py | 2 +- tools/pnnx/tests/onnx/test_swin_t.py | 2 +- tools/pnnx/tests/onnx/test_torch_roll.py | 2 +- .../test_transformers_funnel_attention.py | 2 +- 16 files changed, 121 insertions(+), 52 deletions(-) diff --git a/.github/workflows/pnnx.yml b/.github/workflows/pnnx.yml index 96794c1f0ec..bc4c54ef021 100644 --- a/.github/workflows/pnnx.yml +++ b/.github/workflows/pnnx.yml @@ -21,11 +21,11 @@ permissions: contents: read env: - LIBTORCH_VERSION: 2.9.0 - TORCHVISION_VERSION: 0.24.0 + LIBTORCH_VERSION: 2.10.0 + TORCHVISION_VERSION: 0.25.0 PROTOBUF_VERSION: 21.12 - ONNXRUNTIME_VERSION: 1.23.1 - CACHE_DATE: 20251020 + ONNXRUNTIME_VERSION: 1.24.3 + CACHE_DATE: 20260309 SEGMENT_DOWNLOAD_TIMEOUT_MINS: 15 jobs: @@ -94,21 +94,21 @@ jobs: - name: cache-libtorch id: cache-libtorch - uses: actions/cache@v5 + uses: actions/cache@v4 with: path: libtorch-${{ env.LIBTORCH_VERSION }}-install key: libtorch-${{ env.LIBTORCH_VERSION }}-linux-install-${{ env.CACHE_DATE }} - name: cache-torchvision id: cache-torchvision - uses: actions/cache@v5 + uses: actions/cache@v4 with: path: torchvision-${{ env.TORCHVISION_VERSION }}-install key: torchvision-${{ env.TORCHVISION_VERSION }}-linux-install-${{ env.CACHE_DATE }} - name: cache-onnxruntime id: cache-onnxruntime - uses: actions/cache@v5 + uses: actions/cache@v4 with: path: onnxruntime-${{ env.ONNXRUNTIME_VERSION }}-install key: onnxruntime-${{ env.ONNXRUNTIME_VERSION }}-linux-install-${{ env.CACHE_DATE }} @@ -129,7 +129,9 @@ jobs: pip3 install -r requirements.txt --break-system-packages patch -p1 -i $GITHUB_WORKSPACE/pnnx-patches/pytorch-v${{ env.LIBTORCH_VERSION }}-fix-mobile-build.patch patch -p1 -i $GITHUB_WORKSPACE/pnnx-patches/pytorch-v${{ env.LIBTORCH_VERSION }}-no-link-system-lib.patch - patch -p1 -R -i $GITHUB_WORKSPACE/pnnx-patches/pytorch-v${{ env.LIBTORCH_VERSION }}-revert-nativert-api.patch + patch -p1 -i $GITHUB_WORKSPACE/pnnx-patches/pytorch-v${{ env.LIBTORCH_VERSION }}-fix-eigen-build.patch + patch -p1 -i $GITHUB_WORKSPACE/pnnx-patches/pytorch-v${{ env.LIBTORCH_VERSION }}-fix-link-local-sleef.patch + patch -p1 -i $GITHUB_WORKSPACE/pnnx-patches/pytorch-v${{ env.LIBTORCH_VERSION }}-revert-nativert-api.patch mkdir -p build && cd build cmake 
-DCMAKE_INSTALL_PREFIX=$GITHUB_WORKSPACE/libtorch-${{ env.LIBTORCH_VERSION }}-install \ -DCMAKE_BUILD_TYPE=MinSizeRel \ @@ -151,7 +153,10 @@ jobs: -DUSE_NUMPY=OFF \ -DUSE_OPENMP=OFF \ -DUSE_SOURCE_DEBUG_ON_MOBILE=OFF \ - -DUSE_XNNPACK=OFF .. + -DUSE_XNNPACK=OFF \ + -DBUILD_TEST=OFF \ + -DATEN_NO_TEST=ON \ + .. cmake --build . -j 8 cmake --build . -j 8 --target install/strip @@ -192,7 +197,7 @@ jobs: cd onnxruntime-${{ env.ONNXRUNTIME_VERSION }} patch -p1 -i $GITHUB_WORKSPACE/pnnx-patches/onnxruntime-${{ env.ONNXRUNTIME_VERSION }}-less-mlas-features.patch patch -p1 -i $GITHUB_WORKSPACE/pnnx-patches/onnxruntime-${{ env.ONNXRUNTIME_VERSION }}-monolithic-static-library.patch - patch -p1 -i $GITHUB_WORKSPACE/pnnx-patches/onnxruntime-${{ env.ONNXRUNTIME_VERSION }}-include-cstdint.patch + patch -p1 -i $GITHUB_WORKSPACE/pnnx-patches/onnxruntime-${{ env.ONNXRUNTIME_VERSION }}-use-clog.patch mkdir -p build2 && cd build2 cmake -DCMAKE_INSTALL_PREFIX=$GITHUB_WORKSPACE/onnxruntime-${{ env.ONNXRUNTIME_VERSION }}-install \ -DCMAKE_BUILD_TYPE=MinSizeRel \ @@ -250,6 +255,7 @@ jobs: - { python: '3.12', numpy: '2.2.5', opencv: '4.11.*', torch: '2.7.0', torchvision: '0.22.0', torchaudio: '2.7.0+cpu', transformers: '4.52.1' } - { python: '3.13', numpy: '2.2.5', opencv: '4.12.*', torch: '2.8.0', torchvision: '0.23.0', torchaudio: '2.8.0+cpu', transformers: '4.56.2' } - { python: '3.13', numpy: '2.2.5', opencv: '4.12.*', torch: '2.9.0', torchvision: '0.24.0', torchaudio: '2.9.0+cpu', transformers: '4.56.2' } + - { python: '3.13', numpy: '2.2.5', opencv: '4.12.*', torch: '2.10.0', torchvision: '0.25.0', torchaudio: '2.10.0+cpu', transformers: '4.56.2' } name: test-${{ matrix.torch }}-py${{ matrix.python }} diff --git a/tools/pnnx/src/pass_level2/nn_GRU.cpp b/tools/pnnx/src/pass_level2/nn_GRU.cpp index ed563c4d42b..6404b3d55bf 100644 --- a/tools/pnnx/src/pass_level2/nn_GRU.cpp +++ b/tools/pnnx/src/pass_level2/nn_GRU.cpp @@ -32,7 +32,7 @@ pnnx.Output output 1 0 out1 return "gru"; } - bool match(const std::map& captured_params, const std::map& captured_attrs) const + bool match(const std::map& /*matched_operators*/, const std::map& captured_params, const std::map& captured_attrs) const { if (captured_params.find("gru.hidden_size") == captured_params.end()) return false; @@ -223,9 +223,9 @@ pnnx.Output output 1 0 out1 )PNNXIR"; } - bool match(const std::map& captured_params, const std::map& captured_attrs) const + bool match(const std::map& matched_operators, const std::map& captured_params, const std::map& captured_attrs) const { - if (!nn_GRU_onnx::match(captured_params, captured_attrs)) + if (!nn_GRU_onnx::match(matched_operators, captured_params, captured_attrs)) return false; const int hidden_size = captured_params.at("gru.hidden_size").i; @@ -360,7 +360,7 @@ class nn_GRU_onnx_1 : public nn_GRU_onnx const char* match_pattern_graph() const { return R"PNNXIR(7767517 -7 8 +7 7 pnnx.Input input_0 0 1 input pnnx.Input input_1 0 1 initial_h pnnx.Attribute W 0 1 W @data @@ -380,7 +380,7 @@ class nn_GRU_onnx_B1 : public nn_GRU_onnx_B const char* match_pattern_graph() const { return R"PNNXIR(7767517 -8 9 +8 8 pnnx.Input input_0 0 1 input pnnx.Input input_1 0 1 initial_h pnnx.Attribute W 0 1 W @data @@ -453,15 +453,30 @@ pnnx.Output output 1 0 out2 )PNNXIR"; } - bool match(const std::map& captured_params, const std::map& captured_attrs) const + bool match(const std::map& matched_operators, const std::map& captured_params, const std::map& captured_attrs) const { - if (!nn_GRU_onnx::match(captured_params, captured_attrs)) + 
if (!nn_GRU_onnx::match(matched_operators, captured_params, captured_attrs)) return false; - if (captured_params.at("reshape.shape").ai != std::vector{0, 0, -1}) - return false; + if (captured_params.at("reshape.shape").ai == std::vector{0, 0, -1}) + return true; - return true; + const Operator* op_reshape = matched_operators.at("reshape"); + const std::vector& out1_shape = op_reshape->inputs[0]->shape; + const std::vector& out2_shape = op_reshape->outputs[0]->shape; + if (out2_shape.size() == 3 && captured_params.at("reshape.shape").ai.size() == 3 && out1_shape.size() >= out2_shape.size()) + { + if (out1_shape[0] != out2_shape[0]) + return false; + if (out1_shape[1] != out2_shape[1]) + return false; + if (captured_params.at("reshape.shape").ai[2] != out2_shape[2]) + return false; + + return true; + } + + return false; } }; @@ -485,15 +500,30 @@ pnnx.Output output 1 0 out2 )PNNXIR"; } - bool match(const std::map& captured_params, const std::map& captured_attrs) const + bool match(const std::map& matched_operators, const std::map& captured_params, const std::map& captured_attrs) const { - if (!nn_GRU_onnx_B::match(captured_params, captured_attrs)) + if (!nn_GRU_onnx_B::match(matched_operators, captured_params, captured_attrs)) return false; - if (captured_params.at("reshape.shape").ai != std::vector{0, 0, -1}) - return false; + if (captured_params.at("reshape.shape").ai == std::vector{0, 0, -1}) + return true; - return true; + const Operator* op_reshape = matched_operators.at("reshape"); + const std::vector& out1_shape = op_reshape->inputs[0]->shape; + const std::vector& out2_shape = op_reshape->outputs[0]->shape; + if (out2_shape.size() == 3 && captured_params.at("reshape.shape").ai.size() == 3 && out1_shape.size() >= out2_shape.size()) + { + if (out1_shape[0] != out2_shape[0]) + return false; + if (out1_shape[1] != out2_shape[1]) + return false; + if (captured_params.at("reshape.shape").ai[2] != out2_shape[2]) + return false; + + return true; + } + + return false; } }; @@ -505,7 +535,7 @@ class nn_GRU_onnx_4 : public nn_GRU_onnx_3 const char* match_pattern_graph() const { return R"PNNXIR(7767517 -8 9 +8 8 pnnx.Input input_0 0 1 input pnnx.Input input_1 0 1 initial_h pnnx.Attribute W 0 1 W @data @@ -526,7 +556,7 @@ class nn_GRU_onnx_B4 : public nn_GRU_onnx_B3 const char* match_pattern_graph() const { return R"PNNXIR(7767517 -9 10 +9 9 pnnx.Input input_0 0 1 input pnnx.Input input_1 0 1 initial_h pnnx.Attribute W 0 1 W @data diff --git a/tools/pnnx/src/pass_level2/nn_LSTM.cpp b/tools/pnnx/src/pass_level2/nn_LSTM.cpp index 12a9ac25ac8..90b58cc7706 100644 --- a/tools/pnnx/src/pass_level2/nn_LSTM.cpp +++ b/tools/pnnx/src/pass_level2/nn_LSTM.cpp @@ -32,7 +32,7 @@ pnnx.Output output 1 0 out1 return "lstm"; } - bool match(const std::map& captured_params, const std::map& captured_attrs) const + bool match(const std::map& /*matched_operators*/, const std::map& captured_params, const std::map& captured_attrs) const { if (captured_params.find("lstm.hidden_size") == captured_params.end()) return false; @@ -236,9 +236,9 @@ pnnx.Output output 1 0 out1 )PNNXIR"; } - bool match(const std::map& captured_params, const std::map& captured_attrs) const + bool match(const std::map& matched_operators, const std::map& captured_params, const std::map& captured_attrs) const { - if (!nn_LSTM_onnx::match(captured_params, captured_attrs)) + if (!nn_LSTM_onnx::match(matched_operators, captured_params, captured_attrs)) return false; const int hidden_size = captured_params.at("lstm.hidden_size").i; @@ -482,15 +482,30 @@ 
pnnx.Output output 1 0 out2 )PNNXIR"; } - bool match(const std::map& captured_params, const std::map& captured_attrs) const + bool match(const std::map& matched_operators, const std::map& captured_params, const std::map& captured_attrs) const { - if (!nn_LSTM_onnx::match(captured_params, captured_attrs)) + if (!nn_LSTM_onnx::match(matched_operators, captured_params, captured_attrs)) return false; - if (captured_params.at("reshape.shape").ai != std::vector{0, 0, -1}) - return false; + if (captured_params.at("reshape.shape").ai == std::vector{0, 0, -1}) + return true; - return true; + const Operator* op_reshape = matched_operators.at("reshape"); + const std::vector& out1_shape = op_reshape->inputs[0]->shape; + const std::vector& out2_shape = op_reshape->outputs[0]->shape; + if (out2_shape.size() == 3 && captured_params.at("reshape.shape").ai.size() == 3 && out1_shape.size() >= out2_shape.size()) + { + if (out1_shape[0] != out2_shape[0]) + return false; + if (out1_shape[1] != out2_shape[1]) + return false; + if (captured_params.at("reshape.shape").ai[2] != out2_shape[2]) + return false; + + return true; + } + + return false; } }; @@ -514,15 +529,30 @@ pnnx.Output output 1 0 out2 )PNNXIR"; } - bool match(const std::map& captured_params, const std::map& captured_attrs) const + bool match(const std::map& matched_operators, const std::map& captured_params, const std::map& captured_attrs) const { - if (!nn_LSTM_onnx_B::match(captured_params, captured_attrs)) + if (!nn_LSTM_onnx_B::match(matched_operators, captured_params, captured_attrs)) return false; - if (captured_params.at("reshape.shape").ai != std::vector{0, 0, -1}) - return false; + if (captured_params.at("reshape.shape").ai == std::vector{0, 0, -1}) + return true; - return true; + const Operator* op_reshape = matched_operators.at("reshape"); + const std::vector& out1_shape = op_reshape->inputs[0]->shape; + const std::vector& out2_shape = op_reshape->outputs[0]->shape; + if (out2_shape.size() == 3 && captured_params.at("reshape.shape").ai.size() == 3 && out1_shape.size() >= out2_shape.size()) + { + if (out1_shape[0] != out2_shape[0]) + return false; + if (out1_shape[1] != out2_shape[1]) + return false; + if (captured_params.at("reshape.shape").ai[2] != out2_shape[2]) + return false; + + return true; + } + + return false; } }; diff --git a/tools/pnnx/src/pass_level5/attribute_unpooling.cpp b/tools/pnnx/src/pass_level5/attribute_unpooling.cpp index 9e453159dd0..126881bab5e 100644 --- a/tools/pnnx/src/pass_level5/attribute_unpooling.cpp +++ b/tools/pnnx/src/pass_level5/attribute_unpooling.cpp @@ -52,7 +52,10 @@ void attribute_unpooling(Graph& graph) for (size_t j = 0; j < x->inputs.size(); j++) { if (x->inputs[j] == attr) + { x->inputs[j] = attr2; + break; + } } } diff --git a/tools/pnnx/tests/onnx/test_F_adaptive_avg_pool3d.py b/tools/pnnx/tests/onnx/test_F_adaptive_avg_pool3d.py index db29a3243c7..151e9a7480d 100644 --- a/tools/pnnx/tests/onnx/test_F_adaptive_avg_pool3d.py +++ b/tools/pnnx/tests/onnx/test_F_adaptive_avg_pool3d.py @@ -30,7 +30,7 @@ def test(): a = net(x) # export onnx - if version.parse(torch.__version__) >= version.parse('2.9') and version.parse(torch.__version__) < version.parse('2.10'): + if version.parse(torch.__version__) >= version.parse('2.9') and version.parse(torch.__version__) < version.parse('2.11'): torch.onnx.export(net, (x,), "test_F_adaptive_avg_pool3d.onnx", dynamo=False) else: torch.onnx.export(net, (x,), "test_F_adaptive_avg_pool3d.onnx") diff --git a/tools/pnnx/tests/onnx/test_F_adaptive_max_pool1d.py 
b/tools/pnnx/tests/onnx/test_F_adaptive_max_pool1d.py index 0348447046a..73a77b70254 100644 --- a/tools/pnnx/tests/onnx/test_F_adaptive_max_pool1d.py +++ b/tools/pnnx/tests/onnx/test_F_adaptive_max_pool1d.py @@ -25,7 +25,7 @@ def test(): a0, a1 = net(x) # export onnx - if version.parse(torch.__version__) >= version.parse('2.9') and version.parse(torch.__version__) < version.parse('2.10'): + if version.parse(torch.__version__) >= version.parse('2.9') and version.parse(torch.__version__) < version.parse('2.11'): torch.onnx.export(net, (x,), "test_F_adaptive_max_pool1d.onnx", dynamo=False) else: torch.onnx.export(net, (x,), "test_F_adaptive_max_pool1d.onnx") diff --git a/tools/pnnx/tests/onnx/test_F_adaptive_max_pool2d.py b/tools/pnnx/tests/onnx/test_F_adaptive_max_pool2d.py index 1ed0851e14e..75685507eef 100644 --- a/tools/pnnx/tests/onnx/test_F_adaptive_max_pool2d.py +++ b/tools/pnnx/tests/onnx/test_F_adaptive_max_pool2d.py @@ -30,7 +30,7 @@ def test(): a = net(x) # export onnx - if version.parse(torch.__version__) >= version.parse('2.9') and version.parse(torch.__version__) < version.parse('2.10'): + if version.parse(torch.__version__) >= version.parse('2.9') and version.parse(torch.__version__) < version.parse('2.11'): torch.onnx.export(net, (x,), "test_F_adaptive_max_pool2d.onnx", dynamo=False) else: torch.onnx.export(net, (x,), "test_F_adaptive_max_pool2d.onnx") diff --git a/tools/pnnx/tests/onnx/test_F_adaptive_max_pool3d.py b/tools/pnnx/tests/onnx/test_F_adaptive_max_pool3d.py index eff04244e53..dc13c2d2878 100644 --- a/tools/pnnx/tests/onnx/test_F_adaptive_max_pool3d.py +++ b/tools/pnnx/tests/onnx/test_F_adaptive_max_pool3d.py @@ -30,7 +30,7 @@ def test(): a = net(x) # export onnx - if version.parse(torch.__version__) >= version.parse('2.9') and version.parse(torch.__version__) < version.parse('2.10'): + if version.parse(torch.__version__) >= version.parse('2.9') and version.parse(torch.__version__) < version.parse('2.11'): torch.onnx.export(net, (x,), "test_F_adaptive_max_pool3d.onnx", dynamo=False) else: torch.onnx.export(net, (x,), "test_F_adaptive_max_pool3d.onnx") diff --git a/tools/pnnx/tests/onnx/test_convnext_tiny.py b/tools/pnnx/tests/onnx/test_convnext_tiny.py index 69ae5c2c518..4a84bea9813 100644 --- a/tools/pnnx/tests/onnx/test_convnext_tiny.py +++ b/tools/pnnx/tests/onnx/test_convnext_tiny.py @@ -19,7 +19,7 @@ def test(): a = net(x) # export onnx - if version.parse(torch.__version__) >= version.parse('2.9') and version.parse(torch.__version__) < version.parse('2.10'): + if version.parse(torch.__version__) >= version.parse('2.9') and version.parse(torch.__version__) < version.parse('2.11'): torch.onnx.export(net, (x,), "test_convnext_tiny.onnx", dynamo=False) else: torch.onnx.export(net, (x,), "test_convnext_tiny.onnx") diff --git a/tools/pnnx/tests/onnx/test_nn_AdaptiveAvgPool3d.py b/tools/pnnx/tests/onnx/test_nn_AdaptiveAvgPool3d.py index 77aff13e2c5..ec452dc4090 100644 --- a/tools/pnnx/tests/onnx/test_nn_AdaptiveAvgPool3d.py +++ b/tools/pnnx/tests/onnx/test_nn_AdaptiveAvgPool3d.py @@ -35,7 +35,7 @@ def test(): a = net(x) # export onnx - if version.parse(torch.__version__) >= version.parse('2.9') and version.parse(torch.__version__) < version.parse('2.10'): + if version.parse(torch.__version__) >= version.parse('2.9') and version.parse(torch.__version__) < version.parse('2.11'): torch.onnx.export(net, (x,), "test_nn_AdaptiveAvgPool3d.onnx", dynamo=False) else: torch.onnx.export(net, (x,), "test_nn_AdaptiveAvgPool3d.onnx") diff --git 
a/tools/pnnx/tests/onnx/test_nn_AdaptiveMaxPool1d.py b/tools/pnnx/tests/onnx/test_nn_AdaptiveMaxPool1d.py index dca92769a75..413dea99c64 100644 --- a/tools/pnnx/tests/onnx/test_nn_AdaptiveMaxPool1d.py +++ b/tools/pnnx/tests/onnx/test_nn_AdaptiveMaxPool1d.py @@ -28,7 +28,7 @@ def test(): a0, a1 = net(x) # export onnx - if version.parse(torch.__version__) >= version.parse('2.9') and version.parse(torch.__version__) < version.parse('2.10'): + if version.parse(torch.__version__) >= version.parse('2.9') and version.parse(torch.__version__) < version.parse('2.11'): torch.onnx.export(net, (x,), "test_nn_AdaptiveMaxPool1d.onnx", dynamo=False) else: torch.onnx.export(net, (x,), "test_nn_AdaptiveMaxPool1d.onnx") diff --git a/tools/pnnx/tests/onnx/test_nn_AdaptiveMaxPool2d.py b/tools/pnnx/tests/onnx/test_nn_AdaptiveMaxPool2d.py index 6687cb53049..652c15a531d 100644 --- a/tools/pnnx/tests/onnx/test_nn_AdaptiveMaxPool2d.py +++ b/tools/pnnx/tests/onnx/test_nn_AdaptiveMaxPool2d.py @@ -35,7 +35,7 @@ def test(): a = net(x) # export onnx - if version.parse(torch.__version__) >= version.parse('2.9') and version.parse(torch.__version__) < version.parse('2.10'): + if version.parse(torch.__version__) >= version.parse('2.9') and version.parse(torch.__version__) < version.parse('2.11'): torch.onnx.export(net, (x,), "test_nn_AdaptiveMaxPool2d.onnx", dynamo=False) else: torch.onnx.export(net, (x,), "test_nn_AdaptiveMaxPool2d.onnx") diff --git a/tools/pnnx/tests/onnx/test_nn_AdaptiveMaxPool3d.py b/tools/pnnx/tests/onnx/test_nn_AdaptiveMaxPool3d.py index ae86998348e..9f76a7c23a0 100644 --- a/tools/pnnx/tests/onnx/test_nn_AdaptiveMaxPool3d.py +++ b/tools/pnnx/tests/onnx/test_nn_AdaptiveMaxPool3d.py @@ -35,7 +35,7 @@ def test(): a = net(x) # export onnx - if version.parse(torch.__version__) >= version.parse('2.9') and version.parse(torch.__version__) < version.parse('2.10'): + if version.parse(torch.__version__) >= version.parse('2.9') and version.parse(torch.__version__) < version.parse('2.11'): torch.onnx.export(net, (x,), "test_nn_AdaptiveMaxPool3d.onnx", dynamo=False) else: torch.onnx.export(net, (x,), "test_nn_AdaptiveMaxPool3d.onnx") diff --git a/tools/pnnx/tests/onnx/test_swin_t.py b/tools/pnnx/tests/onnx/test_swin_t.py index 3d1168b979e..7c7b8a7f1d6 100644 --- a/tools/pnnx/tests/onnx/test_swin_t.py +++ b/tools/pnnx/tests/onnx/test_swin_t.py @@ -19,7 +19,7 @@ def test(): a = net(x) # export onnx - if version.parse(torch.__version__) >= version.parse('2.9') and version.parse(torch.__version__) < version.parse('2.10'): + if version.parse(torch.__version__) >= version.parse('2.9') and version.parse(torch.__version__) < version.parse('2.11'): torch.onnx.export(net, (x,), "test_swin_t.onnx", dynamo=False) else: torch.onnx.export(net, (x,), "test_swin_t.onnx") diff --git a/tools/pnnx/tests/onnx/test_torch_roll.py b/tools/pnnx/tests/onnx/test_torch_roll.py index a31da582c0a..4cce38000f0 100644 --- a/tools/pnnx/tests/onnx/test_torch_roll.py +++ b/tools/pnnx/tests/onnx/test_torch_roll.py @@ -31,7 +31,7 @@ def test(): a = net(x, y, z) # export onnx - if version.parse(torch.__version__) >= version.parse('2.9') and version.parse(torch.__version__) < version.parse('2.10'): + if version.parse(torch.__version__) >= version.parse('2.9') and version.parse(torch.__version__) < version.parse('2.11'): torch.onnx.export(net, (x, y, z), "test_torch_roll.onnx", dynamo=False) else: torch.onnx.export(net, (x, y, z), "test_torch_roll.onnx") diff --git a/tools/pnnx/tests/onnx/test_transformers_funnel_attention.py 
b/tools/pnnx/tests/onnx/test_transformers_funnel_attention.py index b7e0743ec8e..1c6fa392376 100644 --- a/tools/pnnx/tests/onnx/test_transformers_funnel_attention.py +++ b/tools/pnnx/tests/onnx/test_transformers_funnel_attention.py @@ -45,7 +45,7 @@ def test(): a = net(x, mask0) # export onnx - if version.parse(torch.__version__) >= version.parse('2.9') and version.parse(torch.__version__) < version.parse('2.10'): + if version.parse(torch.__version__) >= version.parse('2.9') and version.parse(torch.__version__) < version.parse('2.11'): torch.onnx.export(net, (x, mask0), "test_transformers_funnel_attention.onnx", dynamo=False) else: torch.onnx.export(net, (x, mask0), "test_transformers_funnel_attention.onnx") From 919c896af90356d4d3eba1f6b00a0ff816c83778 Mon Sep 17 00:00:00 2001 From: nihui Date: Tue, 10 Mar 2026 17:24:02 +0800 Subject: [PATCH 16/36] x86 concat slice flatten reshape crop padding packing support fp16 bf16 storage (#6593) --- src/layer/x86/concat_x86.cpp | 815 +++++++++++++++- src/layer/x86/concat_x86.h | 5 +- src/layer/x86/crop_x86.cpp | 197 +++- src/layer/x86/crop_x86.h | 2 +- src/layer/x86/flatten_x86.cpp | 323 +++++- src/layer/x86/flatten_x86.h | 3 +- src/layer/x86/packing_x86.cpp | 1029 ++++++++++++++++++++ src/layer/x86/packing_x86.h | 1 + src/layer/x86/padding_pack16_bf16s_fp16s.h | 201 ++++ src/layer/x86/padding_pack4_bf16s_fp16s.h | 192 ++++ src/layer/x86/padding_pack8_bf16s_fp16s.h | 201 ++++ src/layer/x86/padding_x86.cpp | 508 +++++++++- src/layer/x86/padding_x86.h | 15 +- src/layer/x86/reshape_x86.cpp | 454 ++++++++- src/layer/x86/reshape_x86.h | 5 +- src/layer/x86/slice_x86.cpp | 954 +++++++++++++++++- src/layer/x86/slice_x86.h | 5 +- 17 files changed, 4874 insertions(+), 36 deletions(-) create mode 100644 src/layer/x86/padding_pack16_bf16s_fp16s.h create mode 100644 src/layer/x86/padding_pack4_bf16s_fp16s.h create mode 100644 src/layer/x86/padding_pack8_bf16s_fp16s.h diff --git a/src/layer/x86/concat_x86.cpp b/src/layer/x86/concat_x86.cpp index e34edb70e1d..14c6854740b 100644 --- a/src/layer/x86/concat_x86.cpp +++ b/src/layer/x86/concat_x86.cpp @@ -1,8 +1,10 @@ -// Copyright 2019 Tencent +// Copyright 2026 Tencent // SPDX-License-Identifier: BSD-3-Clause #include "concat_x86.h" +#include "cpu.h" + namespace ncnn { Concat_x86::Concat_x86() @@ -10,10 +12,19 @@ Concat_x86::Concat_x86() #if __SSE2__ support_packing = true; #endif // __SSE2__ + support_fp16_storage = cpu_support_x86_f16c(); +#if NCNN_BF16 + support_bf16_storage = true; +#endif } int Concat_x86::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const { + int elembits = bottom_blobs[0].elembits(); + + if (elembits == 16) + return forward_bf16s_fp16s(bottom_blobs, top_blobs, opt); + int dims = bottom_blobs[0].dims; int positive_axis = axis < 0 ? dims + axis : axis; @@ -814,4 +825,806 @@ int Concat_x86::forward(const std::vector& bottom_blobs, std::vector& return 0; } +int Concat_x86::forward_bf16s_fp16s(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const +{ + int dims = bottom_blobs[0].dims; + int positive_axis = axis < 0 ? 
dims + axis : axis; + + if (dims == 1) // positive_axis == 0 + { + // concat vector + // total length + size_t elemsize = bottom_blobs[0].elemsize; + int elempack = bottom_blobs[0].elempack; + int top_w = 0; + for (size_t b = 0; b < bottom_blobs.size(); b++) + { + const Mat& bottom_blob = bottom_blobs[b]; + top_w += bottom_blob.w * bottom_blob.elempack; + } + + int out_elempack = 1; +#if __SSE2__ + if (opt.use_packing_layout) + { +#if __AVX512F__ + out_elempack = top_w % 16 == 0 ? 16 : top_w % 8 == 0 ? 8 : top_w % 4 == 0 ? 4 : 1; +#elif __AVX__ + out_elempack = top_w % 8 == 0 ? 8 : top_w % 4 == 0 ? 4 : 1; +#else + out_elempack = top_w % 4 == 0 ? 4 : 1; +#endif + } +#endif // __SSE2__ + size_t out_elemsize = elemsize / elempack * out_elempack; + + Mat& top_blob = top_blobs[0]; + top_blob.create(top_w / out_elempack, out_elemsize, out_elempack, opt.blob_allocator); + if (top_blob.empty()) + return -100; + + unsigned short* outptr = top_blob; + for (size_t b = 0; b < bottom_blobs.size(); b++) + { + const Mat& bottom_blob = bottom_blobs[b]; + + const unsigned short* ptr = bottom_blob; + memcpy(outptr, ptr, bottom_blob.w * bottom_blob.elemsize); + + outptr += bottom_blob.w * bottom_blob.elempack; + } + } + + if (dims == 2 && positive_axis == 0) + { + // concat image + int w = bottom_blobs[0].w; + + // total height + size_t elemsize = bottom_blobs[0].elemsize; + int elempack = bottom_blobs[0].elempack; + int top_h = 0; + for (size_t b = 0; b < bottom_blobs.size(); b++) + { + const Mat& bottom_blob = bottom_blobs[b]; + elemsize = std::min(elemsize, bottom_blob.elemsize); + elempack = std::min(elempack, bottom_blob.elempack); + top_h += bottom_blob.h * bottom_blob.elempack; + } + + int out_elempack = 1; +#if __SSE2__ + if (opt.use_packing_layout) + { +#if __AVX512F__ + out_elempack = top_h % 16 == 0 ? 16 : top_h % 8 == 0 ? 8 : top_h % 4 == 0 ? 4 : 1; +#elif __AVX__ + out_elempack = top_h % 8 == 0 ? 8 : top_h % 4 == 0 ? 4 : 1; +#else + out_elempack = top_h % 4 == 0 ? 
4 : 1; +#endif + } +#endif // __SSE2__ + size_t out_elemsize = elemsize / elempack * out_elempack; + + Mat& top_blob = top_blobs[0]; + top_blob.create(w, top_h / out_elempack, out_elemsize, out_elempack, opt.blob_allocator); + if (top_blob.empty()) + return -100; + + Mat top_blob_unpacked = top_blob; + if (elempack < out_elempack) + { + top_blob_unpacked.create(w, top_h / elempack, elemsize, elempack, opt.workspace_allocator); + if (top_blob_unpacked.empty()) + return -100; + } + + unsigned short* outptr = top_blob_unpacked; + for (size_t b = 0; b < bottom_blobs.size(); b++) + { + const Mat& bottom_blob = bottom_blobs[b]; + +#if __AVX__ +#if __AVX512F__ + if (bottom_blob.elempack == 16 && elempack == 8) + { + for (int i = 0; i < bottom_blob.h; i++) + { + const unsigned short* r0 = bottom_blob.row(i); + + unsigned short* outptr0 = outptr; + unsigned short* outptr1 = outptr + w * 8; + + for (int j = 0; j < w; j++) + { + outptr0[0] = r0[0]; + outptr0[1] = r0[1]; + outptr0[2] = r0[2]; + outptr0[3] = r0[3]; + outptr0[4] = r0[4]; + outptr0[5] = r0[5]; + outptr0[6] = r0[6]; + outptr0[7] = r0[7]; + outptr1[0] = r0[8]; + outptr1[1] = r0[9]; + outptr1[2] = r0[10]; + outptr1[3] = r0[11]; + outptr1[4] = r0[12]; + outptr1[5] = r0[13]; + outptr1[6] = r0[14]; + outptr1[7] = r0[15]; + + outptr0 += 8; + outptr1 += 8; + r0 += 16; + } + + outptr += w * 16; + } + } + if (bottom_blob.elempack == 16 && elempack == 4) + { + for (int i = 0; i < bottom_blob.h; i++) + { + const unsigned short* r0 = bottom_blob.row(i); + + unsigned short* outptr0 = outptr; + unsigned short* outptr1 = outptr + w * 4; + unsigned short* outptr2 = outptr + w * 8; + unsigned short* outptr3 = outptr + w * 12; + + for (int j = 0; j < w; j++) + { + outptr0[0] = r0[0]; + outptr0[1] = r0[1]; + outptr0[2] = r0[2]; + outptr0[3] = r0[3]; + outptr1[0] = r0[4]; + outptr1[1] = r0[5]; + outptr1[2] = r0[6]; + outptr1[3] = r0[7]; + outptr2[0] = r0[8]; + outptr2[1] = r0[9]; + outptr2[2] = r0[10]; + outptr2[3] = r0[11]; + outptr3[0] = r0[12]; + outptr3[1] = r0[13]; + outptr3[2] = r0[14]; + outptr3[3] = r0[15]; + + outptr0 += 4; + outptr1 += 4; + outptr2 += 4; + outptr3 += 4; + r0 += 16; + } + + outptr += w * 16; + } + } + if (bottom_blob.elempack == 16 && elempack == 1) + { + for (int i = 0; i < bottom_blob.h; i++) + { + const unsigned short* r0 = bottom_blob.row(i); + + unsigned short* outptr0 = outptr; + unsigned short* outptr1 = outptr + w; + unsigned short* outptr2 = outptr + w * 2; + unsigned short* outptr3 = outptr + w * 3; + unsigned short* outptr4 = outptr + w * 4; + unsigned short* outptr5 = outptr + w * 5; + unsigned short* outptr6 = outptr + w * 6; + unsigned short* outptr7 = outptr + w * 7; + unsigned short* outptr8 = outptr + w * 8; + unsigned short* outptr9 = outptr + w * 9; + unsigned short* outptra = outptr + w * 10; + unsigned short* outptrb = outptr + w * 11; + unsigned short* outptrc = outptr + w * 12; + unsigned short* outptrd = outptr + w * 13; + unsigned short* outptre = outptr + w * 14; + unsigned short* outptrf = outptr + w * 15; + + for (int j = 0; j < w; j++) + { + *outptr0++ = r0[0]; + *outptr1++ = r0[1]; + *outptr2++ = r0[2]; + *outptr3++ = r0[3]; + *outptr4++ = r0[4]; + *outptr5++ = r0[5]; + *outptr6++ = r0[6]; + *outptr7++ = r0[7]; + *outptr8++ = r0[8]; + *outptr9++ = r0[9]; + *outptra++ = r0[10]; + *outptrb++ = r0[11]; + *outptrc++ = r0[12]; + *outptrd++ = r0[13]; + *outptre++ = r0[14]; + *outptrf++ = r0[15]; + + r0 += 16; + } + + outptr += w * 16; + } + } +#endif // __AVX512F__ + if (bottom_blob.elempack == 8 && elempack 
== 4) + { + for (int i = 0; i < bottom_blob.h; i++) + { + const unsigned short* r0 = bottom_blob.row(i); + + unsigned short* outptr0 = outptr; + unsigned short* outptr1 = outptr + w * 4; + + for (int j = 0; j < w; j++) + { + outptr0[0] = r0[0]; + outptr0[1] = r0[1]; + outptr0[2] = r0[2]; + outptr0[3] = r0[3]; + outptr1[0] = r0[4]; + outptr1[1] = r0[5]; + outptr1[2] = r0[6]; + outptr1[3] = r0[7]; + + outptr0 += 4; + outptr1 += 4; + r0 += 8; + } + + outptr += w * 8; + } + } + if (bottom_blob.elempack == 8 && elempack == 1) + { + for (int i = 0; i < bottom_blob.h; i++) + { + const unsigned short* r0 = bottom_blob.row(i); + + unsigned short* outptr0 = outptr; + unsigned short* outptr1 = outptr + w; + unsigned short* outptr2 = outptr + w * 2; + unsigned short* outptr3 = outptr + w * 3; + unsigned short* outptr4 = outptr + w * 4; + unsigned short* outptr5 = outptr + w * 5; + unsigned short* outptr6 = outptr + w * 6; + unsigned short* outptr7 = outptr + w * 7; + + for (int j = 0; j < w; j++) + { + *outptr0++ = r0[0]; + *outptr1++ = r0[1]; + *outptr2++ = r0[2]; + *outptr3++ = r0[3]; + *outptr4++ = r0[4]; + *outptr5++ = r0[5]; + *outptr6++ = r0[6]; + *outptr7++ = r0[7]; + + r0 += 8; + } + + outptr += w * 8; + } + } +#endif // __AVX__ + if (bottom_blob.elempack == 4 && elempack == 1) + { + for (int i = 0; i < bottom_blob.h; i++) + { + const unsigned short* r0 = bottom_blob.row(i); + + unsigned short* outptr0 = outptr; + unsigned short* outptr1 = outptr + w; + unsigned short* outptr2 = outptr + w * 2; + unsigned short* outptr3 = outptr + w * 3; + + for (int j = 0; j < w; j++) + { + *outptr0++ = r0[0]; + *outptr1++ = r0[1]; + *outptr2++ = r0[2]; + *outptr3++ = r0[3]; + + r0 += 4; + } + + outptr += w * 4; + } + } + if (bottom_blob.elempack == elempack) // 1-1 4-4 8-8 16-16 + { + int size = w * bottom_blob.h; + + const unsigned short* ptr = bottom_blob; + memcpy(outptr, ptr, size * bottom_blob.elemsize); + + outptr += size * bottom_blob.elempack; + } + } + + // packing + if (elempack < out_elempack) + { + convert_packing(top_blob_unpacked, top_blob, out_elempack, opt); + } + } + + if (dims == 2 && positive_axis == 1) + { + // interleave image row + int h = bottom_blobs[0].h; + size_t elemsize = bottom_blobs[0].elemsize; + int elempack = bottom_blobs[0].elempack; + + // total width + int top_w = 0; + for (size_t b = 0; b < bottom_blobs.size(); b++) + { + const Mat& bottom_blob = bottom_blobs[b]; + top_w += bottom_blob.w; + } + + Mat& top_blob = top_blobs[0]; + top_blob.create(top_w, h, elemsize, elempack, opt.blob_allocator); + if (top_blob.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = 0; i < h; i++) + { + unsigned short* outptr = top_blob.row(i); + for (size_t b = 0; b < bottom_blobs.size(); b++) + { + const Mat& bottom_blob = bottom_blobs[b]; + + const unsigned short* ptr = bottom_blob.row(i); + memcpy(outptr, ptr, bottom_blob.w * elemsize); + + outptr += bottom_blob.w * elempack; + } + } + } + + if ((dims == 3 || dims == 4) && positive_axis == 0) + { + // concat dim + int w = bottom_blobs[0].w; + int h = bottom_blobs[0].h; + int d = bottom_blobs[0].d; + + // total channels + size_t elemsize = bottom_blobs[0].elemsize; + int elempack = bottom_blobs[0].elempack; + int top_channels = 0; + for (size_t b = 0; b < bottom_blobs.size(); b++) + { + const Mat& bottom_blob = bottom_blobs[b]; + elemsize = std::min(elemsize, bottom_blob.elemsize); + elempack = std::min(elempack, bottom_blob.elempack); + top_channels += bottom_blob.c * bottom_blob.elempack; + } + + 
int out_elempack = 1; +#if __SSE2__ + if (opt.use_packing_layout) + { +#if __AVX512F__ + out_elempack = top_channels % 16 == 0 ? 16 : top_channels % 8 == 0 ? 8 : top_channels % 4 == 0 ? 4 : 1; +#elif __AVX__ + out_elempack = top_channels % 8 == 0 ? 8 : top_channels % 4 == 0 ? 4 : 1; +#else + out_elempack = top_channels % 4 == 0 ? 4 : 1; +#endif + } +#endif // __SSE2__ + size_t out_elemsize = elemsize / elempack * out_elempack; + + Mat& top_blob = top_blobs[0]; + top_blob.create(w, h, d, top_channels / out_elempack, out_elemsize, out_elempack, opt.blob_allocator); + if (top_blob.empty()) + return -100; + + top_blob.dims = dims; + + Mat top_blob_unpacked = top_blob; + if (elempack < out_elempack) + { + top_blob_unpacked.create(w, h, d, top_channels / elempack, elemsize, elempack, opt.workspace_allocator); + if (top_blob_unpacked.empty()) + return -100; + + top_blob_unpacked.dims = dims; + } + + int p = 0; + for (size_t b = 0; b < bottom_blobs.size(); b++) + { + const Mat& bottom_blob = bottom_blobs[b]; + +#if __AVX__ +#if __AVX512F__ + if (bottom_blob.elempack == 16 && elempack == 8) + { + int size = bottom_blob.w * bottom_blob.h * bottom_blob.d; + + for (int q = 0; q < bottom_blob.c; q++) + { + const unsigned short* r0 = bottom_blob.channel(q); + + unsigned short* outptr0 = top_blob_unpacked.channel(p); + unsigned short* outptr1 = top_blob_unpacked.channel(p + 1); + + for (int i = 0; i < size; i++) + { + outptr0[0] = r0[0]; + outptr0[1] = r0[1]; + outptr0[2] = r0[2]; + outptr0[3] = r0[3]; + outptr0[4] = r0[4]; + outptr0[5] = r0[5]; + outptr0[6] = r0[6]; + outptr0[7] = r0[7]; + outptr1[0] = r0[8]; + outptr1[1] = r0[9]; + outptr1[2] = r0[10]; + outptr1[3] = r0[11]; + outptr1[4] = r0[12]; + outptr1[5] = r0[13]; + outptr1[6] = r0[14]; + outptr1[7] = r0[15]; + + outptr0 += 8; + outptr1 += 8; + r0 += 16; + } + + p += 2; + } + } + if (bottom_blob.elempack == 16 && elempack == 4) + { + int size = bottom_blob.w * bottom_blob.h * bottom_blob.d; + + for (int q = 0; q < bottom_blob.c; q++) + { + const unsigned short* r0 = bottom_blob.channel(q); + + unsigned short* outptr0 = top_blob_unpacked.channel(p); + unsigned short* outptr1 = top_blob_unpacked.channel(p + 1); + unsigned short* outptr2 = top_blob_unpacked.channel(p + 2); + unsigned short* outptr3 = top_blob_unpacked.channel(p + 3); + + for (int i = 0; i < size; i++) + { + outptr0[0] = r0[0]; + outptr0[1] = r0[1]; + outptr0[2] = r0[2]; + outptr0[3] = r0[3]; + outptr1[0] = r0[4]; + outptr1[1] = r0[5]; + outptr1[2] = r0[6]; + outptr1[3] = r0[7]; + outptr2[0] = r0[8]; + outptr2[1] = r0[9]; + outptr2[2] = r0[10]; + outptr2[3] = r0[11]; + outptr3[0] = r0[12]; + outptr3[1] = r0[13]; + outptr3[2] = r0[14]; + outptr3[3] = r0[15]; + + outptr0 += 4; + outptr1 += 4; + outptr2 += 4; + outptr3 += 4; + r0 += 16; + } + + p += 4; + } + } + if (bottom_blob.elempack == 16 && elempack == 1) + { + int size = bottom_blob.w * bottom_blob.h * bottom_blob.d; + + for (int q = 0; q < bottom_blob.c; q++) + { + const unsigned short* r0 = bottom_blob.channel(q); + + unsigned short* outptr0 = top_blob_unpacked.channel(p); + unsigned short* outptr1 = top_blob_unpacked.channel(p + 1); + unsigned short* outptr2 = top_blob_unpacked.channel(p + 2); + unsigned short* outptr3 = top_blob_unpacked.channel(p + 3); + unsigned short* outptr4 = top_blob_unpacked.channel(p + 4); + unsigned short* outptr5 = top_blob_unpacked.channel(p + 5); + unsigned short* outptr6 = top_blob_unpacked.channel(p + 6); + unsigned short* outptr7 = top_blob_unpacked.channel(p + 7); + unsigned short* outptr8 = 
top_blob_unpacked.channel(p + 8); + unsigned short* outptr9 = top_blob_unpacked.channel(p + 9); + unsigned short* outptra = top_blob_unpacked.channel(p + 10); + unsigned short* outptrb = top_blob_unpacked.channel(p + 11); + unsigned short* outptrc = top_blob_unpacked.channel(p + 12); + unsigned short* outptrd = top_blob_unpacked.channel(p + 13); + unsigned short* outptre = top_blob_unpacked.channel(p + 14); + unsigned short* outptrf = top_blob_unpacked.channel(p + 15); + + for (int i = 0; i < size; i++) + { + *outptr0++ = r0[0]; + *outptr1++ = r0[1]; + *outptr2++ = r0[2]; + *outptr3++ = r0[3]; + *outptr4++ = r0[4]; + *outptr5++ = r0[5]; + *outptr6++ = r0[6]; + *outptr7++ = r0[7]; + *outptr8++ = r0[8]; + *outptr9++ = r0[9]; + *outptra++ = r0[10]; + *outptrb++ = r0[11]; + *outptrc++ = r0[12]; + *outptrd++ = r0[13]; + *outptre++ = r0[14]; + *outptrf++ = r0[15]; + + r0 += 16; + } + + p += 16; + } + } +#endif // __AVX512F__ + if (bottom_blob.elempack == 8 && elempack == 4) + { + int size = bottom_blob.w * bottom_blob.h * bottom_blob.d; + + for (int q = 0; q < bottom_blob.c; q++) + { + const unsigned short* r0 = bottom_blob.channel(q); + + unsigned short* outptr0 = top_blob_unpacked.channel(p); + unsigned short* outptr1 = top_blob_unpacked.channel(p + 1); + + for (int i = 0; i < size; i++) + { + outptr0[0] = r0[0]; + outptr0[1] = r0[1]; + outptr0[2] = r0[2]; + outptr0[3] = r0[3]; + outptr1[0] = r0[4]; + outptr1[1] = r0[5]; + outptr1[2] = r0[6]; + outptr1[3] = r0[7]; + + outptr0 += 4; + outptr1 += 4; + r0 += 8; + } + + p += 2; + } + } + if (bottom_blob.elempack == 8 && elempack == 1) + { + int size = bottom_blob.w * bottom_blob.h * bottom_blob.d; + + for (int q = 0; q < bottom_blob.c; q++) + { + const unsigned short* r0 = bottom_blob.channel(q); + + unsigned short* outptr0 = top_blob_unpacked.channel(p); + unsigned short* outptr1 = top_blob_unpacked.channel(p + 1); + unsigned short* outptr2 = top_blob_unpacked.channel(p + 2); + unsigned short* outptr3 = top_blob_unpacked.channel(p + 3); + unsigned short* outptr4 = top_blob_unpacked.channel(p + 4); + unsigned short* outptr5 = top_blob_unpacked.channel(p + 5); + unsigned short* outptr6 = top_blob_unpacked.channel(p + 6); + unsigned short* outptr7 = top_blob_unpacked.channel(p + 7); + + for (int i = 0; i < size; i++) + { + *outptr0++ = r0[0]; + *outptr1++ = r0[1]; + *outptr2++ = r0[2]; + *outptr3++ = r0[3]; + *outptr4++ = r0[4]; + *outptr5++ = r0[5]; + *outptr6++ = r0[6]; + *outptr7++ = r0[7]; + + r0 += 8; + } + + p += 8; + } + } +#endif // __AVX__ + if (bottom_blob.elempack == 4 && elempack == 1) + { + int size = bottom_blob.w * bottom_blob.h * bottom_blob.d; + + for (int q = 0; q < bottom_blob.c; q++) + { + const unsigned short* r0 = bottom_blob.channel(q); + + unsigned short* outptr0 = top_blob_unpacked.channel(p); + unsigned short* outptr1 = top_blob_unpacked.channel(p + 1); + unsigned short* outptr2 = top_blob_unpacked.channel(p + 2); + unsigned short* outptr3 = top_blob_unpacked.channel(p + 3); + + for (int i = 0; i < size; i++) + { + *outptr0++ = r0[0]; + *outptr1++ = r0[1]; + *outptr2++ = r0[2]; + *outptr3++ = r0[3]; + + r0 += 4; + } + + p += 4; + } + } + if (bottom_blob.elempack == elempack) // 1-1 4-4 8-8 + { + int size = bottom_blob.total(); + + const unsigned short* ptr = bottom_blob; + unsigned short* outptr = top_blob_unpacked.channel(p); + memcpy(outptr, ptr, size * bottom_blob.elemsize); + + p += bottom_blob.c; + } + } + + // packing + if (elempack < out_elempack) + { + convert_packing(top_blob_unpacked, top_blob, out_elempack, opt); 
+ } + } + + if ((dims == 3 && positive_axis == 1) || (dims == 4 && positive_axis == 2)) + { + // interleave dim height + int w = bottom_blobs[0].w; + int d = bottom_blobs[0].d; + int channels = bottom_blobs[0].c; + size_t elemsize = bottom_blobs[0].elemsize; + int elempack = bottom_blobs[0].elempack; + + // total height + int top_h = 0; + for (size_t b = 0; b < bottom_blobs.size(); b++) + { + const Mat& bottom_blob = bottom_blobs[b]; + top_h += bottom_blob.h; + } + + Mat& top_blob = top_blobs[0]; + top_blob.create(w, top_h, d, channels, elemsize, elempack, opt.blob_allocator); + if (top_blob.empty()) + return -100; + + top_blob.dims = dims; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + unsigned short* outptr = top_blob.channel(q); + + for (int i = 0; i < d; i++) + { + for (size_t b = 0; b < bottom_blobs.size(); b++) + { + const Mat& bottom_blob = bottom_blobs[b]; + + int size = bottom_blob.w * bottom_blob.h; + + const unsigned short* ptr = bottom_blob.channel(q).depth(i); + memcpy(outptr, ptr, size * elemsize); + + outptr += size * elempack; + } + } + } + } + + if ((dims == 3 && positive_axis == 2) || (dims == 4 && positive_axis == 3)) + { + // interleave dim width + int h = bottom_blobs[0].h; + int d = bottom_blobs[0].d; + int channels = bottom_blobs[0].c; + size_t elemsize = bottom_blobs[0].elemsize; + int elempack = bottom_blobs[0].elempack; + + // total height + int top_w = 0; + for (size_t b = 0; b < bottom_blobs.size(); b++) + { + const Mat& bottom_blob = bottom_blobs[b]; + top_w += bottom_blob.w; + } + + Mat& top_blob = top_blobs[0]; + top_blob.create(top_w, h, d, channels, elemsize, elempack, opt.blob_allocator); + if (top_blob.empty()) + return -100; + + top_blob.dims = dims; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + unsigned short* outptr = top_blob.channel(q); + + for (int i = 0; i < d; i++) + { + for (int j = 0; j < h; j++) + { + for (size_t b = 0; b < bottom_blobs.size(); b++) + { + const Mat& bottom_blob = bottom_blobs[b]; + + const unsigned short* ptr = bottom_blob.channel(q).depth(i).row(j); + memcpy(outptr, ptr, bottom_blob.w * elemsize); + + outptr += bottom_blob.w * elempack; + } + } + } + } + } + + if (dims == 4 && positive_axis == 1) + { + // interleave dim depth + int w = bottom_blobs[0].w; + int h = bottom_blobs[0].h; + int channels = bottom_blobs[0].c; + size_t elemsize = bottom_blobs[0].elemsize; + int elempack = bottom_blobs[0].elempack; + + // total depth + int top_d = 0; + for (size_t b = 0; b < bottom_blobs.size(); b++) + { + const Mat& bottom_blob = bottom_blobs[b]; + top_d += bottom_blob.d; + } + + Mat& top_blob = top_blobs[0]; + top_blob.create(w, h, top_d, channels, elemsize, elempack, opt.blob_allocator); + if (top_blob.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + unsigned short* outptr = top_blob.channel(q); + + for (size_t b = 0; b < bottom_blobs.size(); b++) + { + const Mat& bottom_blob = bottom_blobs[b]; + + int size = bottom_blob.w * bottom_blob.h * bottom_blob.d; + + const unsigned short* ptr = bottom_blob.channel(q); + memcpy(outptr, ptr, size * elemsize); + + outptr += size * elempack; + } + } + } + + return 0; +} + } // namespace ncnn diff --git a/src/layer/x86/concat_x86.h b/src/layer/x86/concat_x86.h index e3d382695d6..2ba283a1e68 100644 --- a/src/layer/x86/concat_x86.h +++ b/src/layer/x86/concat_x86.h @@ -1,4 +1,4 @@ -// Copyright 2019 Tencent +// 
Copyright 2026 Tencent // SPDX-License-Identifier: BSD-3-Clause #ifndef LAYER_CONCAT_X86_H @@ -14,6 +14,9 @@ class Concat_x86 : public Concat Concat_x86(); virtual int forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; + +protected: + int forward_bf16s_fp16s(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; }; } // namespace ncnn diff --git a/src/layer/x86/crop_x86.cpp b/src/layer/x86/crop_x86.cpp index 6097c68daf2..af5f60ee2e7 100644 --- a/src/layer/x86/crop_x86.cpp +++ b/src/layer/x86/crop_x86.cpp @@ -1,8 +1,11 @@ -// Copyright 2019 Tencent +// Copyright 2026 Tencent // SPDX-License-Identifier: BSD-3-Clause #include "crop_x86.h" +#include "cpu.h" + +#include #if __SSE2__ #include #if __AVX__ @@ -17,6 +20,10 @@ Crop_x86::Crop_x86() #if __SSE2__ support_packing = true; #endif // __SSE2__ + support_fp16_storage = cpu_support_x86_f16c(); +#if NCNN_BF16 + support_bf16_storage = true; +#endif } #if __SSE2__ @@ -44,6 +51,29 @@ static void crop_pack16_avx512(const Mat& src, Mat& dst, int top, int left) ptr += (left + right) * 16; } } + +static void crop_pack16_bf16s_fp16s_avx512(const Mat& src, Mat& dst, int top, int left) +{ + int w = dst.w; + int h = dst.h; + int right = src.w - dst.w - left; + + const unsigned short* ptr = (const unsigned short*)src.row(top) + left * 16; + unsigned short* outptr = dst; + + for (int y = 0; y < h; y++) + { + for (int x = 0; x < w; x++) + { + __m256i _p = _mm256_loadu_si256((const __m256i*)ptr); + _mm256_storeu_si256((__m256i*)outptr, _p); + ptr += 16; + outptr += 16; + } + + ptr += (left + right) * 16; + } +} #endif // __AVX512F__ static void crop_pack8_avx(const Mat& src, Mat& dst, int top, int left) @@ -68,6 +98,29 @@ static void crop_pack8_avx(const Mat& src, Mat& dst, int top, int left) ptr += (left + right) * 8; } } + +static void crop_pack8_bf16s_fp16s_avx(const Mat& src, Mat& dst, int top, int left) +{ + int w = dst.w; + int h = dst.h; + int right = src.w - dst.w - left; + + const unsigned short* ptr = (const unsigned short*)src.row(top) + left * 8; + unsigned short* outptr = dst; + + for (int y = 0; y < h; y++) + { + for (int x = 0; x < w; x++) + { + __m128i _p = _mm_loadu_si128((const __m128i*)ptr); + _mm_storeu_si128((__m128i*)outptr, _p); + ptr += 8; + outptr += 8; + } + + ptr += (left + right) * 8; + } +} #endif // __AVX__ static void crop_pack4_sse(const Mat& src, Mat& dst, int top, int left) @@ -92,6 +145,28 @@ static void crop_pack4_sse(const Mat& src, Mat& dst, int top, int left) ptr += (left + right) * 4; } } + +static void crop_pack4_bf16s_fp16s_sse(const Mat& src, Mat& dst, int top, int left) +{ + int w = dst.w; + int h = dst.h; + int right = src.w - dst.w - left; + + const unsigned short* ptr = (const unsigned short*)src.row(top) + left * 4; + unsigned short* outptr = dst; + + for (int y = 0; y < h; y++) + { + for (int x = 0; x < w; x++) + { + *(int64_t*)outptr = *(const int64_t*)ptr; + ptr += 4; + outptr += 4; + } + + ptr += (left + right) * 4; + } +} #endif // __SSE2__ int Crop_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const @@ -139,7 +214,10 @@ int Crop_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) if (top_blob.empty()) return -100; - crop_pack16_avx512(bottom_blob, top_blob, 0, _woffset / elempack); + if (elemsize == 32u) + crop_pack16_bf16s_fp16s_avx512(bottom_blob, top_blob, 0, _woffset / elempack); + else + crop_pack16_avx512(bottom_blob, top_blob, 0, _woffset / elempack); return 0; } @@ -162,7 +240,10 @@ int 
Crop_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) if (top_blob.empty()) return -100; - crop_pack16_avx512(bottom_blob, top_blob, _hoffset / elempack, _woffset); + if (elemsize == 32u) + crop_pack16_bf16s_fp16s_avx512(bottom_blob, top_blob, _hoffset / elempack, _woffset); + else + crop_pack16_avx512(bottom_blob, top_blob, _hoffset / elempack, _woffset); return 0; } @@ -199,7 +280,10 @@ int Crop_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) { const Mat m = bottom_blob_sliced.channel(q); Mat borderm = top_blob.channel(q); - crop_pack16_avx512(m, borderm, _hoffset, _woffset); + if (elemsize == 32u) + crop_pack16_bf16s_fp16s_avx512(m, borderm, _hoffset, _woffset); + else + crop_pack16_avx512(m, borderm, _hoffset, _woffset); } return 0; @@ -239,7 +323,10 @@ int Crop_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) { const Mat m = bottom_blob_sliced.channel(q).depth(z + _doffset); Mat borderm = top_blob.channel(q).depth(z); - crop_pack16_avx512(m, borderm, _hoffset, _woffset); + if (elemsize == 32u) + crop_pack16_bf16s_fp16s_avx512(m, borderm, _hoffset, _woffset); + else + crop_pack16_avx512(m, borderm, _hoffset, _woffset); } } @@ -268,7 +355,10 @@ int Crop_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) if (top_blob.empty()) return -100; - crop_pack8_avx(bottom_blob, top_blob, 0, _woffset / elempack); + if (elemsize == 16u) + crop_pack8_bf16s_fp16s_avx(bottom_blob, top_blob, 0, _woffset / elempack); + else + crop_pack8_avx(bottom_blob, top_blob, 0, _woffset / elempack); return 0; } @@ -291,7 +381,10 @@ int Crop_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) if (top_blob.empty()) return -100; - crop_pack8_avx(bottom_blob, top_blob, _hoffset / elempack, _woffset); + if (elemsize == 16u) + crop_pack8_bf16s_fp16s_avx(bottom_blob, top_blob, _hoffset / elempack, _woffset); + else + crop_pack8_avx(bottom_blob, top_blob, _hoffset / elempack, _woffset); return 0; } @@ -328,7 +421,10 @@ int Crop_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) { const Mat m = bottom_blob_sliced.channel(q); Mat borderm = top_blob.channel(q); - crop_pack8_avx(m, borderm, _hoffset, _woffset); + if (elemsize == 16u) + crop_pack8_bf16s_fp16s_avx(m, borderm, _hoffset, _woffset); + else + crop_pack8_avx(m, borderm, _hoffset, _woffset); } return 0; @@ -368,7 +464,10 @@ int Crop_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) { const Mat m = bottom_blob_sliced.channel(q).depth(z + _doffset); Mat borderm = top_blob.channel(q).depth(z); - crop_pack8_avx(m, borderm, _hoffset, _woffset); + if (elemsize == 16u) + crop_pack8_bf16s_fp16s_avx(m, borderm, _hoffset, _woffset); + else + crop_pack8_avx(m, borderm, _hoffset, _woffset); } } @@ -397,7 +496,10 @@ int Crop_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) if (top_blob.empty()) return -100; - crop_pack4_sse(bottom_blob, top_blob, 0, _woffset / elempack); + if (elemsize == 8u) + crop_pack4_bf16s_fp16s_sse(bottom_blob, top_blob, 0, _woffset / elempack); + else + crop_pack4_sse(bottom_blob, top_blob, 0, _woffset / elempack); return 0; } @@ -420,7 +522,10 @@ int Crop_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) if (top_blob.empty()) return -100; - crop_pack4_sse(bottom_blob, top_blob, _hoffset / elempack, _woffset); + if (elemsize == 8u) + crop_pack4_bf16s_fp16s_sse(bottom_blob, top_blob, _hoffset / elempack, _woffset); + else + crop_pack4_sse(bottom_blob, top_blob, _hoffset / 
elempack, _woffset); return 0; } @@ -458,7 +563,10 @@ int Crop_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const Mat m = bottom_blob_sliced.channel(q); Mat borderm = top_blob.channel(q); - crop_pack4_sse(m, borderm, _hoffset, _woffset); + if (elemsize == 8u) + crop_pack4_bf16s_fp16s_sse(m, borderm, _hoffset, _woffset); + else + crop_pack4_sse(m, borderm, _hoffset, _woffset); } return 0; @@ -499,7 +607,10 @@ int Crop_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const Mat m = bottom_blob_sliced.channel(q).depth(z + _doffset); Mat borderm = top_blob.channel(q).depth(z); - crop_pack4_sse(m, borderm, _hoffset, _woffset); + if (elemsize == 8u) + crop_pack4_bf16s_fp16s_sse(m, borderm, _hoffset, _woffset); + else + crop_pack4_sse(m, borderm, _hoffset, _woffset); } } @@ -580,7 +691,10 @@ int Crop_x86::forward(const std::vector& bottom_blobs, std::vector& to if (top_blob.empty()) return -100; - crop_pack16_avx512(bottom_blob, top_blob, 0, _woffset / elempack); + if (elemsize == 32u) + crop_pack16_bf16s_fp16s_avx512(bottom_blob, top_blob, 0, _woffset / elempack); + else + crop_pack16_avx512(bottom_blob, top_blob, 0, _woffset / elempack); return 0; } @@ -603,7 +717,10 @@ int Crop_x86::forward(const std::vector& bottom_blobs, std::vector& to if (top_blob.empty()) return -100; - crop_pack16_avx512(bottom_blob, top_blob, _hoffset / elempack, _woffset); + if (elemsize == 32u) + crop_pack16_bf16s_fp16s_avx512(bottom_blob, top_blob, _hoffset / elempack, _woffset); + else + crop_pack16_avx512(bottom_blob, top_blob, _hoffset / elempack, _woffset); return 0; } @@ -640,7 +757,10 @@ int Crop_x86::forward(const std::vector& bottom_blobs, std::vector& to { const Mat m = bottom_blob_sliced.channel(q); Mat borderm = top_blob.channel(q); - crop_pack16_avx512(m, borderm, _hoffset, _woffset); + if (elemsize == 32u) + crop_pack16_bf16s_fp16s_avx512(m, borderm, _hoffset, _woffset); + else + crop_pack16_avx512(m, borderm, _hoffset, _woffset); } return 0; @@ -680,7 +800,10 @@ int Crop_x86::forward(const std::vector& bottom_blobs, std::vector& to { const Mat m = bottom_blob_sliced.channel(q).depth(z + _doffset); Mat borderm = top_blob.channel(q).depth(z); - crop_pack16_avx512(m, borderm, _hoffset, _woffset); + if (elemsize == 32u) + crop_pack16_bf16s_fp16s_avx512(m, borderm, _hoffset, _woffset); + else + crop_pack16_avx512(m, borderm, _hoffset, _woffset); } } @@ -709,7 +832,10 @@ int Crop_x86::forward(const std::vector& bottom_blobs, std::vector& to if (top_blob.empty()) return -100; - crop_pack8_avx(bottom_blob, top_blob, 0, _woffset / elempack); + if (elemsize == 16u) + crop_pack8_bf16s_fp16s_avx(bottom_blob, top_blob, 0, _woffset / elempack); + else + crop_pack8_avx(bottom_blob, top_blob, 0, _woffset / elempack); return 0; } @@ -732,7 +858,10 @@ int Crop_x86::forward(const std::vector& bottom_blobs, std::vector& to if (top_blob.empty()) return -100; - crop_pack8_avx(bottom_blob, top_blob, _hoffset / elempack, _woffset); + if (elemsize == 16u) + crop_pack8_bf16s_fp16s_avx(bottom_blob, top_blob, _hoffset / elempack, _woffset); + else + crop_pack8_avx(bottom_blob, top_blob, _hoffset / elempack, _woffset); return 0; } @@ -769,7 +898,10 @@ int Crop_x86::forward(const std::vector& bottom_blobs, std::vector& to { const Mat m = bottom_blob_sliced.channel(q); Mat borderm = top_blob.channel(q); - crop_pack8_avx(m, borderm, _hoffset, _woffset); + if (elemsize == 16u) + crop_pack8_bf16s_fp16s_avx(m, borderm, _hoffset, _woffset); + else + crop_pack8_avx(m, borderm, _hoffset, 
_woffset); } return 0; @@ -809,7 +941,10 @@ int Crop_x86::forward(const std::vector& bottom_blobs, std::vector& to { const Mat m = bottom_blob_sliced.channel(q).depth(z + _doffset); Mat borderm = top_blob.channel(q).depth(z); - crop_pack8_avx(m, borderm, _hoffset, _woffset); + if (elemsize == 16u) + crop_pack8_bf16s_fp16s_avx(m, borderm, _hoffset, _woffset); + else + crop_pack8_avx(m, borderm, _hoffset, _woffset); } } @@ -838,7 +973,10 @@ int Crop_x86::forward(const std::vector& bottom_blobs, std::vector& to if (top_blob.empty()) return -100; - crop_pack4_sse(bottom_blob, top_blob, 0, _woffset / elempack); + if (elemsize == 8u) + crop_pack4_bf16s_fp16s_sse(bottom_blob, top_blob, 0, _woffset / elempack); + else + crop_pack4_sse(bottom_blob, top_blob, 0, _woffset / elempack); return 0; } @@ -861,7 +999,10 @@ int Crop_x86::forward(const std::vector& bottom_blobs, std::vector& to if (top_blob.empty()) return -100; - crop_pack4_sse(bottom_blob, top_blob, _hoffset / elempack, _woffset); + if (elemsize == 8u) + crop_pack4_bf16s_fp16s_sse(bottom_blob, top_blob, _hoffset / elempack, _woffset); + else + crop_pack4_sse(bottom_blob, top_blob, _hoffset / elempack, _woffset); return 0; } @@ -899,7 +1040,10 @@ int Crop_x86::forward(const std::vector& bottom_blobs, std::vector& to const Mat m = bottom_blob_sliced.channel(q); Mat borderm = top_blob.channel(q); - crop_pack4_sse(m, borderm, _hoffset, _woffset); + if (elemsize == 8u) + crop_pack4_bf16s_fp16s_sse(m, borderm, _hoffset, _woffset); + else + crop_pack4_sse(m, borderm, _hoffset, _woffset); } return 0; @@ -940,7 +1084,10 @@ int Crop_x86::forward(const std::vector& bottom_blobs, std::vector& to const Mat m = bottom_blob_sliced.channel(q).depth(z + _doffset); Mat borderm = top_blob.channel(q).depth(z); - crop_pack4_sse(m, borderm, _hoffset, _woffset); + if (elemsize == 8u) + crop_pack4_bf16s_fp16s_sse(m, borderm, _hoffset, _woffset); + else + crop_pack4_sse(m, borderm, _hoffset, _woffset); } } diff --git a/src/layer/x86/crop_x86.h b/src/layer/x86/crop_x86.h index 85f460fc7a7..80657150bc4 100644 --- a/src/layer/x86/crop_x86.h +++ b/src/layer/x86/crop_x86.h @@ -1,4 +1,4 @@ -// Copyright 2019 Tencent +// Copyright 2026 Tencent // SPDX-License-Identifier: BSD-3-Clause #ifndef LAYER_CROP_X86_H diff --git a/src/layer/x86/flatten_x86.cpp b/src/layer/x86/flatten_x86.cpp index 241b4c26f0c..f43499297b6 100644 --- a/src/layer/x86/flatten_x86.cpp +++ b/src/layer/x86/flatten_x86.cpp @@ -1,4 +1,4 @@ -// Copyright 2019 Tencent +// Copyright 2026 Tencent // SPDX-License-Identifier: BSD-3-Clause #include "flatten_x86.h" @@ -10,6 +10,7 @@ #endif #endif // __SSE2__ +#include "cpu.h" #include "x86_usability.h" namespace ncnn { @@ -19,6 +20,10 @@ Flatten_x86::Flatten_x86() #if __SSE2__ support_packing = true; #endif // __SSE2__ + support_fp16_storage = cpu_support_x86_f16c(); +#if NCNN_BF16 + support_bf16_storage = true; +#endif } int Flatten_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const @@ -28,6 +33,9 @@ int Flatten_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& op if (elembits == 8) return forward_int8(bottom_blob, top_blob, opt); + if (elembits == 16) + return forward_bf16s_fp16s(bottom_blob, top_blob, opt); + int dims = bottom_blob.dims; if (dims == 1) @@ -552,6 +560,319 @@ int Flatten_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& op return 0; } +int Flatten_x86::forward_bf16s_fp16s(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const +{ + int dims = bottom_blob.dims; + + if (dims == 1) + 
{ + top_blob = bottom_blob; + return 0; + } + + int w = bottom_blob.w; + int h = bottom_blob.h; + int d = bottom_blob.d; + int channels = bottom_blob.c; + size_t elemsize = bottom_blob.elemsize; + int elempack = bottom_blob.elempack; + int size = w * h * d; + + int total = size * channels * elempack; + + int out_elempack = 1; +#if __SSE2__ + if (opt.use_packing_layout) + { +#if __AVX512F__ + out_elempack = total % 16 == 0 ? 16 : total % 8 == 0 ? 8 : total % 4 == 0 ? 4 : 1; +#elif __AVX__ + out_elempack = total % 8 == 0 ? 8 : total % 4 == 0 ? 4 : 1; +#else + out_elempack = total % 4 == 0 ? 4 : 1; +#endif + } +#endif // __SSE2__ + size_t out_elemsize = elemsize / elempack * out_elempack; + + if (out_elempack == 1) + { + return Flatten::forward(bottom_blob, top_blob, opt); + } + + if (dims == 2 && elempack == 1) // out_elempack == 4 || out_elempack == 8 || out_elempack == 16 + { + top_blob = bottom_blob; + top_blob.dims = 1; + top_blob.w = total / out_elempack; + top_blob.h = 1; + top_blob.cstep = bottom_blob.cstep / out_elempack; + top_blob.elemsize = out_elemsize; + top_blob.elempack = out_elempack; + return 0; + } + + top_blob.create(total / out_elempack, out_elemsize, out_elempack, opt.blob_allocator); + if (top_blob.empty()) + return -100; + + if (dims == 2) + { +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + if (elempack == 16) // out_elempack == 16 + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = 0; i < h; i++) + { + const unsigned short* ptr = bottom_blob.row(i); + unsigned short* outptr0 = (unsigned short*)top_blob + w * i * 16; + unsigned short* outptr1 = (unsigned short*)top_blob + w * (i * 16 + 1); + unsigned short* outptr2 = (unsigned short*)top_blob + w * (i * 16 + 2); + unsigned short* outptr3 = (unsigned short*)top_blob + w * (i * 16 + 3); + unsigned short* outptr4 = (unsigned short*)top_blob + w * (i * 16 + 4); + unsigned short* outptr5 = (unsigned short*)top_blob + w * (i * 16 + 5); + unsigned short* outptr6 = (unsigned short*)top_blob + w * (i * 16 + 6); + unsigned short* outptr7 = (unsigned short*)top_blob + w * (i * 16 + 7); + unsigned short* outptr8 = (unsigned short*)top_blob + w * (i * 16 + 8); + unsigned short* outptr9 = (unsigned short*)top_blob + w * (i * 16 + 9); + unsigned short* outptra = (unsigned short*)top_blob + w * (i * 16 + 10); + unsigned short* outptrb = (unsigned short*)top_blob + w * (i * 16 + 11); + unsigned short* outptrc = (unsigned short*)top_blob + w * (i * 16 + 12); + unsigned short* outptrd = (unsigned short*)top_blob + w * (i * 16 + 13); + unsigned short* outptre = (unsigned short*)top_blob + w * (i * 16 + 14); + unsigned short* outptrf = (unsigned short*)top_blob + w * (i * 16 + 15); + + for (int j = 0; j < w; j++) + { + *outptr0++ = ptr[0]; + *outptr1++ = ptr[1]; + *outptr2++ = ptr[2]; + *outptr3++ = ptr[3]; + *outptr4++ = ptr[4]; + *outptr5++ = ptr[5]; + *outptr6++ = ptr[6]; + *outptr7++ = ptr[7]; + *outptr8++ = ptr[8]; + *outptr9++ = ptr[9]; + *outptra++ = ptr[10]; + *outptrb++ = ptr[11]; + *outptrc++ = ptr[12]; + *outptrd++ = ptr[13]; + *outptre++ = ptr[14]; + *outptrf++ = ptr[15]; + + ptr += 16; + } + } + } +#endif // __AVX512F__ + + if (elempack == 8) // out_elempack == 8 || out_elempack == 16 + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = 0; i < h; i++) + { + const unsigned short* ptr = bottom_blob.row(i); + unsigned short* outptr0 = (unsigned short*)top_blob + w * i * 8; + unsigned short* outptr1 = (unsigned short*)top_blob + w * (i * 8 + 1); + unsigned short* outptr2 = (unsigned 
short*)top_blob + w * (i * 8 + 2); + unsigned short* outptr3 = (unsigned short*)top_blob + w * (i * 8 + 3); + unsigned short* outptr4 = (unsigned short*)top_blob + w * (i * 8 + 4); + unsigned short* outptr5 = (unsigned short*)top_blob + w * (i * 8 + 5); + unsigned short* outptr6 = (unsigned short*)top_blob + w * (i * 8 + 6); + unsigned short* outptr7 = (unsigned short*)top_blob + w * (i * 8 + 7); + + for (int j = 0; j < w; j++) + { + *outptr0++ = ptr[0]; + *outptr1++ = ptr[1]; + *outptr2++ = ptr[2]; + *outptr3++ = ptr[3]; + *outptr4++ = ptr[4]; + *outptr5++ = ptr[5]; + *outptr6++ = ptr[6]; + *outptr7++ = ptr[7]; + + ptr += 8; + } + } + } +#endif // __AVX__ + + if (elempack == 4) // out_elempack == 4 || out_elempack == 8 || out_elempack == 16 + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = 0; i < h; i++) + { + const unsigned short* ptr = bottom_blob.row(i); + unsigned short* outptr0 = (unsigned short*)top_blob + w * i * 4; + unsigned short* outptr1 = (unsigned short*)top_blob + w * (i * 4 + 1); + unsigned short* outptr2 = (unsigned short*)top_blob + w * (i * 4 + 2); + unsigned short* outptr3 = (unsigned short*)top_blob + w * (i * 4 + 3); + + for (int j = 0; j < w; j++) + { + *outptr0++ = ptr[0]; + *outptr1++ = ptr[1]; + *outptr2++ = ptr[2]; + *outptr3++ = ptr[3]; + + ptr += 4; + } + } + } +#endif // __SSE2__ + } + + if (dims == 3 || dims == 4) + { +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + if (elempack == 16) // out_elempack == 16 + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const unsigned short* ptr = bottom_blob.channel(q); + unsigned short* outptr0 = (unsigned short*)top_blob + size * q * 16; + unsigned short* outptr1 = (unsigned short*)top_blob + size * (q * 16 + 1); + unsigned short* outptr2 = (unsigned short*)top_blob + size * (q * 16 + 2); + unsigned short* outptr3 = (unsigned short*)top_blob + size * (q * 16 + 3); + unsigned short* outptr4 = (unsigned short*)top_blob + size * (q * 16 + 4); + unsigned short* outptr5 = (unsigned short*)top_blob + size * (q * 16 + 5); + unsigned short* outptr6 = (unsigned short*)top_blob + size * (q * 16 + 6); + unsigned short* outptr7 = (unsigned short*)top_blob + size * (q * 16 + 7); + unsigned short* outptr8 = (unsigned short*)top_blob + size * (q * 16 + 8); + unsigned short* outptr9 = (unsigned short*)top_blob + size * (q * 16 + 9); + unsigned short* outptra = (unsigned short*)top_blob + size * (q * 16 + 10); + unsigned short* outptrb = (unsigned short*)top_blob + size * (q * 16 + 11); + unsigned short* outptrc = (unsigned short*)top_blob + size * (q * 16 + 12); + unsigned short* outptrd = (unsigned short*)top_blob + size * (q * 16 + 13); + unsigned short* outptre = (unsigned short*)top_blob + size * (q * 16 + 14); + unsigned short* outptrf = (unsigned short*)top_blob + size * (q * 16 + 15); + + for (int i = 0; i < size; i++) + { + *outptr0++ = ptr[0]; + *outptr1++ = ptr[1]; + *outptr2++ = ptr[2]; + *outptr3++ = ptr[3]; + *outptr4++ = ptr[4]; + *outptr5++ = ptr[5]; + *outptr6++ = ptr[6]; + *outptr7++ = ptr[7]; + *outptr8++ = ptr[8]; + *outptr9++ = ptr[9]; + *outptra++ = ptr[10]; + *outptrb++ = ptr[11]; + *outptrc++ = ptr[12]; + *outptrd++ = ptr[13]; + *outptre++ = ptr[14]; + *outptrf++ = ptr[15]; + + ptr += 16; + } + } + } +#endif // __AVX512F__ + + if (elempack == 8) // out_elempack == 8 || out_elempack == 16 + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const unsigned short* ptr = bottom_blob.channel(q); 
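+            // de-interleave: lane k of this pack8 channel is written to a contiguous run of length size at offset size * (q * 8 + k) in the flattened output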
+ unsigned short* outptr0 = (unsigned short*)top_blob + size * q * 8; + unsigned short* outptr1 = (unsigned short*)top_blob + size * (q * 8 + 1); + unsigned short* outptr2 = (unsigned short*)top_blob + size * (q * 8 + 2); + unsigned short* outptr3 = (unsigned short*)top_blob + size * (q * 8 + 3); + unsigned short* outptr4 = (unsigned short*)top_blob + size * (q * 8 + 4); + unsigned short* outptr5 = (unsigned short*)top_blob + size * (q * 8 + 5); + unsigned short* outptr6 = (unsigned short*)top_blob + size * (q * 8 + 6); + unsigned short* outptr7 = (unsigned short*)top_blob + size * (q * 8 + 7); + + for (int i = 0; i < size; i++) + { + *outptr0++ = ptr[0]; + *outptr1++ = ptr[1]; + *outptr2++ = ptr[2]; + *outptr3++ = ptr[3]; + *outptr4++ = ptr[4]; + *outptr5++ = ptr[5]; + *outptr6++ = ptr[6]; + *outptr7++ = ptr[7]; + + ptr += 8; + } + } + } +#endif // __AVX__ + + if (elempack == 4) // out_elempack == 4 || out_elempack == 8 || out_elempack == 16 + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const unsigned short* ptr = bottom_blob.channel(q); + unsigned short* outptr0 = (unsigned short*)top_blob + size * q * 4; + unsigned short* outptr1 = (unsigned short*)top_blob + size * (q * 4 + 1); + unsigned short* outptr2 = (unsigned short*)top_blob + size * (q * 4 + 2); + unsigned short* outptr3 = (unsigned short*)top_blob + size * (q * 4 + 3); + + for (int i = 0; i < size; i++) + { + *outptr0++ = ptr[0]; + *outptr1++ = ptr[1]; + *outptr2++ = ptr[2]; + *outptr3++ = ptr[3]; + + ptr += 4; + } + } + } +#endif // __SSE2__ + + if (elempack == 1) // out_elempack == 4 || out_elempack == 8 || out_elempack == 16 + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const unsigned short* ptr = bottom_blob.channel(q); + unsigned short* outptr = (unsigned short*)top_blob + size * q; + + int i = 0; +#if __SSE2__ +#if __AVX__ + for (; i + 15 < size; i += 16) + { + __m256i _v = _mm256_loadu_si256((const __m256i*)ptr); + _mm256_storeu_si256((__m256i*)outptr, _v); + ptr += 16; + outptr += 16; + } +#endif + for (; i + 7 < size; i += 8) + { + __m128i _v = _mm_loadu_si128((const __m128i*)ptr); + _mm_storeu_si128((__m128i*)outptr, _v); + ptr += 8; + outptr += 8; + } +#endif // __SSE2__ + for (; i < size; i++) + { + *outptr++ = *ptr++; + } + } + } + } + + return 0; +} + int Flatten_x86::forward_int8(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const { int dims = bottom_blob.dims; diff --git a/src/layer/x86/flatten_x86.h b/src/layer/x86/flatten_x86.h index 7d4d6315dea..7a328b1a479 100644 --- a/src/layer/x86/flatten_x86.h +++ b/src/layer/x86/flatten_x86.h @@ -1,4 +1,4 @@ -// Copyright 2019 Tencent +// Copyright 2026 Tencent // SPDX-License-Identifier: BSD-3-Clause #ifndef LAYER_FLATTEN_X86_H @@ -16,6 +16,7 @@ class Flatten_x86 : public Flatten virtual int forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const; protected: + int forward_bf16s_fp16s(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const; int forward_int8(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const; }; diff --git a/src/layer/x86/packing_x86.cpp b/src/layer/x86/packing_x86.cpp index c866eadd079..b6211419d84 100644 --- a/src/layer/x86/packing_x86.cpp +++ b/src/layer/x86/packing_x86.cpp @@ -3,6 +3,7 @@ #include "packing_x86.h" +#include "cpu.h" #include "x86_usability.h" namespace ncnn { @@ -10,6 +11,10 @@ namespace ncnn { Packing_x86::Packing_x86() { support_packing = true; + support_fp16_storage = 
cpu_support_x86_f16c(); +#if NCNN_BF16 + support_bf16_storage = true; +#endif } int Packing_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const @@ -19,6 +24,9 @@ int Packing_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& op if (elembits == 8) return forward_int8(bottom_blob, top_blob, opt); + if (elembits == 16) + return forward_bf16s_fp16s(bottom_blob, top_blob, opt); + if (use_padding) { return Packing::forward(bottom_blob, top_blob, opt); @@ -1386,6 +1394,1027 @@ int Packing_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& op return 0; } +int Packing_x86::forward_bf16s_fp16s(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const +{ + if (use_padding) + { + return Packing::forward(bottom_blob, top_blob, opt); + } + + size_t elemsize = bottom_blob.elemsize; + int elempack = bottom_blob.elempack; + + if (elempack == out_elempack) + { + top_blob = bottom_blob; + return 0; + } + + bool pack1to4 = elempack == 1 && out_elempack == 4; + bool pack4to1 = elempack == 4 && out_elempack == 1; + bool pack1to8 = elempack == 1 && out_elempack == 8; + bool pack8to1 = elempack == 8 && out_elempack == 1; + bool pack4to8 = elempack == 4 && out_elempack == 8; + bool pack8to4 = elempack == 8 && out_elempack == 4; + bool pack1to16 = elempack == 1 && out_elempack == 16; + bool pack16to1 = elempack == 16 && out_elempack == 1; + bool pack4to16 = elempack == 4 && out_elempack == 16; + bool pack16to4 = elempack == 16 && out_elempack == 4; + bool pack8to16 = elempack == 8 && out_elempack == 16; + bool pack16to8 = elempack == 16 && out_elempack == 8; + + if (!pack1to4 && !pack4to1 && !pack1to8 && !pack8to1 && !pack4to8 && !pack8to4 && !pack1to16 && !pack16to1 && !pack4to16 && !pack16to4 && !pack8to16 && !pack16to8) + { + return Packing::forward(bottom_blob, top_blob, opt); + } + + int w = bottom_blob.w; + int h = bottom_blob.h; + int d = bottom_blob.d; + int channels = bottom_blob.c; + int dims = bottom_blob.dims; + + if (!use_padding) + { + // identity if use_padding not allowed + if (dims == 1 && w * elempack % out_elempack != 0) + { + top_blob = bottom_blob; + return 0; + } + if (dims == 2 && h * elempack % out_elempack != 0) + { + top_blob = bottom_blob; + return 0; + } + if ((dims == 3 || dims == 4) && channels * elempack % out_elempack != 0) + { + top_blob = bottom_blob; + return 0; + } + } + + if (dims == 1) + { + top_blob = bottom_blob; + top_blob.w = w * elempack / out_elempack; + top_blob.cstep = bottom_blob.cstep * elempack / out_elempack; + top_blob.elemsize = elemsize / elempack * out_elempack; + top_blob.elempack = out_elempack; + return 0; + } + + if (dims == 2) + { + int outh = h * elempack / out_elempack; + size_t out_elemsize = elemsize / elempack * out_elempack; + + top_blob.create(w, outh, out_elemsize, out_elempack, opt.blob_allocator); + if (top_blob.empty()) + return -100; + + if (pack1to4) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = 0; i < outh; i++) + { + const unsigned short* r0 = bottom_blob.row(i * 4); + const unsigned short* r1 = bottom_blob.row(i * 4 + 1); + const unsigned short* r2 = bottom_blob.row(i * 4 + 2); + const unsigned short* r3 = bottom_blob.row(i * 4 + 3); + + unsigned short* outptr = top_blob.row(i); + + int j = 0; +#if __SSE2__ + for (; j + 3 < w; j += 4) + { + // transpose 4x4 unsigned short + __m128i _r0 = _mm_loadl_epi64((const __m128i*)r0); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)r1); + __m128i _r2 = _mm_loadl_epi64((const __m128i*)r2); + __m128i _r3 = 
_mm_loadl_epi64((const __m128i*)r3); + + __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); + __m128i _r23 = _mm_unpacklo_epi16(_r2, _r3); + + __m128i _r0123l = _mm_unpacklo_epi32(_r01, _r23); + __m128i _r0123h = _mm_unpackhi_epi32(_r01, _r23); + + _mm_storeu_si128((__m128i*)outptr, _r0123l); + _mm_storeu_si128((__m128i*)(outptr + 8), _r0123h); + + r0 += 4; + r1 += 4; + r2 += 4; + r3 += 4; + outptr += 16; + } +#endif // __SSE2__ + for (; j < w; j++) + { + outptr[0] = *r0++; + outptr[1] = *r1++; + outptr[2] = *r2++; + outptr[3] = *r3++; + + outptr += 4; + } + } + } + if (pack4to1) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = 0; i < h; i++) + { + const unsigned short* r0 = bottom_blob.row(i); + + unsigned short* outptr0 = top_blob.row(i * 4); + unsigned short* outptr1 = top_blob.row(i * 4 + 1); + unsigned short* outptr2 = top_blob.row(i * 4 + 2); + unsigned short* outptr3 = top_blob.row(i * 4 + 3); + + int j = 0; +#if __SSE2__ + for (; j + 3 < w; j += 4) + { + // transpose 4x4 unsigned short + __m128i _r0 = _mm_loadu_si128((const __m128i*)r0); + __m128i _r1 = _mm_loadu_si128((const __m128i*)(r0 + 8)); + + __m128i _r01l = _mm_unpacklo_epi16(_r0, _r1); + __m128i _r01h = _mm_unpackhi_epi16(_r0, _r1); + + __m128i _r0123ll = _mm_unpacklo_epi16(_r01l, _r01h); + __m128i _r0123lh = _mm_unpackhi_epi16(_r01l, _r01h); + + _mm_storel_epi64((__m128i*)outptr0, _r0123ll); + _mm_storel_epi64((__m128i*)outptr1, _mm_srli_si128(_r0123ll, 8)); + _mm_storel_epi64((__m128i*)outptr2, _r0123lh); + _mm_storel_epi64((__m128i*)outptr3, _mm_srli_si128(_r0123lh, 8)); + + r0 += 16; + outptr0 += 4; + outptr1 += 4; + outptr2 += 4; + outptr3 += 4; + } +#endif // __SSE2__ + for (; j < w; j++) + { + *outptr0++ = r0[0]; + *outptr1++ = r0[1]; + *outptr2++ = r0[2]; + *outptr3++ = r0[3]; + + r0 += 4; + } + } + } + if (pack1to8) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = 0; i < outh; i++) + { + const unsigned short* r0 = bottom_blob.row(i * 8); + const unsigned short* r1 = bottom_blob.row(i * 8 + 1); + const unsigned short* r2 = bottom_blob.row(i * 8 + 2); + const unsigned short* r3 = bottom_blob.row(i * 8 + 3); + const unsigned short* r4 = bottom_blob.row(i * 8 + 4); + const unsigned short* r5 = bottom_blob.row(i * 8 + 5); + const unsigned short* r6 = bottom_blob.row(i * 8 + 6); + const unsigned short* r7 = bottom_blob.row(i * 8 + 7); + + unsigned short* outptr = top_blob.row(i); + + int j = 0; + for (; j < w; j++) + { + outptr[0] = *r0++; + outptr[1] = *r1++; + outptr[2] = *r2++; + outptr[3] = *r3++; + outptr[4] = *r4++; + outptr[5] = *r5++; + outptr[6] = *r6++; + outptr[7] = *r7++; + + outptr += 8; + } + } + } + if (pack8to1) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = 0; i < h; i++) + { + const unsigned short* r0 = bottom_blob.row(i); + + unsigned short* outptr0 = top_blob.row(i * 8); + unsigned short* outptr1 = top_blob.row(i * 8 + 1); + unsigned short* outptr2 = top_blob.row(i * 8 + 2); + unsigned short* outptr3 = top_blob.row(i * 8 + 3); + unsigned short* outptr4 = top_blob.row(i * 8 + 4); + unsigned short* outptr5 = top_blob.row(i * 8 + 5); + unsigned short* outptr6 = top_blob.row(i * 8 + 6); + unsigned short* outptr7 = top_blob.row(i * 8 + 7); + + int j = 0; + for (; j < w; j++) + { + *outptr0++ = r0[0]; + *outptr1++ = r0[1]; + *outptr2++ = r0[2]; + *outptr3++ = r0[3]; + *outptr4++ = r0[4]; + *outptr5++ = r0[5]; + *outptr6++ = r0[6]; + *outptr7++ = r0[7]; + + r0 += 8; + } + } + } + if (pack4to8) + { + #pragma omp parallel for 
num_threads(opt.num_threads) + for (int i = 0; i < outh; i++) + { + const unsigned short* r0 = bottom_blob.row(i * 2); + const unsigned short* r1 = bottom_blob.row(i * 2 + 1); + + unsigned short* outptr = top_blob.row(i); + + for (int j = 0; j < w; j++) + { + outptr[0] = r0[0]; + outptr[1] = r0[1]; + outptr[2] = r0[2]; + outptr[3] = r0[3]; + outptr[4] = r1[0]; + outptr[5] = r1[1]; + outptr[6] = r1[2]; + outptr[7] = r1[3]; + + r0 += 4; + r1 += 4; + outptr += 8; + } + } + } + if (pack8to4) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = 0; i < h; i++) + { + const unsigned short* r0 = bottom_blob.row(i); + + unsigned short* outptr0 = top_blob.row(i * 2); + unsigned short* outptr1 = top_blob.row(i * 2 + 1); + + for (int j = 0; j < w; j++) + { + outptr0[0] = r0[0]; + outptr0[1] = r0[1]; + outptr0[2] = r0[2]; + outptr0[3] = r0[3]; + outptr1[0] = r0[4]; + outptr1[1] = r0[5]; + outptr1[2] = r0[6]; + outptr1[3] = r0[7]; + + r0 += 8; + outptr0 += 4; + outptr1 += 4; + } + } + } + if (pack1to16) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = 0; i < outh; i++) + { + const unsigned short* r0 = bottom_blob.row(i * 16); + const unsigned short* r1 = bottom_blob.row(i * 16 + 1); + const unsigned short* r2 = bottom_blob.row(i * 16 + 2); + const unsigned short* r3 = bottom_blob.row(i * 16 + 3); + const unsigned short* r4 = bottom_blob.row(i * 16 + 4); + const unsigned short* r5 = bottom_blob.row(i * 16 + 5); + const unsigned short* r6 = bottom_blob.row(i * 16 + 6); + const unsigned short* r7 = bottom_blob.row(i * 16 + 7); + const unsigned short* r8 = bottom_blob.row(i * 16 + 8); + const unsigned short* r9 = bottom_blob.row(i * 16 + 9); + const unsigned short* ra = bottom_blob.row(i * 16 + 10); + const unsigned short* rb = bottom_blob.row(i * 16 + 11); + const unsigned short* rc = bottom_blob.row(i * 16 + 12); + const unsigned short* rd = bottom_blob.row(i * 16 + 13); + const unsigned short* re = bottom_blob.row(i * 16 + 14); + const unsigned short* rf = bottom_blob.row(i * 16 + 15); + + unsigned short* outptr = top_blob.row(i); + + int j = 0; + for (; j < w; j++) + { + outptr[0] = *r0++; + outptr[1] = *r1++; + outptr[2] = *r2++; + outptr[3] = *r3++; + outptr[4] = *r4++; + outptr[5] = *r5++; + outptr[6] = *r6++; + outptr[7] = *r7++; + outptr[8] = *r8++; + outptr[9] = *r9++; + outptr[10] = *ra++; + outptr[11] = *rb++; + outptr[12] = *rc++; + outptr[13] = *rd++; + outptr[14] = *re++; + outptr[15] = *rf++; + + outptr += 16; + } + } + } + if (pack16to1) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = 0; i < h; i++) + { + const unsigned short* r0 = bottom_blob.row(i); + + unsigned short* outptr0 = top_blob.row(i * 16); + unsigned short* outptr1 = top_blob.row(i * 16 + 1); + unsigned short* outptr2 = top_blob.row(i * 16 + 2); + unsigned short* outptr3 = top_blob.row(i * 16 + 3); + unsigned short* outptr4 = top_blob.row(i * 16 + 4); + unsigned short* outptr5 = top_blob.row(i * 16 + 5); + unsigned short* outptr6 = top_blob.row(i * 16 + 6); + unsigned short* outptr7 = top_blob.row(i * 16 + 7); + unsigned short* outptr8 = top_blob.row(i * 16 + 8); + unsigned short* outptr9 = top_blob.row(i * 16 + 9); + unsigned short* outptra = top_blob.row(i * 16 + 10); + unsigned short* outptrb = top_blob.row(i * 16 + 11); + unsigned short* outptrc = top_blob.row(i * 16 + 12); + unsigned short* outptrd = top_blob.row(i * 16 + 13); + unsigned short* outptre = top_blob.row(i * 16 + 14); + unsigned short* outptrf = top_blob.row(i * 16 + 15); + + int j = 0; 
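+                // scalar de-interleave: lane k of each pack16 element goes to output row i * 16 + k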
+ for (; j < w; j++) + { + *outptr0++ = r0[0]; + *outptr1++ = r0[1]; + *outptr2++ = r0[2]; + *outptr3++ = r0[3]; + *outptr4++ = r0[4]; + *outptr5++ = r0[5]; + *outptr6++ = r0[6]; + *outptr7++ = r0[7]; + *outptr8++ = r0[8]; + *outptr9++ = r0[9]; + *outptra++ = r0[10]; + *outptrb++ = r0[11]; + *outptrc++ = r0[12]; + *outptrd++ = r0[13]; + *outptre++ = r0[14]; + *outptrf++ = r0[15]; + + r0 += 16; + } + } + } + if (pack4to16) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = 0; i < outh; i++) + { + const unsigned short* r0 = bottom_blob.row(i * 4); + const unsigned short* r1 = bottom_blob.row(i * 4 + 1); + const unsigned short* r2 = bottom_blob.row(i * 4 + 2); + const unsigned short* r3 = bottom_blob.row(i * 4 + 3); + + unsigned short* outptr = top_blob.row(i); + + for (int j = 0; j < w; j++) + { + outptr[0] = r0[0]; + outptr[1] = r0[1]; + outptr[2] = r0[2]; + outptr[3] = r0[3]; + outptr[4] = r1[0]; + outptr[5] = r1[1]; + outptr[6] = r1[2]; + outptr[7] = r1[3]; + outptr[8] = r2[0]; + outptr[9] = r2[1]; + outptr[10] = r2[2]; + outptr[11] = r2[3]; + outptr[12] = r3[0]; + outptr[13] = r3[1]; + outptr[14] = r3[2]; + outptr[15] = r3[3]; + + r0 += 4; + r1 += 4; + r2 += 4; + r3 += 4; + outptr += 16; + } + } + } + if (pack16to4) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = 0; i < h; i++) + { + const unsigned short* r0 = bottom_blob.row(i); + + unsigned short* outptr0 = top_blob.row(i * 4); + unsigned short* outptr1 = top_blob.row(i * 4 + 1); + unsigned short* outptr2 = top_blob.row(i * 4 + 2); + unsigned short* outptr3 = top_blob.row(i * 4 + 3); + + for (int j = 0; j < w; j++) + { + outptr0[0] = r0[0]; + outptr0[1] = r0[1]; + outptr0[2] = r0[2]; + outptr0[3] = r0[3]; + outptr1[0] = r0[4]; + outptr1[1] = r0[5]; + outptr1[2] = r0[6]; + outptr1[3] = r0[7]; + outptr2[0] = r0[8]; + outptr2[1] = r0[9]; + outptr2[2] = r0[10]; + outptr2[3] = r0[11]; + outptr3[0] = r0[12]; + outptr3[1] = r0[13]; + outptr3[2] = r0[14]; + outptr3[3] = r0[15]; + + r0 += 16; + outptr0 += 4; + outptr1 += 4; + outptr2 += 4; + outptr3 += 4; + } + } + } + if (pack8to16) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = 0; i < outh; i++) + { + const unsigned short* r0 = bottom_blob.row(i * 2); + const unsigned short* r1 = bottom_blob.row(i * 2 + 1); + + unsigned short* outptr = top_blob.row(i); + + for (int j = 0; j < w; j++) + { + outptr[0] = r0[0]; + outptr[1] = r0[1]; + outptr[2] = r0[2]; + outptr[3] = r0[3]; + outptr[4] = r0[4]; + outptr[5] = r0[5]; + outptr[6] = r0[6]; + outptr[7] = r0[7]; + outptr[8] = r1[0]; + outptr[9] = r1[1]; + outptr[10] = r1[2]; + outptr[11] = r1[3]; + outptr[12] = r1[4]; + outptr[13] = r1[5]; + outptr[14] = r1[6]; + outptr[15] = r1[7]; + + r0 += 8; + r1 += 8; + outptr += 16; + } + } + } + if (pack16to8) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = 0; i < h; i++) + { + const unsigned short* r0 = bottom_blob.row(i); + + unsigned short* outptr0 = top_blob.row(i * 2); + unsigned short* outptr1 = top_blob.row(i * 2 + 1); + + for (int j = 0; j < w; j++) + { + outptr0[0] = r0[0]; + outptr0[1] = r0[1]; + outptr0[2] = r0[2]; + outptr0[3] = r0[3]; + outptr0[4] = r0[4]; + outptr0[5] = r0[5]; + outptr0[6] = r0[6]; + outptr0[7] = r0[7]; + outptr1[0] = r0[8]; + outptr1[1] = r0[9]; + outptr1[2] = r0[10]; + outptr1[3] = r0[11]; + outptr1[4] = r0[12]; + outptr1[5] = r0[13]; + outptr1[6] = r0[14]; + outptr1[7] = r0[15]; + + r0 += 16; + outptr0 += 8; + outptr1 += 8; + } + } + } + + return 0; + } + + if (dims == 3 || 
dims == 4) + { + int size = w * h * d; + int outc = channels * elempack / out_elempack; + size_t out_elemsize = elemsize / elempack * out_elempack; + + if (dims == 3) + top_blob.create(w, h, outc, out_elemsize, out_elempack, opt.blob_allocator); + else // if (dims == 4) + top_blob.create(w, h, d, outc, out_elemsize, out_elempack, opt.blob_allocator); + if (top_blob.empty()) + return -100; + + if (pack1to4) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < outc; q++) + { + const unsigned short* r0 = bottom_blob.channel(q * 4); + const unsigned short* r1 = bottom_blob.channel(q * 4 + 1); + const unsigned short* r2 = bottom_blob.channel(q * 4 + 2); + const unsigned short* r3 = bottom_blob.channel(q * 4 + 3); + + unsigned short* outptr = top_blob.channel(q); + + int i = 0; +#if __SSE2__ + for (; i + 3 < size; i += 4) + { + __m128i _r0 = _mm_loadl_epi64((const __m128i*)r0); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)r1); + __m128i _r2 = _mm_loadl_epi64((const __m128i*)r2); + __m128i _r3 = _mm_loadl_epi64((const __m128i*)r3); + + __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); + __m128i _r23 = _mm_unpacklo_epi16(_r2, _r3); + + __m128i _r0123l = _mm_unpacklo_epi32(_r01, _r23); + __m128i _r0123h = _mm_unpackhi_epi32(_r01, _r23); + + _mm_storeu_si128((__m128i*)outptr, _r0123l); + _mm_storeu_si128((__m128i*)(outptr + 8), _r0123h); + + r0 += 4; + r1 += 4; + r2 += 4; + r3 += 4; + outptr += 16; + } +#endif // __SSE2__ + for (; i < size; i++) + { + outptr[0] = *r0++; + outptr[1] = *r1++; + outptr[2] = *r2++; + outptr[3] = *r3++; + + outptr += 4; + } + } + } + if (pack4to1) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const unsigned short* r0 = bottom_blob.channel(q); + + unsigned short* outptr0 = top_blob.channel(q * 4); + unsigned short* outptr1 = top_blob.channel(q * 4 + 1); + unsigned short* outptr2 = top_blob.channel(q * 4 + 2); + unsigned short* outptr3 = top_blob.channel(q * 4 + 3); + + int i = 0; +#if __SSE2__ + for (; i + 3 < size; i += 4) + { + __m128i _r0 = _mm_loadu_si128((const __m128i*)r0); + __m128i _r1 = _mm_loadu_si128((const __m128i*)(r0 + 8)); + + __m128i _r01l = _mm_unpacklo_epi16(_r0, _r1); + __m128i _r01h = _mm_unpackhi_epi16(_r0, _r1); + + __m128i _r0123ll = _mm_unpacklo_epi16(_r01l, _r01h); + __m128i _r0123lh = _mm_unpackhi_epi16(_r01l, _r01h); + + _mm_storel_epi64((__m128i*)outptr0, _r0123ll); + _mm_storel_epi64((__m128i*)outptr1, _mm_srli_si128(_r0123ll, 8)); + _mm_storel_epi64((__m128i*)outptr2, _r0123lh); + _mm_storel_epi64((__m128i*)outptr3, _mm_srli_si128(_r0123lh, 8)); + + r0 += 16; + outptr0 += 4; + outptr1 += 4; + outptr2 += 4; + outptr3 += 4; + } +#endif // __SSE2__ + for (; i < size; i++) + { + *outptr0++ = r0[0]; + *outptr1++ = r0[1]; + *outptr2++ = r0[2]; + *outptr3++ = r0[3]; + + r0 += 4; + } + } + } + if (pack1to8) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < outc; q++) + { + const unsigned short* r0 = bottom_blob.channel(q * 8); + const unsigned short* r1 = bottom_blob.channel(q * 8 + 1); + const unsigned short* r2 = bottom_blob.channel(q * 8 + 2); + const unsigned short* r3 = bottom_blob.channel(q * 8 + 3); + const unsigned short* r4 = bottom_blob.channel(q * 8 + 4); + const unsigned short* r5 = bottom_blob.channel(q * 8 + 5); + const unsigned short* r6 = bottom_blob.channel(q * 8 + 6); + const unsigned short* r7 = bottom_blob.channel(q * 8 + 7); + + unsigned short* outptr = top_blob.channel(q); + + int i = 0; + for (; i < size; i++) + { 
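+                    // gather one value from each of the 8 source channels to form a single pack8 element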
+ outptr[0] = *r0++; + outptr[1] = *r1++; + outptr[2] = *r2++; + outptr[3] = *r3++; + outptr[4] = *r4++; + outptr[5] = *r5++; + outptr[6] = *r6++; + outptr[7] = *r7++; + + outptr += 8; + } + } + } + if (pack8to1) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const unsigned short* r0 = bottom_blob.channel(q); + + unsigned short* outptr0 = top_blob.channel(q * 8); + unsigned short* outptr1 = top_blob.channel(q * 8 + 1); + unsigned short* outptr2 = top_blob.channel(q * 8 + 2); + unsigned short* outptr3 = top_blob.channel(q * 8 + 3); + unsigned short* outptr4 = top_blob.channel(q * 8 + 4); + unsigned short* outptr5 = top_blob.channel(q * 8 + 5); + unsigned short* outptr6 = top_blob.channel(q * 8 + 6); + unsigned short* outptr7 = top_blob.channel(q * 8 + 7); + + int i = 0; + for (; i < size; i++) + { + *outptr0++ = r0[0]; + *outptr1++ = r0[1]; + *outptr2++ = r0[2]; + *outptr3++ = r0[3]; + *outptr4++ = r0[4]; + *outptr5++ = r0[5]; + *outptr6++ = r0[6]; + *outptr7++ = r0[7]; + + r0 += 8; + } + } + } + if (pack4to8) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < outc; q++) + { + const unsigned short* r0 = bottom_blob.channel(q * 2); + const unsigned short* r1 = bottom_blob.channel(q * 2 + 1); + + unsigned short* outptr = top_blob.channel(q); + + for (int i = 0; i < size; i++) + { + outptr[0] = r0[0]; + outptr[1] = r0[1]; + outptr[2] = r0[2]; + outptr[3] = r0[3]; + outptr[4] = r1[0]; + outptr[5] = r1[1]; + outptr[6] = r1[2]; + outptr[7] = r1[3]; + + r0 += 4; + r1 += 4; + outptr += 8; + } + } + } + if (pack8to4) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const unsigned short* r0 = bottom_blob.channel(q); + + unsigned short* outptr0 = top_blob.channel(q * 2); + unsigned short* outptr1 = top_blob.channel(q * 2 + 1); + + for (int i = 0; i < size; i++) + { + outptr0[0] = r0[0]; + outptr0[1] = r0[1]; + outptr0[2] = r0[2]; + outptr0[3] = r0[3]; + outptr1[0] = r0[4]; + outptr1[1] = r0[5]; + outptr1[2] = r0[6]; + outptr1[3] = r0[7]; + + r0 += 8; + outptr0 += 4; + outptr1 += 4; + } + } + } + if (pack1to16) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < outc; q++) + { + const unsigned short* r0 = bottom_blob.channel(q * 16); + const unsigned short* r1 = bottom_blob.channel(q * 16 + 1); + const unsigned short* r2 = bottom_blob.channel(q * 16 + 2); + const unsigned short* r3 = bottom_blob.channel(q * 16 + 3); + const unsigned short* r4 = bottom_blob.channel(q * 16 + 4); + const unsigned short* r5 = bottom_blob.channel(q * 16 + 5); + const unsigned short* r6 = bottom_blob.channel(q * 16 + 6); + const unsigned short* r7 = bottom_blob.channel(q * 16 + 7); + const unsigned short* r8 = bottom_blob.channel(q * 16 + 8); + const unsigned short* r9 = bottom_blob.channel(q * 16 + 9); + const unsigned short* ra = bottom_blob.channel(q * 16 + 10); + const unsigned short* rb = bottom_blob.channel(q * 16 + 11); + const unsigned short* rc = bottom_blob.channel(q * 16 + 12); + const unsigned short* rd = bottom_blob.channel(q * 16 + 13); + const unsigned short* re = bottom_blob.channel(q * 16 + 14); + const unsigned short* rf = bottom_blob.channel(q * 16 + 15); + + unsigned short* outptr = top_blob.channel(q); + + int i = 0; + for (; i < size; i++) + { + outptr[0] = *r0++; + outptr[1] = *r1++; + outptr[2] = *r2++; + outptr[3] = *r3++; + outptr[4] = *r4++; + outptr[5] = *r5++; + outptr[6] = *r6++; + outptr[7] = *r7++; + outptr[8] = *r8++; + 
outptr[9] = *r9++; + outptr[10] = *ra++; + outptr[11] = *rb++; + outptr[12] = *rc++; + outptr[13] = *rd++; + outptr[14] = *re++; + outptr[15] = *rf++; + + outptr += 16; + } + } + } + if (pack16to1) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const unsigned short* r0 = bottom_blob.channel(q); + + unsigned short* outptr0 = top_blob.channel(q * 16); + unsigned short* outptr1 = top_blob.channel(q * 16 + 1); + unsigned short* outptr2 = top_blob.channel(q * 16 + 2); + unsigned short* outptr3 = top_blob.channel(q * 16 + 3); + unsigned short* outptr4 = top_blob.channel(q * 16 + 4); + unsigned short* outptr5 = top_blob.channel(q * 16 + 5); + unsigned short* outptr6 = top_blob.channel(q * 16 + 6); + unsigned short* outptr7 = top_blob.channel(q * 16 + 7); + unsigned short* outptr8 = top_blob.channel(q * 16 + 8); + unsigned short* outptr9 = top_blob.channel(q * 16 + 9); + unsigned short* outptra = top_blob.channel(q * 16 + 10); + unsigned short* outptrb = top_blob.channel(q * 16 + 11); + unsigned short* outptrc = top_blob.channel(q * 16 + 12); + unsigned short* outptrd = top_blob.channel(q * 16 + 13); + unsigned short* outptre = top_blob.channel(q * 16 + 14); + unsigned short* outptrf = top_blob.channel(q * 16 + 15); + + int i = 0; + for (; i < size; i++) + { + *outptr0++ = r0[0]; + *outptr1++ = r0[1]; + *outptr2++ = r0[2]; + *outptr3++ = r0[3]; + *outptr4++ = r0[4]; + *outptr5++ = r0[5]; + *outptr6++ = r0[6]; + *outptr7++ = r0[7]; + *outptr8++ = r0[8]; + *outptr9++ = r0[9]; + *outptra++ = r0[10]; + *outptrb++ = r0[11]; + *outptrc++ = r0[12]; + *outptrd++ = r0[13]; + *outptre++ = r0[14]; + *outptrf++ = r0[15]; + + r0 += 16; + } + } + } + if (pack4to16) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < outc; q++) + { + const unsigned short* r0 = bottom_blob.channel(q * 4); + const unsigned short* r1 = bottom_blob.channel(q * 4 + 1); + const unsigned short* r2 = bottom_blob.channel(q * 4 + 2); + const unsigned short* r3 = bottom_blob.channel(q * 4 + 3); + + unsigned short* outptr = top_blob.channel(q); + + for (int i = 0; i < size; i++) + { + outptr[0] = r0[0]; + outptr[1] = r0[1]; + outptr[2] = r0[2]; + outptr[3] = r0[3]; + outptr[4] = r1[0]; + outptr[5] = r1[1]; + outptr[6] = r1[2]; + outptr[7] = r1[3]; + outptr[8] = r2[0]; + outptr[9] = r2[1]; + outptr[10] = r2[2]; + outptr[11] = r2[3]; + outptr[12] = r3[0]; + outptr[13] = r3[1]; + outptr[14] = r3[2]; + outptr[15] = r3[3]; + + r0 += 4; + r1 += 4; + r2 += 4; + r3 += 4; + outptr += 16; + } + } + } + if (pack16to4) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const unsigned short* r0 = bottom_blob.channel(q); + + unsigned short* outptr0 = top_blob.channel(q * 4); + unsigned short* outptr1 = top_blob.channel(q * 4 + 1); + unsigned short* outptr2 = top_blob.channel(q * 4 + 2); + unsigned short* outptr3 = top_blob.channel(q * 4 + 3); + + for (int i = 0; i < size; i++) + { + outptr0[0] = r0[0]; + outptr0[1] = r0[1]; + outptr0[2] = r0[2]; + outptr0[3] = r0[3]; + outptr1[0] = r0[4]; + outptr1[1] = r0[5]; + outptr1[2] = r0[6]; + outptr1[3] = r0[7]; + outptr2[0] = r0[8]; + outptr2[1] = r0[9]; + outptr2[2] = r0[10]; + outptr2[3] = r0[11]; + outptr3[0] = r0[12]; + outptr3[1] = r0[13]; + outptr3[2] = r0[14]; + outptr3[3] = r0[15]; + + r0 += 16; + outptr0 += 4; + outptr1 += 4; + outptr2 += 4; + outptr3 += 4; + } + } + } + if (pack8to16) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < 
outc; q++) + { + const unsigned short* r0 = bottom_blob.channel(q * 2); + const unsigned short* r1 = bottom_blob.channel(q * 2 + 1); + + unsigned short* outptr = top_blob.channel(q); + + for (int i = 0; i < size; i++) + { + outptr[0] = r0[0]; + outptr[1] = r0[1]; + outptr[2] = r0[2]; + outptr[3] = r0[3]; + outptr[4] = r0[4]; + outptr[5] = r0[5]; + outptr[6] = r0[6]; + outptr[7] = r0[7]; + outptr[8] = r1[0]; + outptr[9] = r1[1]; + outptr[10] = r1[2]; + outptr[11] = r1[3]; + outptr[12] = r1[4]; + outptr[13] = r1[5]; + outptr[14] = r1[6]; + outptr[15] = r1[7]; + + r0 += 8; + r1 += 8; + outptr += 16; + } + } + } + if (pack16to8) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const unsigned short* r0 = bottom_blob.channel(q); + + unsigned short* outptr0 = top_blob.channel(q * 2); + unsigned short* outptr1 = top_blob.channel(q * 2 + 1); + + for (int i = 0; i < size; i++) + { + outptr0[0] = r0[0]; + outptr0[1] = r0[1]; + outptr0[2] = r0[2]; + outptr0[3] = r0[3]; + outptr0[4] = r0[4]; + outptr0[5] = r0[5]; + outptr0[6] = r0[6]; + outptr0[7] = r0[7]; + outptr1[0] = r0[8]; + outptr1[1] = r0[9]; + outptr1[2] = r0[10]; + outptr1[3] = r0[11]; + outptr1[4] = r0[12]; + outptr1[5] = r0[13]; + outptr1[6] = r0[14]; + outptr1[7] = r0[15]; + + r0 += 16; + outptr0 += 8; + outptr1 += 8; + } + } + } + + return 0; + } + + return 0; +} + int Packing_x86::forward_int8(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const { if (use_padding) diff --git a/src/layer/x86/packing_x86.h b/src/layer/x86/packing_x86.h index e32a33c33ac..feded196da9 100644 --- a/src/layer/x86/packing_x86.h +++ b/src/layer/x86/packing_x86.h @@ -16,6 +16,7 @@ class Packing_x86 : public Packing virtual int forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const; protected: + int forward_bf16s_fp16s(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const; int forward_int8(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const; }; diff --git a/src/layer/x86/padding_pack16_bf16s_fp16s.h b/src/layer/x86/padding_pack16_bf16s_fp16s.h new file mode 100644 index 00000000000..13ae29df1ec --- /dev/null +++ b/src/layer/x86/padding_pack16_bf16s_fp16s.h @@ -0,0 +1,201 @@ +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +static void padding_constant_pack16_bf16s_fp16s_avx512(const Mat& src, Mat& dst, int top, int bottom, int left, int right, const __m256i& v) +{ + const unsigned short* ptr = src; + unsigned short* outptr = dst; + int top_size = top * dst.w; + int bottom_size = bottom * dst.w; + + // fill top + for (int y = 0; y < top_size; y++) + { + _mm256_storeu_si256((__m256i*)outptr, v); + outptr += 16; + } + // fill center + for (int y = 0; y < src.h; y++) + { + for (int x = 0; x < left; x++) + { + _mm256_storeu_si256((__m256i*)outptr, v); + outptr += 16; + } + for (int x = 0; x < src.w; x++) + { + _mm256_storeu_si256((__m256i*)outptr, _mm256_loadu_si256((const __m256i*)ptr)); + ptr += 16; + outptr += 16; + } + for (int x = 0; x < right; x++) + { + _mm256_storeu_si256((__m256i*)outptr, v); + outptr += 16; + } + } + // fill bottom + for (int y = 0; y < bottom_size; y++) + { + _mm256_storeu_si256((__m256i*)outptr, v); + outptr += 16; + } +} + +static void padding_replicate_pack16_bf16s_fp16s_avx512(const Mat& src, Mat& dst, int top, int bottom, int left, int right) +{ + const unsigned short* ptr = src; + unsigned short* outptr = dst; + + // fill top + for (int y = 0; y < top; y++) + { + const unsigned short* ptr0 = ptr; + __m256i _p = 
_mm256_loadu_si256((const __m256i*)ptr0); + for (int x = 0; x < left; x++) + { + _mm256_storeu_si256((__m256i*)outptr, _p); + outptr += 16; + } + for (int x = 0; x < src.w; x++) + { + _p = _mm256_loadu_si256((const __m256i*)ptr0); + _mm256_storeu_si256((__m256i*)outptr, _p); + ptr0 += 16; + outptr += 16; + } + for (int x = 0; x < right; x++) + { + _mm256_storeu_si256((__m256i*)outptr, _p); + outptr += 16; + } + } + // fill center + for (int y = 0; y < src.h; y++) + { + __m256i _p = _mm256_loadu_si256((const __m256i*)ptr); + for (int x = 0; x < left; x++) + { + _mm256_storeu_si256((__m256i*)outptr, _p); + outptr += 16; + } + for (int x = 0; x < src.w; x++) + { + _p = _mm256_loadu_si256((const __m256i*)ptr); + _mm256_storeu_si256((__m256i*)outptr, _p); + ptr += 16; + outptr += 16; + } + for (int x = 0; x < right; x++) + { + _mm256_storeu_si256((__m256i*)outptr, _p); + outptr += 16; + } + } + // fill bottom + ptr -= src.w * 16; + for (int y = 0; y < bottom; y++) + { + const unsigned short* ptr0 = ptr; + __m256i _p = _mm256_loadu_si256((const __m256i*)ptr0); + for (int x = 0; x < left; x++) + { + _mm256_storeu_si256((__m256i*)outptr, _p); + outptr += 16; + } + for (int x = 0; x < src.w; x++) + { + _p = _mm256_loadu_si256((const __m256i*)ptr0); + _mm256_storeu_si256((__m256i*)outptr, _p); + ptr0 += 16; + outptr += 16; + } + for (int x = 0; x < right; x++) + { + _mm256_storeu_si256((__m256i*)outptr, _p); + outptr += 16; + } + } +} + +static void padding_reflect_pack16_bf16s_fp16s_avx512(const Mat& src, Mat& dst, int top, int bottom, int left, int right) +{ + const unsigned short* ptr = src; + unsigned short* outptr = dst; + + // fill top + ptr += top * src.w * 16; + for (int y = 0; y < top; y++) + { + const unsigned short* ptr0 = ptr; + for (int x = 0; x < left; x++) + { + __m256i _p = _mm256_loadu_si256((const __m256i*)(ptr0 + (left - x) * 16)); + _mm256_storeu_si256((__m256i*)outptr, _p); + outptr += 16; + } + for (int x = 0; x < src.w; x++) + { + __m256i _p = _mm256_loadu_si256((const __m256i*)ptr0); + _mm256_storeu_si256((__m256i*)outptr, _p); + ptr0 += 16; + outptr += 16; + } + for (int x = 0; x < right; x++) + { + __m256i _p = _mm256_loadu_si256((const __m256i*)(ptr0 - 32 - x * 16)); + _mm256_storeu_si256((__m256i*)outptr, _p); + outptr += 16; + } + ptr -= src.w * 16; + } + // fill center + for (int y = 0; y < src.h; y++) + { + for (int x = 0; x < left; x++) + { + __m256i _p = _mm256_loadu_si256((const __m256i*)(ptr + (left - x) * 16)); + _mm256_storeu_si256((__m256i*)outptr, _p); + outptr += 16; + } + for (int x = 0; x < src.w; x++) + { + __m256i _p = _mm256_loadu_si256((const __m256i*)ptr); + _mm256_storeu_si256((__m256i*)outptr, _p); + ptr += 16; + outptr += 16; + } + for (int x = 0; x < right; x++) + { + __m256i _p = _mm256_loadu_si256((const __m256i*)(ptr - 32 - x * 16)); + _mm256_storeu_si256((__m256i*)outptr, _p); + outptr += 16; + } + } + // fill bottom + ptr -= 2 * src.w * 16; + for (int y = 0; y < bottom; y++) + { + const unsigned short* ptr0 = ptr; + for (int x = 0; x < left; x++) + { + __m256i _p = _mm256_loadu_si256((const __m256i*)(ptr0 + (left - x) * 16)); + _mm256_storeu_si256((__m256i*)outptr, _p); + outptr += 16; + } + for (int x = 0; x < src.w; x++) + { + __m256i _p = _mm256_loadu_si256((const __m256i*)ptr0); + _mm256_storeu_si256((__m256i*)outptr, _p); + ptr0 += 16; + outptr += 16; + } + for (int x = 0; x < right; x++) + { + __m256i _p = _mm256_loadu_si256((const __m256i*)(ptr0 - 32 - x * 16)); + _mm256_storeu_si256((__m256i*)outptr, _p); + outptr += 16; + } + ptr -= 
src.w * 16; + } +} diff --git a/src/layer/x86/padding_pack4_bf16s_fp16s.h b/src/layer/x86/padding_pack4_bf16s_fp16s.h new file mode 100644 index 00000000000..81ffb9d17ee --- /dev/null +++ b/src/layer/x86/padding_pack4_bf16s_fp16s.h @@ -0,0 +1,192 @@ +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +static void padding_constant_pack4_bf16s_fp16s_sse(const Mat& src, Mat& dst, int top, int bottom, int left, int right, int64_t v) +{ + const unsigned short* ptr = src; + unsigned short* outptr = dst; + int top_size = top * dst.w; + int bottom_size = bottom * dst.w; + + // fill top + for (int y = 0; y < top_size; y++) + { + *(int64_t*)outptr = v; + outptr += 4; + } + // fill center + for (int y = 0; y < src.h; y++) + { + for (int x = 0; x < left; x++) + { + *(int64_t*)outptr = v; + outptr += 4; + } + for (int x = 0; x < src.w; x++) + { + *(int64_t*)outptr = *(const int64_t*)ptr; + ptr += 4; + outptr += 4; + } + for (int x = 0; x < right; x++) + { + *(int64_t*)outptr = v; + outptr += 4; + } + } + // fill bottom + for (int y = 0; y < bottom_size; y++) + { + *(int64_t*)outptr = v; + outptr += 4; + } +} + +static void padding_replicate_pack4_bf16s_fp16s_sse(const Mat& src, Mat& dst, int top, int bottom, int left, int right) +{ + const unsigned short* ptr = src; + unsigned short* outptr = dst; + + // fill top + for (int y = 0; y < top; y++) + { + const unsigned short* ptr0 = ptr; + int64_t _p = *(const int64_t*)ptr0; + for (int x = 0; x < left; x++) + { + *(int64_t*)outptr = _p; + outptr += 4; + } + for (int x = 0; x < src.w; x++) + { + _p = *(const int64_t*)ptr0; + *(int64_t*)outptr = _p; + ptr0 += 4; + outptr += 4; + } + for (int x = 0; x < right; x++) + { + *(int64_t*)outptr = _p; + outptr += 4; + } + } + // fill center + for (int y = 0; y < src.h; y++) + { + int64_t _p = *(const int64_t*)ptr; + for (int x = 0; x < left; x++) + { + *(int64_t*)outptr = _p; + outptr += 4; + } + for (int x = 0; x < src.w; x++) + { + _p = *(const int64_t*)ptr; + *(int64_t*)outptr = _p; + ptr += 4; + outptr += 4; + } + for (int x = 0; x < right; x++) + { + *(int64_t*)outptr = _p; + outptr += 4; + } + } + // fill bottom + ptr -= src.w * 4; + for (int y = 0; y < bottom; y++) + { + const unsigned short* ptr0 = ptr; + int64_t _p = *(const int64_t*)ptr0; + for (int x = 0; x < left; x++) + { + *(int64_t*)outptr = _p; + outptr += 4; + } + for (int x = 0; x < src.w; x++) + { + _p = *(const int64_t*)ptr0; + *(int64_t*)outptr = _p; + ptr0 += 4; + outptr += 4; + } + for (int x = 0; x < right; x++) + { + *(int64_t*)outptr = _p; + outptr += 4; + } + } +} + +static void padding_reflect_pack4_bf16s_fp16s_sse(const Mat& src, Mat& dst, int top, int bottom, int left, int right) +{ + const unsigned short* ptr = src; + unsigned short* outptr = dst; + + // fill top + ptr += top * src.w * 4; + for (int y = 0; y < top; y++) + { + const unsigned short* ptr0 = ptr; + for (int x = 0; x < left; x++) + { + *(int64_t*)outptr = *(const int64_t*)(ptr0 + (left - x) * 4); + outptr += 4; + } + for (int x = 0; x < src.w; x++) + { + *(int64_t*)outptr = *(const int64_t*)ptr0; + ptr0 += 4; + outptr += 4; + } + for (int x = 0; x < right; x++) + { + *(int64_t*)outptr = *(const int64_t*)(ptr0 - 8 - x * 4); + outptr += 4; + } + ptr -= src.w * 4; + } + // fill center + for (int y = 0; y < src.h; y++) + { + for (int x = 0; x < left; x++) + { + *(int64_t*)outptr = *(const int64_t*)(ptr + (left - x) * 4); + outptr += 4; + } + for (int x = 0; x < src.w; x++) + { + *(int64_t*)outptr = *(const int64_t*)ptr; + ptr += 4; + outptr += 4; + } + for 
(int x = 0; x < right; x++) + { + *(int64_t*)outptr = *(const int64_t*)(ptr - 8 - x * 4); + outptr += 4; + } + } + // fill bottom + ptr -= 2 * src.w * 4; + for (int y = 0; y < bottom; y++) + { + const unsigned short* ptr0 = ptr; + for (int x = 0; x < left; x++) + { + *(int64_t*)outptr = *(const int64_t*)(ptr0 + (left - x) * 4); + outptr += 4; + } + for (int x = 0; x < src.w; x++) + { + *(int64_t*)outptr = *(const int64_t*)ptr0; + ptr0 += 4; + outptr += 4; + } + for (int x = 0; x < right; x++) + { + *(int64_t*)outptr = *(const int64_t*)(ptr0 - 8 - x * 4); + outptr += 4; + } + ptr -= src.w * 4; + } +} diff --git a/src/layer/x86/padding_pack8_bf16s_fp16s.h b/src/layer/x86/padding_pack8_bf16s_fp16s.h new file mode 100644 index 00000000000..fc958452a39 --- /dev/null +++ b/src/layer/x86/padding_pack8_bf16s_fp16s.h @@ -0,0 +1,201 @@ +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +static void padding_constant_pack8_bf16s_fp16s_avx(const Mat& src, Mat& dst, int top, int bottom, int left, int right, const __m128i& v) +{ + const unsigned short* ptr = src; + unsigned short* outptr = dst; + int top_size = top * dst.w; + int bottom_size = bottom * dst.w; + + // fill top + for (int y = 0; y < top_size; y++) + { + _mm_storeu_si128((__m128i*)outptr, v); + outptr += 8; + } + // fill center + for (int y = 0; y < src.h; y++) + { + for (int x = 0; x < left; x++) + { + _mm_storeu_si128((__m128i*)outptr, v); + outptr += 8; + } + for (int x = 0; x < src.w; x++) + { + _mm_storeu_si128((__m128i*)outptr, _mm_loadu_si128((const __m128i*)ptr)); + ptr += 8; + outptr += 8; + } + for (int x = 0; x < right; x++) + { + _mm_storeu_si128((__m128i*)outptr, v); + outptr += 8; + } + } + // fill bottom + for (int y = 0; y < bottom_size; y++) + { + _mm_storeu_si128((__m128i*)outptr, v); + outptr += 8; + } +} + +static void padding_replicate_pack8_bf16s_fp16s_avx(const Mat& src, Mat& dst, int top, int bottom, int left, int right) +{ + const unsigned short* ptr = src; + unsigned short* outptr = dst; + + // fill top + for (int y = 0; y < top; y++) + { + const unsigned short* ptr0 = ptr; + __m128i _p = _mm_loadu_si128((const __m128i*)ptr0); + for (int x = 0; x < left; x++) + { + _mm_storeu_si128((__m128i*)outptr, _p); + outptr += 8; + } + for (int x = 0; x < src.w; x++) + { + _p = _mm_loadu_si128((const __m128i*)ptr0); + _mm_storeu_si128((__m128i*)outptr, _p); + ptr0 += 8; + outptr += 8; + } + for (int x = 0; x < right; x++) + { + _mm_storeu_si128((__m128i*)outptr, _p); + outptr += 8; + } + } + // fill center + for (int y = 0; y < src.h; y++) + { + __m128i _p = _mm_loadu_si128((const __m128i*)ptr); + for (int x = 0; x < left; x++) + { + _mm_storeu_si128((__m128i*)outptr, _p); + outptr += 8; + } + for (int x = 0; x < src.w; x++) + { + _p = _mm_loadu_si128((const __m128i*)ptr); + _mm_storeu_si128((__m128i*)outptr, _p); + ptr += 8; + outptr += 8; + } + for (int x = 0; x < right; x++) + { + _mm_storeu_si128((__m128i*)outptr, _p); + outptr += 8; + } + } + // fill bottom + ptr -= src.w * 8; + for (int y = 0; y < bottom; y++) + { + const unsigned short* ptr0 = ptr; + __m128i _p = _mm_loadu_si128((const __m128i*)ptr0); + for (int x = 0; x < left; x++) + { + _mm_storeu_si128((__m128i*)outptr, _p); + outptr += 8; + } + for (int x = 0; x < src.w; x++) + { + _p = _mm_loadu_si128((const __m128i*)ptr0); + _mm_storeu_si128((__m128i*)outptr, _p); + ptr0 += 8; + outptr += 8; + } + for (int x = 0; x < right; x++) + { + _mm_storeu_si128((__m128i*)outptr, _p); + outptr += 8; + } + } +} + +static void 
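+// reflect padding for pack8 16-bit data (8 lanes per 128-bit register), same scheme as the pack16 variant above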
padding_reflect_pack8_bf16s_fp16s_avx(const Mat& src, Mat& dst, int top, int bottom, int left, int right) +{ + const unsigned short* ptr = src; + unsigned short* outptr = dst; + + // fill top + ptr += top * src.w * 8; + for (int y = 0; y < top; y++) + { + const unsigned short* ptr0 = ptr; + for (int x = 0; x < left; x++) + { + __m128i _p = _mm_loadu_si128((const __m128i*)(ptr0 + (left - x) * 8)); + _mm_storeu_si128((__m128i*)outptr, _p); + outptr += 8; + } + for (int x = 0; x < src.w; x++) + { + __m128i _p = _mm_loadu_si128((const __m128i*)ptr0); + _mm_storeu_si128((__m128i*)outptr, _p); + ptr0 += 8; + outptr += 8; + } + for (int x = 0; x < right; x++) + { + __m128i _p = _mm_loadu_si128((const __m128i*)(ptr0 - 16 - x * 8)); + _mm_storeu_si128((__m128i*)outptr, _p); + outptr += 8; + } + ptr -= src.w * 8; + } + // fill center + for (int y = 0; y < src.h; y++) + { + for (int x = 0; x < left; x++) + { + __m128i _p = _mm_loadu_si128((const __m128i*)(ptr + (left - x) * 8)); + _mm_storeu_si128((__m128i*)outptr, _p); + outptr += 8; + } + for (int x = 0; x < src.w; x++) + { + __m128i _p = _mm_loadu_si128((const __m128i*)ptr); + _mm_storeu_si128((__m128i*)outptr, _p); + ptr += 8; + outptr += 8; + } + for (int x = 0; x < right; x++) + { + __m128i _p = _mm_loadu_si128((const __m128i*)(ptr - 16 - x * 8)); + _mm_storeu_si128((__m128i*)outptr, _p); + outptr += 8; + } + } + // fill bottom + ptr -= 2 * src.w * 8; + for (int y = 0; y < bottom; y++) + { + const unsigned short* ptr0 = ptr; + for (int x = 0; x < left; x++) + { + __m128i _p = _mm_loadu_si128((const __m128i*)(ptr0 + (left - x) * 8)); + _mm_storeu_si128((__m128i*)outptr, _p); + outptr += 8; + } + for (int x = 0; x < src.w; x++) + { + __m128i _p = _mm_loadu_si128((const __m128i*)ptr0); + _mm_storeu_si128((__m128i*)outptr, _p); + ptr0 += 8; + outptr += 8; + } + for (int x = 0; x < right; x++) + { + __m128i _p = _mm_loadu_si128((const __m128i*)(ptr0 - 16 - x * 8)); + _mm_storeu_si128((__m128i*)outptr, _p); + outptr += 8; + } + ptr -= src.w * 8; + } +} diff --git a/src/layer/x86/padding_x86.cpp b/src/layer/x86/padding_x86.cpp index d46bb55157c..7e1eaf5b1b4 100644 --- a/src/layer/x86/padding_x86.cpp +++ b/src/layer/x86/padding_x86.cpp @@ -1,8 +1,10 @@ -// Copyright 2019 Tencent +// Copyright 2026 Tencent // SPDX-License-Identifier: BSD-3-Clause #include "padding_x86.h" +#include "cpu.h" + #include #if __SSE2__ #include @@ -15,11 +17,14 @@ namespace ncnn { #if __SSE2__ #include "padding_pack4.h" +#include "padding_pack4_bf16s_fp16s.h" #include "padding_pack8_int8.h" #if __AVX__ #include "padding_pack8.h" +#include "padding_pack8_bf16s_fp16s.h" #if __AVX512F__ #include "padding_pack16.h" +#include "padding_pack16_bf16s_fp16s.h" #endif // __AVX512F__ #endif // __AVX__ #endif // __SSE2__ @@ -29,6 +34,42 @@ Padding_x86::Padding_x86() #if __SSE2__ support_packing = true; #endif // __SSE2__ + support_fp16_storage = cpu_support_x86_f16c(); +#if NCNN_BF16 + support_bf16_storage = true; +#endif +} + +int Padding_x86::create_pipeline(const Option& opt) +{ + if (support_fp16_storage && opt.use_fp16_storage) + { + value_fp16 = float32_to_float16(value); + + ncnn::cast_float32_to_float16(per_channel_pad_data, per_channel_pad_data_fp16, opt); + } + +#if NCNN_BF16 + if (opt.use_bf16_storage) + { + value_bf16 = float32_to_bfloat16(value); + + ncnn::cast_float32_to_bfloat16(per_channel_pad_data, per_channel_pad_data_bf16, opt); + } +#endif + + return 0; +} + +int Padding_x86::destroy_pipeline(const Option& /*opt*/) +{ + per_channel_pad_data_fp16 = Mat(); + +#if 
NCNN_BF16 + per_channel_pad_data_bf16 = Mat(); +#endif + + return 0; } int Padding_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const @@ -44,6 +85,9 @@ int Padding_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& op if (elembits == 8) return forward_int8(bottom_blob, top_blob, opt); + if (elembits == 16) + return forward_bf16s_fp16s(bottom_blob, top_blob, opt); + int w = bottom_blob.w; int h = bottom_blob.h; int d = bottom_blob.d; @@ -459,6 +503,468 @@ int Padding_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& op return Padding::forward(bottom_blob_unpacked, top_blob, opt); } +int Padding_x86::forward_bf16s_fp16s(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const +{ + int w = bottom_blob.w; + int h = bottom_blob.h; + int d = bottom_blob.d; + int channels = bottom_blob.c; + int dims = bottom_blob.dims; + size_t elemsize = bottom_blob.elemsize; + int elempack = bottom_blob.elempack; + + unsigned short pad_value_bf16_fp16 = 0; + if (support_fp16_storage && opt.use_fp16_storage) + { + pad_value_bf16_fp16 = value_fp16; + } + else +#if NCNN_BF16 + if (opt.use_bf16_storage) + { + pad_value_bf16_fp16 = value_bf16; + } + else +#endif + { + pad_value_bf16_fp16 = 0; + } + + const Mat& per_channel_pad_data_bf16_fp16 = (support_fp16_storage && opt.use_fp16_storage) ? per_channel_pad_data_fp16 : +#if NCNN_BF16 + opt.use_bf16_storage ? per_channel_pad_data_bf16 : +#endif + Mat(); + +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + if (elempack == 16) + { + if (dims == 1) + { + int outw = w * elempack + left + right; + + int out_elempack = outw % 16 == 0 ? 16 : outw % 8 == 0 ? 8 : outw % 4 == 0 ? 4 : 1; + size_t out_elemsize = elemsize / elempack * out_elempack; + + if (left % 16 == 0 && out_elempack == 16 && type == 0) + { + top_blob.create(outw / out_elempack, out_elemsize, out_elempack, opt.blob_allocator); + if (top_blob.empty()) + return -100; + + __m256i pad_value = _mm256_set1_epi16(pad_value_bf16_fp16); + padding_constant_pack16_bf16s_fp16s_avx512(bottom_blob, top_blob, 0, 0, left / 16, right / 16, pad_value); + + return 0; + } + } + + if (dims == 2) + { + int outw = w + left + right; + int outh = h * elempack + top + bottom; + + int out_elempack = outh % 16 == 0 ? 16 : outh % 8 == 0 ? 8 : outh % 4 == 0 ? 4 : 1; + size_t out_elemsize = elemsize / elempack * out_elempack; + + if (top % 16 == 0 && out_elempack == 16 && type == 0) + { + top_blob.create(outw, outh / out_elempack, out_elemsize, out_elempack, opt.blob_allocator); + if (top_blob.empty()) + return -100; + + __m256i pad_value = _mm256_set1_epi16(pad_value_bf16_fp16); + padding_constant_pack16_bf16s_fp16s_avx512(bottom_blob, top_blob, top / 16, bottom / 16, left, right, pad_value); + + return 0; + } + } + + if (dims == 3) + { + int outw = w + left + right; + int outh = h + top + bottom; + int outc = channels * elempack + front + behind; + + int out_elempack = outc % 16 == 0 ? 16 : outc % 8 == 0 ? 8 : outc % 4 == 0 ? 4 : 1; + size_t out_elemsize = elemsize / elempack * out_elempack; + + if (front % 16 == 0 && out_elempack == 16 && !(outc != channels * elempack && type != 0)) + { + top_blob.create(outw, outh, outc / out_elempack, out_elemsize, out_elempack, opt.blob_allocator); + if (top_blob.empty()) + return -100; + + int front_ = front / elempack; + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < outc / out_elempack; q++) + { + Mat borderm = top_blob.channel(q); + + __m256i pad_value = per_channel_pad_data_size ? 
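+                // with per-channel padding, load the 16 pad values of this output channel; otherwise broadcast the scalar pad value prepared in create_pipeline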
_mm256_loadu_si256((const __m256i*)((const unsigned short*)per_channel_pad_data_bf16_fp16 + q * 16)) : _mm256_set1_epi16(pad_value_bf16_fp16); + //Channel padding + if ((q - front_) < 0 || (q - front_) >= channels) + { + borderm.fill(pad_value); + } + else + { + const Mat m = bottom_blob.channel(q - front_); + if (type == 0) + padding_constant_pack16_bf16s_fp16s_avx512(m, borderm, top, bottom, left, right, pad_value); + if (type == 1) + padding_replicate_pack16_bf16s_fp16s_avx512(m, borderm, top, bottom, left, right); + if (type == 2) + padding_reflect_pack16_bf16s_fp16s_avx512(m, borderm, top, bottom, left, right); + } + } + + return 0; + } + } + + if (dims == 4) + { + int outw = w + left + right; + int outh = h + top + bottom; + int outd = d + front + behind; + + if (type == 0) + { + top_blob.create(outw, outh, outd, channels, elemsize, elempack, opt.blob_allocator); + if (top_blob.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + __m256i pad_value = per_channel_pad_data_size ? _mm256_loadu_si256((const __m256i*)((const unsigned short*)per_channel_pad_data_bf16_fp16 + q * 16)) : _mm256_set1_epi16(pad_value_bf16_fp16); + + for (int z = 0; z < outd; z++) + { + Mat borderm = top_blob.channel(q).depth(z); + + // depth padding + if ((z - front) < 0 || (z - front) >= d) + { + borderm.fill(pad_value); + } + else + { + const Mat m = bottom_blob.channel(q).depth(z - front); + padding_constant_pack16_bf16s_fp16s_avx512(m, borderm, top, bottom, left, right, pad_value); + } + } + } + + return 0; + } + } + } +#endif // __AVX512F__ + + if (elempack == 8) + { + if (dims == 1) + { + int outw = w * elempack + left + right; + + int out_elempack = outw % 8 == 0 ? 8 : outw % 4 == 0 ? 4 : 1; + size_t out_elemsize = elemsize / elempack * out_elempack; + + if (left % 8 == 0 && out_elempack == 8 && type == 0) + { + top_blob.create(outw / out_elempack, out_elemsize, out_elempack, opt.blob_allocator); + if (top_blob.empty()) + return -100; + + __m128i pad_value = _mm_set1_epi16(pad_value_bf16_fp16); + padding_constant_pack8_bf16s_fp16s_avx(bottom_blob, top_blob, 0, 0, left / 8, right / 8, pad_value); + + return 0; + } + } + + if (dims == 2) + { + int outw = w + left + right; + int outh = h * elempack + top + bottom; + + int out_elempack = outh % 8 == 0 ? 8 : outh % 4 == 0 ? 4 : 1; + size_t out_elemsize = elemsize / elempack * out_elempack; + + if (top % 8 == 0 && out_elempack == 8 && type == 0) + { + top_blob.create(outw, outh / out_elempack, out_elemsize, out_elempack, opt.blob_allocator); + if (top_blob.empty()) + return -100; + + __m128i pad_value = _mm_set1_epi16(pad_value_bf16_fp16); + padding_constant_pack8_bf16s_fp16s_avx(bottom_blob, top_blob, top / 8, bottom / 8, left, right, pad_value); + + return 0; + } + } + + if (dims == 3) + { + int outw = w + left + right; + int outh = h + top + bottom; + int outc = channels * elempack + front + behind; + + int out_elempack = outc % 8 == 0 ? 8 : outc % 4 == 0 ? 4 : 1; + size_t out_elemsize = elemsize / elempack * out_elempack; + + if (front % 8 == 0 && out_elempack == 8 && !(outc != channels * elempack && type != 0)) + { + top_blob.create(outw, outh, outc / out_elempack, out_elemsize, out_elempack, opt.blob_allocator); + if (top_blob.empty()) + return -100; + + int front_ = front / elempack; + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < outc / out_elempack; q++) + { + Mat borderm = top_blob.channel(q); + + __m128i pad_value = per_channel_pad_data_size ? 
_mm_loadu_si128((const __m128i*)((const unsigned short*)per_channel_pad_data_bf16_fp16 + q * 8)) : _mm_set1_epi16(pad_value_bf16_fp16); + //Channel padding + if ((q - front_) < 0 || (q - front_) >= channels) + { + borderm.fill(pad_value); + } + else + { + const Mat m = bottom_blob.channel(q - front_); + if (type == 0) + padding_constant_pack8_bf16s_fp16s_avx(m, borderm, top, bottom, left, right, pad_value); + if (type == 1) + padding_replicate_pack8_bf16s_fp16s_avx(m, borderm, top, bottom, left, right); + if (type == 2) + padding_reflect_pack8_bf16s_fp16s_avx(m, borderm, top, bottom, left, right); + } + } + + return 0; + } + } + + if (dims == 4) + { + int outw = w + left + right; + int outh = h + top + bottom; + int outd = d + front + behind; + + if (type == 0) + { + top_blob.create(outw, outh, outd, channels, elemsize, elempack, opt.blob_allocator); + if (top_blob.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + __m128i pad_value = per_channel_pad_data_size ? _mm_loadu_si128((const __m128i*)((const unsigned short*)per_channel_pad_data_bf16_fp16 + q * 8)) : _mm_set1_epi16(pad_value_bf16_fp16); + + for (int z = 0; z < outd; z++) + { + Mat borderm = top_blob.channel(q).depth(z); + + // depth padding + if ((z - front) < 0 || (z - front) >= d) + { + borderm.fill(pad_value); + } + else + { + const Mat m = bottom_blob.channel(q).depth(z - front); + padding_constant_pack8_bf16s_fp16s_avx(m, borderm, top, bottom, left, right, pad_value); + } + } + } + + return 0; + } + } + } +#endif // __AVX__ + + if (elempack == 4) + { + if (dims == 1) + { + int outw = w * elempack + left + right; + +#if __AVX__ + int out_elempack = outw % 8 == 0 ? 8 : outw % 4 == 0 ? 4 : 1; +#else + int out_elempack = outw % 4 == 0 ? 4 : 1; +#endif + size_t out_elemsize = elemsize / elempack * out_elempack; + + if (left % 4 == 0 && out_elempack == 4 && type == 0) + { + top_blob.create(outw / out_elempack, out_elemsize, out_elempack, opt.blob_allocator); + if (top_blob.empty()) + return -100; + + int64_t v16 = (int64_t)(unsigned short)pad_value_bf16_fp16; + int64_t pad_value_i64 = v16 | (v16 << 16) | (v16 << 32) | (v16 << 48); + padding_constant_pack4_bf16s_fp16s_sse(bottom_blob, top_blob, 0, 0, left / 4, right / 4, pad_value_i64); + + return 0; + } + } + + if (dims == 2) + { + int outw = w + left + right; + int outh = h * elempack + top + bottom; + +#if __AVX__ + int out_elempack = outh % 8 == 0 ? 8 : outh % 4 == 0 ? 4 : 1; +#else + int out_elempack = outh % 4 == 0 ? 4 : 1; +#endif + size_t out_elemsize = elemsize / elempack * out_elempack; + + if (top % 4 == 0 && out_elempack == 4 && type == 0) + { + top_blob.create(outw, outh / out_elempack, out_elemsize, out_elempack, opt.blob_allocator); + if (top_blob.empty()) + return -100; + + int64_t v16 = (int64_t)(unsigned short)pad_value_bf16_fp16; + int64_t pad_value_i64 = v16 | (v16 << 16) | (v16 << 32) | (v16 << 48); + padding_constant_pack4_bf16s_fp16s_sse(bottom_blob, top_blob, top / 4, bottom / 4, left, right, pad_value_i64); + + return 0; + } + } + + if (dims == 3) + { + int outw = w + left + right; + int outh = h + top + bottom; + int outc = channels * elempack + front + behind; + +#if __AVX__ + int out_elempack = outc % 8 == 0 ? 8 : outc % 4 == 0 ? 4 : 1; +#else + int out_elempack = outc % 4 == 0 ? 
4 : 1; +#endif + size_t out_elemsize = elemsize / elempack * out_elempack; + + if (front % 4 == 0 && out_elempack == 4 && !(outc != channels * elempack && type != 0)) + { + top_blob.create(outw, outh, outc / out_elempack, out_elemsize, out_elempack, opt.blob_allocator); + if (top_blob.empty()) + return -100; + + int front_ = front / elempack; + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < outc / out_elempack; q++) + { + Mat borderm = top_blob.channel(q); + + int64_t pad_value_i64; + if (per_channel_pad_data_size) + { + const unsigned short* p = (const unsigned short*)per_channel_pad_data_bf16_fp16 + q * 4; + pad_value_i64 = *(const int64_t*)p; + } + else + { + int64_t v16 = (int64_t)(unsigned short)pad_value_bf16_fp16; + pad_value_i64 = v16 | (v16 << 16) | (v16 << 32) | (v16 << 48); + } + //Channel padding + if ((q - front_) < 0 || (q - front_) >= channels) + { + borderm.fill(pad_value_i64); + } + else + { + const Mat m = bottom_blob.channel(q - front_); + if (type == 0) + padding_constant_pack4_bf16s_fp16s_sse(m, borderm, top, bottom, left, right, pad_value_i64); + if (type == 1) + padding_replicate_pack4_bf16s_fp16s_sse(m, borderm, top, bottom, left, right); + if (type == 2) + padding_reflect_pack4_bf16s_fp16s_sse(m, borderm, top, bottom, left, right); + } + } + + return 0; + } + } + + if (dims == 4) + { + int outw = w + left + right; + int outh = h + top + bottom; + int outd = d + front + behind; + + if (type == 0) + { + top_blob.create(outw, outh, outd, channels, elemsize, elempack, opt.blob_allocator); + if (top_blob.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + int64_t pad_value_i64; + if (per_channel_pad_data_size) + { + const unsigned short* p = (const unsigned short*)per_channel_pad_data_bf16_fp16 + q * 4; + pad_value_i64 = *(const int64_t*)p; + } + else + { + int64_t v16 = (int64_t)(unsigned short)pad_value_bf16_fp16; + pad_value_i64 = v16 | (v16 << 16) | (v16 << 32) | (v16 << 48); + } + + for (int z = 0; z < outd; z++) + { + Mat borderm = top_blob.channel(q).depth(z); + + // depth padding + if ((z - front) < 0 || (z - front) >= d) + { + borderm.fill(pad_value_i64); + } + else + { + const Mat m = bottom_blob.channel(q).depth(z - front); + padding_constant_pack4_bf16s_fp16s_sse(m, borderm, top, bottom, left, right, pad_value_i64); + } + } + } + + return 0; + } + } + } +#endif // __SSE2__ + + Mat bottom_blob_unpacked = bottom_blob; + if (elempack != 1) + { + Option opt_pack1 = opt; + opt_pack1.blob_allocator = opt.workspace_allocator; + + convert_packing(bottom_blob, bottom_blob_unpacked, 1, opt_pack1); + if (bottom_blob_unpacked.empty()) + return -100; + } + + return Padding::forward(bottom_blob_unpacked, top_blob, opt); +} + int Padding_x86::forward_int8(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const { int w = bottom_blob.w; diff --git a/src/layer/x86/padding_x86.h b/src/layer/x86/padding_x86.h index a2428252f39..e2755333b76 100644 --- a/src/layer/x86/padding_x86.h +++ b/src/layer/x86/padding_x86.h @@ -1,4 +1,4 @@ -// Copyright 2019 Tencent +// Copyright 2026 Tencent // SPDX-License-Identifier: BSD-3-Clause #ifndef LAYER_PADDING_X86_H @@ -13,10 +13,23 @@ class Padding_x86 : public Padding public: Padding_x86(); + virtual int create_pipeline(const Option& opt); + virtual int destroy_pipeline(const Option& opt); + virtual int forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const; protected: + int forward_bf16s_fp16s(const Mat& bottom_blob, 
Mat& top_blob, const Option& opt) const; int forward_int8(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const; + +public: + // bf16 + unsigned short value_bf16; + Mat per_channel_pad_data_bf16; + + // fp16 + unsigned short value_fp16; + Mat per_channel_pad_data_fp16; }; } // namespace ncnn diff --git a/src/layer/x86/reshape_x86.cpp b/src/layer/x86/reshape_x86.cpp index 3dd675ccf1e..985937e9c62 100644 --- a/src/layer/x86/reshape_x86.cpp +++ b/src/layer/x86/reshape_x86.cpp @@ -1,4 +1,4 @@ -// Copyright 2019 Tencent +// Copyright 2026 Tencent // SPDX-License-Identifier: BSD-3-Clause #include "reshape_x86.h" @@ -10,6 +10,7 @@ #endif #endif // __SSE2__ +#include "cpu.h" #include "x86_usability.h" namespace ncnn { @@ -19,6 +20,10 @@ Reshape_x86::Reshape_x86() #if __SSE2__ support_packing = true; #endif // __SSE2__ + support_fp16_storage = cpu_support_x86_f16c(); +#if NCNN_BF16 + support_bf16_storage = true; +#endif } int Reshape_x86::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const @@ -26,6 +31,11 @@ int Reshape_x86::forward(const std::vector& bottom_blobs, std::vector& const Mat& bottom_blob = bottom_blobs[0]; Mat& top_blob = top_blobs[0]; + int elembits = bottom_blob.elembits(); + + if (elembits == 16) + return forward_bf16s_fp16s(bottom_blobs, top_blobs, opt); + // resolve out shape int outw = w; int outh = h; @@ -685,4 +695,446 @@ int Reshape_x86::forward(const std::vector& bottom_blobs, std::vector& return 0; } +int Reshape_x86::forward_bf16s_fp16s(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const +{ + const Mat& bottom_blob = bottom_blobs[0]; + Mat& top_blob = top_blobs[0]; + + // resolve out shape + int outw = w; + int outh = h; + int outd = d; + int outc = c; + + if (!shape_expr.empty()) + { + int er = eval_shape_expr(bottom_blobs, outw, outh, outd, outc); + if (er != 0) + return -1; + } + + if (ndim == 1) + { + // flatten + flatten(bottom_blob, top_blob, opt); + if (top_blob.empty()) + return -100; + + return 0; + } + + const int dims = bottom_blob.dims; + const int elempack = bottom_blob.elempack; + const size_t elemsize = bottom_blob.elemsize; + + const int total = bottom_blob.w * bottom_blob.h * bottom_blob.d * bottom_blob.c * elempack; + + if (ndim == 2) + { + if (outw == 0) + outw = dims == 1 ? bottom_blob.w * elempack : bottom_blob.w; + if (outh == 0) + outh = dims == 2 ? bottom_blob.h * elempack : bottom_blob.h; + + if (outw == -1) + outw = total / outh; + if (outh == -1) + outh = total / outw; + + int out_elempack = 1; +#if __SSE2__ + if (opt.use_packing_layout) + { +#if __AVX512F__ + out_elempack = outh % 16 == 0 ? 16 : outh % 8 == 0 ? 8 : outh % 4 == 0 ? 4 : 1; +#elif __AVX__ + out_elempack = outh % 8 == 0 ? 8 : outh % 4 == 0 ? 4 : 1; +#else + out_elempack = outh % 4 == 0 ? 
4 : 1; +#endif + } +#endif // __SSE2__ + size_t out_elemsize = elemsize / elempack * out_elempack; + + if (dims == 2 && bottom_blob.h * elempack == outh && elempack == out_elempack) + { + top_blob = bottom_blob; + return 0; + } + + if (out_elempack == 1) + { + // flatten + flatten(bottom_blob, top_blob, opt); + if (top_blob.empty()) + return -100; + + top_blob.dims = 2; + top_blob.w = outw; + top_blob.h = outh; + top_blob.cstep = top_blob.cstep * top_blob.elempack; + top_blob.elemsize = out_elemsize; + top_blob.elempack = out_elempack; + + return 0; + } + + // flatten + Mat bottom_blob_flattened = bottom_blob; + { + Option opt_flatten = opt; + opt_flatten.blob_allocator = opt.workspace_allocator; + + flatten(bottom_blob, bottom_blob_flattened, opt_flatten); + if (bottom_blob_flattened.empty()) + return -100; + } + + top_blob.create(outw, outh / out_elempack, out_elemsize, out_elempack, opt.blob_allocator); + if (top_blob.empty()) + return -100; + +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + if (out_elempack == 16) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = 0; i < top_blob.h; i++) + { + const unsigned short* ptr0 = (const unsigned short*)bottom_blob_flattened + outw * i * 16; + const unsigned short* ptr1 = (const unsigned short*)bottom_blob_flattened + outw * (i * 16 + 1); + const unsigned short* ptr2 = (const unsigned short*)bottom_blob_flattened + outw * (i * 16 + 2); + const unsigned short* ptr3 = (const unsigned short*)bottom_blob_flattened + outw * (i * 16 + 3); + const unsigned short* ptr4 = (const unsigned short*)bottom_blob_flattened + outw * (i * 16 + 4); + const unsigned short* ptr5 = (const unsigned short*)bottom_blob_flattened + outw * (i * 16 + 5); + const unsigned short* ptr6 = (const unsigned short*)bottom_blob_flattened + outw * (i * 16 + 6); + const unsigned short* ptr7 = (const unsigned short*)bottom_blob_flattened + outw * (i * 16 + 7); + const unsigned short* ptr8 = (const unsigned short*)bottom_blob_flattened + outw * (i * 16 + 8); + const unsigned short* ptr9 = (const unsigned short*)bottom_blob_flattened + outw * (i * 16 + 9); + const unsigned short* ptra = (const unsigned short*)bottom_blob_flattened + outw * (i * 16 + 10); + const unsigned short* ptrb = (const unsigned short*)bottom_blob_flattened + outw * (i * 16 + 11); + const unsigned short* ptrc = (const unsigned short*)bottom_blob_flattened + outw * (i * 16 + 12); + const unsigned short* ptrd = (const unsigned short*)bottom_blob_flattened + outw * (i * 16 + 13); + const unsigned short* ptre = (const unsigned short*)bottom_blob_flattened + outw * (i * 16 + 14); + const unsigned short* ptrf = (const unsigned short*)bottom_blob_flattened + outw * (i * 16 + 15); + unsigned short* outptr = top_blob.row(i); + + for (int j = 0; j < outw; j++) + { + outptr[0] = *ptr0++; + outptr[1] = *ptr1++; + outptr[2] = *ptr2++; + outptr[3] = *ptr3++; + outptr[4] = *ptr4++; + outptr[5] = *ptr5++; + outptr[6] = *ptr6++; + outptr[7] = *ptr7++; + outptr[8] = *ptr8++; + outptr[9] = *ptr9++; + outptr[10] = *ptra++; + outptr[11] = *ptrb++; + outptr[12] = *ptrc++; + outptr[13] = *ptrd++; + outptr[14] = *ptre++; + outptr[15] = *ptrf++; + + outptr += 16; + } + } + } +#endif // __AVX512F__ + + if (out_elempack == 8) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = 0; i < top_blob.h; i++) + { + const unsigned short* ptr0 = (const unsigned short*)bottom_blob_flattened + outw * i * 8; + const unsigned short* ptr1 = (const unsigned short*)bottom_blob_flattened + outw * (i * 8 + 1); + 
const unsigned short* ptr2 = (const unsigned short*)bottom_blob_flattened + outw * (i * 8 + 2); + const unsigned short* ptr3 = (const unsigned short*)bottom_blob_flattened + outw * (i * 8 + 3); + const unsigned short* ptr4 = (const unsigned short*)bottom_blob_flattened + outw * (i * 8 + 4); + const unsigned short* ptr5 = (const unsigned short*)bottom_blob_flattened + outw * (i * 8 + 5); + const unsigned short* ptr6 = (const unsigned short*)bottom_blob_flattened + outw * (i * 8 + 6); + const unsigned short* ptr7 = (const unsigned short*)bottom_blob_flattened + outw * (i * 8 + 7); + unsigned short* outptr = top_blob.row(i); + + for (int j = 0; j < outw; j++) + { + outptr[0] = *ptr0++; + outptr[1] = *ptr1++; + outptr[2] = *ptr2++; + outptr[3] = *ptr3++; + outptr[4] = *ptr4++; + outptr[5] = *ptr5++; + outptr[6] = *ptr6++; + outptr[7] = *ptr7++; + + outptr += 8; + } + } + } +#endif // __AVX__ + + if (out_elempack == 4) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = 0; i < top_blob.h; i++) + { + const unsigned short* ptr0 = (const unsigned short*)bottom_blob_flattened + outw * i * 4; + const unsigned short* ptr1 = (const unsigned short*)bottom_blob_flattened + outw * (i * 4 + 1); + const unsigned short* ptr2 = (const unsigned short*)bottom_blob_flattened + outw * (i * 4 + 2); + const unsigned short* ptr3 = (const unsigned short*)bottom_blob_flattened + outw * (i * 4 + 3); + unsigned short* outptr = top_blob.row(i); + + for (int j = 0; j < outw; j++) + { + outptr[0] = *ptr0++; + outptr[1] = *ptr1++; + outptr[2] = *ptr2++; + outptr[3] = *ptr3++; + + outptr += 4; + } + } + } +#endif // __SSE2__ + } + + if (ndim == 3 || ndim == 4) + { + if (ndim == 3) + { + if (outw == 0) + outw = dims == 1 ? bottom_blob.w * elempack : bottom_blob.w; + if (outh == 0) + outh = dims == 2 ? bottom_blob.h * elempack : bottom_blob.h; + if (outc == 0) + outc = dims == 3 ? bottom_blob.c * elempack : bottom_blob.c; + + if (outw == -1) + outw = total / outc / outh; + if (outh == -1) + outh = total / outc / outw; + if (outc == -1) + outc = total / outh / outw; + + outd = 1; + } + else // if (ndim == 4) + { + if (outw == 0) + outw = dims == 1 ? bottom_blob.w * elempack : bottom_blob.w; + if (outh == 0) + outh = dims == 2 ? bottom_blob.h * elempack : bottom_blob.h; + if (outd == 0) + outd = bottom_blob.d; + if (outc == 0) + outc = (dims == 3 || dims == 4) ? bottom_blob.c * elempack : bottom_blob.c; + + if (outw == -1) + outw = total / outc / outd / outh; + if (outh == -1) + outh = total / outc / outd / outw; + if (outd == -1) + outd = total / outc / outh / outw; + if (outc == -1) + outc = total / outd / outh / outw; + } + + int out_elempack = 1; +#if __SSE2__ + if (opt.use_packing_layout) + { +#if __AVX512F__ + out_elempack = outc % 16 == 0 ? 16 : outc % 8 == 0 ? 8 : outc % 4 == 0 ? 4 : 1; +#elif __AVX__ + out_elempack = outc % 8 == 0 ? 8 : outc % 4 == 0 ? 4 : 1; +#else + out_elempack = outc % 4 == 0 ? 
4 : 1; +#endif + } +#endif // __SSE2__ + size_t out_elemsize = elemsize / elempack * out_elempack; + + if ((dims == 3 || dims == 4) && bottom_blob.c * elempack == outc && elempack == out_elempack) + { + top_blob = bottom_blob; + top_blob.dims = ndim; + top_blob.w = outw; + top_blob.h = outh; + top_blob.d = outd; + return 0; + } + + // flatten + Mat bottom_blob_flattened = bottom_blob; + { + Option opt_flatten = opt; + opt_flatten.blob_allocator = opt.workspace_allocator; + + flatten(bottom_blob, bottom_blob_flattened, opt_flatten); + if (bottom_blob_flattened.empty()) + return -100; + } + + if (ndim == 3) + { + top_blob.create(outw, outh, outc / out_elempack, out_elemsize, out_elempack, opt.blob_allocator); + } + else // if (ndim == 4) + { + top_blob.create(outw, outh, outd, outc / out_elempack, out_elemsize, out_elempack, opt.blob_allocator); + } + if (top_blob.empty()) + return -100; + + int size = top_blob.w * top_blob.h * top_blob.d; + +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + if (out_elempack == 16) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < top_blob.c; q++) + { + const unsigned short* ptr0 = (const unsigned short*)bottom_blob_flattened + size * q * 16; + const unsigned short* ptr1 = (const unsigned short*)bottom_blob_flattened + size * (q * 16 + 1); + const unsigned short* ptr2 = (const unsigned short*)bottom_blob_flattened + size * (q * 16 + 2); + const unsigned short* ptr3 = (const unsigned short*)bottom_blob_flattened + size * (q * 16 + 3); + const unsigned short* ptr4 = (const unsigned short*)bottom_blob_flattened + size * (q * 16 + 4); + const unsigned short* ptr5 = (const unsigned short*)bottom_blob_flattened + size * (q * 16 + 5); + const unsigned short* ptr6 = (const unsigned short*)bottom_blob_flattened + size * (q * 16 + 6); + const unsigned short* ptr7 = (const unsigned short*)bottom_blob_flattened + size * (q * 16 + 7); + const unsigned short* ptr8 = (const unsigned short*)bottom_blob_flattened + size * (q * 16 + 8); + const unsigned short* ptr9 = (const unsigned short*)bottom_blob_flattened + size * (q * 16 + 9); + const unsigned short* ptra = (const unsigned short*)bottom_blob_flattened + size * (q * 16 + 10); + const unsigned short* ptrb = (const unsigned short*)bottom_blob_flattened + size * (q * 16 + 11); + const unsigned short* ptrc = (const unsigned short*)bottom_blob_flattened + size * (q * 16 + 12); + const unsigned short* ptrd = (const unsigned short*)bottom_blob_flattened + size * (q * 16 + 13); + const unsigned short* ptre = (const unsigned short*)bottom_blob_flattened + size * (q * 16 + 14); + const unsigned short* ptrf = (const unsigned short*)bottom_blob_flattened + size * (q * 16 + 15); + unsigned short* outptr = top_blob.channel(q); + + for (int i = 0; i < size; i++) + { + outptr[0] = *ptr0++; + outptr[1] = *ptr1++; + outptr[2] = *ptr2++; + outptr[3] = *ptr3++; + outptr[4] = *ptr4++; + outptr[5] = *ptr5++; + outptr[6] = *ptr6++; + outptr[7] = *ptr7++; + outptr[8] = *ptr8++; + outptr[9] = *ptr9++; + outptr[10] = *ptra++; + outptr[11] = *ptrb++; + outptr[12] = *ptrc++; + outptr[13] = *ptrd++; + outptr[14] = *ptre++; + outptr[15] = *ptrf++; + + outptr += 16; + } + } + } +#endif // __AVX512F__ + + if (out_elempack == 8) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < top_blob.c; q++) + { + const unsigned short* ptr0 = (const unsigned short*)bottom_blob_flattened + size * q * 8; + const unsigned short* ptr1 = (const unsigned short*)bottom_blob_flattened + size * (q * 8 + 1); + const 
unsigned short* ptr2 = (const unsigned short*)bottom_blob_flattened + size * (q * 8 + 2); + const unsigned short* ptr3 = (const unsigned short*)bottom_blob_flattened + size * (q * 8 + 3); + const unsigned short* ptr4 = (const unsigned short*)bottom_blob_flattened + size * (q * 8 + 4); + const unsigned short* ptr5 = (const unsigned short*)bottom_blob_flattened + size * (q * 8 + 5); + const unsigned short* ptr6 = (const unsigned short*)bottom_blob_flattened + size * (q * 8 + 6); + const unsigned short* ptr7 = (const unsigned short*)bottom_blob_flattened + size * (q * 8 + 7); + unsigned short* outptr = top_blob.channel(q); + + for (int i = 0; i < size; i++) + { + outptr[0] = *ptr0++; + outptr[1] = *ptr1++; + outptr[2] = *ptr2++; + outptr[3] = *ptr3++; + outptr[4] = *ptr4++; + outptr[5] = *ptr5++; + outptr[6] = *ptr6++; + outptr[7] = *ptr7++; + + outptr += 8; + } + } + } +#endif // __AVX__ + + if (out_elempack == 4) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < top_blob.c; q++) + { + const unsigned short* ptr0 = (const unsigned short*)bottom_blob_flattened + size * q * 4; + const unsigned short* ptr1 = (const unsigned short*)bottom_blob_flattened + size * (q * 4 + 1); + const unsigned short* ptr2 = (const unsigned short*)bottom_blob_flattened + size * (q * 4 + 2); + const unsigned short* ptr3 = (const unsigned short*)bottom_blob_flattened + size * (q * 4 + 3); + unsigned short* outptr = top_blob.channel(q); + + for (int i = 0; i < size; i++) + { + outptr[0] = *ptr0++; + outptr[1] = *ptr1++; + outptr[2] = *ptr2++; + outptr[3] = *ptr3++; + + outptr += 4; + } + } + } +#endif // __SSE2__ + + if (out_elempack == 1) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < top_blob.c; q++) + { + const unsigned short* ptr = (const unsigned short*)bottom_blob_flattened + size * q; + unsigned short* outptr = top_blob.channel(q); + + int i = 0; +#if __SSE2__ +#if __AVX__ + for (; i + 15 < size; i += 16) + { + __m256i _v = _mm256_loadu_si256((const __m256i*)ptr); + _mm256_storeu_si256((__m256i*)outptr, _v); + ptr += 16; + outptr += 16; + } +#endif + for (; i + 7 < size; i += 8) + { + __m128i _v = _mm_loadu_si128((const __m128i*)ptr); + _mm_storeu_si128((__m128i*)outptr, _v); + ptr += 8; + outptr += 8; + } +#endif // __SSE2__ + for (; i < size; i++) + { + *outptr++ = *ptr++; + } + } + } + } + + return 0; +} + } // namespace ncnn diff --git a/src/layer/x86/reshape_x86.h b/src/layer/x86/reshape_x86.h index 6b354202521..5bb9237de76 100644 --- a/src/layer/x86/reshape_x86.h +++ b/src/layer/x86/reshape_x86.h @@ -1,4 +1,4 @@ -// Copyright 2019 Tencent +// Copyright 2026 Tencent // SPDX-License-Identifier: BSD-3-Clause #ifndef LAYER_RESHAPE_X86_H @@ -14,6 +14,9 @@ class Reshape_x86 : public Reshape Reshape_x86(); virtual int forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; + +protected: + int forward_bf16s_fp16s(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; }; } // namespace ncnn diff --git a/src/layer/x86/slice_x86.cpp b/src/layer/x86/slice_x86.cpp index e0f895ecc85..a8b012d3e69 100644 --- a/src/layer/x86/slice_x86.cpp +++ b/src/layer/x86/slice_x86.cpp @@ -1,4 +1,4 @@ -// Copyright 2019 Tencent +// Copyright 2026 Tencent // SPDX-License-Identifier: BSD-3-Clause #include "slice_x86.h" @@ -10,6 +10,8 @@ #endif // __AVX__ #endif // __SSE2__ +#include "cpu.h" + namespace ncnn { Slice_x86::Slice_x86() @@ -17,6 +19,10 @@ Slice_x86::Slice_x86() #if __SSE2__ support_packing = true; 
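+    // 16-bit (bf16/fp16) inputs are dispatched to forward_bf16s_fp16s in forward()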
#endif // __SSE2__ + support_fp16_storage = cpu_support_x86_f16c(); +#if NCNN_BF16 + support_bf16_storage = true; +#endif } int Slice_x86::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const @@ -29,6 +35,9 @@ int Slice_x86::forward(const std::vector& bottom_blobs, std::vector& t const int* indices_ptr = indices; int positive_axis = axis < 0 ? dims + axis : axis; + if (bottom_blob.elembits() == 16) + return forward_bf16s_fp16s(bottom_blobs, top_blobs, opt); + if (dims == 1) // positive_axis == 0 { // slice vector @@ -962,4 +971,947 @@ int Slice_x86::forward(const std::vector& bottom_blobs, std::vector& t return 0; } +int Slice_x86::forward_bf16s_fp16s(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const +{ + const Mat& bottom_blob = bottom_blobs[0]; + int dims = bottom_blob.dims; + size_t elemsize = bottom_blob.elemsize; + int elempack = bottom_blob.elempack; + const int* slices_ptr = slices; + const int* indices_ptr = indices; + int positive_axis = axis < 0 ? dims + axis : axis; + + if (dims == 1) // positive_axis == 0 + { + // slice vector + int w = bottom_blob.w * elempack; + int q = 0; + for (size_t i = 0; i < top_blobs.size(); i++) + { + int slice; + if (indices_ptr) + { + if (i == top_blobs.size() - 1) + { + slice = w - q; + } + else + { + int indice = indices_ptr[i]; + int positive_indice = indice < 0 ? w + indice : indice; + slice = positive_indice - q; + } + } + else + { + slice = slices_ptr[i]; + if (slice == -233) + { + slice = static_cast((w - q) / (top_blobs.size() - i)); + } + } + + int out_elempack = 1; +#if __SSE2__ + if (opt.use_packing_layout) + { +#if __AVX512F__ + out_elempack = slice % 16 == 0 ? 16 : slice % 8 == 0 ? 8 : slice % 4 == 0 ? 4 : 1; +#elif __AVX__ + out_elempack = slice % 8 == 0 ? 8 : slice % 4 == 0 ? 4 : 1; +#else + out_elempack = slice % 4 == 0 ? 4 : 1; +#endif + } +#endif // __SSE2__ + size_t out_elemsize = elemsize / elempack * out_elempack; + + Mat& top_blob = top_blobs[i]; + top_blob.create(slice / out_elempack, out_elemsize, out_elempack, opt.blob_allocator); + if (top_blob.empty()) + return -100; + + const unsigned short* ptr = (const unsigned short*)bottom_blob + q; + unsigned short* outptr = top_blob; + memcpy(outptr, ptr, top_blob.w * top_blob.elemsize); + + q += slice; + } + } + + if (dims == 2 && positive_axis == 0) + { + // slice image height + int w = bottom_blob.w; + int h = bottom_blob.h * elempack; + + int q = 0; + for (size_t i = 0; i < top_blobs.size(); i++) + { + int slice; + if (indices_ptr) + { + if (i == top_blobs.size() - 1) + { + slice = h - q; + } + else + { + int indice = indices_ptr[i]; + int positive_indice = indice < 0 ? h + indice : indice; + slice = positive_indice - q; + } + } + else + { + slice = slices_ptr[i]; + if (slice == -233) + { + slice = static_cast((h - q) / (top_blobs.size() - i)); + } + } + + int out_elempack = 1; +#if __SSE2__ + if (opt.use_packing_layout) + { +#if __AVX512F__ + out_elempack = slice % 16 == 0 ? 16 : slice % 8 == 0 ? 8 : slice % 4 == 0 ? 4 : 1; +#elif __AVX__ + out_elempack = slice % 8 == 0 ? 8 : slice % 4 == 0 ? 4 : 1; +#else + out_elempack = slice % 4 == 0 ? 
4 : 1; +#endif + } +#endif // __SSE2__ + size_t out_elemsize = elemsize / elempack * out_elempack; + + Mat& top_blob = top_blobs[i]; + top_blob.create(w, slice / out_elempack, out_elemsize, out_elempack, opt.blob_allocator); + if (top_blob.empty()) + return -100; + + q += slice; + } + + size_t out_elemsize = top_blobs[0].elemsize; + int out_elempack = top_blobs[0].elempack; + for (size_t i = 0; i < top_blobs.size(); i++) + { + out_elemsize = std::min(out_elemsize, top_blobs[i].elemsize); + out_elempack = std::min(out_elempack, top_blobs[i].elempack); + } + + Mat bottom_blob_unpacked = bottom_blob; + if (elempack > out_elempack) + { + convert_packing(bottom_blob, bottom_blob_unpacked, out_elempack, opt); + if (bottom_blob_unpacked.empty()) + return -100; + } + + const unsigned short* ptr = bottom_blob_unpacked; + for (size_t i = 0; i < top_blobs.size(); i++) + { + Mat& top_blob = top_blobs[i]; + +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + if (out_elempack == 8 && top_blob.elempack == 16) + { + for (int j = 0; j < top_blob.h; j++) + { + const unsigned short* r0 = ptr; + const unsigned short* r1 = ptr + w * 8; + + unsigned short* outptr0 = top_blob.row(j); + + for (int j = 0; j < w; j++) + { + outptr0[0] = r0[0]; + outptr0[1] = r0[1]; + outptr0[2] = r0[2]; + outptr0[3] = r0[3]; + outptr0[4] = r0[4]; + outptr0[5] = r0[5]; + outptr0[6] = r0[6]; + outptr0[7] = r0[7]; + outptr0[8] = r1[0]; + outptr0[9] = r1[1]; + outptr0[10] = r1[2]; + outptr0[11] = r1[3]; + outptr0[12] = r1[4]; + outptr0[13] = r1[5]; + outptr0[14] = r1[6]; + outptr0[15] = r1[7]; + + r0 += 8; + r1 += 8; + outptr0 += 16; + } + + ptr += w * 16; + } + } + if (out_elempack == 4 && top_blob.elempack == 16) + { + for (int j = 0; j < top_blob.h; j++) + { + const unsigned short* r0 = ptr; + const unsigned short* r1 = ptr + w * 4; + const unsigned short* r2 = ptr + w * 8; + const unsigned short* r3 = ptr + w * 12; + + unsigned short* outptr0 = top_blob.row(j); + + for (int j = 0; j < w; j++) + { + outptr0[0] = r0[0]; + outptr0[1] = r0[1]; + outptr0[2] = r0[2]; + outptr0[3] = r0[3]; + outptr0[4] = r1[0]; + outptr0[5] = r1[1]; + outptr0[6] = r1[2]; + outptr0[7] = r1[3]; + outptr0[8] = r2[0]; + outptr0[9] = r2[1]; + outptr0[10] = r2[2]; + outptr0[11] = r2[3]; + outptr0[12] = r3[0]; + outptr0[13] = r3[1]; + outptr0[14] = r3[2]; + outptr0[15] = r3[3]; + + r0 += 4; + r1 += 4; + r2 += 4; + r3 += 4; + outptr0 += 16; + } + + ptr += w * 16; + } + } + if (out_elempack == 1 && top_blob.elempack == 16) + { + for (int j = 0; j < top_blob.h; j++) + { + const unsigned short* r0 = ptr; + const unsigned short* r1 = ptr + w; + const unsigned short* r2 = ptr + w * 2; + const unsigned short* r3 = ptr + w * 3; + const unsigned short* r4 = ptr + w * 4; + const unsigned short* r5 = ptr + w * 5; + const unsigned short* r6 = ptr + w * 6; + const unsigned short* r7 = ptr + w * 7; + const unsigned short* r8 = ptr + w * 8; + const unsigned short* r9 = ptr + w * 9; + const unsigned short* ra = ptr + w * 10; + const unsigned short* rb = ptr + w * 11; + const unsigned short* rc = ptr + w * 12; + const unsigned short* rd = ptr + w * 13; + const unsigned short* re = ptr + w * 14; + const unsigned short* rf = ptr + w * 15; + + unsigned short* outptr0 = top_blob.row(j); + + for (int j = 0; j < w; j++) + { + outptr0[0] = *r0++; + outptr0[1] = *r1++; + outptr0[2] = *r2++; + outptr0[3] = *r3++; + outptr0[4] = *r4++; + outptr0[5] = *r5++; + outptr0[6] = *r6++; + outptr0[7] = *r7++; + outptr0[8] = *r8++; + outptr0[9] = *r9++; + outptr0[10] = *ra++; + outptr0[11] = *rb++; + 
outptr0[12] = *rc++; + outptr0[13] = *rd++; + outptr0[14] = *re++; + outptr0[15] = *rf++; + + outptr0 += 16; + } + + ptr += w * 16; + } + } +#endif // __AVX512F__ + if (out_elempack == 4 && top_blob.elempack == 8) + { + for (int j = 0; j < top_blob.h; j++) + { + const unsigned short* r0 = ptr; + const unsigned short* r1 = ptr + w * 4; + + unsigned short* outptr0 = top_blob.row(j); + + for (int j = 0; j < w; j++) + { + outptr0[0] = r0[0]; + outptr0[1] = r0[1]; + outptr0[2] = r0[2]; + outptr0[3] = r0[3]; + outptr0[4] = r1[0]; + outptr0[5] = r1[1]; + outptr0[6] = r1[2]; + outptr0[7] = r1[3]; + + r0 += 4; + r1 += 4; + outptr0 += 8; + } + + ptr += w * 8; + } + } + if (out_elempack == 1 && top_blob.elempack == 8) + { + for (int j = 0; j < top_blob.h; j++) + { + const unsigned short* r0 = ptr; + const unsigned short* r1 = ptr + w; + const unsigned short* r2 = ptr + w * 2; + const unsigned short* r3 = ptr + w * 3; + const unsigned short* r4 = ptr + w * 4; + const unsigned short* r5 = ptr + w * 5; + const unsigned short* r6 = ptr + w * 6; + const unsigned short* r7 = ptr + w * 7; + + unsigned short* outptr0 = top_blob.row(j); + + for (int j = 0; j < w; j++) + { + outptr0[0] = *r0++; + outptr0[1] = *r1++; + outptr0[2] = *r2++; + outptr0[3] = *r3++; + outptr0[4] = *r4++; + outptr0[5] = *r5++; + outptr0[6] = *r6++; + outptr0[7] = *r7++; + + outptr0 += 8; + } + + ptr += w * 8; + } + } +#endif // __AVX__ + if (out_elempack == 1 && top_blob.elempack == 4) + { + for (int j = 0; j < top_blob.h; j++) + { + const unsigned short* r0 = ptr; + const unsigned short* r1 = ptr + w; + const unsigned short* r2 = ptr + w * 2; + const unsigned short* r3 = ptr + w * 3; + + unsigned short* outptr0 = top_blob.row(j); + + for (int j = 0; j < w; j++) + { + outptr0[0] = *r0++; + outptr0[1] = *r1++; + outptr0[2] = *r2++; + outptr0[3] = *r3++; + + outptr0 += 4; + } + + ptr += w * 4; + } + } +#endif // __SSE2__ + if (out_elempack == top_blob.elempack) + { + // 1-1 4-4 8-8 + int size = w * top_blob.h; + + unsigned short* outptr = top_blob; + memcpy(outptr, ptr, size * top_blob.elemsize); + + ptr += size * top_blob.elempack; + } + } + } + + if (dims == 2 && positive_axis == 1) + { + // slice image width + int w = bottom_blob.w; + int h = bottom_blob.h; + + int q = 0; + for (size_t i = 0; i < top_blobs.size(); i++) + { + int slice; + if (indices_ptr) + { + if (i == top_blobs.size() - 1) + { + slice = w - q; + } + else + { + int indice = indices_ptr[i]; + int positive_indice = indice < 0 ? 
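+                    // negative split indices count back from the end of the axis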
w + indice : indice; + slice = positive_indice - q; + } + } + else + { + slice = slices_ptr[i]; + if (slice == -233) + { + slice = static_cast((w - q) / (top_blobs.size() - i)); + } + } + + Mat& top_blob = top_blobs[i]; + top_blob.create(slice, h, elemsize, elempack, opt.blob_allocator); + if (top_blob.empty()) + return -100; + + q += slice; + } + + #pragma omp parallel for num_threads(opt.num_threads) + for (int j = 0; j < h; j++) + { + const unsigned short* ptr = bottom_blob.row(j); + for (size_t i = 0; i < top_blobs.size(); i++) + { + Mat& top_blob = top_blobs[i]; + + unsigned short* outptr = top_blob.row(j); + memcpy(outptr, ptr, top_blob.w * elemsize); + + ptr += top_blob.w * elempack; + } + } + } + + if ((dims == 3 || dims == 4) && positive_axis == 0) + { + // slice dim channel + int w = bottom_blob.w; + int h = bottom_blob.h; + int d = bottom_blob.d; + int channels = bottom_blob.c * elempack; + + int q = 0; + for (size_t i = 0; i < top_blobs.size(); i++) + { + int slice; + if (indices_ptr) + { + if (i == top_blobs.size() - 1) + { + slice = channels - q; + } + else + { + int indice = indices_ptr[i]; + int positive_indice = indice < 0 ? channels + indice : indice; + slice = positive_indice - q; + } + } + else + { + slice = slices_ptr[i]; + if (slice == -233) + { + slice = static_cast((channels - q) / (top_blobs.size() - i)); + } + } + + int out_elempack = 1; +#if __SSE2__ + if (opt.use_packing_layout) + { +#if __AVX512F__ + out_elempack = slice % 16 == 0 ? 16 : slice % 8 == 0 ? 8 : slice % 4 == 0 ? 4 : 1; +#elif __AVX__ + out_elempack = slice % 8 == 0 ? 8 : slice % 4 == 0 ? 4 : 1; +#else + out_elempack = slice % 4 == 0 ? 4 : 1; +#endif + } +#endif // __SSE2__ + size_t out_elemsize = elemsize / elempack * out_elempack; + + Mat& top_blob = top_blobs[i]; + top_blob.create(w, h, d, slice / out_elempack, out_elemsize, out_elempack, opt.blob_allocator); + if (top_blob.empty()) + return -100; + + top_blob.dims = dims; + + q += slice; + } + + size_t out_elemsize = top_blobs[0].elemsize; + int out_elempack = top_blobs[0].elempack; + for (size_t i = 0; i < top_blobs.size(); i++) + { + out_elemsize = std::min(out_elemsize, top_blobs[i].elemsize); + out_elempack = std::min(out_elempack, top_blobs[i].elempack); + } + + Mat bottom_blob_unpacked = bottom_blob; + if (elempack > out_elempack) + { + convert_packing(bottom_blob, bottom_blob_unpacked, out_elempack, opt); + if (bottom_blob_unpacked.empty()) + return -100; + } + + int p = 0; + for (size_t i = 0; i < top_blobs.size(); i++) + { + Mat& top_blob = top_blobs[i]; + +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + if (out_elempack == 8 && top_blob.elempack == 16) + { + int size = top_blob.w * top_blob.h * top_blob.d; + + for (int q = 0; q < top_blob.c; q++) + { + const unsigned short* r0 = (const unsigned short*)bottom_blob_unpacked.channel(p); + const unsigned short* r1 = (const unsigned short*)bottom_blob_unpacked.channel(p + 1); + + unsigned short* outptr0 = top_blob.channel(q); + + for (int j = 0; j < size; j++) + { + outptr0[0] = r0[0]; + outptr0[1] = r0[1]; + outptr0[2] = r0[2]; + outptr0[3] = r0[3]; + outptr0[4] = r0[4]; + outptr0[5] = r0[5]; + outptr0[6] = r0[6]; + outptr0[7] = r0[7]; + outptr0[8] = r1[0]; + outptr0[9] = r1[1]; + outptr0[10] = r1[2]; + outptr0[11] = r1[3]; + outptr0[12] = r1[4]; + outptr0[13] = r1[5]; + outptr0[14] = r1[6]; + outptr0[15] = r1[7]; + + r0 += 8; + r1 += 8; + outptr0 += 16; + } + + p += 2; + } + } + if (out_elempack == 4 && top_blob.elempack == 16) + { + int size = top_blob.w * top_blob.h * top_blob.d; + + for 
(int q = 0; q < top_blob.c; q++) + { + const unsigned short* r0 = (const unsigned short*)bottom_blob_unpacked.channel(p); + const unsigned short* r1 = (const unsigned short*)bottom_blob_unpacked.channel(p + 1); + const unsigned short* r2 = (const unsigned short*)bottom_blob_unpacked.channel(p + 2); + const unsigned short* r3 = (const unsigned short*)bottom_blob_unpacked.channel(p + 3); + + unsigned short* outptr0 = top_blob.channel(q); + + for (int j = 0; j < size; j++) + { + outptr0[0] = r0[0]; + outptr0[1] = r0[1]; + outptr0[2] = r0[2]; + outptr0[3] = r0[3]; + outptr0[4] = r1[0]; + outptr0[5] = r1[1]; + outptr0[6] = r1[2]; + outptr0[7] = r1[3]; + outptr0[8] = r2[0]; + outptr0[9] = r2[1]; + outptr0[10] = r2[2]; + outptr0[11] = r2[3]; + outptr0[12] = r3[0]; + outptr0[13] = r3[1]; + outptr0[14] = r3[2]; + outptr0[15] = r3[3]; + + r0 += 4; + r1 += 4; + r2 += 4; + r3 += 4; + outptr0 += 16; + } + + p += 4; + } + } + if (out_elempack == 1 && top_blob.elempack == 16) + { + int size = top_blob.w * top_blob.h * top_blob.d; + + for (int q = 0; q < top_blob.c; q++) + { + const unsigned short* r0 = (const unsigned short*)bottom_blob_unpacked.channel(p); + const unsigned short* r1 = (const unsigned short*)bottom_blob_unpacked.channel(p + 1); + const unsigned short* r2 = (const unsigned short*)bottom_blob_unpacked.channel(p + 2); + const unsigned short* r3 = (const unsigned short*)bottom_blob_unpacked.channel(p + 3); + const unsigned short* r4 = (const unsigned short*)bottom_blob_unpacked.channel(p + 4); + const unsigned short* r5 = (const unsigned short*)bottom_blob_unpacked.channel(p + 5); + const unsigned short* r6 = (const unsigned short*)bottom_blob_unpacked.channel(p + 6); + const unsigned short* r7 = (const unsigned short*)bottom_blob_unpacked.channel(p + 7); + const unsigned short* r8 = (const unsigned short*)bottom_blob_unpacked.channel(p + 8); + const unsigned short* r9 = (const unsigned short*)bottom_blob_unpacked.channel(p + 9); + const unsigned short* ra = (const unsigned short*)bottom_blob_unpacked.channel(p + 10); + const unsigned short* rb = (const unsigned short*)bottom_blob_unpacked.channel(p + 11); + const unsigned short* rc = (const unsigned short*)bottom_blob_unpacked.channel(p + 12); + const unsigned short* rd = (const unsigned short*)bottom_blob_unpacked.channel(p + 13); + const unsigned short* re = (const unsigned short*)bottom_blob_unpacked.channel(p + 14); + const unsigned short* rf = (const unsigned short*)bottom_blob_unpacked.channel(p + 15); + + unsigned short* outptr0 = top_blob.channel(q); + + for (int j = 0; j < size; j++) + { + outptr0[0] = *r0++; + outptr0[1] = *r1++; + outptr0[2] = *r2++; + outptr0[3] = *r3++; + outptr0[4] = *r4++; + outptr0[5] = *r5++; + outptr0[6] = *r6++; + outptr0[7] = *r7++; + outptr0[8] = *r8++; + outptr0[9] = *r9++; + outptr0[10] = *ra++; + outptr0[11] = *rb++; + outptr0[12] = *rc++; + outptr0[13] = *rd++; + outptr0[14] = *re++; + outptr0[15] = *rf++; + + outptr0 += 16; + } + + p += 16; + } + } +#endif // __AVX512F__ + if (out_elempack == 4 && top_blob.elempack == 8) + { + int size = top_blob.w * top_blob.h * top_blob.d; + + for (int q = 0; q < top_blob.c; q++) + { + const unsigned short* r0 = (const unsigned short*)bottom_blob_unpacked.channel(p); + const unsigned short* r1 = (const unsigned short*)bottom_blob_unpacked.channel(p + 1); + + unsigned short* outptr0 = top_blob.channel(q); + + for (int j = 0; j < size; j++) + { + outptr0[0] = r0[0]; + outptr0[1] = r0[1]; + outptr0[2] = r0[2]; + outptr0[3] = r0[3]; + outptr0[4] = r1[0]; + outptr0[5] 
= r1[1]; + outptr0[6] = r1[2]; + outptr0[7] = r1[3]; + + r0 += 4; + r1 += 4; + outptr0 += 8; + } + + p += 2; + } + } + if (out_elempack == 1 && top_blob.elempack == 8) + { + int size = top_blob.w * top_blob.h * top_blob.d; + + for (int q = 0; q < top_blob.c; q++) + { + const unsigned short* r0 = (const unsigned short*)bottom_blob_unpacked.channel(p); + const unsigned short* r1 = (const unsigned short*)bottom_blob_unpacked.channel(p + 1); + const unsigned short* r2 = (const unsigned short*)bottom_blob_unpacked.channel(p + 2); + const unsigned short* r3 = (const unsigned short*)bottom_blob_unpacked.channel(p + 3); + const unsigned short* r4 = (const unsigned short*)bottom_blob_unpacked.channel(p + 4); + const unsigned short* r5 = (const unsigned short*)bottom_blob_unpacked.channel(p + 5); + const unsigned short* r6 = (const unsigned short*)bottom_blob_unpacked.channel(p + 6); + const unsigned short* r7 = (const unsigned short*)bottom_blob_unpacked.channel(p + 7); + + unsigned short* outptr0 = top_blob.channel(q); + + for (int j = 0; j < size; j++) + { + outptr0[0] = *r0++; + outptr0[1] = *r1++; + outptr0[2] = *r2++; + outptr0[3] = *r3++; + outptr0[4] = *r4++; + outptr0[5] = *r5++; + outptr0[6] = *r6++; + outptr0[7] = *r7++; + + outptr0 += 8; + } + + p += 8; + } + } +#endif // __AVX__ + if (out_elempack == 1 && top_blob.elempack == 4) + { + int size = top_blob.w * top_blob.h * top_blob.d; + + for (int q = 0; q < top_blob.c; q++) + { + const unsigned short* r0 = (const unsigned short*)bottom_blob_unpacked.channel(p); + const unsigned short* r1 = (const unsigned short*)bottom_blob_unpacked.channel(p + 1); + const unsigned short* r2 = (const unsigned short*)bottom_blob_unpacked.channel(p + 2); + const unsigned short* r3 = (const unsigned short*)bottom_blob_unpacked.channel(p + 3); + + unsigned short* outptr0 = top_blob.channel(q); + + for (int j = 0; j < size; j++) + { + outptr0[0] = *r0++; + outptr0[1] = *r1++; + outptr0[2] = *r2++; + outptr0[3] = *r3++; + + outptr0 += 4; + } + + p += 4; + } + } +#endif // __SSE2__ + if (out_elempack == top_blob.elempack) + { + // 1-1 4-4 8-8 + int size = top_blob.total(); + + const unsigned short* ptr = (const unsigned short*)bottom_blob_unpacked.channel(p); + unsigned short* outptr = top_blob; + memcpy(outptr, ptr, size * top_blob.elemsize); + + p += top_blob.c; + } + } + } + + if ((dims == 3 && positive_axis == 1) || (dims == 4 && positive_axis == 2)) + { + // slice dim height + int w = bottom_blob.w; + int h = bottom_blob.h; + int d = bottom_blob.d; + int channels = bottom_blob.c; + + int q = 0; + for (size_t i = 0; i < top_blobs.size(); i++) + { + int slice; + if (indices_ptr) + { + if (i == top_blobs.size() - 1) + { + slice = h - q; + } + else + { + int indice = indices_ptr[i]; + int positive_indice = indice < 0 ? 
h + indice : indice; + slice = positive_indice - q; + } + } + else + { + slice = slices_ptr[i]; + if (slice == -233) + { + slice = static_cast((h - q) / (top_blobs.size() - i)); + } + } + + Mat& top_blob = top_blobs[i]; + top_blob.create(w, slice, d, channels, elemsize, elempack, opt.blob_allocator); + if (top_blob.empty()) + return -100; + + top_blob.dims = dims; + + q += slice; + } + + #pragma omp parallel for num_threads(opt.num_threads) + for (int p = 0; p < channels; p++) + { + const unsigned short* ptr = (const unsigned short*)bottom_blob.channel(p); + + for (int j = 0; j < d; j++) + { + for (size_t i = 0; i < top_blobs.size(); i++) + { + Mat& top_blob = top_blobs[i]; + + int size = top_blob.w * top_blob.h; + + unsigned short* outptr = top_blob.channel(p).depth(j); + memcpy(outptr, ptr, size * elemsize); + + ptr += size * elempack; + } + } + } + } + + if ((dims == 3 && positive_axis == 2) || (dims == 4 && positive_axis == 3)) + { + // slice dim width + int w = bottom_blob.w; + int h = bottom_blob.h; + int d = bottom_blob.d; + int channels = bottom_blob.c; + + int q = 0; + for (size_t i = 0; i < top_blobs.size(); i++) + { + int slice; + if (indices_ptr) + { + if (i == top_blobs.size() - 1) + { + slice = w - q; + } + else + { + int indice = indices_ptr[i]; + int positive_indice = indice < 0 ? w + indice : indice; + slice = positive_indice - q; + } + } + else + { + slice = slices_ptr[i]; + if (slice == -233) + { + slice = static_cast((w - q) / (top_blobs.size() - i)); + } + } + + Mat& top_blob = top_blobs[i]; + top_blob.create(slice, h, d, channels, elemsize, elempack, opt.blob_allocator); + if (top_blob.empty()) + return -100; + + top_blob.dims = dims; + + q += slice; + } + + #pragma omp parallel for num_threads(opt.num_threads) + for (int p = 0; p < channels; p++) + { + const unsigned short* ptr = (const unsigned short*)bottom_blob.channel(p); + + for (int j = 0; j < d; j++) + { + for (int k = 0; k < h; k++) + { + for (size_t i = 0; i < top_blobs.size(); i++) + { + Mat& top_blob = top_blobs[i]; + + unsigned short* outptr = top_blob.channel(p).depth(j).row(k); + memcpy(outptr, ptr, top_blob.w * elemsize); + + ptr += top_blob.w * elempack; + } + } + } + } + } + + if (dims == 4 && positive_axis == 1) + { + int w = bottom_blob.w; + int h = bottom_blob.h; + int d = bottom_blob.d; + int channels = bottom_blob.c; + + int q = 0; + for (size_t i = 0; i < top_blobs.size(); i++) + { + int slice; + if (indices_ptr) + { + if (i == top_blobs.size() - 1) + { + slice = d - q; + } + else + { + int indice = indices_ptr[i]; + int positive_indice = indice < 0 ? 
d + indice : indice; + slice = positive_indice - q; + } + } + else + { + slice = slices_ptr[i]; + if (slice == -233) + { + slice = static_cast((d - q) / (top_blobs.size() - i)); + } + } + + Mat& top_blob = top_blobs[i]; + top_blob.create(w, h, slice, channels, elemsize, elempack, opt.blob_allocator); + if (top_blob.empty()) + return -100; + + q += slice; + } + + #pragma omp parallel for num_threads(opt.num_threads) + for (int p = 0; p < channels; p++) + { + const unsigned short* ptr = (const unsigned short*)bottom_blob.channel(p); + + for (size_t i = 0; i < top_blobs.size(); i++) + { + Mat& top_blob = top_blobs[i]; + + int size = top_blob.w * top_blob.h * top_blob.d; + + unsigned short* outptr = top_blob.channel(p); + memcpy(outptr, ptr, size * elemsize); + + ptr += size * elempack; + } + } + } + + return 0; +} + } // namespace ncnn diff --git a/src/layer/x86/slice_x86.h b/src/layer/x86/slice_x86.h index 06b98f6d31a..2ee83bfb91c 100644 --- a/src/layer/x86/slice_x86.h +++ b/src/layer/x86/slice_x86.h @@ -1,4 +1,4 @@ -// Copyright 2019 Tencent +// Copyright 2026 Tencent // SPDX-License-Identifier: BSD-3-Clause #ifndef LAYER_SLICE_X86_H @@ -14,6 +14,9 @@ class Slice_x86 : public Slice Slice_x86(); virtual int forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; + +protected: + int forward_bf16s_fp16s(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; }; } // namespace ncnn From 3da4ee3441f982451ecceded1ff9b6eb5adc08bc Mon Sep 17 00:00:00 2001 From: nihui Date: Tue, 10 Mar 2026 19:59:07 +0800 Subject: [PATCH 17/36] x86 groupnorm instancenorm support bf16 storage with avx512bf16 dispatch (#6594) --- src/layer/x86/groupnorm_bf16s.h | 329 ++++++++++++++++++ src/layer/x86/groupnorm_x86.cpp | 94 +++++ src/layer/x86/groupnorm_x86.h | 5 + src/layer/x86/groupnorm_x86_avx512bf16.cpp | 17 + src/layer/x86/instancenorm_bf16s.h | 158 +++++++++ src/layer/x86/instancenorm_x86.cpp | 60 ++++ src/layer/x86/instancenorm_x86.h | 7 + src/layer/x86/instancenorm_x86_avx512bf16.cpp | 22 ++ 8 files changed, 692 insertions(+) create mode 100644 src/layer/x86/groupnorm_bf16s.h create mode 100644 src/layer/x86/groupnorm_x86_avx512bf16.cpp create mode 100644 src/layer/x86/instancenorm_bf16s.h create mode 100644 src/layer/x86/instancenorm_x86_avx512bf16.cpp diff --git a/src/layer/x86/groupnorm_bf16s.h b/src/layer/x86/groupnorm_bf16s.h new file mode 100644 index 00000000000..e4e57e440c3 --- /dev/null +++ b/src/layer/x86/groupnorm_bf16s.h @@ -0,0 +1,329 @@ +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ +void groupnorm_bf16s_sse_avx512bf16(unsigned short* ptr, const float* gamma_ptr, const float* beta_ptr, float eps, int channels, int size, int elempack, size_t cstep); +#endif + +static void groupnorm_bf16s_sse(unsigned short* ptr, const float* gamma_ptr, const float* beta_ptr, float eps, int channels, int size, int elempack, size_t cstep) +{ +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ + if (ncnn::cpu_support_x86_avx512_bf16()) + { + groupnorm_bf16s_sse_avx512bf16(ptr, gamma_ptr, beta_ptr, eps, channels, size, elempack, cstep); + return; + } +#endif + +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + __m512 _mean_avx512 = _mm512_set1_ps(0.f); +#endif // __AVX512F__ + __m256 _mean_avx = _mm256_set1_ps(0.f); +#endif // __AVX__ + __m128 _mean = _mm_set1_ps(0.f); +#endif // __SSE2__ + float mean = 0.f; + for (int q = 0; q < channels; q++) 
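+ // first pass: widen each bf16 lane to fp32 and accumulate the sum across every channel of the group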
+ { + const unsigned short* ptr0 = ptr + cstep * q * elempack; + + int i = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + for (; i + 15 < size; i += 16) + { + __m512 _p = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr0)); + _mean_avx512 = _mm512_add_ps(_mean_avx512, _p); + ptr0 += 16; + } +#endif // __AVX512F__ + for (; i + 7 < size; i += 8) + { + __m256 _p = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr0)); + _mean_avx = _mm256_add_ps(_mean_avx, _p); + ptr0 += 8; + } +#endif // __AVX__ + for (; i + 3 < size; i += 4) + { + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr0)); + _mean = _mm_add_ps(_mean, _p); + ptr0 += 4; + } +#endif // __SSE2__ + for (; i < size; i++) + { + mean += bfloat16_to_float32(*ptr0); + ptr0++; + } + } + + { +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + mean += _mm512_comp_reduce_add_ps(_mean_avx512); +#endif // __AVX512F__ + mean += _mm256_reduce_add_ps(_mean_avx); +#endif // __AVX__ + mean += _mm_reduce_add_ps(_mean); +#endif // __SSE2__ + + mean = mean / (channels * size); +#if __SSE2__ + _mean = _mm_set1_ps(mean); +#if __AVX__ + _mean_avx = combine4x2_ps(_mean, _mean); +#if __AVX512F__ + _mean_avx512 = combine8x2_ps(_mean_avx, _mean_avx); +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ + } + +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + __m512 _var_avx512 = _mm512_set1_ps(0.f); +#endif // __AVX512F__ + __m256 _var_avx = _mm256_set1_ps(0.f); +#endif // __AVX__ + __m128 _var = _mm_set1_ps(0.f); +#endif // __SSE2__ + float var = 0.f; + for (int q = 0; q < channels; q++) + { + const unsigned short* ptr0 = ptr + cstep * q * elempack; + + int i = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + for (; i + 15 < size; i += 16) + { + __m512 _p = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr0)); + _p = _mm512_sub_ps(_p, _mean_avx512); + _var_avx512 = _mm512_fmadd_ps(_p, _p, _var_avx512); + ptr0 += 16; + } +#endif // __AVX512F__ + for (; i + 7 < size; i += 8) + { + __m256 _p = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr0)); + _p = _mm256_sub_ps(_p, _mean_avx); + _var_avx = _mm256_comp_fmadd_ps(_p, _p, _var_avx); + ptr0 += 8; + } +#endif // __AVX__ + for (; i + 3 < size; i += 4) + { + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr0)); + _p = _mm_sub_ps(_p, _mean); + _var = _mm_comp_fmadd_ps(_p, _p, _var); + ptr0 += 4; + } +#endif // __SSE2__ + for (; i < size; i++) + { + float v = bfloat16_to_float32(*ptr0) - mean; + var += v * v; + ptr0++; + } + } + + { +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + var += _mm512_comp_reduce_add_ps(_var_avx512); +#endif // __AVX512F__ + var += _mm256_reduce_add_ps(_var_avx); +#endif // __AVX__ + var += _mm_reduce_add_ps(_var); +#endif // __SSE2__ + + var = 1.f / sqrtf(var / (channels * size) + eps); + mean = mean * var; +#if __SSE2__ + _var = _mm_set1_ps(var); + _mean = _mm_set1_ps(mean); +#if __AVX__ + _var_avx = combine4x2_ps(_var, _var); + _mean_avx = combine4x2_ps(_mean, _mean); +#if __AVX512F__ + _var_avx512 = combine8x2_ps(_var_avx, _var_avx); + _mean_avx512 = combine8x2_ps(_mean_avx, _mean_avx); +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ + } + + // v = v * var - mean; + // v = (v * var - mean) * gamma + beta + // = v * var * gamma - mean * gamma + beta + // = v * (var * gamma) - (mean * gamma - beta) + + if (gamma_ptr && beta_ptr) + { + for (int q = 0; q < channels; q++) + { + unsigned short* ptr0 = ptr + cstep * q * elempack; + +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + __m512 _a_avx512 = _mm512_set1_ps(0.f); + __m512 
_b_avx512 = _mm512_set1_ps(0.f); +#endif // __AVX512F__ + __m256 _a_avx = _mm256_set1_ps(0.f); + __m256 _b_avx = _mm256_set1_ps(0.f); +#endif // __AVX__ + __m128 _a = _mm_set1_ps(0.f); + __m128 _b = _mm_set1_ps(0.f); +#endif // __SSE2__ + float a = 0.f; + float b = 0.f; + +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + if (elempack == 16) + { + __m512 _gamma = _mm512_loadu_ps(gamma_ptr + q * elempack); + __m512 _beta = _mm512_loadu_ps(beta_ptr + q * elempack); + + _a_avx512 = _mm512_mul_ps(_var_avx512, _gamma); + _b_avx512 = _mm512_fmsub_ps(_mean_avx512, _gamma, _beta); + } +#endif // __AVX512F__ + if (elempack == 8) + { + __m256 _gamma = _mm256_loadu_ps(gamma_ptr + q * elempack); + __m256 _beta = _mm256_loadu_ps(beta_ptr + q * elempack); + + _a_avx = _mm256_mul_ps(_var_avx, _gamma); + _b_avx = _mm256_comp_fmsub_ps(_mean_avx, _gamma, _beta); +#if __AVX512F__ + _a_avx512 = combine8x2_ps(_a_avx, _a_avx); + _b_avx512 = combine8x2_ps(_b_avx, _b_avx); +#endif // __AVX512F__ + } +#endif // __AVX__ + if (elempack == 4) + { + __m128 _gamma = _mm_loadu_ps(gamma_ptr + q * elempack); + __m128 _beta = _mm_loadu_ps(beta_ptr + q * elempack); + + _a = _mm_mul_ps(_var, _gamma); + _b = _mm_comp_fmsub_ps(_mean, _gamma, _beta); +#if __AVX__ + _a_avx = combine4x2_ps(_a, _a); + _b_avx = combine4x2_ps(_b, _b); +#if __AVX512F__ + _a_avx512 = combine8x2_ps(_a_avx, _a_avx); + _b_avx512 = combine8x2_ps(_b_avx, _b_avx); +#endif // __AVX512F__ +#endif // __AVX__ + } +#endif // __SSE2__ + if (elempack == 1) + { + const float gamma = gamma_ptr[q]; + const float beta = beta_ptr[q]; + + a = var * gamma; + b = mean * gamma - beta; +#if __SSE2__ + _a = _mm_set1_ps(a); + _b = _mm_set1_ps(b); +#if __AVX__ + _a_avx = combine4x2_ps(_a, _a); + _b_avx = combine4x2_ps(_b, _b); +#if __AVX512F__ + _a_avx512 = combine8x2_ps(_a_avx, _a_avx); + _b_avx512 = combine8x2_ps(_b_avx, _b_avx); +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ + } + + int i = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + for (; i + 15 < size; i += 16) + { + __m512 _p = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr0)); + _p = _mm512_fmsub_ps(_p, _a_avx512, _b_avx512); + _mm256_storeu_si256((__m256i*)ptr0, float2bfloat_avx512(_p)); + ptr0 += 16; + } +#endif // __AVX512F__ + for (; i + 7 < size; i += 8) + { + __m256 _p = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr0)); + _p = _mm256_comp_fmsub_ps(_p, _a_avx, _b_avx); + _mm_storeu_si128((__m128i*)ptr0, float2bfloat_avx(_p)); + ptr0 += 8; + } +#endif // __AVX__ + for (; i + 3 < size; i += 4) + { + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr0)); + _p = _mm_comp_fmsub_ps(_p, _a, _b); + _mm_storel_epi64((__m128i*)ptr0, float2bfloat_sse(_p, _p)); + ptr0 += 4; + } +#endif // __SSE2__ + for (; i < size; i++) + { + *ptr0 = float32_to_bfloat16(bfloat16_to_float32(*ptr0) * a - b); + ptr0++; + } + } + } + else + { + for (int q = 0; q < channels; q++) + { + unsigned short* ptr0 = ptr + cstep * q * elempack; + + int i = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + for (; i + 15 < size; i += 16) + { + __m512 _p = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr0)); + _p = _mm512_fmsub_ps(_p, _var_avx512, _mean_avx512); + _mm256_storeu_si256((__m256i*)ptr0, float2bfloat_avx512(_p)); + ptr0 += 16; + } +#endif // __AVX512F__ + for (; i + 7 < size; i += 8) + { + __m256 _p = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr0)); + _p = _mm256_comp_fmsub_ps(_p, _var_avx, _mean_avx); + _mm_storeu_si128((__m128i*)ptr0, float2bfloat_avx(_p)); + ptr0 += 8; + } 
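+ // note: mean was pre-scaled by var above, so x * var - mean is the complete normalization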
+#endif // __AVX__ + for (; i + 3 < size; i += 4) + { + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr0)); + _p = _mm_comp_fmsub_ps(_p, _var, _mean); + _mm_storel_epi64((__m128i*)ptr0, float2bfloat_sse(_p, _p)); + ptr0 += 4; + } +#endif // __SSE2__ + for (; i < size; i++) + { + *ptr0 = float32_to_bfloat16(bfloat16_to_float32(*ptr0) * var - mean); + ptr0++; + } + } + } +} diff --git a/src/layer/x86/groupnorm_x86.cpp b/src/layer/x86/groupnorm_x86.cpp index 29333c20ec0..ccd81747e11 100644 --- a/src/layer/x86/groupnorm_x86.cpp +++ b/src/layer/x86/groupnorm_x86.cpp @@ -11,14 +11,22 @@ #endif // __SSE2__ #include "x86_usability.h" +#include "cpu.h" namespace ncnn { +#if NCNN_BF16 +#include "groupnorm_bf16s.h" +#endif + GroupNorm_x86::GroupNorm_x86() { #if __SSE2__ support_packing = true; #endif // __SSE2__ +#if NCNN_BF16 + support_bf16_storage = true; +#endif } static void groupnorm(float* ptr, const float* gamma_ptr, const float* beta_ptr, float eps, int channels, int size, int elempack, size_t cstep) @@ -338,6 +346,11 @@ static void groupnorm(float* ptr, const float* gamma_ptr, const float* beta_ptr, int GroupNorm_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const { +#if NCNN_BF16 + if (opt.use_bf16_storage && bottom_top_blob.elembits() == 16) + return forward_inplace_bf16s(bottom_top_blob, opt); +#endif + const int dims = bottom_top_blob.dims; const int elempack = bottom_top_blob.elempack; const int channels_g = channels / group; @@ -415,4 +428,85 @@ int GroupNorm_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) cons return 0; } +#if NCNN_BF16 +int GroupNorm_x86::forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt) const +{ + const int dims = bottom_top_blob.dims; + const int elempack = bottom_top_blob.elempack; + const int channels_g = channels / group; + + int g_elempack = 1; +#if __SSE2__ + if (opt.use_packing_layout) + { +#if __AVX512F__ + g_elempack = channels_g % 16 == 0 ? 16 : channels_g % 8 == 0 ? 8 : channels_g % 4 == 0 ? 4 : 1; +#elif __AVX__ + g_elempack = channels_g % 8 == 0 ? 8 : channels_g % 4 == 0 ? 4 : 1; +#else + g_elempack = channels_g % 4 == 0 ? 4 : 1; +#endif + } +#endif // __SSE2__ + + Mat bottom_top_blob_unpacked = bottom_top_blob; + if (elempack > g_elempack) + { + Option opt_p = opt; + opt_p.blob_allocator = opt.workspace_allocator; + convert_packing(bottom_top_blob, bottom_top_blob_unpacked, g_elempack, opt_p); + if (bottom_top_blob_unpacked.empty()) + return -100; + } + + if (dims == 1) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int g = 0; g < group; g++) + { + Mat bottom_top_blob_g = bottom_top_blob_unpacked.range(g * channels_g / g_elempack, channels_g / g_elempack); + const float* gamma_ptr = affine ? (const float*)gamma_data + g * channels_g : 0; + const float* beta_ptr = affine ? (const float*)beta_data + g * channels_g : 0; + groupnorm_bf16s_sse(bottom_top_blob_g, gamma_ptr, beta_ptr, eps, channels_g / g_elempack, 1 * g_elempack, g_elempack, 1); + } + } + + if (dims == 2) + { + const int w = bottom_top_blob_unpacked.w; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int g = 0; g < group; g++) + { + Mat bottom_top_blob_g = bottom_top_blob_unpacked.row_range(g * channels_g / g_elempack, channels_g / g_elempack); + const float* gamma_ptr = affine ? (const float*)gamma_data + g * channels_g : 0; + const float* beta_ptr = affine ? 
(const float*)beta_data + g * channels_g : 0; + groupnorm_bf16s_sse(bottom_top_blob_g, gamma_ptr, beta_ptr, eps, channels_g / g_elempack, w * g_elempack, g_elempack, w); + } + } + + if (dims == 3 || dims == 4) + { + const int size = bottom_top_blob_unpacked.w * bottom_top_blob_unpacked.h * bottom_top_blob_unpacked.d; + const size_t cstep = bottom_top_blob_unpacked.cstep; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int g = 0; g < group; g++) + { + Mat bottom_top_blob_g = bottom_top_blob_unpacked.channel_range(g * channels_g / g_elempack, channels_g / g_elempack); + const float* gamma_ptr = affine ? (const float*)gamma_data + g * channels_g : 0; + const float* beta_ptr = affine ? (const float*)beta_data + g * channels_g : 0; + groupnorm_bf16s_sse(bottom_top_blob_g, gamma_ptr, beta_ptr, eps, channels_g / g_elempack, size * g_elempack, g_elempack, cstep); + } + } + + if (g_elempack != elempack) + { + convert_packing(bottom_top_blob_unpacked, bottom_top_blob, elempack, opt); + } + + return 0; +} +#endif // NCNN_BF16 + } // namespace ncnn diff --git a/src/layer/x86/groupnorm_x86.h b/src/layer/x86/groupnorm_x86.h index 42adbc4a0e2..dd3a7c7b3b2 100644 --- a/src/layer/x86/groupnorm_x86.h +++ b/src/layer/x86/groupnorm_x86.h @@ -14,6 +14,11 @@ class GroupNorm_x86 : public GroupNorm GroupNorm_x86(); virtual int forward_inplace(Mat& bottom_top_blob, const Option& opt) const; + +protected: +#if NCNN_BF16 + int forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt) const; +#endif }; } // namespace ncnn diff --git a/src/layer/x86/groupnorm_x86_avx512bf16.cpp b/src/layer/x86/groupnorm_x86_avx512bf16.cpp new file mode 100644 index 00000000000..5b6ee44dbbe --- /dev/null +++ b/src/layer/x86/groupnorm_x86_avx512bf16.cpp @@ -0,0 +1,17 @@ +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#include "cpu.h" +#include "mat.h" +#include "x86_usability.h" + +namespace ncnn { + +#include "groupnorm_bf16s.h" + +void groupnorm_bf16s_sse_avx512bf16(unsigned short* ptr, const float* gamma_ptr, const float* beta_ptr, float eps, int channels, int size, int elempack, size_t cstep) +{ + groupnorm_bf16s_sse(ptr, gamma_ptr, beta_ptr, eps, channels, size, elempack, cstep); +} + +} // namespace ncnn diff --git a/src/layer/x86/instancenorm_bf16s.h b/src/layer/x86/instancenorm_bf16s.h new file mode 100644 index 00000000000..ee450c99923 --- /dev/null +++ b/src/layer/x86/instancenorm_bf16s.h @@ -0,0 +1,158 @@ +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ +void instancenorm_bf16s_sse_avx512bf16(unsigned short* ptr, int size, float a, float b); +void instancenorm_bf16s_compute_mean_var_avx512bf16(const unsigned short* ptr, int size, float& mean, float& var); +#endif + +static void instancenorm_bf16s_sse(unsigned short* ptr, int size, float a, float b) +{ +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ + if (ncnn::cpu_support_x86_avx512_bf16()) + { + instancenorm_bf16s_sse_avx512bf16(ptr, size, a, b); + return; + } +#endif + + int i = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + __m512 _a_avx512 = _mm512_set1_ps(a); + __m512 _b_avx512 = _mm512_set1_ps(b); + for (; i + 15 < size; i += 16) + { + __m512 _p = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr)); + _p = _mm512_fmadd_ps(_p, _a_avx512, _b_avx512); + _mm256_storeu_si256((__m256i*)ptr, float2bfloat_avx512(_p)); + ptr += 16; + } +#endif // __AVX512F__ + __m256 _a_avx = _mm256_set1_ps(a); + 
__m256 _b_avx = _mm256_set1_ps(b); + for (; i + 7 < size; i += 8) + { + __m256 _p = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr)); + _p = _mm256_comp_fmadd_ps(_p, _a_avx, _b_avx); + _mm_storeu_si128((__m128i*)ptr, float2bfloat_avx(_p)); + ptr += 8; + } +#endif // __AVX__ + __m128 _a = _mm_set1_ps(a); + __m128 _b = _mm_set1_ps(b); + for (; i + 3 < size; i += 4) + { + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr)); + _p = _mm_comp_fmadd_ps(_p, _a, _b); + _mm_storel_epi64((__m128i*)ptr, float2bfloat_sse(_p, _p)); + ptr += 4; + } +#endif // __SSE2__ + for (; i < size; i++) + { + *ptr = float32_to_bfloat16(bfloat16_to_float32(*ptr) * a + b); + ptr++; + } +} + +static void instancenorm_bf16s_compute_mean_var(const unsigned short* ptr, int size, float& mean, float& var) +{ +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ + if (ncnn::cpu_support_x86_avx512_bf16()) + { + instancenorm_bf16s_compute_mean_var_avx512bf16(ptr, size, mean, var); + return; + } +#endif + + float sum = 0.f; + float sqsum = 0.f; + + int i = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + __m512 _sum_avx512 = _mm512_setzero_ps(); + for (; i + 15 < size; i += 16) + { + __m512 _p = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr)); + _sum_avx512 = _mm512_add_ps(_sum_avx512, _p); + ptr += 16; + } + sum += _mm512_comp_reduce_add_ps(_sum_avx512); +#endif // __AVX512F__ + __m256 _sum_avx = _mm256_setzero_ps(); + for (; i + 7 < size; i += 8) + { + __m256 _p = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr)); + _sum_avx = _mm256_add_ps(_sum_avx, _p); + ptr += 8; + } + sum += _mm256_reduce_add_ps(_sum_avx); +#endif // __AVX__ + __m128 _sum = _mm_setzero_ps(); + for (; i + 3 < size; i += 4) + { + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr)); + _sum = _mm_add_ps(_sum, _p); + ptr += 4; + } + sum += _mm_reduce_add_ps(_sum); +#endif // __SSE2__ + for (; i < size; i++) + { + sum += bfloat16_to_float32(*ptr); + ptr++; + } + + mean = sum / size; + + ptr -= size; + i = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + __m512 _sqsum_avx512 = _mm512_setzero_ps(); + __m512 _mean_avx512 = _mm512_set1_ps(mean); + for (; i + 15 < size; i += 16) + { + __m512 _p = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr)); + __m512 _diff = _mm512_sub_ps(_p, _mean_avx512); + _sqsum_avx512 = _mm512_fmadd_ps(_diff, _diff, _sqsum_avx512); + ptr += 16; + } + sqsum += _mm512_comp_reduce_add_ps(_sqsum_avx512); +#endif // __AVX512F__ + __m256 _sqsum_avx = _mm256_setzero_ps(); + __m256 _mean_avx = _mm256_set1_ps(mean); + for (; i + 7 < size; i += 8) + { + __m256 _p = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr)); + __m256 _diff = _mm256_sub_ps(_p, _mean_avx); + _sqsum_avx = _mm256_comp_fmadd_ps(_diff, _diff, _sqsum_avx); + ptr += 8; + } + sqsum += _mm256_reduce_add_ps(_sqsum_avx); +#endif // __AVX__ + __m128 _sqsum = _mm_setzero_ps(); + __m128 _mean_sse = _mm_set1_ps(mean); + for (; i + 3 < size; i += 4) + { + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr)); + __m128 _diff = _mm_sub_ps(_p, _mean_sse); + _sqsum = _mm_comp_fmadd_ps(_diff, _diff, _sqsum); + ptr += 4; + } + sqsum += _mm_reduce_add_ps(_sqsum); +#endif // __SSE2__ + for (; i < size; i++) + { + float v = bfloat16_to_float32(*ptr) - mean; + sqsum += v * v; + ptr++; + } + + var = sqsum / size; +} diff --git a/src/layer/x86/instancenorm_x86.cpp b/src/layer/x86/instancenorm_x86.cpp index 22b80ddc73d..01b2b87369a 100644 --- a/src/layer/x86/instancenorm_x86.cpp +++ 
b/src/layer/x86/instancenorm_x86.cpp @@ -10,11 +10,28 @@ #endif #endif // __SSE2__ #include "x86_usability.h" +#include "cpu.h" namespace ncnn { +#if NCNN_BF16 +#include "instancenorm_bf16s.h" +#endif + +InstanceNorm_x86::InstanceNorm_x86() +{ +#if NCNN_BF16 + support_bf16_storage = true; +#endif +} + int InstanceNorm_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const { +#if NCNN_BF16 + if (opt.use_bf16_storage && bottom_top_blob.elembits() == 16) + return forward_inplace_bf16s(bottom_top_blob, opt); +#endif + // x = (x - mean) / (sqrt(var + eps)) * gamma + beta int w = bottom_top_blob.w; @@ -179,4 +196,47 @@ int InstanceNorm_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) c return 0; } +#if NCNN_BF16 +int InstanceNorm_x86::forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt) const +{ + // x = (x - mean) / (sqrt(var + eps)) * gamma + beta + + int w = bottom_top_blob.w; + int h = bottom_top_blob.h; + int c = bottom_top_blob.c; + int size = w * h; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < c; q++) + { + unsigned short* ptr = bottom_top_blob.channel(q); + + // compute mean and var + float mean = 0.f; + float var = 0.f; + instancenorm_bf16s_compute_mean_var(ptr, size, mean, var); + + float a; + float b; + if (affine) + { + float gamma = gamma_data[q]; + float beta = beta_data[q]; + + a = gamma / (sqrtf(var + eps)); + b = -mean * a + beta; + } + else + { + a = 1.f / (sqrtf(var + eps)); + b = -mean * a; + } + + instancenorm_bf16s_sse(ptr, size, a, b); + } + + return 0; +} +#endif // NCNN_BF16 + } // namespace ncnn diff --git a/src/layer/x86/instancenorm_x86.h b/src/layer/x86/instancenorm_x86.h index 1f3e0d3338c..658b93e53a9 100644 --- a/src/layer/x86/instancenorm_x86.h +++ b/src/layer/x86/instancenorm_x86.h @@ -11,7 +11,14 @@ namespace ncnn { class InstanceNorm_x86 : public InstanceNorm { public: + InstanceNorm_x86(); + virtual int forward_inplace(Mat& bottom_top_blob, const Option& opt) const; + +protected: +#if NCNN_BF16 + int forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt) const; +#endif }; } // namespace ncnn diff --git a/src/layer/x86/instancenorm_x86_avx512bf16.cpp b/src/layer/x86/instancenorm_x86_avx512bf16.cpp new file mode 100644 index 00000000000..7ec0257e43f --- /dev/null +++ b/src/layer/x86/instancenorm_x86_avx512bf16.cpp @@ -0,0 +1,22 @@ +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#include "cpu.h" +#include "mat.h" +#include "x86_usability.h" + +namespace ncnn { + +#include "instancenorm_bf16s.h" + +void instancenorm_bf16s_sse_avx512bf16(unsigned short* ptr, int size, float a, float b) +{ + instancenorm_bf16s_sse(ptr, size, a, b); +} + +void instancenorm_bf16s_compute_mean_var_avx512bf16(const unsigned short* ptr, int size, float& mean, float& var) +{ + instancenorm_bf16s_compute_mean_var(ptr, size, mean, var); +} + +} // namespace ncnn From 1d917334de0a87a55b6d12248418222a2e8784e3 Mon Sep 17 00:00:00 2001 From: nihui Date: Wed, 11 Mar 2026 12:37:40 +0800 Subject: [PATCH 18/36] x86 batchnorm prelu scale swish softmax support bf16 storage with avx512bf16 dispatch (#6595) --- src/layer/x86/batchnorm_bf16s.h | 129 +++ src/layer/x86/batchnorm_x86.cpp | 135 ++- src/layer/x86/batchnorm_x86.h | 5 + src/layer/x86/batchnorm_x86_avx512bf16.cpp | 22 + src/layer/x86/prelu_bf16s.h | 213 +++++ src/layer/x86/prelu_x86.cpp | 76 ++ src/layer/x86/prelu_x86.h | 5 + src/layer/x86/prelu_x86_avx512bf16.cpp | 27 + src/layer/x86/scale_bf16s.h | 244 ++++++ src/layer/x86/scale_x86.cpp | 
113 +++ src/layer/x86/scale_x86.h | 5 + src/layer/x86/scale_x86_avx512bf16.cpp | 32 + src/layer/x86/softmax_bf16s.h | 912 +++++++++++++++++++++ src/layer/x86/softmax_x86.cpp | 204 ++++- src/layer/x86/softmax_x86.h | 5 + src/layer/x86/softmax_x86_avx512bf16.cpp | 54 ++ src/layer/x86/swish_bf16s.h | 72 ++ src/layer/x86/swish_x86.cpp | 25 + src/layer/x86/swish_x86.h | 5 + src/layer/x86/swish_x86_avx512bf16.cpp | 32 + src/layer/x86/x86_usability.h | 24 +- 21 files changed, 2277 insertions(+), 62 deletions(-) create mode 100644 src/layer/x86/batchnorm_bf16s.h create mode 100644 src/layer/x86/batchnorm_x86_avx512bf16.cpp create mode 100644 src/layer/x86/prelu_bf16s.h create mode 100644 src/layer/x86/prelu_x86_avx512bf16.cpp create mode 100644 src/layer/x86/scale_bf16s.h create mode 100644 src/layer/x86/scale_x86_avx512bf16.cpp create mode 100644 src/layer/x86/softmax_bf16s.h create mode 100644 src/layer/x86/softmax_x86_avx512bf16.cpp create mode 100644 src/layer/x86/swish_bf16s.h create mode 100644 src/layer/x86/swish_x86_avx512bf16.cpp diff --git a/src/layer/x86/batchnorm_bf16s.h b/src/layer/x86/batchnorm_bf16s.h new file mode 100644 index 00000000000..0022cea5e49 --- /dev/null +++ b/src/layer/x86/batchnorm_bf16s.h @@ -0,0 +1,129 @@ +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ +void batchnorm_bf16s_sse_avx512bf16(unsigned short* ptr, const float* a, const float* b, int size, int elempack); +void batchnorm_bf16s_per_element_sse_avx512bf16(unsigned short* ptr, const float* a, const float* b, int size, int num_threads); +#endif + +static void batchnorm_bf16s_sse(unsigned short* ptr, const float* a, const float* b, int size, int elempack) +{ +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ + if (ncnn::cpu_support_x86_avx512_bf16()) + { + batchnorm_bf16s_sse_avx512bf16(ptr, a, b, size, elempack); + return; + } +#endif + + // Load a/b into SIMD registers with correct elempack broadcasting +#if __SSE2__ + __m128 _a128 = (elempack == 4) ? _mm_loadu_ps(a) : _mm_set1_ps(a[0]); + __m128 _b128 = (elempack == 4) ? _mm_loadu_ps(b) : _mm_set1_ps(b[0]); +#if __AVX__ + __m256 _a256 = (elempack == 8) ? _mm256_loadu_ps(a) : combine4x2_ps(_a128, _a128); + __m256 _b256 = (elempack == 8) ? _mm256_loadu_ps(b) : combine4x2_ps(_b128, _b128); +#if __AVX512F__ + __m512 _a512 = (elempack == 16) ? _mm512_loadu_ps(a) : combine8x2_ps(_a256, _a256); + __m512 _b512 = (elempack == 16) ? 
_mm512_loadu_ps(b) : combine8x2_ps(_b256, _b256); +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ + float sa = a[0]; + float sb = b[0]; + + int i = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + for (; i + 15 < size; i += 16) + { + __m512 _p = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr)); + _p = _mm512_fmadd_ps(_p, _b512, _a512); + _mm256_storeu_si256((__m256i*)ptr, float2bfloat_avx512(_p)); + ptr += 16; + } +#endif // __AVX512F__ + for (; i + 7 < size; i += 8) + { + __m256 _p = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr)); + _p = _mm256_comp_fmadd_ps(_p, _b256, _a256); + _mm_storeu_si128((__m128i*)ptr, float2bfloat_avx(_p)); + ptr += 8; + } +#endif // __AVX__ + for (; i + 3 < size; i += 4) + { + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr)); + _p = _mm_comp_fmadd_ps(_p, _b128, _a128); + _mm_storel_epi64((__m128i*)ptr, float2bfloat_sse(_p, _p)); + ptr += 4; + } +#endif // __SSE2__ + for (; i < size; i++) + { + *ptr = float32_to_bfloat16(sb * bfloat16_to_float32(*ptr) + sa); + ptr++; + } +} + +static void batchnorm_bf16s_per_element_sse(unsigned short* ptr, const float* a, const float* b, int size, int num_threads) +{ +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ + if (ncnn::cpu_support_x86_avx512_bf16()) + { + batchnorm_bf16s_per_element_sse_avx512bf16(ptr, a, b, size, num_threads); + return; + } +#endif + + int nn_size = 0; + int remain_size_start = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + nn_size = (size - remain_size_start) / 16; + #pragma omp parallel for num_threads(num_threads) + for (int ii = 0; ii < nn_size; ii++) + { + int i = remain_size_start + ii * 16; + __m512 _p = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(ptr + i))); + __m512 _a = _mm512_loadu_ps(a + i); + __m512 _b = _mm512_loadu_ps(b + i); + _p = _mm512_fmadd_ps(_p, _b, _a); + _mm256_storeu_si256((__m256i*)(ptr + i), float2bfloat_avx512(_p)); + } + remain_size_start += nn_size * 16; +#endif // __AVX512F__ + nn_size = (size - remain_size_start) / 8; + #pragma omp parallel for num_threads(num_threads) + for (int ii = 0; ii < nn_size; ii++) + { + int i = remain_size_start + ii * 8; + __m256 _p = bfloat2float_avx(_mm_loadu_si128((const __m128i*)(ptr + i))); + __m256 _a = _mm256_loadu_ps(a + i); + __m256 _b = _mm256_loadu_ps(b + i); + _p = _mm256_comp_fmadd_ps(_p, _b, _a); + _mm_storeu_si128((__m128i*)(ptr + i), float2bfloat_avx(_p)); + } + remain_size_start += nn_size * 8; +#endif // __AVX__ + nn_size = (size - remain_size_start) / 4; + #pragma omp parallel for num_threads(num_threads) + for (int ii = 0; ii < nn_size; ii++) + { + int i = remain_size_start + ii * 4; + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(ptr + i))); + __m128 _a = _mm_loadu_ps(a + i); + __m128 _b = _mm_loadu_ps(b + i); + _p = _mm_comp_fmadd_ps(_p, _b, _a); + _mm_storel_epi64((__m128i*)(ptr + i), float2bfloat_sse(_p, _p)); + } + remain_size_start += nn_size * 4; +#endif // __SSE2__ + #pragma omp parallel for num_threads(num_threads) + for (int i = remain_size_start; i < size; i++) + { + ptr[i] = float32_to_bfloat16(b[i] * bfloat16_to_float32(ptr[i]) + a[i]); + } +} diff --git a/src/layer/x86/batchnorm_x86.cpp b/src/layer/x86/batchnorm_x86.cpp index dbfd678c7b4..d574d3a5336 100644 --- a/src/layer/x86/batchnorm_x86.cpp +++ b/src/layer/x86/batchnorm_x86.cpp @@ -10,18 +10,31 @@ #endif // __AVX__ #endif // __SSE2__ #include "x86_usability.h" +#include "cpu.h" namespace ncnn { +#if NCNN_BF16 +#include "batchnorm_bf16s.h" +#endif + 
BatchNorm_x86::BatchNorm_x86() { #if __SSE2__ support_packing = true; #endif // __SSE2__ +#if NCNN_BF16 + support_bf16_storage = true; +#endif } int BatchNorm_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const { +#if NCNN_BF16 + if (opt.use_bf16_storage && bottom_top_blob.elembits() == 16) + return forward_inplace_bf16s(bottom_top_blob, opt); +#endif + int dims = bottom_top_blob.dims; int w = bottom_top_blob.w; int h = bottom_top_blob.h; @@ -37,52 +50,51 @@ int BatchNorm_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) cons const int size = w * elempack; - int i = 0; + int nn_size = 0; + int remain_size_start = 0; #if __SSE2__ #if __AVX__ #if __AVX512F__ - for (; i + 15 < size; i += 16) + nn_size = (size - remain_size_start) / 16; + #pragma omp parallel for num_threads(opt.num_threads) + for (int ii = 0; ii < nn_size; ii++) { - __m512 _p512 = _mm512_loadu_ps(ptr); - __m512 _a512 = _mm512_loadu_ps(aptr); - __m512 _b512 = _mm512_loadu_ps(bptr); - _p512 = _mm512_fmadd_ps(_p512, _b512, _a512); - _mm512_storeu_ps(ptr, _p512); - ptr += 16; - aptr += 16; - bptr += 16; + int i = remain_size_start + ii * 16; + __m512 _p512 = _mm512_loadu_ps(ptr + i); + __m512 _a512 = _mm512_loadu_ps(aptr + i); + __m512 _b512 = _mm512_loadu_ps(bptr + i); + _mm512_storeu_ps(ptr + i, _mm512_fmadd_ps(_p512, _b512, _a512)); } + remain_size_start += nn_size * 16; #endif // __AVX512F__ - for (; i + 7 < size; i += 8) + nn_size = (size - remain_size_start) / 8; + #pragma omp parallel for num_threads(opt.num_threads) + for (int ii = 0; ii < nn_size; ii++) { - __m256 _p256 = _mm256_loadu_ps(ptr); - __m256 _a256 = _mm256_loadu_ps(aptr); - __m256 _b256 = _mm256_loadu_ps(bptr); - _p256 = _mm256_comp_fmadd_ps(_p256, _b256, _a256); - _mm256_storeu_ps(ptr, _p256); - ptr += 8; - aptr += 8; - bptr += 8; + int i = remain_size_start + ii * 8; + __m256 _p256 = _mm256_loadu_ps(ptr + i); + __m256 _a256 = _mm256_loadu_ps(aptr + i); + __m256 _b256 = _mm256_loadu_ps(bptr + i); + _mm256_storeu_ps(ptr + i, _mm256_comp_fmadd_ps(_p256, _b256, _a256)); } + remain_size_start += nn_size * 8; #endif // __AVX__ - for (; i + 3 < size; i += 4) + nn_size = (size - remain_size_start) / 4; + #pragma omp parallel for num_threads(opt.num_threads) + for (int ii = 0; ii < nn_size; ii++) { - __m128 _p128 = _mm_loadu_ps(ptr); - __m128 _a128 = _mm_loadu_ps(aptr); - __m128 _b128 = _mm_loadu_ps(bptr); - _p128 = _mm_comp_fmadd_ps(_p128, _b128, _a128); - _mm_storeu_ps(ptr, _p128); - ptr += 4; - aptr += 4; - bptr += 4; + int i = remain_size_start + ii * 4; + __m128 _p128 = _mm_loadu_ps(ptr + i); + __m128 _a128 = _mm_loadu_ps(aptr + i); + __m128 _b128 = _mm_loadu_ps(bptr + i); + _mm_storeu_ps(ptr + i, _mm_comp_fmadd_ps(_p128, _b128, _a128)); } -#endif // __SSE__ - for (; i < size; i++) + remain_size_start += nn_size * 4; +#endif // __SSE2__ + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = remain_size_start; i < size; i++) { - *ptr = *bptr * *ptr + *aptr; - ptr++; - aptr++; - bptr++; + ptr[i] = bptr[i] * ptr[i] + aptr[i]; } } @@ -209,4 +221,59 @@ int BatchNorm_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) cons return 0; } +#if NCNN_BF16 +int BatchNorm_x86::forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt) const +{ + int dims = bottom_top_blob.dims; + int w = bottom_top_blob.w; + int h = bottom_top_blob.h; + int d = bottom_top_blob.d; + int c = bottom_top_blob.c; + int elempack = bottom_top_blob.elempack; + + if (dims == 1) + { + unsigned short* ptr = bottom_top_blob; + const float* aptr = 
a_data; + const float* bptr = b_data; + + const int size = w * elempack; + + batchnorm_bf16s_per_element_sse(ptr, aptr, bptr, size, opt.num_threads); + } + + if (dims == 2) + { + const int size = w * elempack; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = 0; i < h; i++) + { + unsigned short* ptr = bottom_top_blob.row(i); + const float* aptr = (const float*)a_data + i * elempack; + const float* bptr = (const float*)b_data + i * elempack; + + batchnorm_bf16s_sse(ptr, aptr, bptr, size, elempack); + } + } + + if (dims == 3 || dims == 4) + { + const int size = w * h * d * elempack; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < c; q++) + { + unsigned short* ptr = bottom_top_blob.channel(q); + const float* aptr = (const float*)a_data + q * elempack; + const float* bptr = (const float*)b_data + q * elempack; + + batchnorm_bf16s_sse(ptr, aptr, bptr, size, elempack); + } + } + + return 0; +} +#endif // NCNN_BF16 + } // namespace ncnn diff --git a/src/layer/x86/batchnorm_x86.h b/src/layer/x86/batchnorm_x86.h index 752c840b3af..7ffa814f3eb 100644 --- a/src/layer/x86/batchnorm_x86.h +++ b/src/layer/x86/batchnorm_x86.h @@ -14,6 +14,11 @@ class BatchNorm_x86 : public BatchNorm BatchNorm_x86(); virtual int forward_inplace(Mat& bottom_top_blob, const Option& opt) const; + +protected: +#if NCNN_BF16 + int forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt) const; +#endif }; } // namespace ncnn diff --git a/src/layer/x86/batchnorm_x86_avx512bf16.cpp b/src/layer/x86/batchnorm_x86_avx512bf16.cpp new file mode 100644 index 00000000000..03748a25fe3 --- /dev/null +++ b/src/layer/x86/batchnorm_x86_avx512bf16.cpp @@ -0,0 +1,22 @@ +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#include "cpu.h" +#include "mat.h" +#include "x86_usability.h" + +namespace ncnn { + +#include "batchnorm_bf16s.h" + +void batchnorm_bf16s_sse_avx512bf16(unsigned short* ptr, const float* a, const float* b, int size, int elempack) +{ + batchnorm_bf16s_sse(ptr, a, b, size, elempack); +} + +void batchnorm_bf16s_per_element_sse_avx512bf16(unsigned short* ptr, const float* a, const float* b, int size, int num_threads) +{ + batchnorm_bf16s_per_element_sse(ptr, a, b, size, num_threads); +} + +} // namespace ncnn diff --git a/src/layer/x86/prelu_bf16s.h b/src/layer/x86/prelu_bf16s.h new file mode 100644 index 00000000000..40e58b452df --- /dev/null +++ b/src/layer/x86/prelu_bf16s.h @@ -0,0 +1,213 @@ +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ +void prelu_bf16s_sse_avx512bf16(unsigned short* ptr, const float* slope, int size, int elempack); +void prelu_bf16s_per_element_sse_avx512bf16(unsigned short* ptr, const float* slope, int size, int num_threads); +void prelu_bf16s_single_slope_sse_avx512bf16(unsigned short* ptr, float slope, int size, int num_threads); +#endif + +static void prelu_bf16s_sse(unsigned short* ptr, const float* slope, int size, int elempack) +{ +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ + if (ncnn::cpu_support_x86_avx512_bf16()) + { + prelu_bf16s_sse_avx512bf16(ptr, slope, size, elempack); + return; + } +#endif + +#if __SSE2__ + __m128 _slope128 = (elempack == 4) ? _mm_loadu_ps(slope) : _mm_set1_ps(slope[0]); + __m128 _zero = _mm_setzero_ps(); +#if __AVX__ + __m256 _slope256 = (elempack == 8) ? 
_mm256_loadu_ps(slope) : combine4x2_ps(_slope128, _slope128); + __m256 _zero_avx = _mm256_setzero_ps(); +#if __AVX512F__ + __m512 _slope512 = (elempack == 16) ? _mm512_loadu_ps(slope) : combine8x2_ps(_slope256, _slope256); + __m512 _zero_avx512 = _mm512_setzero_ps(); +#endif +#endif +#endif + float s = slope[0]; + + int i = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + for (; i + 15 < size; i += 16) + { + __m512 _p = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr)); + __mmask16 _mask = _mm512_cmp_ps_mask(_p, _zero_avx512, _CMP_LT_OQ); + __m512 _ps = _mm512_mul_ps(_p, _slope512); + _p = _mm512_mask_mov_ps(_p, _mask, _ps); + _mm256_storeu_si256((__m256i*)ptr, float2bfloat_avx512(_p)); + ptr += 16; + } +#endif // __AVX512F__ + for (; i + 7 < size; i += 8) + { + __m256 _p = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr)); + __m256 _ps = _mm256_mul_ps(_p, _slope256); + _p = _mm256_blendv_ps(_p, _ps, _mm256_cmp_ps(_p, _zero_avx, _CMP_LT_OQ)); + _mm_storeu_si128((__m128i*)ptr, float2bfloat_avx(_p)); + ptr += 8; + } +#endif // __AVX__ + for (; i + 3 < size; i += 4) + { + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr)); + __m128 _ps = _mm_mul_ps(_p, _slope128); + __m128 _mask = _mm_cmplt_ps(_p, _zero); + _p = _mm_or_ps(_mm_andnot_ps(_mask, _p), _mm_and_ps(_mask, _ps)); + _mm_storel_epi64((__m128i*)ptr, float2bfloat_sse(_p, _p)); + ptr += 4; + } +#endif // __SSE2__ + for (; i < size; i++) + { + float v = bfloat16_to_float32(*ptr); + if (v < 0.f) + v *= s; + *ptr = float32_to_bfloat16(v); + ptr++; + } +} + +static void prelu_bf16s_per_element_sse(unsigned short* ptr, const float* slope, int size, int num_threads) +{ +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ + if (ncnn::cpu_support_x86_avx512_bf16()) + { + prelu_bf16s_per_element_sse_avx512bf16(ptr, slope, size, num_threads); + return; + } +#endif + + int nn_size = 0; + int remain_size_start = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + nn_size = (size - remain_size_start) / 16; + #pragma omp parallel for num_threads(num_threads) + for (int ii = 0; ii < nn_size; ii++) + { + int i = remain_size_start + ii * 16; + __m512 _zero_avx512 = _mm512_setzero_ps(); + __m512 _p = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(ptr + i))); + __m512 _slope = _mm512_loadu_ps(slope + i); + __mmask16 _mask = _mm512_cmp_ps_mask(_p, _zero_avx512, _CMP_LT_OQ); + __m512 _ps = _mm512_mul_ps(_p, _slope); + _p = _mm512_mask_mov_ps(_p, _mask, _ps); + _mm256_storeu_si256((__m256i*)(ptr + i), float2bfloat_avx512(_p)); + } + remain_size_start += nn_size * 16; +#endif // __AVX512F__ + nn_size = (size - remain_size_start) / 8; + #pragma omp parallel for num_threads(num_threads) + for (int ii = 0; ii < nn_size; ii++) + { + int i = remain_size_start + ii * 8; + __m256 _zero_avx = _mm256_setzero_ps(); + __m256 _p = bfloat2float_avx(_mm_loadu_si128((const __m128i*)(ptr + i))); + __m256 _slope = _mm256_loadu_ps(slope + i); + __m256 _ps = _mm256_mul_ps(_p, _slope); + _p = _mm256_blendv_ps(_p, _ps, _mm256_cmp_ps(_p, _zero_avx, _CMP_LT_OQ)); + _mm_storeu_si128((__m128i*)(ptr + i), float2bfloat_avx(_p)); + } + remain_size_start += nn_size * 8; +#endif // __AVX__ + nn_size = (size - remain_size_start) / 4; + #pragma omp parallel for num_threads(num_threads) + for (int ii = 0; ii < nn_size; ii++) + { + int i = remain_size_start + ii * 4; + __m128 _zero = _mm_setzero_ps(); + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(ptr + i))); + __m128 _slope = _mm_loadu_ps(slope + i); + __m128 _ps = 
_mm_mul_ps(_p, _slope); + __m128 _mask = _mm_cmplt_ps(_p, _zero); + _p = _mm_or_ps(_mm_andnot_ps(_mask, _p), _mm_and_ps(_mask, _ps)); + _mm_storel_epi64((__m128i*)(ptr + i), float2bfloat_sse(_p, _p)); + } + remain_size_start += nn_size * 4; +#endif // __SSE2__ + #pragma omp parallel for num_threads(num_threads) + for (int i = remain_size_start; i < size; i++) + { + float v = bfloat16_to_float32(ptr[i]); + if (v < 0.f) + v *= slope[i]; + ptr[i] = float32_to_bfloat16(v); + } +} + +static void prelu_bf16s_single_slope_sse(unsigned short* ptr, float slope, int size, int num_threads) +{ +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ + if (ncnn::cpu_support_x86_avx512_bf16()) + { + prelu_bf16s_single_slope_sse_avx512bf16(ptr, slope, size, num_threads); + return; + } +#endif + + int nn_size = 0; + int remain_size_start = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + nn_size = (size - remain_size_start) / 16; + #pragma omp parallel for num_threads(num_threads) + for (int ii = 0; ii < nn_size; ii++) + { + int i = remain_size_start + ii * 16; + __m512 _zero_avx512 = _mm512_setzero_ps(); + __m512 _slope512 = _mm512_set1_ps(slope); + __m512 _p = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(ptr + i))); + __mmask16 _mask = _mm512_cmp_ps_mask(_p, _zero_avx512, _CMP_LT_OQ); + __m512 _ps = _mm512_mul_ps(_p, _slope512); + _p = _mm512_mask_mov_ps(_p, _mask, _ps); + _mm256_storeu_si256((__m256i*)(ptr + i), float2bfloat_avx512(_p)); + } + remain_size_start += nn_size * 16; +#endif // __AVX512F__ + nn_size = (size - remain_size_start) / 8; + #pragma omp parallel for num_threads(num_threads) + for (int ii = 0; ii < nn_size; ii++) + { + int i = remain_size_start + ii * 8; + __m256 _zero_avx = _mm256_setzero_ps(); + __m256 _slope256 = _mm256_set1_ps(slope); + __m256 _p = bfloat2float_avx(_mm_loadu_si128((const __m128i*)(ptr + i))); + __m256 _ps = _mm256_mul_ps(_p, _slope256); + _p = _mm256_blendv_ps(_p, _ps, _mm256_cmp_ps(_p, _zero_avx, _CMP_LT_OQ)); + _mm_storeu_si128((__m128i*)(ptr + i), float2bfloat_avx(_p)); + } + remain_size_start += nn_size * 8; +#endif // __AVX__ + nn_size = (size - remain_size_start) / 4; + #pragma omp parallel for num_threads(num_threads) + for (int ii = 0; ii < nn_size; ii++) + { + int i = remain_size_start + ii * 4; + __m128 _zero = _mm_setzero_ps(); + __m128 _slope128 = _mm_set1_ps(slope); + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(ptr + i))); + __m128 _ps = _mm_mul_ps(_p, _slope128); + __m128 _mask = _mm_cmplt_ps(_p, _zero); + _p = _mm_or_ps(_mm_andnot_ps(_mask, _p), _mm_and_ps(_mask, _ps)); + _mm_storel_epi64((__m128i*)(ptr + i), float2bfloat_sse(_p, _p)); + } + remain_size_start += nn_size * 4; +#endif // __SSE2__ + #pragma omp parallel for num_threads(num_threads) + for (int i = remain_size_start; i < size; i++) + { + float v = bfloat16_to_float32(ptr[i]); + if (v < 0.f) + v *= slope; + ptr[i] = float32_to_bfloat16(v); + } +} diff --git a/src/layer/x86/prelu_x86.cpp b/src/layer/x86/prelu_x86.cpp index aafb554e039..5063af6bd57 100644 --- a/src/layer/x86/prelu_x86.cpp +++ b/src/layer/x86/prelu_x86.cpp @@ -10,18 +10,32 @@ #endif // __AVX__ #endif // __SSE2__ #include "x86_activation.h" +#include "x86_usability.h" +#include "cpu.h" namespace ncnn { +#if NCNN_BF16 +#include "prelu_bf16s.h" +#endif + PReLU_x86::PReLU_x86() { #if __SSE2__ support_packing = true; #endif // __SSE2__ +#if NCNN_BF16 + support_bf16_storage = true; +#endif } int PReLU_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const { +#if 
NCNN_BF16 + if (opt.use_bf16_storage && bottom_top_blob.elembits() == 16) + return forward_inplace_bf16s(bottom_top_blob, opt); +#endif + int dims = bottom_top_blob.dims; int w = bottom_top_blob.w; int h = bottom_top_blob.h; @@ -233,4 +247,66 @@ int PReLU_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const return 0; } +#if NCNN_BF16 +int PReLU_x86::forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt) const +{ + int dims = bottom_top_blob.dims; + int w = bottom_top_blob.w; + int h = bottom_top_blob.h; + int channels = bottom_top_blob.c; + int elempack = bottom_top_blob.elempack; + + if (dims == 1) + { + unsigned short* ptr = bottom_top_blob; + const int size = w * elempack; + + if (num_slope > 1) + { + prelu_bf16s_per_element_sse(ptr, (const float*)slope_data, size, opt.num_threads); + } + else + { + prelu_bf16s_single_slope_sse(ptr, slope_data[0], size, opt.num_threads); + } + } + + if (dims == 2) + { + const int size = w * elempack; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = 0; i < h; i++) + { + unsigned short* ptr = bottom_top_blob.row(i); + + float slope = num_slope > 1 ? slope_data[i] : slope_data[0]; + const float* sptr = num_slope > 1 ? (const float*)slope_data + i * elempack : &slope; + int ep = num_slope > 1 ? elempack : 1; + + prelu_bf16s_sse(ptr, sptr, size, ep); + } + } + + if (dims == 3) + { + const int size = w * h * elempack; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + unsigned short* ptr = bottom_top_blob.channel(q); + + float slope = num_slope > 1 ? slope_data[q] : slope_data[0]; + const float* sptr = num_slope > 1 ? (const float*)slope_data + q * elempack : &slope; + int ep = num_slope > 1 ? elempack : 1; + + prelu_bf16s_sse(ptr, sptr, size, ep); + } + } + + return 0; +} +#endif // NCNN_BF16 + } // namespace ncnn diff --git a/src/layer/x86/prelu_x86.h b/src/layer/x86/prelu_x86.h index 2f33524c14f..ccd190fd8cb 100644 --- a/src/layer/x86/prelu_x86.h +++ b/src/layer/x86/prelu_x86.h @@ -14,6 +14,11 @@ class PReLU_x86 : public PReLU PReLU_x86(); virtual int forward_inplace(Mat& bottom_top_blob, const Option& opt) const; + +protected: +#if NCNN_BF16 + int forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt) const; +#endif }; } // namespace ncnn diff --git a/src/layer/x86/prelu_x86_avx512bf16.cpp b/src/layer/x86/prelu_x86_avx512bf16.cpp new file mode 100644 index 00000000000..73a1e95a5e7 --- /dev/null +++ b/src/layer/x86/prelu_x86_avx512bf16.cpp @@ -0,0 +1,27 @@ +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#include "cpu.h" +#include "mat.h" +#include "x86_usability.h" + +namespace ncnn { + +#include "prelu_bf16s.h" + +void prelu_bf16s_sse_avx512bf16(unsigned short* ptr, const float* slope, int size, int elempack) +{ + prelu_bf16s_sse(ptr, slope, size, elempack); +} + +void prelu_bf16s_per_element_sse_avx512bf16(unsigned short* ptr, const float* slope, int size, int num_threads) +{ + prelu_bf16s_per_element_sse(ptr, slope, size, num_threads); +} + +void prelu_bf16s_single_slope_sse_avx512bf16(unsigned short* ptr, float slope, int size, int num_threads) +{ + prelu_bf16s_single_slope_sse(ptr, slope, size, num_threads); +} + +} // namespace ncnn diff --git a/src/layer/x86/scale_bf16s.h b/src/layer/x86/scale_bf16s.h new file mode 100644 index 00000000000..10fca67fdd1 --- /dev/null +++ b/src/layer/x86/scale_bf16s.h @@ -0,0 +1,244 @@ +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#if NCNN_RUNTIME_CPU && 
NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ +void scale_bf16s_sse_avx512bf16(unsigned short* ptr, const float* scale, const float* bias, int size, int elempack); +void scale_bf16s_no_bias_sse_avx512bf16(unsigned short* ptr, const float* scale, int size, int elempack); +void scale_bf16s_per_element_sse_avx512bf16(unsigned short* ptr, const float* scale, const float* bias, int size, int num_threads); +void scale_bf16s_no_bias_per_element_sse_avx512bf16(unsigned short* ptr, const float* scale, int size, int num_threads); +#endif + +static void scale_bf16s_sse(unsigned short* ptr, const float* scale, const float* bias, int size, int elempack) +{ +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ + if (ncnn::cpu_support_x86_avx512_bf16()) + { + scale_bf16s_sse_avx512bf16(ptr, scale, bias, size, elempack); + return; + } +#endif + +#if __SSE2__ + __m128 _s128 = (elempack == 4) ? _mm_loadu_ps(scale) : _mm_set1_ps(scale[0]); + __m128 _b128 = (elempack == 4) ? _mm_loadu_ps(bias) : _mm_set1_ps(bias[0]); +#if __AVX__ + __m256 _s256 = (elempack == 8) ? _mm256_loadu_ps(scale) : combine4x2_ps(_s128, _s128); + __m256 _b256 = (elempack == 8) ? _mm256_loadu_ps(bias) : combine4x2_ps(_b128, _b128); +#if __AVX512F__ + __m512 _s512 = (elempack == 16) ? _mm512_loadu_ps(scale) : combine8x2_ps(_s256, _s256); + __m512 _b512 = (elempack == 16) ? _mm512_loadu_ps(bias) : combine8x2_ps(_b256, _b256); +#endif +#endif +#endif + float s = scale[0]; + float b = bias[0]; + + int i = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + for (; i + 15 < size; i += 16) + { + __m512 _p = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr)); + _p = _mm512_fmadd_ps(_p, _s512, _b512); + _mm256_storeu_si256((__m256i*)ptr, float2bfloat_avx512(_p)); + ptr += 16; + } +#endif // __AVX512F__ + for (; i + 7 < size; i += 8) + { + __m256 _p = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr)); + _p = _mm256_comp_fmadd_ps(_p, _s256, _b256); + _mm_storeu_si128((__m128i*)ptr, float2bfloat_avx(_p)); + ptr += 8; + } +#endif // __AVX__ + for (; i + 3 < size; i += 4) + { + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr)); + _p = _mm_comp_fmadd_ps(_p, _s128, _b128); + _mm_storel_epi64((__m128i*)ptr, float2bfloat_sse(_p, _p)); + ptr += 4; + } +#endif // __SSE2__ + for (; i < size; i++) + { + *ptr = float32_to_bfloat16(bfloat16_to_float32(*ptr) * s + b); + ptr++; + } +} + +static void scale_bf16s_no_bias_sse(unsigned short* ptr, const float* scale, int size, int elempack) +{ +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ + if (ncnn::cpu_support_x86_avx512_bf16()) + { + scale_bf16s_no_bias_sse_avx512bf16(ptr, scale, size, elempack); + return; + } +#endif + +#if __SSE2__ + __m128 _s128 = (elempack == 4) ? _mm_loadu_ps(scale) : _mm_set1_ps(scale[0]); +#if __AVX__ + __m256 _s256 = (elempack == 8) ? _mm256_loadu_ps(scale) : combine4x2_ps(_s128, _s128); +#if __AVX512F__ + __m512 _s512 = (elempack == 16) ? 
_mm512_loadu_ps(scale) : combine8x2_ps(_s256, _s256); +#endif +#endif +#endif + float s = scale[0]; + + int i = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + for (; i + 15 < size; i += 16) + { + __m512 _p = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr)); + _p = _mm512_mul_ps(_p, _s512); + _mm256_storeu_si256((__m256i*)ptr, float2bfloat_avx512(_p)); + ptr += 16; + } +#endif // __AVX512F__ + for (; i + 7 < size; i += 8) + { + __m256 _p = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr)); + _p = _mm256_mul_ps(_p, _s256); + _mm_storeu_si128((__m128i*)ptr, float2bfloat_avx(_p)); + ptr += 8; + } +#endif // __AVX__ + for (; i + 3 < size; i += 4) + { + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr)); + _p = _mm_mul_ps(_p, _s128); + _mm_storel_epi64((__m128i*)ptr, float2bfloat_sse(_p, _p)); + ptr += 4; + } +#endif // __SSE2__ + for (; i < size; i++) + { + *ptr = float32_to_bfloat16(bfloat16_to_float32(*ptr) * s); + ptr++; + } +} + +static void scale_bf16s_per_element_sse(unsigned short* ptr, const float* scale, const float* bias, int size, int num_threads) +{ +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ + if (ncnn::cpu_support_x86_avx512_bf16()) + { + scale_bf16s_per_element_sse_avx512bf16(ptr, scale, bias, size, num_threads); + return; + } +#endif + + int nn_size = 0; + int remain_size_start = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + nn_size = (size - remain_size_start) / 16; + #pragma omp parallel for num_threads(num_threads) + for (int ii = 0; ii < nn_size; ii++) + { + int i = remain_size_start + ii * 16; + __m512 _p = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(ptr + i))); + __m512 _s = _mm512_loadu_ps(scale + i); + __m512 _bias = _mm512_loadu_ps(bias + i); + _p = _mm512_fmadd_ps(_p, _s, _bias); + _mm256_storeu_si256((__m256i*)(ptr + i), float2bfloat_avx512(_p)); + } + remain_size_start += nn_size * 16; +#endif // __AVX512F__ + nn_size = (size - remain_size_start) / 8; + #pragma omp parallel for num_threads(num_threads) + for (int ii = 0; ii < nn_size; ii++) + { + int i = remain_size_start + ii * 8; + __m256 _p = bfloat2float_avx(_mm_loadu_si128((const __m128i*)(ptr + i))); + __m256 _s = _mm256_loadu_ps(scale + i); + __m256 _bias = _mm256_loadu_ps(bias + i); + _p = _mm256_comp_fmadd_ps(_p, _s, _bias); + _mm_storeu_si128((__m128i*)(ptr + i), float2bfloat_avx(_p)); + } + remain_size_start += nn_size * 8; +#endif // __AVX__ + nn_size = (size - remain_size_start) / 4; + #pragma omp parallel for num_threads(num_threads) + for (int ii = 0; ii < nn_size; ii++) + { + int i = remain_size_start + ii * 4; + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(ptr + i))); + __m128 _s = _mm_loadu_ps(scale + i); + __m128 _bias = _mm_loadu_ps(bias + i); + _p = _mm_comp_fmadd_ps(_p, _s, _bias); + _mm_storel_epi64((__m128i*)(ptr + i), float2bfloat_sse(_p, _p)); + } + remain_size_start += nn_size * 4; +#endif // __SSE2__ + #pragma omp parallel for num_threads(num_threads) + for (int i = remain_size_start; i < size; i++) + { + ptr[i] = float32_to_bfloat16(bfloat16_to_float32(ptr[i]) * scale[i] + bias[i]); + } +} + +static void scale_bf16s_no_bias_per_element_sse(unsigned short* ptr, const float* scale, int size, int num_threads) +{ +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ + if (ncnn::cpu_support_x86_avx512_bf16()) + { + scale_bf16s_no_bias_per_element_sse_avx512bf16(ptr, scale, size, num_threads); + return; + } +#endif + + int nn_size = 0; + int remain_size_start = 0; +#if __SSE2__ 
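+ // widest vectors first: 16-lane AVX-512, then 8-lane AVX, then 4-lane SSE, with a scalar tail below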
+#if __AVX__ +#if __AVX512F__ + nn_size = (size - remain_size_start) / 16; + #pragma omp parallel for num_threads(num_threads) + for (int ii = 0; ii < nn_size; ii++) + { + int i = remain_size_start + ii * 16; + __m512 _p = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(ptr + i))); + __m512 _s = _mm512_loadu_ps(scale + i); + _p = _mm512_mul_ps(_p, _s); + _mm256_storeu_si256((__m256i*)(ptr + i), float2bfloat_avx512(_p)); + } + remain_size_start += nn_size * 16; +#endif // __AVX512F__ + nn_size = (size - remain_size_start) / 8; + #pragma omp parallel for num_threads(num_threads) + for (int ii = 0; ii < nn_size; ii++) + { + int i = remain_size_start + ii * 8; + __m256 _p = bfloat2float_avx(_mm_loadu_si128((const __m128i*)(ptr + i))); + __m256 _s = _mm256_loadu_ps(scale + i); + _p = _mm256_mul_ps(_p, _s); + _mm_storeu_si128((__m128i*)(ptr + i), float2bfloat_avx(_p)); + } + remain_size_start += nn_size * 8; +#endif // __AVX__ + nn_size = (size - remain_size_start) / 4; + #pragma omp parallel for num_threads(num_threads) + for (int ii = 0; ii < nn_size; ii++) + { + int i = remain_size_start + ii * 4; + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(ptr + i))); + __m128 _s = _mm_loadu_ps(scale + i); + _p = _mm_mul_ps(_p, _s); + _mm_storel_epi64((__m128i*)(ptr + i), float2bfloat_sse(_p, _p)); + } + remain_size_start += nn_size * 4; +#endif // __SSE2__ + #pragma omp parallel for num_threads(num_threads) + for (int i = remain_size_start; i < size; i++) + { + ptr[i] = float32_to_bfloat16(bfloat16_to_float32(ptr[i]) * scale[i]); + } +} diff --git a/src/layer/x86/scale_x86.cpp b/src/layer/x86/scale_x86.cpp index 42147e972fc..7ceaf6fe0cc 100644 --- a/src/layer/x86/scale_x86.cpp +++ b/src/layer/x86/scale_x86.cpp @@ -10,13 +10,22 @@ #endif // __AVX__ #endif // __SSE2__ #include "x86_usability.h" +#include "cpu.h" + namespace ncnn { +#if NCNN_BF16 +#include "scale_bf16s.h" +#endif + Scale_x86::Scale_x86() { #if __SSE2__ support_packing = true; #endif // __SSE2__ +#if NCNN_BF16 + support_bf16_storage = true; +#endif } int Scale_x86::forward_inplace(std::vector& bottom_top_blobs, const Option& opt) const @@ -24,6 +33,11 @@ int Scale_x86::forward_inplace(std::vector& bottom_top_blobs, const Option& Mat& bottom_top_blob = bottom_top_blobs[0]; const Mat& scale_blob = bottom_top_blobs[1]; +#if NCNN_BF16 + if (opt.use_bf16_storage && bottom_top_blob.elembits() == 16) + return forward_inplace_bf16s(bottom_top_blobs, opt); +#endif + const int w = bottom_top_blob.w; const int h = bottom_top_blob.h; const int d = bottom_top_blob.d; @@ -340,4 +354,103 @@ int Scale_x86::forward_inplace(std::vector& bottom_top_blobs, const Option& return 0; } +#if NCNN_BF16 +int Scale_x86::forward_inplace_bf16s(std::vector& bottom_top_blobs, const Option& opt) const +{ + Mat& bottom_top_blob = bottom_top_blobs[0]; + const Mat& scale_blob = bottom_top_blobs[1]; + + const int w = bottom_top_blob.w; + const int h = bottom_top_blob.h; + const int d = bottom_top_blob.d; + const int channels = bottom_top_blob.c; + const int dims = bottom_top_blob.dims; + const int elempack = bottom_top_blob.elempack; + + // scale_blob may be bf16 (from second input) or fp32 (from scale_data weight) + const float* scale = 0; + Mat scale_fp32; + if (scale_blob.elembits() == 16) + { + const int scale_data_size = scale_blob.w * scale_blob.elempack; + scale_fp32.create(scale_data_size, 4u, 1, opt.workspace_allocator); + if (scale_fp32.empty()) + return -100; + const unsigned short* src = scale_blob; + float* dst = scale_fp32; + for (int i = 
0; i < scale_data_size; i++) + { + dst[i] = bfloat16_to_float32(src[i]); + } + scale = scale_fp32; + } + else + { + scale = scale_blob; + } + const float* bias = bias_data; + + if (dims == 1) + { + unsigned short* ptr = (unsigned short*)bottom_top_blob; + const int size = w * elempack; + + if (bias_term) + { + scale_bf16s_per_element_sse(ptr, scale, bias, size, opt.num_threads); + } + else + { + scale_bf16s_no_bias_per_element_sse(ptr, scale, size, opt.num_threads); + } + } + + if (dims == 2) + { + const int size = w * elempack; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = 0; i < h; i++) + { + unsigned short* ptr = bottom_top_blob.row(i); + const float* sptr = scale + i * elempack; + + if (bias_term) + { + const float* bptr = bias + i * elempack; + scale_bf16s_sse(ptr, sptr, bptr, size, elempack); + } + else + { + scale_bf16s_no_bias_sse(ptr, sptr, size, elempack); + } + } + } + + if (dims == 3 || dims == 4) + { + const int size = w * h * d * elempack; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + unsigned short* ptr = bottom_top_blob.channel(q); + const float* sptr = scale + q * elempack; + + if (bias_term) + { + const float* bptr = bias + q * elempack; + scale_bf16s_sse(ptr, sptr, bptr, size, elempack); + } + else + { + scale_bf16s_no_bias_sse(ptr, sptr, size, elempack); + } + } + } + + return 0; +} +#endif // NCNN_BF16 + } // namespace ncnn diff --git a/src/layer/x86/scale_x86.h b/src/layer/x86/scale_x86.h index e0884615de8..31b78252c45 100644 --- a/src/layer/x86/scale_x86.h +++ b/src/layer/x86/scale_x86.h @@ -14,6 +14,11 @@ class Scale_x86 : public Scale Scale_x86(); virtual int forward_inplace(std::vector& bottom_top_blobs, const Option& opt) const; + +protected: +#if NCNN_BF16 + int forward_inplace_bf16s(std::vector& bottom_top_blobs, const Option& opt) const; +#endif }; } // namespace ncnn diff --git a/src/layer/x86/scale_x86_avx512bf16.cpp b/src/layer/x86/scale_x86_avx512bf16.cpp new file mode 100644 index 00000000000..236dde407a0 --- /dev/null +++ b/src/layer/x86/scale_x86_avx512bf16.cpp @@ -0,0 +1,32 @@ +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#include "cpu.h" +#include "mat.h" +#include "x86_usability.h" + +namespace ncnn { + +#include "scale_bf16s.h" + +void scale_bf16s_sse_avx512bf16(unsigned short* ptr, const float* scale, const float* bias, int size, int elempack) +{ + scale_bf16s_sse(ptr, scale, bias, size, elempack); +} + +void scale_bf16s_no_bias_sse_avx512bf16(unsigned short* ptr, const float* scale, int size, int elempack) +{ + scale_bf16s_no_bias_sse(ptr, scale, size, elempack); +} + +void scale_bf16s_per_element_sse_avx512bf16(unsigned short* ptr, const float* scale, const float* bias, int size, int num_threads) +{ + scale_bf16s_per_element_sse(ptr, scale, bias, size, num_threads); +} + +void scale_bf16s_no_bias_per_element_sse_avx512bf16(unsigned short* ptr, const float* scale, int size, int num_threads) +{ + scale_bf16s_no_bias_per_element_sse(ptr, scale, size, num_threads); +} + +} // namespace ncnn diff --git a/src/layer/x86/softmax_bf16s.h b/src/layer/x86/softmax_bf16s.h new file mode 100644 index 00000000000..6bb5444605e --- /dev/null +++ b/src/layer/x86/softmax_bf16s.h @@ -0,0 +1,912 @@ +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ +void softmax_bf16s_sse_avx512bf16(unsigned short* _ptr, int elemcount, int elempack); +void 
softmax_bf16s_pack1_sse_avx512bf16(unsigned short* _ptr, int elemcount, size_t stride, int size1, float* _maxptr, float* _sumptr); +void softmax_bf16s_pack4_sse_avx512bf16(unsigned short* _ptr, int elemcount, size_t stride, int size1, float* _maxptr, float* _sumptr); +void softmax_bf16s_pack8_sse_avx512bf16(unsigned short* _ptr, int elemcount, size_t stride, int size1, float* _maxptr, float* _sumptr); +void softmax_bf16s_pack16_sse_avx512bf16(unsigned short* _ptr, int elemcount, size_t stride, int size1, float* _maxptr, float* _sumptr); +#endif + +static void softmax_bf16s_sse(unsigned short* _ptr, int elemcount, int elempack) +{ +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ + if (ncnn::cpu_support_x86_avx512_bf16()) + { + softmax_bf16s_sse_avx512bf16(_ptr, elemcount, elempack); + return; + } +#endif + + const int size = elemcount * elempack; + + // reduce max +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + __m512 _max_avx512 = _mm512_set1_ps(-FLT_MAX); +#endif // __AVX512F__ + __m256 _max_avx = _mm256_set1_ps(-FLT_MAX); +#endif // __AVX__ + __m128 _max = _mm_set1_ps(-FLT_MAX); +#endif // __SSE2__ + float max = -FLT_MAX; + { + const unsigned short* ptr = _ptr; + + int i = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + for (; i + 15 < size; i += 16) + { + __m512 _p = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr)); + _max_avx512 = _mm512_max_ps(_max_avx512, _p); + ptr += 16; + } +#endif // __AVX512F__ + for (; i + 7 < size; i += 8) + { + __m256 _p = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr)); + _max_avx = _mm256_max_ps(_max_avx, _p); + ptr += 8; + } +#endif // __AVX__ + for (; i + 3 < size; i += 4) + { + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr)); + _max = _mm_max_ps(_max, _p); + ptr += 4; + } +#endif // __SSE2__ + for (; i < size; i++) + { + max = std::max(max, bfloat16_to_float32(*ptr++)); + } + } + +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + if (elempack == 8) + { + { + __m256 _max0 = _mm512_castps512_ps256(_max_avx512); + __m256 _max1 = _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(_max_avx512), 1)); + _max_avx = _mm256_max_ps(_max_avx, _max0); + _max_avx = _mm256_max_ps(_max_avx, _max1); + } + + _max_avx512 = combine8x2_ps(_max_avx, _max_avx); + } +#endif // __AVX512F__ + if (elempack == 4) + { +#if __AVX512F__ + { + __m256 _max0 = _mm512_castps512_ps256(_max_avx512); + __m256 _max1 = _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(_max_avx512), 1)); + _max_avx = _mm256_max_ps(_max_avx, _max0); + _max_avx = _mm256_max_ps(_max_avx, _max1); + } +#endif // __AVX512F__ + { + __m128 _max0 = _mm256_castps256_ps128(_max_avx); + __m128 _max1 = _mm256_extractf128_ps(_max_avx, 1); + _max = _mm_max_ps(_max, _max0); + _max = _mm_max_ps(_max, _max1); + } + + _max_avx = combine4x2_ps(_max, _max); +#if __AVX512F__ + _max_avx512 = combine8x2_ps(_max_avx, _max_avx); +#endif // __AVX512F__ + } +#endif // __AVX__ + if (elempack == 1) + { +#if __AVX__ +#if __AVX512F__ + max = std::max(max, _mm512_comp_reduce_max_ps(_max_avx512)); +#endif // __AVX512F__ + max = std::max(max, _mm256_reduce_max_ps(_max_avx)); +#endif // __AVX__ + max = std::max(max, _mm_reduce_max_ps(_max)); + + _max = _mm_set1_ps(max); +#if __AVX__ + _max_avx = combine4x2_ps(_max, _max); +#if __AVX512F__ + _max_avx512 = combine8x2_ps(_max_avx, _max_avx); +#endif // __AVX512F__ +#endif // __AVX__ + } +#endif // __SSE2__ + + // reduce exp(x - max) and store back to bf16 +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + __m512 _sum_avx512 = 
_mm512_set1_ps(0.f); +#endif // __AVX512F__ + __m256 _sum_avx = _mm256_set1_ps(0.f); +#endif // __AVX__ + __m128 _sum = _mm_set1_ps(0.f); +#endif // __SSE2__ + float sum = 0.f; + { + unsigned short* ptr = _ptr; + + int i = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + for (; i + 15 < size; i += 16) + { + __m512 _p = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr)); + _p = _mm512_sub_ps(_p, _max_avx512); + _p = exp512_ps(_p); + _mm256_storeu_si256((__m256i*)ptr, float2bfloat_avx512(_p)); + _sum_avx512 = _mm512_add_ps(_sum_avx512, _p); + ptr += 16; + } +#endif // __AVX512F__ + for (; i + 7 < size; i += 8) + { + __m256 _p = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr)); + _p = _mm256_sub_ps(_p, _max_avx); + _p = exp256_ps(_p); + _mm_storeu_si128((__m128i*)ptr, float2bfloat_avx(_p)); + _sum_avx = _mm256_add_ps(_sum_avx, _p); + ptr += 8; + } +#endif // __AVX__ + for (; i + 3 < size; i += 4) + { + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr)); + _p = _mm_sub_ps(_p, _max); + _p = exp_ps(_p); + _mm_storel_epi64((__m128i*)ptr, float2bfloat_sse(_p, _p)); + _sum = _mm_add_ps(_sum, _p); + ptr += 4; + } +#endif // __SSE2__ + for (; i < size; i++) + { + float v = expf(bfloat16_to_float32(*ptr) - max); + *ptr = float32_to_bfloat16(v); + sum += v; + ptr++; + } + } + +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + if (elempack == 16) + { + _sum_avx512 = _mm512_rcp_nr_ps(_sum_avx512); + } +#endif // __AVX512F__ + if (elempack == 8) + { +#if __AVX512F__ + { + __m256 _sum0 = _mm512_castps512_ps256(_sum_avx512); + __m256 _sum1 = _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(_sum_avx512), 1)); + _sum_avx = _mm256_add_ps(_sum_avx, _sum0); + _sum_avx = _mm256_add_ps(_sum_avx, _sum1); + } +#endif // __AVX512F__ + + _sum_avx = _mm256_rcp_nr_ps(_sum_avx); + +#if __AVX512F__ + _sum_avx512 = combine8x2_ps(_sum_avx, _sum_avx); +#endif // __AVX512F__ + } +#endif // __AVX__ + if (elempack == 4) + { +#if __AVX__ +#if __AVX512F__ + { + __m256 _sum0 = _mm512_castps512_ps256(_sum_avx512); + __m256 _sum1 = _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(_sum_avx512), 1)); + _sum_avx = _mm256_add_ps(_sum_avx, _sum0); + _sum_avx = _mm256_add_ps(_sum_avx, _sum1); + } +#endif // __AVX512F__ + { + __m128 _sum0 = _mm256_castps256_ps128(_sum_avx); + __m128 _sum1 = _mm256_extractf128_ps(_sum_avx, 1); + _sum = _mm_add_ps(_sum, _sum0); + _sum = _mm_add_ps(_sum, _sum1); + } +#endif // __AVX__ + + _sum = _mm_rcp_nr_ps(_sum); + +#if __AVX__ + _sum_avx = combine4x2_ps(_sum, _sum); +#if __AVX512F__ + _sum_avx512 = combine8x2_ps(_sum_avx, _sum_avx); +#endif // __AVX512F__ +#endif // __AVX__ + } +#endif // __SSE2__ + if (elempack == 1) + { +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + sum += _mm512_comp_reduce_add_ps(_sum_avx512); +#endif // __AVX512F__ + sum += _mm256_reduce_add_ps(_sum_avx); +#endif // __AVX__ + sum += _mm_reduce_add_ps(_sum); +#endif // __SSE2__ + + sum = 1.f / sum; + +#if __SSE2__ + _sum = _mm_set1_ps(sum); +#if __AVX__ + _sum_avx = combine4x2_ps(_sum, _sum); +#if __AVX512F__ + _sum_avx512 = combine8x2_ps(_sum_avx, _sum_avx); +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ + } + + // div sum + { + unsigned short* ptr = _ptr; + + int i = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + for (; i + 15 < size; i += 16) + { + __m512 _p = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr)); + _p = _mm512_mul_ps(_p, _sum_avx512); + _mm256_storeu_si256((__m256i*)ptr, float2bfloat_avx512(_p)); + ptr += 16; + } +#endif // __AVX512F__ + for (; i + 
7 < size; i += 8) + { + __m256 _p = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr)); + _p = _mm256_mul_ps(_p, _sum_avx); + _mm_storeu_si128((__m128i*)ptr, float2bfloat_avx(_p)); + ptr += 8; + } +#endif // __AVX__ + for (; i + 3 < size; i += 4) + { + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr)); + _p = _mm_mul_ps(_p, _sum); + _mm_storel_epi64((__m128i*)ptr, float2bfloat_sse(_p, _p)); + ptr += 4; + } +#endif // __SSE2__ + for (; i < size; i++) + { + *ptr = float32_to_bfloat16(bfloat16_to_float32(*ptr) * sum); + ptr++; + } + } +} + +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ +static void softmax_bf16s_pack16_sse(unsigned short* _ptr, int elemcount, size_t stride, int size1, float* _maxptr, float* _sumptr) +{ +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && !__AVX512BF16__ + if (ncnn::cpu_support_x86_avx512_bf16()) + { + softmax_bf16s_pack16_sse_avx512bf16(_ptr, elemcount, stride, size1, _maxptr, _sumptr); + return; + } +#endif + + // reduce max + for (int i = 0; i < elemcount; i++) + { + const unsigned short* ptr = _ptr + i * stride; + float* maxptr = _maxptr; + + int j = 0; + for (; j < size1; j++) + { + __m512 _p = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr)); + *maxptr = std::max(*maxptr, _mm512_comp_reduce_max_ps(_p)); + ptr += 16; + maxptr++; + } + } + + // reduce exp(x - max) + for (int i = 0; i < elemcount; i++) + { + unsigned short* ptr = _ptr + i * stride; + const float* maxptr = _maxptr; + float* sumptr = _sumptr; + + int j = 0; + for (; j < size1; j++) + { + __m512 _p = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr)); + __m512 _max = _mm512_set1_ps(*maxptr); + _p = exp512_ps(_mm512_sub_ps(_p, _max)); + _mm256_storeu_si256((__m256i*)ptr, float2bfloat_avx512(_p)); + *sumptr += _mm512_comp_reduce_add_ps(_p); + ptr += 16; + maxptr++; + sumptr++; + } + } + + { + float* sumptr = _sumptr; + int j = 0; + for (; j + 15 < size1; j += 16) + { + __m512 _sum = _mm512_loadu_ps(sumptr); + _sum = _mm512_rcp_nr_ps(_sum); + _mm512_storeu_ps(sumptr, _sum); + sumptr += 16; + } + for (; j + 7 < size1; j += 8) + { + __m256 _sum = _mm256_loadu_ps(sumptr); + _sum = _mm256_rcp_nr_ps(_sum); + _mm256_storeu_ps(sumptr, _sum); + sumptr += 8; + } + for (; j + 3 < size1; j += 4) + { + __m128 _sum = _mm_loadu_ps(sumptr); + _sum = _mm_rcp_nr_ps(_sum); + _mm_storeu_ps(sumptr, _sum); + sumptr += 4; + } + for (; j < size1; j++) + { + *sumptr = 1.f / *sumptr; + sumptr++; + } + } + + // div sum + for (int i = 0; i < elemcount; i++) + { + unsigned short* ptr = _ptr + i * stride; + const float* sumptr = _sumptr; + + int j = 0; + for (; j < size1; j++) + { + __m512 _p = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr)); + __m512 _sum = _mm512_set1_ps(*sumptr); + _p = _mm512_mul_ps(_p, _sum); + _mm256_storeu_si256((__m256i*)ptr, float2bfloat_avx512(_p)); + ptr += 16; + sumptr++; + } + } +} +#endif // __AVX512F__ + +static void softmax_bf16s_pack8_sse(unsigned short* _ptr, int elemcount, size_t stride, int size1, float* _maxptr, float* _sumptr) +{ +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ + if (ncnn::cpu_support_x86_avx512_bf16()) + { + softmax_bf16s_pack8_sse_avx512bf16(_ptr, elemcount, stride, size1, _maxptr, _sumptr); + return; + } +#endif + + // reduce max + for (int i = 0; i < elemcount; i++) + { + const unsigned short* ptr = _ptr + i * stride; + float* maxptr = _maxptr; + + int j = 0; + for (; j < size1; j++) + { + __m256 _p = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr)); + *maxptr = std::max(*maxptr, 
_mm256_reduce_max_ps(_p)); + ptr += 8; + maxptr++; + } + } + + // reduce exp(x - max) + for (int i = 0; i < elemcount; i++) + { + unsigned short* ptr = _ptr + i * stride; + const float* maxptr = _maxptr; + float* sumptr = _sumptr; + + int j = 0; + for (; j < size1; j++) + { + __m256 _p = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr)); + __m256 _max = _mm256_set1_ps(*maxptr); + _p = exp256_ps(_mm256_sub_ps(_p, _max)); + _mm_storeu_si128((__m128i*)ptr, float2bfloat_avx(_p)); + *sumptr += _mm256_reduce_add_ps(_p); + ptr += 8; + maxptr++; + sumptr++; + } + } + + { + float* sumptr = _sumptr; + int j = 0; +#if __AVX512F__ + for (; j + 15 < size1; j += 16) + { + __m512 _sum = _mm512_loadu_ps(sumptr); + _sum = _mm512_rcp_nr_ps(_sum); + _mm512_storeu_ps(sumptr, _sum); + sumptr += 16; + } +#endif // __AVX512F__ + for (; j + 7 < size1; j += 8) + { + __m256 _sum = _mm256_loadu_ps(sumptr); + _sum = _mm256_rcp_nr_ps(_sum); + _mm256_storeu_ps(sumptr, _sum); + sumptr += 8; + } + for (; j + 3 < size1; j += 4) + { + __m128 _sum = _mm_loadu_ps(sumptr); + _sum = _mm_rcp_nr_ps(_sum); + _mm_storeu_ps(sumptr, _sum); + sumptr += 4; + } + for (; j < size1; j++) + { + *sumptr = 1.f / *sumptr; + sumptr++; + } + } + + // div sum + for (int i = 0; i < elemcount; i++) + { + unsigned short* ptr = _ptr + i * stride; + const float* sumptr = _sumptr; + + int j = 0; + for (; j < size1; j++) + { + __m256 _p = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr)); + __m256 _sum = _mm256_set1_ps(*sumptr); + _p = _mm256_mul_ps(_p, _sum); + _mm_storeu_si128((__m128i*)ptr, float2bfloat_avx(_p)); + ptr += 8; + sumptr++; + } + } +} +#endif // __AVX__ + +static void softmax_bf16s_pack4_sse(unsigned short* _ptr, int elemcount, size_t stride, int size1, float* _maxptr, float* _sumptr) +{ +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ + if (ncnn::cpu_support_x86_avx512_bf16()) + { + softmax_bf16s_pack4_sse_avx512bf16(_ptr, elemcount, stride, size1, _maxptr, _sumptr); + return; + } +#endif + + // reduce max + for (int i = 0; i < elemcount; i++) + { + const unsigned short* ptr = _ptr + i * stride; + float* maxptr = _maxptr; + + int j = 0; + for (; j < size1; j++) + { + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr)); + *maxptr = std::max(*maxptr, _mm_reduce_max_ps(_p)); + ptr += 4; + maxptr++; + } + } + + // reduce exp(x - max) + for (int i = 0; i < elemcount; i++) + { + unsigned short* ptr = _ptr + i * stride; + const float* maxptr = _maxptr; + float* sumptr = _sumptr; + + int j = 0; + for (; j < size1; j++) + { + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr)); + __m128 _max = _mm_set1_ps(*maxptr); + _p = exp_ps(_mm_sub_ps(_p, _max)); + _mm_storel_epi64((__m128i*)ptr, float2bfloat_sse(_p, _p)); + *sumptr += _mm_reduce_add_ps(_p); + ptr += 4; + maxptr++; + sumptr++; + } + } + + { + float* sumptr = _sumptr; + int j = 0; +#if __AVX__ +#if __AVX512F__ + for (; j + 15 < size1; j += 16) + { + __m512 _sum = _mm512_loadu_ps(sumptr); + _sum = _mm512_rcp_nr_ps(_sum); + _mm512_storeu_ps(sumptr, _sum); + sumptr += 16; + } +#endif // __AVX512F__ + for (; j + 7 < size1; j += 8) + { + __m256 _sum = _mm256_loadu_ps(sumptr); + _sum = _mm256_rcp_nr_ps(_sum); + _mm256_storeu_ps(sumptr, _sum); + sumptr += 8; + } +#endif // __AVX__ + for (; j + 3 < size1; j += 4) + { + __m128 _sum = _mm_loadu_ps(sumptr); + _sum = _mm_rcp_nr_ps(_sum); + _mm_storeu_ps(sumptr, _sum); + sumptr += 4; + } + for (; j < size1; j++) + { + *sumptr = 1.f / *sumptr; + sumptr++; + } + } + + // div sum + for (int 
i = 0; i < elemcount; i++) + { + unsigned short* ptr = _ptr + i * stride; + const float* sumptr = _sumptr; + + int j = 0; + for (; j < size1; j++) + { + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr)); + __m128 _sum = _mm_set1_ps(*sumptr); + _p = _mm_mul_ps(_p, _sum); + _mm_storel_epi64((__m128i*)ptr, float2bfloat_sse(_p, _p)); + ptr += 4; + sumptr++; + } + } +} +#endif // __SSE2__ + +static void softmax_bf16s_pack1_sse(unsigned short* _ptr, int elemcount, size_t stride, int size1, float* _maxptr, float* _sumptr) +{ +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ + if (ncnn::cpu_support_x86_avx512_bf16()) + { + softmax_bf16s_pack1_sse_avx512bf16(_ptr, elemcount, stride, size1, _maxptr, _sumptr); + return; + } +#endif + + // reduce max + for (int i = 0; i < elemcount; i++) + { + const unsigned short* ptr = _ptr + i * stride; + float* maxptr = _maxptr; + + int j = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + for (; j + 15 < size1; j += 16) + { + __m512 _p = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr)); + __m512 _max = _mm512_loadu_ps(maxptr); + _max = _mm512_max_ps(_max, _p); + _mm512_storeu_ps(maxptr, _max); + ptr += 16; + maxptr += 16; + } +#endif // __AVX512F__ + for (; j + 7 < size1; j += 8) + { + __m256 _p = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr)); + __m256 _max = _mm256_loadu_ps(maxptr); + _max = _mm256_max_ps(_max, _p); + _mm256_storeu_ps(maxptr, _max); + ptr += 8; + maxptr += 8; + } +#endif // __AVX__ + for (; j + 3 < size1; j += 4) + { + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr)); + __m128 _max = _mm_loadu_ps(maxptr); + _max = _mm_max_ps(_max, _p); + _mm_storeu_ps(maxptr, _max); + ptr += 4; + maxptr += 4; + } +#endif // __SSE2__ + for (; j < size1; j++) + { + *maxptr = std::max(*maxptr, bfloat16_to_float32(*ptr)); + ptr++; + maxptr++; + } + } + + // reduce exp(x - max) + for (int i = 0; i < elemcount; i++) + { + unsigned short* ptr = _ptr + i * stride; + const float* maxptr = _maxptr; + float* sumptr = _sumptr; + + int j = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + for (; j + 15 < size1; j += 16) + { + __m512 _p = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr)); + __m512 _max = _mm512_loadu_ps(maxptr); + __m512 _sum = _mm512_loadu_ps(sumptr); + _p = _mm512_sub_ps(_p, _max); + _p = exp512_ps(_p); + _mm256_storeu_si256((__m256i*)ptr, float2bfloat_avx512(_p)); + _sum = _mm512_add_ps(_sum, _p); + _mm512_storeu_ps(sumptr, _sum); + ptr += 16; + maxptr += 16; + sumptr += 16; + } +#endif // __AVX512F__ + for (; j + 7 < size1; j += 8) + { + __m256 _p = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr)); + __m256 _max = _mm256_loadu_ps(maxptr); + __m256 _sum = _mm256_loadu_ps(sumptr); + _p = _mm256_sub_ps(_p, _max); + _p = exp256_ps(_p); + _mm_storeu_si128((__m128i*)ptr, float2bfloat_avx(_p)); + _sum = _mm256_add_ps(_sum, _p); + _mm256_storeu_ps(sumptr, _sum); + ptr += 8; + maxptr += 8; + sumptr += 8; + } +#endif // __AVX__ + for (; j + 3 < size1; j += 4) + { + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr)); + __m128 _max = _mm_loadu_ps(maxptr); + __m128 _sum = _mm_loadu_ps(sumptr); + _p = _mm_sub_ps(_p, _max); + _p = exp_ps(_p); + _mm_storel_epi64((__m128i*)ptr, float2bfloat_sse(_p, _p)); + _sum = _mm_add_ps(_sum, _p); + _mm_storeu_ps(sumptr, _sum); + ptr += 4; + maxptr += 4; + sumptr += 4; + } +#endif // __SSE2__ + for (; j < size1; j++) + { + float v = expf(bfloat16_to_float32(*ptr) - *maxptr); + *ptr = float32_to_bfloat16(v); + *sumptr += v; + 
ptr++; + maxptr++; + sumptr++; + } + } + + { + float* sumptr = _sumptr; + int j = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + for (; j + 15 < size1; j += 16) + { + __m512 _sum = _mm512_loadu_ps(sumptr); + _sum = _mm512_rcp_nr_ps(_sum); + _mm512_storeu_ps(sumptr, _sum); + sumptr += 16; + } +#endif // __AVX512F__ + for (; j + 7 < size1; j += 8) + { + __m256 _sum = _mm256_loadu_ps(sumptr); + _sum = _mm256_rcp_nr_ps(_sum); + _mm256_storeu_ps(sumptr, _sum); + sumptr += 8; + } +#endif // __AVX__ + for (; j + 3 < size1; j += 4) + { + __m128 _sum = _mm_loadu_ps(sumptr); + _sum = _mm_rcp_nr_ps(_sum); + _mm_storeu_ps(sumptr, _sum); + sumptr += 4; + } +#endif // __SSE2__ + for (; j < size1; j++) + { + *sumptr = 1.f / *sumptr; + sumptr++; + } + } + + // div sum + for (int i = 0; i < elemcount; i++) + { + unsigned short* ptr = _ptr + i * stride; + const float* sumptr = _sumptr; + + int j = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + for (; j + 15 < size1; j += 16) + { + __m512 _p = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr)); + __m512 _sum = _mm512_loadu_ps(sumptr); + _p = _mm512_mul_ps(_p, _sum); + _mm256_storeu_si256((__m256i*)ptr, float2bfloat_avx512(_p)); + ptr += 16; + sumptr += 16; + } +#endif // __AVX512F__ + for (; j + 7 < size1; j += 8) + { + __m256 _p = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr)); + __m256 _sum = _mm256_loadu_ps(sumptr); + _p = _mm256_mul_ps(_p, _sum); + _mm_storeu_si128((__m128i*)ptr, float2bfloat_avx(_p)); + ptr += 8; + sumptr += 8; + } +#endif // __AVX__ + for (; j + 3 < size1; j += 4) + { + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr)); + __m128 _sum = _mm_loadu_ps(sumptr); + _p = _mm_mul_ps(_p, _sum); + _mm_storel_epi64((__m128i*)ptr, float2bfloat_sse(_p, _p)); + ptr += 4; + sumptr += 4; + } +#endif // __SSE2__ + for (; j < size1; j++) + { + *ptr = float32_to_bfloat16(bfloat16_to_float32(*ptr) * *sumptr); + ptr++; + sumptr++; + } + } +} + +static void softmax_bf16s_sse_dispatch(unsigned short* _ptr, int elemcount, int elempack, size_t stride, int size1, float* _maxptr, float* _sumptr) +{ + // init max + { + float* maxptr = _maxptr; + + int j = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + __m512 _negmax_avx512 = _mm512_set1_ps(-FLT_MAX); + for (; j + 15 < size1; j += 16) + { + _mm512_storeu_ps(maxptr, _negmax_avx512); + maxptr += 16; + } +#endif // __AVX512F__ + __m256 _negmax_avx = _mm256_set1_ps(-FLT_MAX); + for (; j + 7 < size1; j += 8) + { + _mm256_storeu_ps(maxptr, _negmax_avx); + maxptr += 8; + } +#endif // __AVX__ + __m128 _negmax = _mm_set1_ps(-FLT_MAX); + for (; j + 3 < size1; j += 4) + { + _mm_storeu_ps(maxptr, _negmax); + maxptr += 4; + } +#endif // __SSE2__ + for (; j < size1; j++) + { + *maxptr++ = -FLT_MAX; + } + } + + // init sum + { + float* sumptr = _sumptr; + + int j = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + __m512 _zero_avx512 = _mm512_set1_ps(0.f); + for (; j + 15 < size1; j += 16) + { + _mm512_storeu_ps(sumptr, _zero_avx512); + sumptr += 16; + } +#endif // __AVX512F__ + __m256 _zero_avx = _mm256_set1_ps(0.f); + for (; j + 7 < size1; j += 8) + { + _mm256_storeu_ps(sumptr, _zero_avx); + sumptr += 8; + } +#endif // __AVX__ + __m128 _zero = _mm_set1_ps(0.f); + for (; j + 3 < size1; j += 4) + { + _mm_storeu_ps(sumptr, _zero); + sumptr += 4; + } +#endif // __SSE2__ + for (; j < size1; j++) + { + *sumptr++ = 0.f; + } + } + +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + if (elempack == 16) + { + softmax_bf16s_pack16_sse(_ptr, elemcount, stride, size1, _maxptr, _sumptr); + } +#endif // 
__AVX512F__ + if (elempack == 8) + { + softmax_bf16s_pack8_sse(_ptr, elemcount, stride, size1, _maxptr, _sumptr); + } +#endif // __AVX__ + if (elempack == 4) + { + softmax_bf16s_pack4_sse(_ptr, elemcount, stride, size1, _maxptr, _sumptr); + } +#endif // __SSE2__ + if (elempack == 1) + { + softmax_bf16s_pack1_sse(_ptr, elemcount, stride, size1, _maxptr, _sumptr); + } +} diff --git a/src/layer/x86/softmax_x86.cpp b/src/layer/x86/softmax_x86.cpp index ec29c4de97f..b2e1bdc4093 100644 --- a/src/layer/x86/softmax_x86.cpp +++ b/src/layer/x86/softmax_x86.cpp @@ -22,33 +22,8 @@ namespace ncnn { -#if __SSE2__ -static NCNN_FORCEINLINE __m128 _mm_rcp_nr_ps(const __m128& x) -{ - __m128 y = _mm_rcp_ps(x); // approx - __m128 t = _mm_comp_fnmadd_ps(x, y, _mm_set1_ps(2.0f)); // (2 - x*y) - y = _mm_mul_ps(y, t); - return y; // 1 NR step -} -#endif - -#if __AVX__ -static NCNN_FORCEINLINE __m256 _mm256_rcp_nr_ps(const __m256& x) -{ - __m256 y = _mm256_rcp_ps(x); - __m256 t = _mm256_comp_fnmadd_ps(x, y, _mm256_set1_ps(2.0f)); - y = _mm256_mul_ps(y, t); - return y; -} -#endif - -#if __AVX512F__ -static NCNN_FORCEINLINE __m512 _mm512_rcp_nr_ps(const __m512& x) -{ - __m512 y = _mm512_rcp14_ps(x); - __m512 t = _mm512_fnmadd_ps(x, y, _mm512_set1_ps(2.0f)); - return _mm512_mul_ps(y, t); -} +#if NCNN_BF16 +#include "softmax_bf16s.h" #endif Softmax_x86::Softmax_x86() @@ -56,6 +31,9 @@ Softmax_x86::Softmax_x86() #if __SSE2__ support_packing = true; #endif // __SSE2__ +#if NCNN_BF16 + support_bf16_storage = true; +#endif } static void softmax(float* _ptr, int elemcount, int elempack) @@ -1670,6 +1648,11 @@ static void softmax(float* _ptr, int elemcount, int elempack, size_t stride, int int Softmax_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const { +#if NCNN_BF16 + if (opt.use_bf16_storage && bottom_top_blob.elembits() == 16) + return forward_inplace_bf16s(bottom_top_blob, opt); +#endif + const int dims = bottom_top_blob.dims; const int w = bottom_top_blob.w; const int h = bottom_top_blob.h; @@ -1833,4 +1816,171 @@ int Softmax_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const return 0; } +#if NCNN_BF16 +int Softmax_x86::forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt) const +{ + const int dims = bottom_top_blob.dims; + const int w = bottom_top_blob.w; + const int h = bottom_top_blob.h; + const int d = bottom_top_blob.d; + const int channels = bottom_top_blob.c; + const int elempack = bottom_top_blob.elempack; + const int positive_axis = axis < 0 ? 
dims + axis : axis; + + if (dims == 1) // positive_axis == 0 + { + unsigned short* ptr = bottom_top_blob; + + const int size = w * elempack; + + softmax_bf16s_sse(ptr, size, 1); + } + + if (dims == 2 && positive_axis == 0) + { + const int size = w; + const int sizen = (size + (opt.num_threads - 1)) / opt.num_threads; + const size_t stride = (size_t)w * elempack; + + Mat maxsum(sizen, 2, opt.num_threads, 4u, opt.workspace_allocator); + if (maxsum.empty()) + return -100; + + const int nn_size = (size + sizen - 1) / sizen; + #pragma omp parallel for num_threads(opt.num_threads) + for (int ii = 0; ii < nn_size; ii++) + { + const int i = ii * sizen; + const int size1 = std::min(sizen, size - i); + + float* maxsumptr = maxsum.channel(get_omp_thread_num()); + float* maxptr = maxsumptr; + float* sumptr = maxptr + sizen; + + unsigned short* ptr = (unsigned short*)bottom_top_blob + i * elempack; + + softmax_bf16s_sse_dispatch(ptr, h, elempack, stride, size1, maxptr, sumptr); + } + } + + if (dims == 2 && positive_axis == 1) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = 0; i < h; i++) + { + unsigned short* ptr = bottom_top_blob.row(i); + + softmax_bf16s_sse(ptr, w, elempack); + } + } + + if ((dims == 3 || dims == 4) && positive_axis == 0) + { + const int size = w * h * d; + const int sizen = (size + (opt.num_threads - 1)) / opt.num_threads; + const size_t stride = bottom_top_blob.cstep * elempack; + + Mat maxsum(sizen, 2, opt.num_threads, 4u, opt.workspace_allocator); + if (maxsum.empty()) + return -100; + + const int nn_size = (size + sizen - 1) / sizen; + #pragma omp parallel for num_threads(opt.num_threads) + for (int ii = 0; ii < nn_size; ii++) + { + const int i = ii * sizen; + const int size1 = std::min(sizen, size - i); + + float* maxsumptr = maxsum.channel(get_omp_thread_num()); + float* maxptr = maxsumptr; + float* sumptr = maxptr + sizen; + + unsigned short* ptr = (unsigned short*)bottom_top_blob + i * elempack; + + softmax_bf16s_sse_dispatch(ptr, channels, elempack, stride, size1, maxptr, sumptr); + } + } + + if ((dims == 3 && positive_axis == 1) || (dims == 4 && positive_axis == 2)) + { + const int size = w * elempack; + + Mat maxsum(size, 2, opt.num_threads, 4u, opt.workspace_allocator); + if (maxsum.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + for (int i = 0; i < d; i++) + { + unsigned short* ptr = bottom_top_blob.channel(q).depth(i); + + float* maxsumptr = maxsum.channel(get_omp_thread_num()); + float* maxptr = maxsumptr; + float* sumptr = maxptr + size; + + softmax_bf16s_sse_dispatch(ptr, h, 1, size, size, maxptr, sumptr); + } + } + } + + if (dims == 3 && positive_axis == 2) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + unsigned short* ptr = bottom_top_blob.channel(q); + + for (int i = 0; i < h; i++) + { + softmax_bf16s_sse(ptr, w, elempack); + ptr += w * elempack; + } + } + } + + if (dims == 4 && positive_axis == 1) + { + const int size = w * h * elempack; + + Mat maxsum(size, 2, opt.num_threads, 4u, opt.workspace_allocator); + if (maxsum.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + unsigned short* ptr = bottom_top_blob.channel(q); + + float* maxsumptr = maxsum.channel(get_omp_thread_num()); + float* maxptr = maxsumptr; + float* sumptr = maxptr + size; + + softmax_bf16s_sse_dispatch(ptr, d, 1, size, size, maxptr, sumptr); + } + } + + if (dims 
== 4 && positive_axis == 3) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + unsigned short* ptr = bottom_top_blob.channel(q); + + for (int i = 0; i < d; i++) + { + for (int j = 0; j < h; j++) + { + softmax_bf16s_sse(ptr, w, elempack); + ptr += w * elempack; + } + } + } + } + + return 0; +} +#endif // NCNN_BF16 + } // namespace ncnn diff --git a/src/layer/x86/softmax_x86.h b/src/layer/x86/softmax_x86.h index 0e3da9e0a9e..c193aaa60b9 100644 --- a/src/layer/x86/softmax_x86.h +++ b/src/layer/x86/softmax_x86.h @@ -14,6 +14,11 @@ class Softmax_x86 : public Softmax Softmax_x86(); virtual int forward_inplace(Mat& bottom_top_blob, const Option& opt) const; + +protected: +#if NCNN_BF16 + int forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt) const; +#endif }; } // namespace ncnn diff --git a/src/layer/x86/softmax_x86_avx512bf16.cpp b/src/layer/x86/softmax_x86_avx512bf16.cpp new file mode 100644 index 00000000000..2f69d47749e --- /dev/null +++ b/src/layer/x86/softmax_x86_avx512bf16.cpp @@ -0,0 +1,54 @@ +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#include "softmax_x86.h" + +#include + +#if __SSE2__ +#include +#include "sse_mathfun.h" +#if __AVX__ +#include +#include "avx_mathfun.h" +#if __AVX512F__ +#include "avx512_mathfun.h" +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ + +#include "x86_usability.h" + +#include "cpu.h" +#include "mat.h" + +namespace ncnn { + +#include "softmax_bf16s.h" + +void softmax_bf16s_sse_avx512bf16(unsigned short* _ptr, int elemcount, int elempack) +{ + softmax_bf16s_sse(_ptr, elemcount, elempack); +} + +void softmax_bf16s_pack1_sse_avx512bf16(unsigned short* _ptr, int elemcount, size_t stride, int size1, float* _maxptr, float* _sumptr) +{ + softmax_bf16s_pack1_sse(_ptr, elemcount, stride, size1, _maxptr, _sumptr); +} + +void softmax_bf16s_pack4_sse_avx512bf16(unsigned short* _ptr, int elemcount, size_t stride, int size1, float* _maxptr, float* _sumptr) +{ + softmax_bf16s_pack4_sse(_ptr, elemcount, stride, size1, _maxptr, _sumptr); +} + +void softmax_bf16s_pack8_sse_avx512bf16(unsigned short* _ptr, int elemcount, size_t stride, int size1, float* _maxptr, float* _sumptr) +{ + softmax_bf16s_pack8_sse(_ptr, elemcount, stride, size1, _maxptr, _sumptr); +} + +void softmax_bf16s_pack16_sse_avx512bf16(unsigned short* _ptr, int elemcount, size_t stride, int size1, float* _maxptr, float* _sumptr) +{ + softmax_bf16s_pack16_sse(_ptr, elemcount, stride, size1, _maxptr, _sumptr); +} + +} // namespace ncnn diff --git a/src/layer/x86/swish_bf16s.h b/src/layer/x86/swish_bf16s.h new file mode 100644 index 00000000000..2087f56452a --- /dev/null +++ b/src/layer/x86/swish_bf16s.h @@ -0,0 +1,72 @@ +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ +void swish_bf16s_avx512bf16(Mat& a, const Option& opt); +#endif + +static void swish_bf16s(Mat& a, const Option& opt) +{ +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ + if (ncnn::cpu_support_x86_avx512_bf16()) + { + swish_bf16s_avx512bf16(a, opt); + return; + } +#endif + + int w = a.w; + int h = a.h; + int d = a.d; + int channels = a.c; + int elempack = a.elempack; + int size = w * h * d * elempack; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + unsigned short* ptr = a.channel(q); + + int i = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + __m512 _one_avx512 = 
_mm512_set1_ps(1.f); + __m512 _zero_avx512 = _mm512_setzero_ps(); + for (; i + 15 < size; i += 16) + { + __m512 _p = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr)); + _p = _mm512_div_ps(_p, _mm512_add_ps(_one_avx512, exp512_ps(_mm512_sub_ps(_zero_avx512, _p)))); + _mm256_storeu_si256((__m256i*)ptr, float2bfloat_avx512(_p)); + ptr += 16; + } +#endif // __AVX512F__ + __m256 _one_avx = _mm256_set1_ps(1.f); + __m256 _zero_avx = _mm256_setzero_ps(); + for (; i + 7 < size; i += 8) + { + __m256 _p = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr)); + _p = _mm256_div_ps(_p, _mm256_add_ps(_one_avx, exp256_ps(_mm256_sub_ps(_zero_avx, _p)))); + _mm_storeu_si128((__m128i*)ptr, float2bfloat_avx(_p)); + ptr += 8; + } +#endif // __AVX__ + __m128 _one = _mm_set1_ps(1.f); + __m128 _zero = _mm_setzero_ps(); + for (; i + 3 < size; i += 4) + { + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr)); + _p = _mm_div_ps(_p, _mm_add_ps(_one, exp_ps(_mm_sub_ps(_zero, _p)))); + _mm_storel_epi64((__m128i*)ptr, float2bfloat_sse(_p, _p)); + ptr += 4; + } +#endif // __SSE2__ + for (; i < size; i++) + { + float v = bfloat16_to_float32(*ptr); + v = v / (1.f + expf(-v)); + *ptr = float32_to_bfloat16(v); + ptr++; + } + } +} diff --git a/src/layer/x86/swish_x86.cpp b/src/layer/x86/swish_x86.cpp index 2aaf80ef8ec..1866f76b56b 100644 --- a/src/layer/x86/swish_x86.cpp +++ b/src/layer/x86/swish_x86.cpp @@ -15,13 +15,24 @@ #endif // __AVX__ #endif // __SSE2__ +#include "x86_usability.h" + +#include "cpu.h" + namespace ncnn { +#if NCNN_BF16 +#include "swish_bf16s.h" +#endif + Swish_x86::Swish_x86() { #if __SSE2__ support_packing = true; #endif // __SSE2__ +#if NCNN_BF16 + support_bf16_storage = true; +#endif } int Swish_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const @@ -33,6 +44,11 @@ int Swish_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const int elempack = bottom_top_blob.elempack; int size = w * h * d * elempack; +#if NCNN_BF16 + if (opt.use_bf16_storage && bottom_top_blob.elembits() == 16) + return forward_inplace_bf16s(bottom_top_blob, opt); +#endif + #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; q++) { @@ -82,4 +98,13 @@ int Swish_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const return 0; } +#if NCNN_BF16 +int Swish_x86::forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt) const +{ + swish_bf16s(bottom_top_blob, opt); + + return 0; +} +#endif // NCNN_BF16 + } // namespace ncnn diff --git a/src/layer/x86/swish_x86.h b/src/layer/x86/swish_x86.h index 2a2ad564123..dbfdf187623 100644 --- a/src/layer/x86/swish_x86.h +++ b/src/layer/x86/swish_x86.h @@ -14,6 +14,11 @@ class Swish_x86 : public Swish Swish_x86(); virtual int forward_inplace(Mat& bottom_top_blob, const Option& opt) const; + +protected: +#if NCNN_BF16 + int forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt) const; +#endif }; } // namespace ncnn diff --git a/src/layer/x86/swish_x86_avx512bf16.cpp b/src/layer/x86/swish_x86_avx512bf16.cpp new file mode 100644 index 00000000000..e95f23b5598 --- /dev/null +++ b/src/layer/x86/swish_x86_avx512bf16.cpp @@ -0,0 +1,32 @@ +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#include "swish_x86.h" + +#if __SSE2__ +#include +#include "sse_mathfun.h" +#if __AVX__ +#include +#include "avx_mathfun.h" +#if __AVX512F__ +#include "avx512_mathfun.h" +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ + +#include "x86_usability.h" + +#include "cpu.h" +#include 
"mat.h" + +namespace ncnn { + +#include "swish_bf16s.h" + +void swish_bf16s_avx512bf16(Mat& a, const Option& opt) +{ + swish_bf16s(a, opt); +} + +} // namespace ncnn diff --git a/src/layer/x86/x86_usability.h b/src/layer/x86/x86_usability.h index bae92af8eac..0da0b001f93 100644 --- a/src/layer/x86/x86_usability.h +++ b/src/layer/x86/x86_usability.h @@ -443,6 +443,14 @@ static NCNN_FORCEINLINE __m128 _mm_comp_fnmsub_ps(const __m128& _a, const __m128 #endif } +static NCNN_FORCEINLINE __m128 _mm_rcp_nr_ps(const __m128& x) +{ + __m128 y = _mm_rcp_ps(x); + __m128 t = _mm_comp_fnmadd_ps(x, y, _mm_set1_ps(2.0f)); + y = _mm_mul_ps(y, t); + return y; +} + static NCNN_FORCEINLINE __m128i _mm_comp_dpwssd_epi32(const __m128i& src, const __m128i& a, const __m128i& b) { #if __AVX512VNNI__ @@ -497,6 +505,14 @@ static NCNN_FORCEINLINE __m256 _mm256_comp_fnmsub_ps(const __m256& _a, const __m #endif } +static NCNN_FORCEINLINE __m256 _mm256_rcp_nr_ps(const __m256& x) +{ + __m256 y = _mm256_rcp_ps(x); + __m256 t = _mm256_comp_fnmadd_ps(x, y, _mm256_set1_ps(2.0f)); + y = _mm256_mul_ps(y, t); + return y; +} + static NCNN_FORCEINLINE __m256 _mm256_fmadd_1_ps(const __m256& a, const __m256& b, float c) { return _mm256_comp_fmadd_ps(b, _mm256_set1_ps(c), a); @@ -1732,6 +1748,13 @@ static NCNN_FORCEINLINE float _mm512_comp_reduce_max_ps(const __m512& x) return _mm_cvtss_f32(x32); } +static NCNN_FORCEINLINE __m512 _mm512_rcp_nr_ps(const __m512& x) +{ + __m512 y = _mm512_rcp14_ps(x); + __m512 t = _mm512_fnmadd_ps(x, y, _mm512_set1_ps(2.0f)); + return _mm512_mul_ps(y, t); +} + static NCNN_FORCEINLINE __m512 combine8x2_ps(const __m256& a, const __m256& b) { return _mm512_insertf32x8(_mm512_castps256_ps512(a), b, 1); @@ -1818,7 +1841,6 @@ static NCNN_FORCEINLINE __m512i float2bfloat_avx512(const __m512& v0, const __m5 #endif return _v; } - #endif // __AVX512F__ #endif // __AVX2__ #endif // __AVX__ From 8b5f2bf1b1b2676ce08201e6bc8a02a2678e784a Mon Sep 17 00:00:00 2001 From: nihui Date: Wed, 11 Mar 2026 19:10:50 +0800 Subject: [PATCH 19/36] x86 interp optimization (#6597) --- src/layer/x86/interp_bicubic.h | 525 +++++++++++++--- src/layer/x86/interp_bicubic_pack16.h | 30 +- src/layer/x86/interp_bicubic_pack4.h | 363 ++++++++++-- src/layer/x86/interp_bicubic_pack8.h | 201 +++++-- src/layer/x86/interp_bilinear.h | 257 ++++++-- src/layer/x86/interp_bilinear_pack16.h | 20 +- src/layer/x86/interp_bilinear_pack4.h | 139 ++++- src/layer/x86/interp_bilinear_pack8.h | 63 +- src/layer/x86/interp_x86.cpp | 789 ++++++------------------- src/layer/x86/interp_x86_avx2.cpp | 23 + tests/test_interp.cpp | 202 ++++++- 11 files changed, 1754 insertions(+), 858 deletions(-) create mode 100644 src/layer/x86/interp_x86_avx2.cpp diff --git a/src/layer/x86/interp_bicubic.h b/src/layer/x86/interp_bicubic.h index 77097597a8f..5588bf8a756 100644 --- a/src/layer/x86/interp_bicubic.h +++ b/src/layer/x86/interp_bicubic.h @@ -1,6 +1,10 @@ // Copyright 2022 Tencent // SPDX-License-Identifier: BSD-3-Clause +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__ +void resize_bicubic_image_avx2(const Mat& src, Mat& dst, float* alpha, int* xofs, float* beta, int* yofs); +#endif + static inline void interpolate_cubic(float fx, float* coeffs) { const float A = -0.75f; @@ -74,8 +78,79 @@ static void cubic_coeffs(int w, int outw, int* xofs, float* alpha, int align_cor } } +static void vresize_bicubic(const float* rows0, const float* rows1, const float* rows2, const float* rows3, float* Dp, int n, float b0, float b1, float 
b2, float b3) +{ + int nn = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + __m512 _b0_512 = _mm512_set1_ps(b0); + __m512 _b1_512 = _mm512_set1_ps(b1); + __m512 _b2_512 = _mm512_set1_ps(b2); + __m512 _b3_512 = _mm512_set1_ps(b3); + for (; nn + 15 < n; nn += 16) + { + __m512 _rows0 = _mm512_loadu_ps(rows0 + nn); + __m512 _rows1 = _mm512_loadu_ps(rows1 + nn); + __m512 _rows2 = _mm512_loadu_ps(rows2 + nn); + __m512 _rows3 = _mm512_loadu_ps(rows3 + nn); + __m512 _Dp = _mm512_mul_ps(_rows0, _b0_512); + _Dp = _mm512_fmadd_ps(_rows1, _b1_512, _Dp); + _Dp = _mm512_fmadd_ps(_rows2, _b2_512, _Dp); + _Dp = _mm512_fmadd_ps(_rows3, _b3_512, _Dp); + _mm512_storeu_ps(Dp + nn, _Dp); + } +#endif // __AVX512F__ + __m256 _b0_256 = _mm256_set1_ps(b0); + __m256 _b1_256 = _mm256_set1_ps(b1); + __m256 _b2_256 = _mm256_set1_ps(b2); + __m256 _b3_256 = _mm256_set1_ps(b3); + for (; nn + 7 < n; nn += 8) + { + __m256 _rows0 = _mm256_loadu_ps(rows0 + nn); + __m256 _rows1 = _mm256_loadu_ps(rows1 + nn); + __m256 _rows2 = _mm256_loadu_ps(rows2 + nn); + __m256 _rows3 = _mm256_loadu_ps(rows3 + nn); + __m256 _Dp = _mm256_mul_ps(_rows0, _b0_256); + _Dp = _mm256_comp_fmadd_ps(_rows1, _b1_256, _Dp); + _Dp = _mm256_comp_fmadd_ps(_rows2, _b2_256, _Dp); + _Dp = _mm256_comp_fmadd_ps(_rows3, _b3_256, _Dp); + _mm256_storeu_ps(Dp + nn, _Dp); + } +#endif // __AVX__ + __m128 _b0_128 = _mm_set1_ps(b0); + __m128 _b1_128 = _mm_set1_ps(b1); + __m128 _b2_128 = _mm_set1_ps(b2); + __m128 _b3_128 = _mm_set1_ps(b3); + for (; nn + 3 < n; nn += 4) + { + __m128 _rows0 = _mm_loadu_ps(rows0 + nn); + __m128 _rows1 = _mm_loadu_ps(rows1 + nn); + __m128 _rows2 = _mm_loadu_ps(rows2 + nn); + __m128 _rows3 = _mm_loadu_ps(rows3 + nn); + __m128 _Dp = _mm_mul_ps(_rows0, _b0_128); + _Dp = _mm_comp_fmadd_ps(_rows1, _b1_128, _Dp); + _Dp = _mm_comp_fmadd_ps(_rows2, _b2_128, _Dp); + _Dp = _mm_comp_fmadd_ps(_rows3, _b3_128, _Dp); + _mm_storeu_ps(Dp + nn, _Dp); + } +#endif // __SSE2__ + for (; nn < n; nn++) + { + Dp[nn] = rows0[nn] * b0 + rows1[nn] * b1 + rows2[nn] * b2 + rows3[nn] * b3; + } +} + static void resize_bicubic_image(const Mat& src, Mat& dst, float* alpha, int* xofs, float* beta, int* yofs) { +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx2()) + { + resize_bicubic_image_avx2(src, dst, alpha, xofs, beta, yofs); + return; + } +#endif + int w = dst.w; int h = dst.h; @@ -111,7 +186,78 @@ static void resize_bicubic_image(const Mat& src, Mat& dst, float* alpha, int* xo const float* alphap = alpha; float* rows3p = rows3; - for (int dx = 0; dx < w; dx++) + int dx = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + for (; dx + 15 < w; dx += 16) + { + __m512i _sx = _mm512_loadu_si512(xofs + dx); + __m512i _sxn1 = _mm512_add_epi32(_sx, _mm512_set1_epi32(-1)); + __m512i _sx1 = _mm512_add_epi32(_sx, _mm512_set1_epi32(1)); + __m512i _sx2 = _mm512_add_epi32(_sx, _mm512_set1_epi32(2)); + + __m512 _S30 = _mm512_i32gather_ps(_sxn1, S3, sizeof(float)); + __m512 _S31 = _mm512_i32gather_ps(_sx, S3, sizeof(float)); + __m512 _S32 = _mm512_i32gather_ps(_sx1, S3, sizeof(float)); + __m512 _S33 = _mm512_i32gather_ps(_sx2, S3, sizeof(float)); + + __m512i _alpha_idx = _mm512_setr_epi32(0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60); + __m512 _a0 = _mm512_i32gather_ps(_alpha_idx, alphap, sizeof(float)); + __m512 _a1 = _mm512_i32gather_ps(_mm512_add_epi32(_alpha_idx, _mm512_set1_epi32(1)), alphap, sizeof(float)); + __m512 _a2 = _mm512_i32gather_ps(_mm512_add_epi32(_alpha_idx, 
_mm512_set1_epi32(2)), alphap, sizeof(float)); + __m512 _a3 = _mm512_i32gather_ps(_mm512_add_epi32(_alpha_idx, _mm512_set1_epi32(3)), alphap, sizeof(float)); + + __m512 _rows3 = _mm512_mul_ps(_S30, _a0); + _rows3 = _mm512_fmadd_ps(_S31, _a1, _rows3); + _rows3 = _mm512_fmadd_ps(_S32, _a2, _rows3); + _rows3 = _mm512_fmadd_ps(_S33, _a3, _rows3); + _mm512_storeu_ps(rows3p + dx, _rows3); + + alphap += 64; + } +#endif // __AVX512F__ + for (; dx + 7 < w; dx += 8) + { +#if __AVX2__ + __m256i _sx = _mm256_loadu_si256((const __m256i*)(xofs + dx)); + __m256i _sxn1 = _mm256_add_epi32(_sx, _mm256_set1_epi32(-1)); + __m256i _sx1 = _mm256_add_epi32(_sx, _mm256_set1_epi32(1)); + __m256i _sx2 = _mm256_add_epi32(_sx, _mm256_set1_epi32(2)); + + __m256 _S30 = _mm256_i32gather_ps(S3, _sxn1, sizeof(float)); + __m256 _S31 = _mm256_i32gather_ps(S3, _sx, sizeof(float)); + __m256 _S32 = _mm256_i32gather_ps(S3, _sx1, sizeof(float)); + __m256 _S33 = _mm256_i32gather_ps(S3, _sx2, sizeof(float)); + + __m256i _alpha_idx = _mm256_setr_epi32(0, 4, 8, 12, 16, 20, 24, 28); + __m256 _a0 = _mm256_i32gather_ps(alphap, _alpha_idx, sizeof(float)); + __m256 _a1 = _mm256_i32gather_ps(alphap, _mm256_add_epi32(_alpha_idx, _mm256_set1_epi32(1)), sizeof(float)); + __m256 _a2 = _mm256_i32gather_ps(alphap, _mm256_add_epi32(_alpha_idx, _mm256_set1_epi32(2)), sizeof(float)); + __m256 _a3 = _mm256_i32gather_ps(alphap, _mm256_add_epi32(_alpha_idx, _mm256_set1_epi32(3)), sizeof(float)); +#else + __m256 _S30 = _mm256_setr_ps(S3[xofs[dx] - 1], S3[xofs[dx + 1] - 1], S3[xofs[dx + 2] - 1], S3[xofs[dx + 3] - 1], S3[xofs[dx + 4] - 1], S3[xofs[dx + 5] - 1], S3[xofs[dx + 6] - 1], S3[xofs[dx + 7] - 1]); + __m256 _S31 = _mm256_setr_ps(S3[xofs[dx]], S3[xofs[dx + 1]], S3[xofs[dx + 2]], S3[xofs[dx + 3]], S3[xofs[dx + 4]], S3[xofs[dx + 5]], S3[xofs[dx + 6]], S3[xofs[dx + 7]]); + __m256 _S32 = _mm256_setr_ps(S3[xofs[dx] + 1], S3[xofs[dx + 1] + 1], S3[xofs[dx + 2] + 1], S3[xofs[dx + 3] + 1], S3[xofs[dx + 4] + 1], S3[xofs[dx + 5] + 1], S3[xofs[dx + 6] + 1], S3[xofs[dx + 7] + 1]); + __m256 _S33 = _mm256_setr_ps(S3[xofs[dx] + 2], S3[xofs[dx + 1] + 2], S3[xofs[dx + 2] + 2], S3[xofs[dx + 3] + 2], S3[xofs[dx + 4] + 2], S3[xofs[dx + 5] + 2], S3[xofs[dx + 6] + 2], S3[xofs[dx + 7] + 2]); + + __m256 _a0 = _mm256_setr_ps(alphap[0], alphap[4], alphap[8], alphap[12], alphap[16], alphap[20], alphap[24], alphap[28]); + __m256 _a1 = _mm256_setr_ps(alphap[1], alphap[5], alphap[9], alphap[13], alphap[17], alphap[21], alphap[25], alphap[29]); + __m256 _a2 = _mm256_setr_ps(alphap[2], alphap[6], alphap[10], alphap[14], alphap[18], alphap[22], alphap[26], alphap[30]); + __m256 _a3 = _mm256_setr_ps(alphap[3], alphap[7], alphap[11], alphap[15], alphap[19], alphap[23], alphap[27], alphap[31]); +#endif + + __m256 _rows3 = _mm256_mul_ps(_S30, _a0); + _rows3 = _mm256_comp_fmadd_ps(_S31, _a1, _rows3); + _rows3 = _mm256_comp_fmadd_ps(_S32, _a2, _rows3); + _rows3 = _mm256_comp_fmadd_ps(_S33, _a3, _rows3); + _mm256_storeu_ps(rows3p + dx, _rows3); + + alphap += 32; + } +#endif // __AVX__ +#endif // __SSE2__ + for (; dx < w; dx++) { int sx = xofs[dx]; const float* S3p = S3 + sx; @@ -140,7 +286,107 @@ static void resize_bicubic_image(const Mat& src, Mat& dst, float* alpha, int* xo const float* alphap = alpha; float* rows2p = rows2; float* rows3p = rows3; - for (int dx = 0; dx < w; dx++) + int dx = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + for (; dx + 15 < w; dx += 16) + { + __m512i _sx = _mm512_loadu_si512(xofs + dx); + __m512i _sxn1 = _mm512_add_epi32(_sx, _mm512_set1_epi32(-1)); + 
__m512i _sx1 = _mm512_add_epi32(_sx, _mm512_set1_epi32(1)); + __m512i _sx2 = _mm512_add_epi32(_sx, _mm512_set1_epi32(2)); + + __m512i _alpha_idx = _mm512_setr_epi32(0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60); + __m512 _a0 = _mm512_i32gather_ps(_alpha_idx, alphap, sizeof(float)); + __m512 _a1 = _mm512_i32gather_ps(_mm512_add_epi32(_alpha_idx, _mm512_set1_epi32(1)), alphap, sizeof(float)); + __m512 _a2 = _mm512_i32gather_ps(_mm512_add_epi32(_alpha_idx, _mm512_set1_epi32(2)), alphap, sizeof(float)); + __m512 _a3 = _mm512_i32gather_ps(_mm512_add_epi32(_alpha_idx, _mm512_set1_epi32(3)), alphap, sizeof(float)); + + __m512 _S20 = _mm512_i32gather_ps(_sxn1, S2, sizeof(float)); + __m512 _S21 = _mm512_i32gather_ps(_sx, S2, sizeof(float)); + __m512 _S22 = _mm512_i32gather_ps(_sx1, S2, sizeof(float)); + __m512 _S23 = _mm512_i32gather_ps(_sx2, S2, sizeof(float)); + + __m512 _rows2 = _mm512_mul_ps(_S20, _a0); + _rows2 = _mm512_fmadd_ps(_S21, _a1, _rows2); + _rows2 = _mm512_fmadd_ps(_S22, _a2, _rows2); + _rows2 = _mm512_fmadd_ps(_S23, _a3, _rows2); + _mm512_storeu_ps(rows2p + dx, _rows2); + + __m512 _S30 = _mm512_i32gather_ps(_sxn1, S3, sizeof(float)); + __m512 _S31 = _mm512_i32gather_ps(_sx, S3, sizeof(float)); + __m512 _S32 = _mm512_i32gather_ps(_sx1, S3, sizeof(float)); + __m512 _S33 = _mm512_i32gather_ps(_sx2, S3, sizeof(float)); + + __m512 _rows3 = _mm512_mul_ps(_S30, _a0); + _rows3 = _mm512_fmadd_ps(_S31, _a1, _rows3); + _rows3 = _mm512_fmadd_ps(_S32, _a2, _rows3); + _rows3 = _mm512_fmadd_ps(_S33, _a3, _rows3); + _mm512_storeu_ps(rows3p + dx, _rows3); + + alphap += 64; + } +#endif // __AVX512F__ + for (; dx + 7 < w; dx += 8) + { +#if __AVX2__ + __m256i _sx = _mm256_loadu_si256((const __m256i*)(xofs + dx)); + __m256i _sxn1 = _mm256_add_epi32(_sx, _mm256_set1_epi32(-1)); + __m256i _sx1 = _mm256_add_epi32(_sx, _mm256_set1_epi32(1)); + __m256i _sx2 = _mm256_add_epi32(_sx, _mm256_set1_epi32(2)); + + __m256i _alpha_idx = _mm256_setr_epi32(0, 4, 8, 12, 16, 20, 24, 28); + __m256 _a0 = _mm256_i32gather_ps(alphap, _alpha_idx, sizeof(float)); + __m256 _a1 = _mm256_i32gather_ps(alphap, _mm256_add_epi32(_alpha_idx, _mm256_set1_epi32(1)), sizeof(float)); + __m256 _a2 = _mm256_i32gather_ps(alphap, _mm256_add_epi32(_alpha_idx, _mm256_set1_epi32(2)), sizeof(float)); + __m256 _a3 = _mm256_i32gather_ps(alphap, _mm256_add_epi32(_alpha_idx, _mm256_set1_epi32(3)), sizeof(float)); + + __m256 _S20 = _mm256_i32gather_ps(S2, _sxn1, sizeof(float)); + __m256 _S21 = _mm256_i32gather_ps(S2, _sx, sizeof(float)); + __m256 _S22 = _mm256_i32gather_ps(S2, _sx1, sizeof(float)); + __m256 _S23 = _mm256_i32gather_ps(S2, _sx2, sizeof(float)); +#else + __m256 _S20 = _mm256_setr_ps(S2[xofs[dx] - 1], S2[xofs[dx + 1] - 1], S2[xofs[dx + 2] - 1], S2[xofs[dx + 3] - 1], S2[xofs[dx + 4] - 1], S2[xofs[dx + 5] - 1], S2[xofs[dx + 6] - 1], S2[xofs[dx + 7] - 1]); + __m256 _S21 = _mm256_setr_ps(S2[xofs[dx]], S2[xofs[dx + 1]], S2[xofs[dx + 2]], S2[xofs[dx + 3]], S2[xofs[dx + 4]], S2[xofs[dx + 5]], S2[xofs[dx + 6]], S2[xofs[dx + 7]]); + __m256 _S22 = _mm256_setr_ps(S2[xofs[dx] + 1], S2[xofs[dx + 1] + 1], S2[xofs[dx + 2] + 1], S2[xofs[dx + 3] + 1], S2[xofs[dx + 4] + 1], S2[xofs[dx + 5] + 1], S2[xofs[dx + 6] + 1], S2[xofs[dx + 7] + 1]); + __m256 _S23 = _mm256_setr_ps(S2[xofs[dx] + 2], S2[xofs[dx + 1] + 2], S2[xofs[dx + 2] + 2], S2[xofs[dx + 3] + 2], S2[xofs[dx + 4] + 2], S2[xofs[dx + 5] + 2], S2[xofs[dx + 6] + 2], S2[xofs[dx + 7] + 2]); + + __m256 _a0 = _mm256_setr_ps(alphap[0], alphap[4], alphap[8], alphap[12], alphap[16], alphap[20], 
alphap[24], alphap[28]); + __m256 _a1 = _mm256_setr_ps(alphap[1], alphap[5], alphap[9], alphap[13], alphap[17], alphap[21], alphap[25], alphap[29]); + __m256 _a2 = _mm256_setr_ps(alphap[2], alphap[6], alphap[10], alphap[14], alphap[18], alphap[22], alphap[26], alphap[30]); + __m256 _a3 = _mm256_setr_ps(alphap[3], alphap[7], alphap[11], alphap[15], alphap[19], alphap[23], alphap[27], alphap[31]); +#endif + + __m256 _rows2 = _mm256_mul_ps(_S20, _a0); + _rows2 = _mm256_comp_fmadd_ps(_S21, _a1, _rows2); + _rows2 = _mm256_comp_fmadd_ps(_S22, _a2, _rows2); + _rows2 = _mm256_comp_fmadd_ps(_S23, _a3, _rows2); + _mm256_storeu_ps(rows2p + dx, _rows2); + +#if __AVX2__ + __m256 _S30 = _mm256_i32gather_ps(S3, _sxn1, sizeof(float)); + __m256 _S31 = _mm256_i32gather_ps(S3, _sx, sizeof(float)); + __m256 _S32 = _mm256_i32gather_ps(S3, _sx1, sizeof(float)); + __m256 _S33 = _mm256_i32gather_ps(S3, _sx2, sizeof(float)); +#else + __m256 _S30 = _mm256_setr_ps(S3[xofs[dx] - 1], S3[xofs[dx + 1] - 1], S3[xofs[dx + 2] - 1], S3[xofs[dx + 3] - 1], S3[xofs[dx + 4] - 1], S3[xofs[dx + 5] - 1], S3[xofs[dx + 6] - 1], S3[xofs[dx + 7] - 1]); + __m256 _S31 = _mm256_setr_ps(S3[xofs[dx]], S3[xofs[dx + 1]], S3[xofs[dx + 2]], S3[xofs[dx + 3]], S3[xofs[dx + 4]], S3[xofs[dx + 5]], S3[xofs[dx + 6]], S3[xofs[dx + 7]]); + __m256 _S32 = _mm256_setr_ps(S3[xofs[dx] + 1], S3[xofs[dx + 1] + 1], S3[xofs[dx + 2] + 1], S3[xofs[dx + 3] + 1], S3[xofs[dx + 4] + 1], S3[xofs[dx + 5] + 1], S3[xofs[dx + 6] + 1], S3[xofs[dx + 7] + 1]); + __m256 _S33 = _mm256_setr_ps(S3[xofs[dx] + 2], S3[xofs[dx + 1] + 2], S3[xofs[dx + 2] + 2], S3[xofs[dx + 3] + 2], S3[xofs[dx + 4] + 2], S3[xofs[dx + 5] + 2], S3[xofs[dx + 6] + 2], S3[xofs[dx + 7] + 2]); +#endif + + __m256 _rows3 = _mm256_mul_ps(_S30, _a0); + _rows3 = _mm256_comp_fmadd_ps(_S31, _a1, _rows3); + _rows3 = _mm256_comp_fmadd_ps(_S32, _a2, _rows3); + _rows3 = _mm256_comp_fmadd_ps(_S33, _a3, _rows3); + _mm256_storeu_ps(rows3p + dx, _rows3); + + alphap += 32; + } +#endif // __AVX__ +#endif // __SSE2__ + for (; dx < w; dx++) { int sx = xofs[dx]; const float* S2p = S2 + sx; @@ -174,7 +420,105 @@ static void resize_bicubic_image(const Mat& src, Mat& dst, float* alpha, int* xo float* rows1p = rows1; float* rows2p = rows2; float* rows3p = rows3; - for (int dx = 0; dx < w; dx++) + int dx = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + for (; dx + 15 < w; dx += 16) + { + __m512i _sx = _mm512_loadu_si512(xofs + dx); + __m512i _sxn1 = _mm512_add_epi32(_sx, _mm512_set1_epi32(-1)); + __m512i _sx1 = _mm512_add_epi32(_sx, _mm512_set1_epi32(1)); + __m512i _sx2 = _mm512_add_epi32(_sx, _mm512_set1_epi32(2)); + + __m512i _alpha_idx = _mm512_setr_epi32(0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60); + __m512 _a0 = _mm512_i32gather_ps(_alpha_idx, alphap, sizeof(float)); + __m512 _a1 = _mm512_i32gather_ps(_mm512_add_epi32(_alpha_idx, _mm512_set1_epi32(1)), alphap, sizeof(float)); + __m512 _a2 = _mm512_i32gather_ps(_mm512_add_epi32(_alpha_idx, _mm512_set1_epi32(2)), alphap, sizeof(float)); + __m512 _a3 = _mm512_i32gather_ps(_mm512_add_epi32(_alpha_idx, _mm512_set1_epi32(3)), alphap, sizeof(float)); + + __m512 _S10 = _mm512_i32gather_ps(_sxn1, S1, sizeof(float)); + __m512 _S11 = _mm512_i32gather_ps(_sx, S1, sizeof(float)); + __m512 _S12 = _mm512_i32gather_ps(_sx1, S1, sizeof(float)); + __m512 _S13 = _mm512_i32gather_ps(_sx2, S1, sizeof(float)); + __m512 _rows1 = _mm512_mul_ps(_S10, _a0); + _rows1 = _mm512_fmadd_ps(_S11, _a1, _rows1); + _rows1 = _mm512_fmadd_ps(_S12, _a2, _rows1); + _rows1 = _mm512_fmadd_ps(_S13, 
_a3, _rows1); + _mm512_storeu_ps(rows1p + dx, _rows1); + + __m512 _S20 = _mm512_i32gather_ps(_sxn1, S2, sizeof(float)); + __m512 _S21 = _mm512_i32gather_ps(_sx, S2, sizeof(float)); + __m512 _S22 = _mm512_i32gather_ps(_sx1, S2, sizeof(float)); + __m512 _S23 = _mm512_i32gather_ps(_sx2, S2, sizeof(float)); + __m512 _rows2 = _mm512_mul_ps(_S20, _a0); + _rows2 = _mm512_fmadd_ps(_S21, _a1, _rows2); + _rows2 = _mm512_fmadd_ps(_S22, _a2, _rows2); + _rows2 = _mm512_fmadd_ps(_S23, _a3, _rows2); + _mm512_storeu_ps(rows2p + dx, _rows2); + + __m512 _S30 = _mm512_i32gather_ps(_sxn1, S3, sizeof(float)); + __m512 _S31 = _mm512_i32gather_ps(_sx, S3, sizeof(float)); + __m512 _S32 = _mm512_i32gather_ps(_sx1, S3, sizeof(float)); + __m512 _S33 = _mm512_i32gather_ps(_sx2, S3, sizeof(float)); + __m512 _rows3 = _mm512_mul_ps(_S30, _a0); + _rows3 = _mm512_fmadd_ps(_S31, _a1, _rows3); + _rows3 = _mm512_fmadd_ps(_S32, _a2, _rows3); + _rows3 = _mm512_fmadd_ps(_S33, _a3, _rows3); + _mm512_storeu_ps(rows3p + dx, _rows3); + + alphap += 64; + } +#endif // __AVX512F__ + for (; dx + 7 < w; dx += 8) + { +#if __AVX2__ + __m256i _sx = _mm256_loadu_si256((const __m256i*)(xofs + dx)); + __m256i _sxn1 = _mm256_add_epi32(_sx, _mm256_set1_epi32(-1)); + __m256i _sx1 = _mm256_add_epi32(_sx, _mm256_set1_epi32(1)); + __m256i _sx2 = _mm256_add_epi32(_sx, _mm256_set1_epi32(2)); + + __m256i _alpha_idx = _mm256_setr_epi32(0, 4, 8, 12, 16, 20, 24, 28); + __m256 _a0 = _mm256_i32gather_ps(alphap, _alpha_idx, sizeof(float)); + __m256 _a1 = _mm256_i32gather_ps(alphap, _mm256_add_epi32(_alpha_idx, _mm256_set1_epi32(1)), sizeof(float)); + __m256 _a2 = _mm256_i32gather_ps(alphap, _mm256_add_epi32(_alpha_idx, _mm256_set1_epi32(2)), sizeof(float)); + __m256 _a3 = _mm256_i32gather_ps(alphap, _mm256_add_epi32(_alpha_idx, _mm256_set1_epi32(3)), sizeof(float)); +#else + __m256 _a0 = _mm256_setr_ps(alphap[0], alphap[4], alphap[8], alphap[12], alphap[16], alphap[20], alphap[24], alphap[28]); + __m256 _a1 = _mm256_setr_ps(alphap[1], alphap[5], alphap[9], alphap[13], alphap[17], alphap[21], alphap[25], alphap[29]); + __m256 _a2 = _mm256_setr_ps(alphap[2], alphap[6], alphap[10], alphap[14], alphap[18], alphap[22], alphap[26], alphap[30]); + __m256 _a3 = _mm256_setr_ps(alphap[3], alphap[7], alphap[11], alphap[15], alphap[19], alphap[23], alphap[27], alphap[31]); +#endif + + for (int r = 0; r < 3; r++) + { + const float* Sn = (r == 0) ? S1 : ((r == 1) ? S2 : S3); + float* rowsnp = (r == 0) ? rows1p : ((r == 1) ? 
rows2p : rows3p); + +#if __AVX2__ + __m256 _Sn0 = _mm256_i32gather_ps(Sn, _sxn1, sizeof(float)); + __m256 _Sn1 = _mm256_i32gather_ps(Sn, _sx, sizeof(float)); + __m256 _Sn2 = _mm256_i32gather_ps(Sn, _sx1, sizeof(float)); + __m256 _Sn3 = _mm256_i32gather_ps(Sn, _sx2, sizeof(float)); +#else + __m256 _Sn0 = _mm256_setr_ps(Sn[xofs[dx] - 1], Sn[xofs[dx + 1] - 1], Sn[xofs[dx + 2] - 1], Sn[xofs[dx + 3] - 1], Sn[xofs[dx + 4] - 1], Sn[xofs[dx + 5] - 1], Sn[xofs[dx + 6] - 1], Sn[xofs[dx + 7] - 1]); + __m256 _Sn1 = _mm256_setr_ps(Sn[xofs[dx]], Sn[xofs[dx + 1]], Sn[xofs[dx + 2]], Sn[xofs[dx + 3]], Sn[xofs[dx + 4]], Sn[xofs[dx + 5]], Sn[xofs[dx + 6]], Sn[xofs[dx + 7]]); + __m256 _Sn2 = _mm256_setr_ps(Sn[xofs[dx] + 1], Sn[xofs[dx + 1] + 1], Sn[xofs[dx + 2] + 1], Sn[xofs[dx + 3] + 1], Sn[xofs[dx + 4] + 1], Sn[xofs[dx + 5] + 1], Sn[xofs[dx + 6] + 1], Sn[xofs[dx + 7] + 1]); + __m256 _Sn3 = _mm256_setr_ps(Sn[xofs[dx] + 2], Sn[xofs[dx + 1] + 2], Sn[xofs[dx + 2] + 2], Sn[xofs[dx + 3] + 2], Sn[xofs[dx + 4] + 2], Sn[xofs[dx + 5] + 2], Sn[xofs[dx + 6] + 2], Sn[xofs[dx + 7] + 2]); +#endif + + __m256 _rowsn = _mm256_mul_ps(_Sn0, _a0); + _rowsn = _mm256_comp_fmadd_ps(_Sn1, _a1, _rowsn); + _rowsn = _mm256_comp_fmadd_ps(_Sn2, _a2, _rowsn); + _rowsn = _mm256_comp_fmadd_ps(_Sn3, _a3, _rowsn); + _mm256_storeu_ps(rowsnp + dx, _rowsn); + } + + alphap += 32; + } +#endif // __AVX__ +#endif // __SSE2__ + for (; dx < w; dx++) { int sx = xofs[dx]; const float* S1p = S1 + sx; @@ -205,7 +549,115 @@ static void resize_bicubic_image(const Mat& src, Mat& dst, float* alpha, int* xo float* rows1p = rows1; float* rows2p = rows2; float* rows3p = rows3; - for (int dx = 0; dx < w; dx++) + int dx = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + for (; dx + 15 < w; dx += 16) + { + __m512i _sx = _mm512_loadu_si512(xofs + dx); + __m512i _sxn1 = _mm512_add_epi32(_sx, _mm512_set1_epi32(-1)); + __m512i _sx1 = _mm512_add_epi32(_sx, _mm512_set1_epi32(1)); + __m512i _sx2 = _mm512_add_epi32(_sx, _mm512_set1_epi32(2)); + + __m512i _alpha_idx = _mm512_setr_epi32(0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60); + __m512 _a0 = _mm512_i32gather_ps(_alpha_idx, alphap, sizeof(float)); + __m512 _a1 = _mm512_i32gather_ps(_mm512_add_epi32(_alpha_idx, _mm512_set1_epi32(1)), alphap, sizeof(float)); + __m512 _a2 = _mm512_i32gather_ps(_mm512_add_epi32(_alpha_idx, _mm512_set1_epi32(2)), alphap, sizeof(float)); + __m512 _a3 = _mm512_i32gather_ps(_mm512_add_epi32(_alpha_idx, _mm512_set1_epi32(3)), alphap, sizeof(float)); + + __m512 _S00 = _mm512_i32gather_ps(_sxn1, S0, sizeof(float)); + __m512 _S01 = _mm512_i32gather_ps(_sx, S0, sizeof(float)); + __m512 _S02 = _mm512_i32gather_ps(_sx1, S0, sizeof(float)); + __m512 _S03 = _mm512_i32gather_ps(_sx2, S0, sizeof(float)); + __m512 _rows0 = _mm512_mul_ps(_S00, _a0); + _rows0 = _mm512_fmadd_ps(_S01, _a1, _rows0); + _rows0 = _mm512_fmadd_ps(_S02, _a2, _rows0); + _rows0 = _mm512_fmadd_ps(_S03, _a3, _rows0); + _mm512_storeu_ps(rows0p + dx, _rows0); + + __m512 _S10 = _mm512_i32gather_ps(_sxn1, S1, sizeof(float)); + __m512 _S11 = _mm512_i32gather_ps(_sx, S1, sizeof(float)); + __m512 _S12 = _mm512_i32gather_ps(_sx1, S1, sizeof(float)); + __m512 _S13 = _mm512_i32gather_ps(_sx2, S1, sizeof(float)); + __m512 _rows1 = _mm512_mul_ps(_S10, _a0); + _rows1 = _mm512_fmadd_ps(_S11, _a1, _rows1); + _rows1 = _mm512_fmadd_ps(_S12, _a2, _rows1); + _rows1 = _mm512_fmadd_ps(_S13, _a3, _rows1); + _mm512_storeu_ps(rows1p + dx, _rows1); + + __m512 _S20 = _mm512_i32gather_ps(_sxn1, S2, sizeof(float)); + __m512 _S21 = 
_mm512_i32gather_ps(_sx, S2, sizeof(float)); + __m512 _S22 = _mm512_i32gather_ps(_sx1, S2, sizeof(float)); + __m512 _S23 = _mm512_i32gather_ps(_sx2, S2, sizeof(float)); + __m512 _rows2 = _mm512_mul_ps(_S20, _a0); + _rows2 = _mm512_fmadd_ps(_S21, _a1, _rows2); + _rows2 = _mm512_fmadd_ps(_S22, _a2, _rows2); + _rows2 = _mm512_fmadd_ps(_S23, _a3, _rows2); + _mm512_storeu_ps(rows2p + dx, _rows2); + + __m512 _S30 = _mm512_i32gather_ps(_sxn1, S3, sizeof(float)); + __m512 _S31 = _mm512_i32gather_ps(_sx, S3, sizeof(float)); + __m512 _S32 = _mm512_i32gather_ps(_sx1, S3, sizeof(float)); + __m512 _S33 = _mm512_i32gather_ps(_sx2, S3, sizeof(float)); + __m512 _rows3 = _mm512_mul_ps(_S30, _a0); + _rows3 = _mm512_fmadd_ps(_S31, _a1, _rows3); + _rows3 = _mm512_fmadd_ps(_S32, _a2, _rows3); + _rows3 = _mm512_fmadd_ps(_S33, _a3, _rows3); + _mm512_storeu_ps(rows3p + dx, _rows3); + + alphap += 64; + } +#endif // __AVX512F__ + for (; dx + 7 < w; dx += 8) + { +#if __AVX2__ + __m256i _sx = _mm256_loadu_si256((const __m256i*)(xofs + dx)); + __m256i _sxn1 = _mm256_add_epi32(_sx, _mm256_set1_epi32(-1)); + __m256i _sx1 = _mm256_add_epi32(_sx, _mm256_set1_epi32(1)); + __m256i _sx2 = _mm256_add_epi32(_sx, _mm256_set1_epi32(2)); + + __m256i _alpha_idx = _mm256_setr_epi32(0, 4, 8, 12, 16, 20, 24, 28); + __m256 _a0 = _mm256_i32gather_ps(alphap, _alpha_idx, sizeof(float)); + __m256 _a1 = _mm256_i32gather_ps(alphap, _mm256_add_epi32(_alpha_idx, _mm256_set1_epi32(1)), sizeof(float)); + __m256 _a2 = _mm256_i32gather_ps(alphap, _mm256_add_epi32(_alpha_idx, _mm256_set1_epi32(2)), sizeof(float)); + __m256 _a3 = _mm256_i32gather_ps(alphap, _mm256_add_epi32(_alpha_idx, _mm256_set1_epi32(3)), sizeof(float)); +#else + __m256 _a0 = _mm256_setr_ps(alphap[0], alphap[4], alphap[8], alphap[12], alphap[16], alphap[20], alphap[24], alphap[28]); + __m256 _a1 = _mm256_setr_ps(alphap[1], alphap[5], alphap[9], alphap[13], alphap[17], alphap[21], alphap[25], alphap[29]); + __m256 _a2 = _mm256_setr_ps(alphap[2], alphap[6], alphap[10], alphap[14], alphap[18], alphap[22], alphap[26], alphap[30]); + __m256 _a3 = _mm256_setr_ps(alphap[3], alphap[7], alphap[11], alphap[15], alphap[19], alphap[23], alphap[27], alphap[31]); +#endif + + for (int r = 0; r < 4; r++) + { + const float* Sn = (r == 0) ? S0 : ((r == 1) ? S1 : ((r == 2) ? S2 : S3)); + float* rowsnp = (r == 0) ? rows0p : ((r == 1) ? rows1p : ((r == 2) ? 
rows2p : rows3p)); + +#if __AVX2__ + __m256 _Sn0 = _mm256_i32gather_ps(Sn, _sxn1, sizeof(float)); + __m256 _Sn1 = _mm256_i32gather_ps(Sn, _sx, sizeof(float)); + __m256 _Sn2 = _mm256_i32gather_ps(Sn, _sx1, sizeof(float)); + __m256 _Sn3 = _mm256_i32gather_ps(Sn, _sx2, sizeof(float)); +#else + __m256 _Sn0 = _mm256_setr_ps(Sn[xofs[dx] - 1], Sn[xofs[dx + 1] - 1], Sn[xofs[dx + 2] - 1], Sn[xofs[dx + 3] - 1], Sn[xofs[dx + 4] - 1], Sn[xofs[dx + 5] - 1], Sn[xofs[dx + 6] - 1], Sn[xofs[dx + 7] - 1]); + __m256 _Sn1 = _mm256_setr_ps(Sn[xofs[dx]], Sn[xofs[dx + 1]], Sn[xofs[dx + 2]], Sn[xofs[dx + 3]], Sn[xofs[dx + 4]], Sn[xofs[dx + 5]], Sn[xofs[dx + 6]], Sn[xofs[dx + 7]]); + __m256 _Sn2 = _mm256_setr_ps(Sn[xofs[dx] + 1], Sn[xofs[dx + 1] + 1], Sn[xofs[dx + 2] + 1], Sn[xofs[dx + 3] + 1], Sn[xofs[dx + 4] + 1], Sn[xofs[dx + 5] + 1], Sn[xofs[dx + 6] + 1], Sn[xofs[dx + 7] + 1]); + __m256 _Sn3 = _mm256_setr_ps(Sn[xofs[dx] + 2], Sn[xofs[dx + 1] + 2], Sn[xofs[dx + 2] + 2], Sn[xofs[dx + 3] + 2], Sn[xofs[dx + 4] + 2], Sn[xofs[dx + 5] + 2], Sn[xofs[dx + 6] + 2], Sn[xofs[dx + 7] + 2]); +#endif + + __m256 _rowsn = _mm256_mul_ps(_Sn0, _a0); + _rowsn = _mm256_comp_fmadd_ps(_Sn1, _a1, _rowsn); + _rowsn = _mm256_comp_fmadd_ps(_Sn2, _a2, _rowsn); + _rowsn = _mm256_comp_fmadd_ps(_Sn3, _a3, _rowsn); + _mm256_storeu_ps(rowsnp + dx, _rowsn); + } + + alphap += 32; + } +#endif // __AVX__ +#endif // __SSE2__ + for (; dx < w; dx++) { int sx = xofs[dx]; const float* S0p = S0 + sx; @@ -229,70 +681,7 @@ static void resize_bicubic_image(const Mat& src, Mat& dst, float* alpha, int* xo prev_sy1 = sy; // vresize - float b0 = beta[0]; - float b1 = beta[1]; - float b2 = beta[2]; - float b3 = beta[3]; - - float* rows0p = rows0; - float* rows1p = rows1; - float* rows2p = rows2; - float* rows3p = rows3; - float* Dp = dst.row(dy); - - int dx = 0; -#if __SSE2__ -#if __AVX__ - __m256 _b0_256 = _mm256_set1_ps(b0); - __m256 _b1_256 = _mm256_set1_ps(b1); - __m256 _b2_256 = _mm256_set1_ps(b2); - __m256 _b3_256 = _mm256_set1_ps(b3); - for (; dx + 7 < w; dx += 8) - { - __m256 _rows0 = _mm256_loadu_ps(rows0p); - __m256 _rows1 = _mm256_loadu_ps(rows1p); - __m256 _rows2 = _mm256_loadu_ps(rows2p); - __m256 _rows3 = _mm256_loadu_ps(rows3p); - __m256 _Dp = _mm256_mul_ps(_rows0, _b0_256); - _Dp = _mm256_comp_fmadd_ps(_rows1, _b1_256, _Dp); - _Dp = _mm256_comp_fmadd_ps(_rows2, _b2_256, _Dp); - _Dp = _mm256_comp_fmadd_ps(_rows3, _b3_256, _Dp); - _mm256_storeu_ps(Dp, _Dp); - - Dp += 8; - rows0p += 8; - rows1p += 8; - rows2p += 8; - rows3p += 8; - } -#endif // __AVX__ - __m128 _b0_128 = _mm_set1_ps(b0); - __m128 _b1_128 = _mm_set1_ps(b1); - __m128 _b2_128 = _mm_set1_ps(b2); - __m128 _b3_128 = _mm_set1_ps(b3); - for (; dx + 3 < w; dx += 4) - { - __m128 _rows0 = _mm_loadu_ps(rows0p); - __m128 _rows1 = _mm_loadu_ps(rows1p); - __m128 _rows2 = _mm_loadu_ps(rows2p); - __m128 _rows3 = _mm_loadu_ps(rows3p); - __m128 _Dp = _mm_mul_ps(_rows0, _b0_128); - _Dp = _mm_comp_fmadd_ps(_rows1, _b1_128, _Dp); - _Dp = _mm_comp_fmadd_ps(_rows2, _b2_128, _Dp); - _Dp = _mm_comp_fmadd_ps(_rows3, _b3_128, _Dp); - _mm_storeu_ps(Dp, _Dp); - - Dp += 4; - rows0p += 4; - rows1p += 4; - rows2p += 4; - rows3p += 4; - } -#endif // __SSE2__ - for (; dx < w; dx++) - { - *Dp++ = *rows0p++ * b0 + *rows1p++ * b1 + *rows2p++ * b2 + *rows3p++ * b3; - } + vresize_bicubic(rows0, rows1, rows2, rows3, dst.row(dy), w, beta[0], beta[1], beta[2], beta[3]); beta += 4; } diff --git a/src/layer/x86/interp_bicubic_pack16.h b/src/layer/x86/interp_bicubic_pack16.h index 6d92497efb9..e5d363c5de4 100644 --- 
a/src/layer/x86/interp_bicubic_pack16.h +++ b/src/layer/x86/interp_bicubic_pack16.h @@ -240,35 +240,7 @@ static void resize_bicubic_image_pack16(const Mat& src, Mat& dst, float* alpha, prev_sy1 = sy; // vresize - __m512 _b0 = _mm512_set1_ps(beta[0]); - __m512 _b1 = _mm512_set1_ps(beta[1]); - __m512 _b2 = _mm512_set1_ps(beta[2]); - __m512 _b3 = _mm512_set1_ps(beta[3]); - - float* rows0p = rows0; - float* rows1p = rows1; - float* rows2p = rows2; - float* rows3p = rows3; - float* Dp = dst.row(dy); - - for (int dx = 0; dx < w; dx++) - { - __m512 _rows0 = _mm512_load_ps(rows0p); - __m512 _rows1 = _mm512_load_ps(rows1p); - __m512 _rows2 = _mm512_load_ps(rows2p); - __m512 _rows3 = _mm512_load_ps(rows3p); - __m512 _Dp = _mm512_mul_ps(_rows0, _b0); - _Dp = _mm512_fmadd_ps(_rows1, _b1, _Dp); - _Dp = _mm512_fmadd_ps(_rows2, _b2, _Dp); - _Dp = _mm512_fmadd_ps(_rows3, _b3, _Dp); - _mm512_store_ps(Dp, _Dp); - - Dp += 16; - rows0p += 16; - rows1p += 16; - rows2p += 16; - rows3p += 16; - } + vresize_bicubic(rows0, rows1, rows2, rows3, dst.row(dy), w * 16, beta[0], beta[1], beta[2], beta[3]); beta += 4; } diff --git a/src/layer/x86/interp_bicubic_pack4.h b/src/layer/x86/interp_bicubic_pack4.h index 7d26e8b0c29..dc226e86274 100644 --- a/src/layer/x86/interp_bicubic_pack4.h +++ b/src/layer/x86/interp_bicubic_pack4.h @@ -38,7 +38,60 @@ static void resize_bicubic_image_pack4(const Mat& src, Mat& dst, float* alpha, i const float* alphap = alpha; float* rows3p = rows3; - for (int dx = 0; dx < w; dx++) + int dx = 0; +#if __AVX__ +#if __AVX512F__ + for (; dx + 3 < w; dx += 4) + { + int sx0 = xofs[dx] * 4; + int sx1 = xofs[dx + 1] * 4; + int sx2 = xofs[dx + 2] * 4; + int sx3 = xofs[dx + 3] * 4; + + __m512 _a0 = _mm512_setr_ps(alphap[0], alphap[0], alphap[0], alphap[0], alphap[4], alphap[4], alphap[4], alphap[4], alphap[8], alphap[8], alphap[8], alphap[8], alphap[12], alphap[12], alphap[12], alphap[12]); + __m512 _a1 = _mm512_setr_ps(alphap[1], alphap[1], alphap[1], alphap[1], alphap[5], alphap[5], alphap[5], alphap[5], alphap[9], alphap[9], alphap[9], alphap[9], alphap[13], alphap[13], alphap[13], alphap[13]); + __m512 _a2 = _mm512_setr_ps(alphap[2], alphap[2], alphap[2], alphap[2], alphap[6], alphap[6], alphap[6], alphap[6], alphap[10], alphap[10], alphap[10], alphap[10], alphap[14], alphap[14], alphap[14], alphap[14]); + __m512 _a3 = _mm512_setr_ps(alphap[3], alphap[3], alphap[3], alphap[3], alphap[7], alphap[7], alphap[7], alphap[7], alphap[11], alphap[11], alphap[11], alphap[11], alphap[15], alphap[15], alphap[15], alphap[15]); + + __m512 _S30 = combine4x4_ps(_mm_load_ps(S3 + sx0 - 4), _mm_load_ps(S3 + sx1 - 4), _mm_load_ps(S3 + sx2 - 4), _mm_load_ps(S3 + sx3 - 4)); + __m512 _S31 = combine4x4_ps(_mm_load_ps(S3 + sx0), _mm_load_ps(S3 + sx1), _mm_load_ps(S3 + sx2), _mm_load_ps(S3 + sx3)); + __m512 _S32 = combine4x4_ps(_mm_load_ps(S3 + sx0 + 4), _mm_load_ps(S3 + sx1 + 4), _mm_load_ps(S3 + sx2 + 4), _mm_load_ps(S3 + sx3 + 4)); + __m512 _S33 = combine4x4_ps(_mm_load_ps(S3 + sx0 + 8), _mm_load_ps(S3 + sx1 + 8), _mm_load_ps(S3 + sx2 + 8), _mm_load_ps(S3 + sx3 + 8)); + + __m512 _rows3 = _mm512_mul_ps(_S30, _a0); + _rows3 = _mm512_fmadd_ps(_S31, _a1, _rows3); + _rows3 = _mm512_fmadd_ps(_S32, _a2, _rows3); + _rows3 = _mm512_fmadd_ps(_S33, _a3, _rows3); + _mm512_storeu_ps(rows3p + dx * 4, _rows3); + + alphap += 16; + } +#endif // __AVX512F__ + for (; dx + 1 < w; dx += 2) + { + int sx0 = xofs[dx] * 4; + int sx1 = xofs[dx + 1] * 4; + + __m256 _a0 = _mm256_setr_ps(alphap[0], alphap[0], alphap[0], alphap[0], alphap[4], 
alphap[4], alphap[4], alphap[4]); + __m256 _a1 = _mm256_setr_ps(alphap[1], alphap[1], alphap[1], alphap[1], alphap[5], alphap[5], alphap[5], alphap[5]); + __m256 _a2 = _mm256_setr_ps(alphap[2], alphap[2], alphap[2], alphap[2], alphap[6], alphap[6], alphap[6], alphap[6]); + __m256 _a3 = _mm256_setr_ps(alphap[3], alphap[3], alphap[3], alphap[3], alphap[7], alphap[7], alphap[7], alphap[7]); + + __m256 _S30 = combine4x2_ps(_mm_load_ps(S3 + sx0 - 4), _mm_load_ps(S3 + sx1 - 4)); + __m256 _S31 = combine4x2_ps(_mm_load_ps(S3 + sx0), _mm_load_ps(S3 + sx1)); + __m256 _S32 = combine4x2_ps(_mm_load_ps(S3 + sx0 + 4), _mm_load_ps(S3 + sx1 + 4)); + __m256 _S33 = combine4x2_ps(_mm_load_ps(S3 + sx0 + 8), _mm_load_ps(S3 + sx1 + 8)); + + __m256 _rows3 = _mm256_mul_ps(_S30, _a0); + _rows3 = _mm256_comp_fmadd_ps(_S31, _a1, _rows3); + _rows3 = _mm256_comp_fmadd_ps(_S32, _a2, _rows3); + _rows3 = _mm256_comp_fmadd_ps(_S33, _a3, _rows3); + _mm256_storeu_ps(rows3p + dx * 4, _rows3); + + alphap += 8; + } +#endif // __AVX__ + for (; dx < w; dx++) { int sx = xofs[dx] * 4; const float* S3p = S3 + sx; @@ -76,7 +129,80 @@ static void resize_bicubic_image_pack4(const Mat& src, Mat& dst, float* alpha, i const float* alphap = alpha; float* rows2p = rows2; float* rows3p = rows3; - for (int dx = 0; dx < w; dx++) + int dx = 0; +#if __AVX__ +#if __AVX512F__ + for (; dx + 3 < w; dx += 4) + { + int sx0 = xofs[dx] * 4; + int sx1 = xofs[dx + 1] * 4; + int sx2 = xofs[dx + 2] * 4; + int sx3 = xofs[dx + 3] * 4; + + __m512 _a0 = _mm512_setr_ps(alphap[0], alphap[0], alphap[0], alphap[0], alphap[4], alphap[4], alphap[4], alphap[4], alphap[8], alphap[8], alphap[8], alphap[8], alphap[12], alphap[12], alphap[12], alphap[12]); + __m512 _a1 = _mm512_setr_ps(alphap[1], alphap[1], alphap[1], alphap[1], alphap[5], alphap[5], alphap[5], alphap[5], alphap[9], alphap[9], alphap[9], alphap[9], alphap[13], alphap[13], alphap[13], alphap[13]); + __m512 _a2 = _mm512_setr_ps(alphap[2], alphap[2], alphap[2], alphap[2], alphap[6], alphap[6], alphap[6], alphap[6], alphap[10], alphap[10], alphap[10], alphap[10], alphap[14], alphap[14], alphap[14], alphap[14]); + __m512 _a3 = _mm512_setr_ps(alphap[3], alphap[3], alphap[3], alphap[3], alphap[7], alphap[7], alphap[7], alphap[7], alphap[11], alphap[11], alphap[11], alphap[11], alphap[15], alphap[15], alphap[15], alphap[15]); + + __m512 _S20 = combine4x4_ps(_mm_load_ps(S2 + sx0 - 4), _mm_load_ps(S2 + sx1 - 4), _mm_load_ps(S2 + sx2 - 4), _mm_load_ps(S2 + sx3 - 4)); + __m512 _S21 = combine4x4_ps(_mm_load_ps(S2 + sx0), _mm_load_ps(S2 + sx1), _mm_load_ps(S2 + sx2), _mm_load_ps(S2 + sx3)); + __m512 _S22 = combine4x4_ps(_mm_load_ps(S2 + sx0 + 4), _mm_load_ps(S2 + sx1 + 4), _mm_load_ps(S2 + sx2 + 4), _mm_load_ps(S2 + sx3 + 4)); + __m512 _S23 = combine4x4_ps(_mm_load_ps(S2 + sx0 + 8), _mm_load_ps(S2 + sx1 + 8), _mm_load_ps(S2 + sx2 + 8), _mm_load_ps(S2 + sx3 + 8)); + + __m512 _rows2 = _mm512_mul_ps(_S20, _a0); + _rows2 = _mm512_fmadd_ps(_S21, _a1, _rows2); + _rows2 = _mm512_fmadd_ps(_S22, _a2, _rows2); + _rows2 = _mm512_fmadd_ps(_S23, _a3, _rows2); + _mm512_storeu_ps(rows2p + dx * 4, _rows2); + + __m512 _S30 = combine4x4_ps(_mm_load_ps(S3 + sx0 - 4), _mm_load_ps(S3 + sx1 - 4), _mm_load_ps(S3 + sx2 - 4), _mm_load_ps(S3 + sx3 - 4)); + __m512 _S31 = combine4x4_ps(_mm_load_ps(S3 + sx0), _mm_load_ps(S3 + sx1), _mm_load_ps(S3 + sx2), _mm_load_ps(S3 + sx3)); + __m512 _S32 = combine4x4_ps(_mm_load_ps(S3 + sx0 + 4), _mm_load_ps(S3 + sx1 + 4), _mm_load_ps(S3 + sx2 + 4), _mm_load_ps(S3 + sx3 + 4)); + __m512 _S33 = 
combine4x4_ps(_mm_load_ps(S3 + sx0 + 8), _mm_load_ps(S3 + sx1 + 8), _mm_load_ps(S3 + sx2 + 8), _mm_load_ps(S3 + sx3 + 8)); + + __m512 _rows3 = _mm512_mul_ps(_S30, _a0); + _rows3 = _mm512_fmadd_ps(_S31, _a1, _rows3); + _rows3 = _mm512_fmadd_ps(_S32, _a2, _rows3); + _rows3 = _mm512_fmadd_ps(_S33, _a3, _rows3); + _mm512_storeu_ps(rows3p + dx * 4, _rows3); + + alphap += 16; + } +#endif // __AVX512F__ + for (; dx + 1 < w; dx += 2) + { + int sx0 = xofs[dx] * 4; + int sx1 = xofs[dx + 1] * 4; + + __m256 _a0 = _mm256_setr_ps(alphap[0], alphap[0], alphap[0], alphap[0], alphap[4], alphap[4], alphap[4], alphap[4]); + __m256 _a1 = _mm256_setr_ps(alphap[1], alphap[1], alphap[1], alphap[1], alphap[5], alphap[5], alphap[5], alphap[5]); + __m256 _a2 = _mm256_setr_ps(alphap[2], alphap[2], alphap[2], alphap[2], alphap[6], alphap[6], alphap[6], alphap[6]); + __m256 _a3 = _mm256_setr_ps(alphap[3], alphap[3], alphap[3], alphap[3], alphap[7], alphap[7], alphap[7], alphap[7]); + + __m256 _S20 = combine4x2_ps(_mm_load_ps(S2 + sx0 - 4), _mm_load_ps(S2 + sx1 - 4)); + __m256 _S21 = combine4x2_ps(_mm_load_ps(S2 + sx0), _mm_load_ps(S2 + sx1)); + __m256 _S22 = combine4x2_ps(_mm_load_ps(S2 + sx0 + 4), _mm_load_ps(S2 + sx1 + 4)); + __m256 _S23 = combine4x2_ps(_mm_load_ps(S2 + sx0 + 8), _mm_load_ps(S2 + sx1 + 8)); + __m256 _S30 = combine4x2_ps(_mm_load_ps(S3 + sx0 - 4), _mm_load_ps(S3 + sx1 - 4)); + __m256 _S31 = combine4x2_ps(_mm_load_ps(S3 + sx0), _mm_load_ps(S3 + sx1)); + __m256 _S32 = combine4x2_ps(_mm_load_ps(S3 + sx0 + 4), _mm_load_ps(S3 + sx1 + 4)); + __m256 _S33 = combine4x2_ps(_mm_load_ps(S3 + sx0 + 8), _mm_load_ps(S3 + sx1 + 8)); + + __m256 _rows2 = _mm256_mul_ps(_S20, _a0); + __m256 _rows3 = _mm256_mul_ps(_S30, _a0); + _rows2 = _mm256_comp_fmadd_ps(_S21, _a1, _rows2); + _rows3 = _mm256_comp_fmadd_ps(_S31, _a1, _rows3); + _rows2 = _mm256_comp_fmadd_ps(_S22, _a2, _rows2); + _rows3 = _mm256_comp_fmadd_ps(_S32, _a2, _rows3); + _rows2 = _mm256_comp_fmadd_ps(_S23, _a3, _rows2); + _rows3 = _mm256_comp_fmadd_ps(_S33, _a3, _rows3); + _mm256_storeu_ps(rows2p + dx * 4, _rows2); + _mm256_storeu_ps(rows3p + dx * 4, _rows3); + + alphap += 8; + } +#endif // __AVX__ + for (; dx < w; dx++) { int sx = xofs[dx] * 4; const float* S2p = S2 + sx; @@ -127,7 +253,97 @@ static void resize_bicubic_image_pack4(const Mat& src, Mat& dst, float* alpha, i float* rows1p = rows1; float* rows2p = rows2; float* rows3p = rows3; - for (int dx = 0; dx < w; dx++) + int dx = 0; +#if __AVX__ +#if __AVX512F__ + for (; dx + 3 < w; dx += 4) + { + int sx0 = xofs[dx] * 4; + int sx1 = xofs[dx + 1] * 4; + int sx2 = xofs[dx + 2] * 4; + int sx3 = xofs[dx + 3] * 4; + + __m512 _a0 = _mm512_setr_ps(alphap[0], alphap[0], alphap[0], alphap[0], alphap[4], alphap[4], alphap[4], alphap[4], alphap[8], alphap[8], alphap[8], alphap[8], alphap[12], alphap[12], alphap[12], alphap[12]); + __m512 _a1 = _mm512_setr_ps(alphap[1], alphap[1], alphap[1], alphap[1], alphap[5], alphap[5], alphap[5], alphap[5], alphap[9], alphap[9], alphap[9], alphap[9], alphap[13], alphap[13], alphap[13], alphap[13]); + __m512 _a2 = _mm512_setr_ps(alphap[2], alphap[2], alphap[2], alphap[2], alphap[6], alphap[6], alphap[6], alphap[6], alphap[10], alphap[10], alphap[10], alphap[10], alphap[14], alphap[14], alphap[14], alphap[14]); + __m512 _a3 = _mm512_setr_ps(alphap[3], alphap[3], alphap[3], alphap[3], alphap[7], alphap[7], alphap[7], alphap[7], alphap[11], alphap[11], alphap[11], alphap[11], alphap[15], alphap[15], alphap[15], alphap[15]); + + __m512 _S10 = combine4x4_ps(_mm_load_ps(S1 + sx0 - 4), 
_mm_load_ps(S1 + sx1 - 4), _mm_load_ps(S1 + sx2 - 4), _mm_load_ps(S1 + sx3 - 4)); + __m512 _S11 = combine4x4_ps(_mm_load_ps(S1 + sx0), _mm_load_ps(S1 + sx1), _mm_load_ps(S1 + sx2), _mm_load_ps(S1 + sx3)); + __m512 _S12 = combine4x4_ps(_mm_load_ps(S1 + sx0 + 4), _mm_load_ps(S1 + sx1 + 4), _mm_load_ps(S1 + sx2 + 4), _mm_load_ps(S1 + sx3 + 4)); + __m512 _S13 = combine4x4_ps(_mm_load_ps(S1 + sx0 + 8), _mm_load_ps(S1 + sx1 + 8), _mm_load_ps(S1 + sx2 + 8), _mm_load_ps(S1 + sx3 + 8)); + __m512 _rows1 = _mm512_mul_ps(_S10, _a0); + _rows1 = _mm512_fmadd_ps(_S11, _a1, _rows1); + _rows1 = _mm512_fmadd_ps(_S12, _a2, _rows1); + _rows1 = _mm512_fmadd_ps(_S13, _a3, _rows1); + _mm512_storeu_ps(rows1p + dx * 4, _rows1); + + __m512 _S20 = combine4x4_ps(_mm_load_ps(S2 + sx0 - 4), _mm_load_ps(S2 + sx1 - 4), _mm_load_ps(S2 + sx2 - 4), _mm_load_ps(S2 + sx3 - 4)); + __m512 _S21 = combine4x4_ps(_mm_load_ps(S2 + sx0), _mm_load_ps(S2 + sx1), _mm_load_ps(S2 + sx2), _mm_load_ps(S2 + sx3)); + __m512 _S22 = combine4x4_ps(_mm_load_ps(S2 + sx0 + 4), _mm_load_ps(S2 + sx1 + 4), _mm_load_ps(S2 + sx2 + 4), _mm_load_ps(S2 + sx3 + 4)); + __m512 _S23 = combine4x4_ps(_mm_load_ps(S2 + sx0 + 8), _mm_load_ps(S2 + sx1 + 8), _mm_load_ps(S2 + sx2 + 8), _mm_load_ps(S2 + sx3 + 8)); + __m512 _rows2 = _mm512_mul_ps(_S20, _a0); + _rows2 = _mm512_fmadd_ps(_S21, _a1, _rows2); + _rows2 = _mm512_fmadd_ps(_S22, _a2, _rows2); + _rows2 = _mm512_fmadd_ps(_S23, _a3, _rows2); + _mm512_storeu_ps(rows2p + dx * 4, _rows2); + + __m512 _S30 = combine4x4_ps(_mm_load_ps(S3 + sx0 - 4), _mm_load_ps(S3 + sx1 - 4), _mm_load_ps(S3 + sx2 - 4), _mm_load_ps(S3 + sx3 - 4)); + __m512 _S31 = combine4x4_ps(_mm_load_ps(S3 + sx0), _mm_load_ps(S3 + sx1), _mm_load_ps(S3 + sx2), _mm_load_ps(S3 + sx3)); + __m512 _S32 = combine4x4_ps(_mm_load_ps(S3 + sx0 + 4), _mm_load_ps(S3 + sx1 + 4), _mm_load_ps(S3 + sx2 + 4), _mm_load_ps(S3 + sx3 + 4)); + __m512 _S33 = combine4x4_ps(_mm_load_ps(S3 + sx0 + 8), _mm_load_ps(S3 + sx1 + 8), _mm_load_ps(S3 + sx2 + 8), _mm_load_ps(S3 + sx3 + 8)); + __m512 _rows3 = _mm512_mul_ps(_S30, _a0); + _rows3 = _mm512_fmadd_ps(_S31, _a1, _rows3); + _rows3 = _mm512_fmadd_ps(_S32, _a2, _rows3); + _rows3 = _mm512_fmadd_ps(_S33, _a3, _rows3); + _mm512_storeu_ps(rows3p + dx * 4, _rows3); + + alphap += 16; + } +#endif // __AVX512F__ + for (; dx + 1 < w; dx += 2) + { + int sx0 = xofs[dx] * 4; + int sx1 = xofs[dx + 1] * 4; + + __m256 _a0 = _mm256_setr_ps(alphap[0], alphap[0], alphap[0], alphap[0], alphap[4], alphap[4], alphap[4], alphap[4]); + __m256 _a1 = _mm256_setr_ps(alphap[1], alphap[1], alphap[1], alphap[1], alphap[5], alphap[5], alphap[5], alphap[5]); + __m256 _a2 = _mm256_setr_ps(alphap[2], alphap[2], alphap[2], alphap[2], alphap[6], alphap[6], alphap[6], alphap[6]); + __m256 _a3 = _mm256_setr_ps(alphap[3], alphap[3], alphap[3], alphap[3], alphap[7], alphap[7], alphap[7], alphap[7]); + + __m256 _S10 = combine4x2_ps(_mm_load_ps(S1 + sx0 - 4), _mm_load_ps(S1 + sx1 - 4)); + __m256 _S11 = combine4x2_ps(_mm_load_ps(S1 + sx0), _mm_load_ps(S1 + sx1)); + __m256 _S12 = combine4x2_ps(_mm_load_ps(S1 + sx0 + 4), _mm_load_ps(S1 + sx1 + 4)); + __m256 _S13 = combine4x2_ps(_mm_load_ps(S1 + sx0 + 8), _mm_load_ps(S1 + sx1 + 8)); + __m256 _S20 = combine4x2_ps(_mm_load_ps(S2 + sx0 - 4), _mm_load_ps(S2 + sx1 - 4)); + __m256 _S21 = combine4x2_ps(_mm_load_ps(S2 + sx0), _mm_load_ps(S2 + sx1)); + __m256 _S22 = combine4x2_ps(_mm_load_ps(S2 + sx0 + 4), _mm_load_ps(S2 + sx1 + 4)); + __m256 _S23 = combine4x2_ps(_mm_load_ps(S2 + sx0 + 8), _mm_load_ps(S2 + sx1 + 8)); + __m256 _S30 = 
combine4x2_ps(_mm_load_ps(S3 + sx0 - 4), _mm_load_ps(S3 + sx1 - 4)); + __m256 _S31 = combine4x2_ps(_mm_load_ps(S3 + sx0), _mm_load_ps(S3 + sx1)); + __m256 _S32 = combine4x2_ps(_mm_load_ps(S3 + sx0 + 4), _mm_load_ps(S3 + sx1 + 4)); + __m256 _S33 = combine4x2_ps(_mm_load_ps(S3 + sx0 + 8), _mm_load_ps(S3 + sx1 + 8)); + + __m256 _rows1 = _mm256_mul_ps(_S10, _a0); + __m256 _rows2 = _mm256_mul_ps(_S20, _a0); + __m256 _rows3 = _mm256_mul_ps(_S30, _a0); + _rows1 = _mm256_comp_fmadd_ps(_S11, _a1, _rows1); + _rows2 = _mm256_comp_fmadd_ps(_S21, _a1, _rows2); + _rows3 = _mm256_comp_fmadd_ps(_S31, _a1, _rows3); + _rows1 = _mm256_comp_fmadd_ps(_S12, _a2, _rows1); + _rows2 = _mm256_comp_fmadd_ps(_S22, _a2, _rows2); + _rows3 = _mm256_comp_fmadd_ps(_S32, _a2, _rows3); + _rows1 = _mm256_comp_fmadd_ps(_S13, _a3, _rows1); + _rows2 = _mm256_comp_fmadd_ps(_S23, _a3, _rows2); + _rows3 = _mm256_comp_fmadd_ps(_S33, _a3, _rows3); + _mm256_storeu_ps(rows1p + dx * 4, _rows1); + _mm256_storeu_ps(rows2p + dx * 4, _rows2); + _mm256_storeu_ps(rows3p + dx * 4, _rows3); + + alphap += 8; + } +#endif // __AVX__ + for (; dx < w; dx++) { int sx = xofs[dx] * 4; const float* S1p = S1 + sx; @@ -183,7 +399,116 @@ static void resize_bicubic_image_pack4(const Mat& src, Mat& dst, float* alpha, i float* rows1p = rows1; float* rows2p = rows2; float* rows3p = rows3; - for (int dx = 0; dx < w; dx++) + int dx = 0; +#if __AVX__ +#if __AVX512F__ + for (; dx + 3 < w; dx += 4) + { + int sx0 = xofs[dx] * 4; + int sx1 = xofs[dx + 1] * 4; + int sx2 = xofs[dx + 2] * 4; + int sx3 = xofs[dx + 3] * 4; + + __m512 _a0 = _mm512_setr_ps(alphap[0], alphap[0], alphap[0], alphap[0], alphap[4], alphap[4], alphap[4], alphap[4], alphap[8], alphap[8], alphap[8], alphap[8], alphap[12], alphap[12], alphap[12], alphap[12]); + __m512 _a1 = _mm512_setr_ps(alphap[1], alphap[1], alphap[1], alphap[1], alphap[5], alphap[5], alphap[5], alphap[5], alphap[9], alphap[9], alphap[9], alphap[9], alphap[13], alphap[13], alphap[13], alphap[13]); + __m512 _a2 = _mm512_setr_ps(alphap[2], alphap[2], alphap[2], alphap[2], alphap[6], alphap[6], alphap[6], alphap[6], alphap[10], alphap[10], alphap[10], alphap[10], alphap[14], alphap[14], alphap[14], alphap[14]); + __m512 _a3 = _mm512_setr_ps(alphap[3], alphap[3], alphap[3], alphap[3], alphap[7], alphap[7], alphap[7], alphap[7], alphap[11], alphap[11], alphap[11], alphap[11], alphap[15], alphap[15], alphap[15], alphap[15]); + + __m512 _S00 = combine4x4_ps(_mm_load_ps(S0 + sx0 - 4), _mm_load_ps(S0 + sx1 - 4), _mm_load_ps(S0 + sx2 - 4), _mm_load_ps(S0 + sx3 - 4)); + __m512 _S01 = combine4x4_ps(_mm_load_ps(S0 + sx0), _mm_load_ps(S0 + sx1), _mm_load_ps(S0 + sx2), _mm_load_ps(S0 + sx3)); + __m512 _S02 = combine4x4_ps(_mm_load_ps(S0 + sx0 + 4), _mm_load_ps(S0 + sx1 + 4), _mm_load_ps(S0 + sx2 + 4), _mm_load_ps(S0 + sx3 + 4)); + __m512 _S03 = combine4x4_ps(_mm_load_ps(S0 + sx0 + 8), _mm_load_ps(S0 + sx1 + 8), _mm_load_ps(S0 + sx2 + 8), _mm_load_ps(S0 + sx3 + 8)); + __m512 _rows0 = _mm512_mul_ps(_S00, _a0); + _rows0 = _mm512_fmadd_ps(_S01, _a1, _rows0); + _rows0 = _mm512_fmadd_ps(_S02, _a2, _rows0); + _rows0 = _mm512_fmadd_ps(_S03, _a3, _rows0); + _mm512_storeu_ps(rows0p + dx * 4, _rows0); + + __m512 _S10 = combine4x4_ps(_mm_load_ps(S1 + sx0 - 4), _mm_load_ps(S1 + sx1 - 4), _mm_load_ps(S1 + sx2 - 4), _mm_load_ps(S1 + sx3 - 4)); + __m512 _S11 = combine4x4_ps(_mm_load_ps(S1 + sx0), _mm_load_ps(S1 + sx1), _mm_load_ps(S1 + sx2), _mm_load_ps(S1 + sx3)); + __m512 _S12 = combine4x4_ps(_mm_load_ps(S1 + sx0 + 4), _mm_load_ps(S1 + sx1 + 4), 
_mm_load_ps(S1 + sx2 + 4), _mm_load_ps(S1 + sx3 + 4)); + __m512 _S13 = combine4x4_ps(_mm_load_ps(S1 + sx0 + 8), _mm_load_ps(S1 + sx1 + 8), _mm_load_ps(S1 + sx2 + 8), _mm_load_ps(S1 + sx3 + 8)); + __m512 _rows1 = _mm512_mul_ps(_S10, _a0); + _rows1 = _mm512_fmadd_ps(_S11, _a1, _rows1); + _rows1 = _mm512_fmadd_ps(_S12, _a2, _rows1); + _rows1 = _mm512_fmadd_ps(_S13, _a3, _rows1); + _mm512_storeu_ps(rows1p + dx * 4, _rows1); + + __m512 _S20 = combine4x4_ps(_mm_load_ps(S2 + sx0 - 4), _mm_load_ps(S2 + sx1 - 4), _mm_load_ps(S2 + sx2 - 4), _mm_load_ps(S2 + sx3 - 4)); + __m512 _S21 = combine4x4_ps(_mm_load_ps(S2 + sx0), _mm_load_ps(S2 + sx1), _mm_load_ps(S2 + sx2), _mm_load_ps(S2 + sx3)); + __m512 _S22 = combine4x4_ps(_mm_load_ps(S2 + sx0 + 4), _mm_load_ps(S2 + sx1 + 4), _mm_load_ps(S2 + sx2 + 4), _mm_load_ps(S2 + sx3 + 4)); + __m512 _S23 = combine4x4_ps(_mm_load_ps(S2 + sx0 + 8), _mm_load_ps(S2 + sx1 + 8), _mm_load_ps(S2 + sx2 + 8), _mm_load_ps(S2 + sx3 + 8)); + __m512 _rows2 = _mm512_mul_ps(_S20, _a0); + _rows2 = _mm512_fmadd_ps(_S21, _a1, _rows2); + _rows2 = _mm512_fmadd_ps(_S22, _a2, _rows2); + _rows2 = _mm512_fmadd_ps(_S23, _a3, _rows2); + _mm512_storeu_ps(rows2p + dx * 4, _rows2); + + __m512 _S30 = combine4x4_ps(_mm_load_ps(S3 + sx0 - 4), _mm_load_ps(S3 + sx1 - 4), _mm_load_ps(S3 + sx2 - 4), _mm_load_ps(S3 + sx3 - 4)); + __m512 _S31 = combine4x4_ps(_mm_load_ps(S3 + sx0), _mm_load_ps(S3 + sx1), _mm_load_ps(S3 + sx2), _mm_load_ps(S3 + sx3)); + __m512 _S32 = combine4x4_ps(_mm_load_ps(S3 + sx0 + 4), _mm_load_ps(S3 + sx1 + 4), _mm_load_ps(S3 + sx2 + 4), _mm_load_ps(S3 + sx3 + 4)); + __m512 _S33 = combine4x4_ps(_mm_load_ps(S3 + sx0 + 8), _mm_load_ps(S3 + sx1 + 8), _mm_load_ps(S3 + sx2 + 8), _mm_load_ps(S3 + sx3 + 8)); + __m512 _rows3 = _mm512_mul_ps(_S30, _a0); + _rows3 = _mm512_fmadd_ps(_S31, _a1, _rows3); + _rows3 = _mm512_fmadd_ps(_S32, _a2, _rows3); + _rows3 = _mm512_fmadd_ps(_S33, _a3, _rows3); + _mm512_storeu_ps(rows3p + dx * 4, _rows3); + + alphap += 16; + } +#endif // __AVX512F__ + for (; dx + 1 < w; dx += 2) + { + int sx0 = xofs[dx] * 4; + int sx1 = xofs[dx + 1] * 4; + + __m256 _a0 = _mm256_setr_ps(alphap[0], alphap[0], alphap[0], alphap[0], alphap[4], alphap[4], alphap[4], alphap[4]); + __m256 _a1 = _mm256_setr_ps(alphap[1], alphap[1], alphap[1], alphap[1], alphap[5], alphap[5], alphap[5], alphap[5]); + __m256 _a2 = _mm256_setr_ps(alphap[2], alphap[2], alphap[2], alphap[2], alphap[6], alphap[6], alphap[6], alphap[6]); + __m256 _a3 = _mm256_setr_ps(alphap[3], alphap[3], alphap[3], alphap[3], alphap[7], alphap[7], alphap[7], alphap[7]); + + __m256 _S00 = combine4x2_ps(_mm_load_ps(S0 + sx0 - 4), _mm_load_ps(S0 + sx1 - 4)); + __m256 _S01 = combine4x2_ps(_mm_load_ps(S0 + sx0), _mm_load_ps(S0 + sx1)); + __m256 _S02 = combine4x2_ps(_mm_load_ps(S0 + sx0 + 4), _mm_load_ps(S0 + sx1 + 4)); + __m256 _S03 = combine4x2_ps(_mm_load_ps(S0 + sx0 + 8), _mm_load_ps(S0 + sx1 + 8)); + __m256 _S10 = combine4x2_ps(_mm_load_ps(S1 + sx0 - 4), _mm_load_ps(S1 + sx1 - 4)); + __m256 _S11 = combine4x2_ps(_mm_load_ps(S1 + sx0), _mm_load_ps(S1 + sx1)); + __m256 _S12 = combine4x2_ps(_mm_load_ps(S1 + sx0 + 4), _mm_load_ps(S1 + sx1 + 4)); + __m256 _S13 = combine4x2_ps(_mm_load_ps(S1 + sx0 + 8), _mm_load_ps(S1 + sx1 + 8)); + __m256 _S20 = combine4x2_ps(_mm_load_ps(S2 + sx0 - 4), _mm_load_ps(S2 + sx1 - 4)); + __m256 _S21 = combine4x2_ps(_mm_load_ps(S2 + sx0), _mm_load_ps(S2 + sx1)); + __m256 _S22 = combine4x2_ps(_mm_load_ps(S2 + sx0 + 4), _mm_load_ps(S2 + sx1 + 4)); + __m256 _S23 = combine4x2_ps(_mm_load_ps(S2 + sx0 + 8), 
_mm_load_ps(S2 + sx1 + 8)); + __m256 _S30 = combine4x2_ps(_mm_load_ps(S3 + sx0 - 4), _mm_load_ps(S3 + sx1 - 4)); + __m256 _S31 = combine4x2_ps(_mm_load_ps(S3 + sx0), _mm_load_ps(S3 + sx1)); + __m256 _S32 = combine4x2_ps(_mm_load_ps(S3 + sx0 + 4), _mm_load_ps(S3 + sx1 + 4)); + __m256 _S33 = combine4x2_ps(_mm_load_ps(S3 + sx0 + 8), _mm_load_ps(S3 + sx1 + 8)); + + __m256 _rows0 = _mm256_mul_ps(_S00, _a0); + __m256 _rows1 = _mm256_mul_ps(_S10, _a0); + __m256 _rows2 = _mm256_mul_ps(_S20, _a0); + __m256 _rows3 = _mm256_mul_ps(_S30, _a0); + _rows0 = _mm256_comp_fmadd_ps(_S01, _a1, _rows0); + _rows1 = _mm256_comp_fmadd_ps(_S11, _a1, _rows1); + _rows2 = _mm256_comp_fmadd_ps(_S21, _a1, _rows2); + _rows3 = _mm256_comp_fmadd_ps(_S31, _a1, _rows3); + _rows0 = _mm256_comp_fmadd_ps(_S02, _a2, _rows0); + _rows1 = _mm256_comp_fmadd_ps(_S12, _a2, _rows1); + _rows2 = _mm256_comp_fmadd_ps(_S22, _a2, _rows2); + _rows3 = _mm256_comp_fmadd_ps(_S32, _a2, _rows3); + _rows0 = _mm256_comp_fmadd_ps(_S03, _a3, _rows0); + _rows1 = _mm256_comp_fmadd_ps(_S13, _a3, _rows1); + _rows2 = _mm256_comp_fmadd_ps(_S23, _a3, _rows2); + _rows3 = _mm256_comp_fmadd_ps(_S33, _a3, _rows3); + _mm256_storeu_ps(rows0p + dx * 4, _rows0); + _mm256_storeu_ps(rows1p + dx * 4, _rows1); + _mm256_storeu_ps(rows2p + dx * 4, _rows2); + _mm256_storeu_ps(rows3p + dx * 4, _rows3); + + alphap += 8; + } +#endif // __AVX__ + for (; dx < w; dx++) { int sx = xofs[dx] * 4; const float* S0p = S0 + sx; @@ -240,35 +565,7 @@ static void resize_bicubic_image_pack4(const Mat& src, Mat& dst, float* alpha, i prev_sy1 = sy; // vresize - __m128 _b0 = _mm_set1_ps(beta[0]); - __m128 _b1 = _mm_set1_ps(beta[1]); - __m128 _b2 = _mm_set1_ps(beta[2]); - __m128 _b3 = _mm_set1_ps(beta[3]); - - float* rows0p = rows0; - float* rows1p = rows1; - float* rows2p = rows2; - float* rows3p = rows3; - float* Dp = dst.row(dy); - - for (int dx = 0; dx < w; dx++) - { - __m128 _rows0 = _mm_load_ps(rows0p); - __m128 _rows1 = _mm_load_ps(rows1p); - __m128 _rows2 = _mm_load_ps(rows2p); - __m128 _rows3 = _mm_load_ps(rows3p); - __m128 _Dp = _mm_mul_ps(_rows0, _b0); - _Dp = _mm_comp_fmadd_ps(_rows1, _b1, _Dp); - _Dp = _mm_comp_fmadd_ps(_rows2, _b2, _Dp); - _Dp = _mm_comp_fmadd_ps(_rows3, _b3, _Dp); - _mm_store_ps(Dp, _Dp); - - Dp += 4; - rows0p += 4; - rows1p += 4; - rows2p += 4; - rows3p += 4; - } + vresize_bicubic(rows0, rows1, rows2, rows3, dst.row(dy), w * 4, beta[0], beta[1], beta[2], beta[3]); beta += 4; } diff --git a/src/layer/x86/interp_bicubic_pack8.h b/src/layer/x86/interp_bicubic_pack8.h index e85975368fa..a19ff4dea34 100644 --- a/src/layer/x86/interp_bicubic_pack8.h +++ b/src/layer/x86/interp_bicubic_pack8.h @@ -38,7 +38,33 @@ static void resize_bicubic_image_pack8(const Mat& src, Mat& dst, float* alpha, i const float* alphap = alpha; float* rows3p = rows3; - for (int dx = 0; dx < w; dx++) + int dx = 0; +#if __AVX512F__ + for (; dx + 1 < w; dx += 2) + { + int sx0 = xofs[dx] * 8; + int sx1 = xofs[dx + 1] * 8; + + __m512 _a0 = _mm512_setr_ps(alphap[0], alphap[0], alphap[0], alphap[0], alphap[0], alphap[0], alphap[0], alphap[0], alphap[4], alphap[4], alphap[4], alphap[4], alphap[4], alphap[4], alphap[4], alphap[4]); + __m512 _a1 = _mm512_setr_ps(alphap[1], alphap[1], alphap[1], alphap[1], alphap[1], alphap[1], alphap[1], alphap[1], alphap[5], alphap[5], alphap[5], alphap[5], alphap[5], alphap[5], alphap[5], alphap[5]); + __m512 _a2 = _mm512_setr_ps(alphap[2], alphap[2], alphap[2], alphap[2], alphap[2], alphap[2], alphap[2], alphap[2], alphap[6], alphap[6], alphap[6], alphap[6], 
alphap[6], alphap[6], alphap[6], alphap[6]); + __m512 _a3 = _mm512_setr_ps(alphap[3], alphap[3], alphap[3], alphap[3], alphap[3], alphap[3], alphap[3], alphap[3], alphap[7], alphap[7], alphap[7], alphap[7], alphap[7], alphap[7], alphap[7], alphap[7]); + + __m512 _S30 = combine8x2_ps(_mm256_load_ps(S3 + sx0 - 8), _mm256_load_ps(S3 + sx1 - 8)); + __m512 _S31 = combine8x2_ps(_mm256_load_ps(S3 + sx0), _mm256_load_ps(S3 + sx1)); + __m512 _S32 = combine8x2_ps(_mm256_load_ps(S3 + sx0 + 8), _mm256_load_ps(S3 + sx1 + 8)); + __m512 _S33 = combine8x2_ps(_mm256_load_ps(S3 + sx0 + 16), _mm256_load_ps(S3 + sx1 + 16)); + + __m512 _rows3 = _mm512_mul_ps(_S30, _a0); + _rows3 = _mm512_fmadd_ps(_S31, _a1, _rows3); + _rows3 = _mm512_fmadd_ps(_S32, _a2, _rows3); + _rows3 = _mm512_fmadd_ps(_S33, _a3, _rows3); + _mm512_storeu_ps(rows3p + dx * 8, _rows3); + + alphap += 8; + } +#endif // __AVX512F__ + for (; dx < w; dx++) { int sx = xofs[dx] * 8; const float* S3p = S3 + sx; @@ -76,7 +102,44 @@ static void resize_bicubic_image_pack8(const Mat& src, Mat& dst, float* alpha, i const float* alphap = alpha; float* rows2p = rows2; float* rows3p = rows3; - for (int dx = 0; dx < w; dx++) + int dx = 0; +#if __AVX512F__ + for (; dx + 1 < w; dx += 2) + { + int sx0 = xofs[dx] * 8; + int sx1 = xofs[dx + 1] * 8; + + __m512 _a0 = _mm512_setr_ps(alphap[0], alphap[0], alphap[0], alphap[0], alphap[0], alphap[0], alphap[0], alphap[0], alphap[4], alphap[4], alphap[4], alphap[4], alphap[4], alphap[4], alphap[4], alphap[4]); + __m512 _a1 = _mm512_setr_ps(alphap[1], alphap[1], alphap[1], alphap[1], alphap[1], alphap[1], alphap[1], alphap[1], alphap[5], alphap[5], alphap[5], alphap[5], alphap[5], alphap[5], alphap[5], alphap[5]); + __m512 _a2 = _mm512_setr_ps(alphap[2], alphap[2], alphap[2], alphap[2], alphap[2], alphap[2], alphap[2], alphap[2], alphap[6], alphap[6], alphap[6], alphap[6], alphap[6], alphap[6], alphap[6], alphap[6]); + __m512 _a3 = _mm512_setr_ps(alphap[3], alphap[3], alphap[3], alphap[3], alphap[3], alphap[3], alphap[3], alphap[3], alphap[7], alphap[7], alphap[7], alphap[7], alphap[7], alphap[7], alphap[7], alphap[7]); + + __m512 _S20 = combine8x2_ps(_mm256_load_ps(S2 + sx0 - 8), _mm256_load_ps(S2 + sx1 - 8)); + __m512 _S21 = combine8x2_ps(_mm256_load_ps(S2 + sx0), _mm256_load_ps(S2 + sx1)); + __m512 _S22 = combine8x2_ps(_mm256_load_ps(S2 + sx0 + 8), _mm256_load_ps(S2 + sx1 + 8)); + __m512 _S23 = combine8x2_ps(_mm256_load_ps(S2 + sx0 + 16), _mm256_load_ps(S2 + sx1 + 16)); + + __m512 _rows2 = _mm512_mul_ps(_S20, _a0); + _rows2 = _mm512_fmadd_ps(_S21, _a1, _rows2); + _rows2 = _mm512_fmadd_ps(_S22, _a2, _rows2); + _rows2 = _mm512_fmadd_ps(_S23, _a3, _rows2); + _mm512_storeu_ps(rows2p + dx * 8, _rows2); + + __m512 _S30 = combine8x2_ps(_mm256_load_ps(S3 + sx0 - 8), _mm256_load_ps(S3 + sx1 - 8)); + __m512 _S31 = combine8x2_ps(_mm256_load_ps(S3 + sx0), _mm256_load_ps(S3 + sx1)); + __m512 _S32 = combine8x2_ps(_mm256_load_ps(S3 + sx0 + 8), _mm256_load_ps(S3 + sx1 + 8)); + __m512 _S33 = combine8x2_ps(_mm256_load_ps(S3 + sx0 + 16), _mm256_load_ps(S3 + sx1 + 16)); + + __m512 _rows3 = _mm512_mul_ps(_S30, _a0); + _rows3 = _mm512_fmadd_ps(_S31, _a1, _rows3); + _rows3 = _mm512_fmadd_ps(_S32, _a2, _rows3); + _rows3 = _mm512_fmadd_ps(_S33, _a3, _rows3); + _mm512_storeu_ps(rows3p + dx * 8, _rows3); + + alphap += 8; + } +#endif // __AVX512F__ + for (; dx < w; dx++) { int sx = xofs[dx] * 8; const float* S2p = S2 + sx; @@ -127,7 +190,52 @@ static void resize_bicubic_image_pack8(const Mat& src, Mat& dst, float* alpha, i float* rows1p = rows1; 
float* rows2p = rows2; float* rows3p = rows3; - for (int dx = 0; dx < w; dx++) + int dx = 0; +#if __AVX512F__ + for (; dx + 1 < w; dx += 2) + { + int sx0 = xofs[dx] * 8; + int sx1 = xofs[dx + 1] * 8; + + __m512 _a0 = _mm512_setr_ps(alphap[0], alphap[0], alphap[0], alphap[0], alphap[0], alphap[0], alphap[0], alphap[0], alphap[4], alphap[4], alphap[4], alphap[4], alphap[4], alphap[4], alphap[4], alphap[4]); + __m512 _a1 = _mm512_setr_ps(alphap[1], alphap[1], alphap[1], alphap[1], alphap[1], alphap[1], alphap[1], alphap[1], alphap[5], alphap[5], alphap[5], alphap[5], alphap[5], alphap[5], alphap[5], alphap[5]); + __m512 _a2 = _mm512_setr_ps(alphap[2], alphap[2], alphap[2], alphap[2], alphap[2], alphap[2], alphap[2], alphap[2], alphap[6], alphap[6], alphap[6], alphap[6], alphap[6], alphap[6], alphap[6], alphap[6]); + __m512 _a3 = _mm512_setr_ps(alphap[3], alphap[3], alphap[3], alphap[3], alphap[3], alphap[3], alphap[3], alphap[3], alphap[7], alphap[7], alphap[7], alphap[7], alphap[7], alphap[7], alphap[7], alphap[7]); + + __m512 _S10 = combine8x2_ps(_mm256_load_ps(S1 + sx0 - 8), _mm256_load_ps(S1 + sx1 - 8)); + __m512 _S11 = combine8x2_ps(_mm256_load_ps(S1 + sx0), _mm256_load_ps(S1 + sx1)); + __m512 _S12 = combine8x2_ps(_mm256_load_ps(S1 + sx0 + 8), _mm256_load_ps(S1 + sx1 + 8)); + __m512 _S13 = combine8x2_ps(_mm256_load_ps(S1 + sx0 + 16), _mm256_load_ps(S1 + sx1 + 16)); + __m512 _rows1 = _mm512_mul_ps(_S10, _a0); + _rows1 = _mm512_fmadd_ps(_S11, _a1, _rows1); + _rows1 = _mm512_fmadd_ps(_S12, _a2, _rows1); + _rows1 = _mm512_fmadd_ps(_S13, _a3, _rows1); + _mm512_storeu_ps(rows1p + dx * 8, _rows1); + + __m512 _S20 = combine8x2_ps(_mm256_load_ps(S2 + sx0 - 8), _mm256_load_ps(S2 + sx1 - 8)); + __m512 _S21 = combine8x2_ps(_mm256_load_ps(S2 + sx0), _mm256_load_ps(S2 + sx1)); + __m512 _S22 = combine8x2_ps(_mm256_load_ps(S2 + sx0 + 8), _mm256_load_ps(S2 + sx1 + 8)); + __m512 _S23 = combine8x2_ps(_mm256_load_ps(S2 + sx0 + 16), _mm256_load_ps(S2 + sx1 + 16)); + __m512 _rows2 = _mm512_mul_ps(_S20, _a0); + _rows2 = _mm512_fmadd_ps(_S21, _a1, _rows2); + _rows2 = _mm512_fmadd_ps(_S22, _a2, _rows2); + _rows2 = _mm512_fmadd_ps(_S23, _a3, _rows2); + _mm512_storeu_ps(rows2p + dx * 8, _rows2); + + __m512 _S30 = combine8x2_ps(_mm256_load_ps(S3 + sx0 - 8), _mm256_load_ps(S3 + sx1 - 8)); + __m512 _S31 = combine8x2_ps(_mm256_load_ps(S3 + sx0), _mm256_load_ps(S3 + sx1)); + __m512 _S32 = combine8x2_ps(_mm256_load_ps(S3 + sx0 + 8), _mm256_load_ps(S3 + sx1 + 8)); + __m512 _S33 = combine8x2_ps(_mm256_load_ps(S3 + sx0 + 16), _mm256_load_ps(S3 + sx1 + 16)); + __m512 _rows3 = _mm512_mul_ps(_S30, _a0); + _rows3 = _mm512_fmadd_ps(_S31, _a1, _rows3); + _rows3 = _mm512_fmadd_ps(_S32, _a2, _rows3); + _rows3 = _mm512_fmadd_ps(_S33, _a3, _rows3); + _mm512_storeu_ps(rows3p + dx * 8, _rows3); + + alphap += 8; + } +#endif // __AVX512F__ + for (; dx < w; dx++) { int sx = xofs[dx] * 8; const float* S1p = S1 + sx; @@ -183,7 +291,62 @@ static void resize_bicubic_image_pack8(const Mat& src, Mat& dst, float* alpha, i float* rows1p = rows1; float* rows2p = rows2; float* rows3p = rows3; - for (int dx = 0; dx < w; dx++) + int dx = 0; +#if __AVX512F__ + for (; dx + 1 < w; dx += 2) + { + int sx0 = xofs[dx] * 8; + int sx1 = xofs[dx + 1] * 8; + + __m512 _a0 = _mm512_setr_ps(alphap[0], alphap[0], alphap[0], alphap[0], alphap[0], alphap[0], alphap[0], alphap[0], alphap[4], alphap[4], alphap[4], alphap[4], alphap[4], alphap[4], alphap[4], alphap[4]); + __m512 _a1 = _mm512_setr_ps(alphap[1], alphap[1], alphap[1], alphap[1], alphap[1], alphap[1], 
alphap[1], alphap[1], alphap[5], alphap[5], alphap[5], alphap[5], alphap[5], alphap[5], alphap[5], alphap[5]); + __m512 _a2 = _mm512_setr_ps(alphap[2], alphap[2], alphap[2], alphap[2], alphap[2], alphap[2], alphap[2], alphap[2], alphap[6], alphap[6], alphap[6], alphap[6], alphap[6], alphap[6], alphap[6], alphap[6]); + __m512 _a3 = _mm512_setr_ps(alphap[3], alphap[3], alphap[3], alphap[3], alphap[3], alphap[3], alphap[3], alphap[3], alphap[7], alphap[7], alphap[7], alphap[7], alphap[7], alphap[7], alphap[7], alphap[7]); + + __m512 _S00 = combine8x2_ps(_mm256_load_ps(S0 + sx0 - 8), _mm256_load_ps(S0 + sx1 - 8)); + __m512 _S01 = combine8x2_ps(_mm256_load_ps(S0 + sx0), _mm256_load_ps(S0 + sx1)); + __m512 _S02 = combine8x2_ps(_mm256_load_ps(S0 + sx0 + 8), _mm256_load_ps(S0 + sx1 + 8)); + __m512 _S03 = combine8x2_ps(_mm256_load_ps(S0 + sx0 + 16), _mm256_load_ps(S0 + sx1 + 16)); + __m512 _rows0 = _mm512_mul_ps(_S00, _a0); + _rows0 = _mm512_fmadd_ps(_S01, _a1, _rows0); + _rows0 = _mm512_fmadd_ps(_S02, _a2, _rows0); + _rows0 = _mm512_fmadd_ps(_S03, _a3, _rows0); + _mm512_storeu_ps(rows0p + dx * 8, _rows0); + + __m512 _S10 = combine8x2_ps(_mm256_load_ps(S1 + sx0 - 8), _mm256_load_ps(S1 + sx1 - 8)); + __m512 _S11 = combine8x2_ps(_mm256_load_ps(S1 + sx0), _mm256_load_ps(S1 + sx1)); + __m512 _S12 = combine8x2_ps(_mm256_load_ps(S1 + sx0 + 8), _mm256_load_ps(S1 + sx1 + 8)); + __m512 _S13 = combine8x2_ps(_mm256_load_ps(S1 + sx0 + 16), _mm256_load_ps(S1 + sx1 + 16)); + __m512 _rows1 = _mm512_mul_ps(_S10, _a0); + _rows1 = _mm512_fmadd_ps(_S11, _a1, _rows1); + _rows1 = _mm512_fmadd_ps(_S12, _a2, _rows1); + _rows1 = _mm512_fmadd_ps(_S13, _a3, _rows1); + _mm512_storeu_ps(rows1p + dx * 8, _rows1); + + __m512 _S20 = combine8x2_ps(_mm256_load_ps(S2 + sx0 - 8), _mm256_load_ps(S2 + sx1 - 8)); + __m512 _S21 = combine8x2_ps(_mm256_load_ps(S2 + sx0), _mm256_load_ps(S2 + sx1)); + __m512 _S22 = combine8x2_ps(_mm256_load_ps(S2 + sx0 + 8), _mm256_load_ps(S2 + sx1 + 8)); + __m512 _S23 = combine8x2_ps(_mm256_load_ps(S2 + sx0 + 16), _mm256_load_ps(S2 + sx1 + 16)); + __m512 _rows2 = _mm512_mul_ps(_S20, _a0); + _rows2 = _mm512_fmadd_ps(_S21, _a1, _rows2); + _rows2 = _mm512_fmadd_ps(_S22, _a2, _rows2); + _rows2 = _mm512_fmadd_ps(_S23, _a3, _rows2); + _mm512_storeu_ps(rows2p + dx * 8, _rows2); + + __m512 _S30 = combine8x2_ps(_mm256_load_ps(S3 + sx0 - 8), _mm256_load_ps(S3 + sx1 - 8)); + __m512 _S31 = combine8x2_ps(_mm256_load_ps(S3 + sx0), _mm256_load_ps(S3 + sx1)); + __m512 _S32 = combine8x2_ps(_mm256_load_ps(S3 + sx0 + 8), _mm256_load_ps(S3 + sx1 + 8)); + __m512 _S33 = combine8x2_ps(_mm256_load_ps(S3 + sx0 + 16), _mm256_load_ps(S3 + sx1 + 16)); + __m512 _rows3 = _mm512_mul_ps(_S30, _a0); + _rows3 = _mm512_fmadd_ps(_S31, _a1, _rows3); + _rows3 = _mm512_fmadd_ps(_S32, _a2, _rows3); + _rows3 = _mm512_fmadd_ps(_S33, _a3, _rows3); + _mm512_storeu_ps(rows3p + dx * 8, _rows3); + + alphap += 8; + } +#endif // __AVX512F__ + for (; dx < w; dx++) { int sx = xofs[dx] * 8; const float* S0p = S0 + sx; @@ -240,35 +403,7 @@ static void resize_bicubic_image_pack8(const Mat& src, Mat& dst, float* alpha, i prev_sy1 = sy; // vresize - __m256 _b0 = _mm256_set1_ps(beta[0]); - __m256 _b1 = _mm256_set1_ps(beta[1]); - __m256 _b2 = _mm256_set1_ps(beta[2]); - __m256 _b3 = _mm256_set1_ps(beta[3]); - - float* rows0p = rows0; - float* rows1p = rows1; - float* rows2p = rows2; - float* rows3p = rows3; - float* Dp = dst.row(dy); - - for (int dx = 0; dx < w; dx++) - { - __m256 _rows0 = _mm256_load_ps(rows0p); - __m256 _rows1 = _mm256_load_ps(rows1p); - __m256 
_rows2 = _mm256_load_ps(rows2p); - __m256 _rows3 = _mm256_load_ps(rows3p); - __m256 _Dp = _mm256_mul_ps(_rows0, _b0); - _Dp = _mm256_comp_fmadd_ps(_rows1, _b1, _Dp); - _Dp = _mm256_comp_fmadd_ps(_rows2, _b2, _Dp); - _Dp = _mm256_comp_fmadd_ps(_rows3, _b3, _Dp); - _mm256_store_ps(Dp, _Dp); - - Dp += 8; - rows0p += 8; - rows1p += 8; - rows2p += 8; - rows3p += 8; - } + vresize_bicubic(rows0, rows1, rows2, rows3, dst.row(dy), w * 8, beta[0], beta[1], beta[2], beta[3]); beta += 4; } diff --git a/src/layer/x86/interp_bilinear.h b/src/layer/x86/interp_bilinear.h index 7a3e84dcc9c..82c65a7956d 100644 --- a/src/layer/x86/interp_bilinear.h +++ b/src/layer/x86/interp_bilinear.h @@ -1,6 +1,10 @@ // Copyright 2020 Tencent // SPDX-License-Identifier: BSD-3-Clause +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__ +void resize_bilinear_image_avx2(const Mat& src, Mat& dst, float* alpha, int* xofs, float* beta, int* yofs); +#endif + static void linear_coeffs(int w, int outw, int* xofs, float* alpha, int align_corner) { double scale = (double)w / outw; @@ -38,8 +42,61 @@ static void linear_coeffs(int w, int outw, int* xofs, float* alpha, int align_co } } +static void vresize_bilinear(const float* rows0, const float* rows1, float* Dp, int n, float b0, float b1) +{ + int nn = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + __m512 _b0_512 = _mm512_set1_ps(b0); + __m512 _b1_512 = _mm512_set1_ps(b1); + for (; nn + 15 < n; nn += 16) + { + __m512 _rows0 = _mm512_loadu_ps(rows0 + nn); + __m512 _rows1 = _mm512_loadu_ps(rows1 + nn); + __m512 _Dp = _mm512_mul_ps(_rows0, _b0_512); + _Dp = _mm512_fmadd_ps(_rows1, _b1_512, _Dp); + _mm512_storeu_ps(Dp + nn, _Dp); + } +#endif // __AVX512F__ + __m256 _b0_256 = _mm256_set1_ps(b0); + __m256 _b1_256 = _mm256_set1_ps(b1); + for (; nn + 7 < n; nn += 8) + { + __m256 _rows0 = _mm256_loadu_ps(rows0 + nn); + __m256 _rows1 = _mm256_loadu_ps(rows1 + nn); + __m256 _Dp = _mm256_mul_ps(_rows0, _b0_256); + _Dp = _mm256_comp_fmadd_ps(_rows1, _b1_256, _Dp); + _mm256_storeu_ps(Dp + nn, _Dp); + } +#endif // __AVX__ + __m128 _b0_128 = _mm_set1_ps(b0); + __m128 _b1_128 = _mm_set1_ps(b1); + for (; nn + 3 < n; nn += 4) + { + __m128 _rows0 = _mm_loadu_ps(rows0 + nn); + __m128 _rows1 = _mm_loadu_ps(rows1 + nn); + __m128 _Dp = _mm_mul_ps(_rows0, _b0_128); + _Dp = _mm_comp_fmadd_ps(_rows1, _b1_128, _Dp); + _mm_storeu_ps(Dp + nn, _Dp); + } +#endif // __SSE2__ + for (; nn < n; nn++) + { + Dp[nn] = rows0[nn] * b0 + rows1[nn] * b1; + } +} + static void resize_bilinear_image(const Mat& src, Mat& dst, float* alpha, int* xofs, float* beta, int* yofs) { +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx2()) + { + resize_bilinear_image_avx2(src, dst, alpha, xofs, beta, yofs); + return; + } +#endif + int w = dst.w; int h = dst.h; @@ -70,6 +127,75 @@ static void resize_bilinear_image(const Mat& src, Mat& dst, float* alpha, int* x const float* alphap = alpha; float* rows1p = rows1; int dx = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + for (; dx + 15 < w; dx += 16) + { + __m512i _sx = _mm512_loadu_si512(xofs + dx); + __m512i _sx1 = _mm512_add_epi32(_sx, _mm512_set1_epi32(1)); + + __m512 _S10 = _mm512_i32gather_ps(_sx, S1, sizeof(float)); + __m512 _S11 = _mm512_i32gather_ps(_sx1, S1, sizeof(float)); + + __m512i _alpha_idx = _mm512_setr_epi32(0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30); + __m512 _a0 = _mm512_i32gather_ps(_alpha_idx, alphap, sizeof(float)); + __m512i _alpha_idx1 = 
_mm512_add_epi32(_alpha_idx, _mm512_set1_epi32(1)); + __m512 _a1 = _mm512_i32gather_ps(_alpha_idx1, alphap, sizeof(float)); + + __m512 _rows1 = _mm512_mul_ps(_S10, _a0); + _rows1 = _mm512_fmadd_ps(_S11, _a1, _rows1); + _mm512_storeu_ps(rows1p + dx, _rows1); + + alphap += 32; + } +#endif // __AVX512F__ + for (; dx + 7 < w; dx += 8) + { +#if __AVX2__ + __m256i _sx = _mm256_loadu_si256((const __m256i*)(xofs + dx)); + __m256i _sx1 = _mm256_add_epi32(_sx, _mm256_set1_epi32(1)); + + __m256 _S10 = _mm256_i32gather_ps(S1, _sx, sizeof(float)); + __m256 _S11 = _mm256_i32gather_ps(S1, _sx1, sizeof(float)); + + __m256i _alpha_idx = _mm256_setr_epi32(0, 2, 4, 6, 8, 10, 12, 14); + __m256 _a0 = _mm256_i32gather_ps(alphap, _alpha_idx, sizeof(float)); + __m256i _alpha_idx1 = _mm256_setr_epi32(1, 3, 5, 7, 9, 11, 13, 15); + __m256 _a1 = _mm256_i32gather_ps(alphap, _alpha_idx1, sizeof(float)); +#else + __m256 _S10 = _mm256_setr_ps(S1[xofs[dx]], S1[xofs[dx + 1]], S1[xofs[dx + 2]], S1[xofs[dx + 3]], S1[xofs[dx + 4]], S1[xofs[dx + 5]], S1[xofs[dx + 6]], S1[xofs[dx + 7]]); + __m256 _S11 = _mm256_setr_ps(S1[xofs[dx] + 1], S1[xofs[dx + 1] + 1], S1[xofs[dx + 2] + 1], S1[xofs[dx + 3] + 1], S1[xofs[dx + 4] + 1], S1[xofs[dx + 5] + 1], S1[xofs[dx + 6] + 1], S1[xofs[dx + 7] + 1]); + + __m256 _a0 = _mm256_setr_ps(alphap[0], alphap[2], alphap[4], alphap[6], alphap[8], alphap[10], alphap[12], alphap[14]); + __m256 _a1 = _mm256_setr_ps(alphap[1], alphap[3], alphap[5], alphap[7], alphap[9], alphap[11], alphap[13], alphap[15]); +#endif + + __m256 _rows1 = _mm256_mul_ps(_S10, _a0); + _rows1 = _mm256_comp_fmadd_ps(_S11, _a1, _rows1); + _mm256_storeu_ps(rows1p + dx, _rows1); + + alphap += 16; + } +#endif // __AVX__ + for (; dx + 3 < w; dx += 4) + { + __m128 _S10 = _mm_setr_ps(S1[xofs[dx]], S1[xofs[dx + 1]], S1[xofs[dx + 2]], S1[xofs[dx + 3]]); + __m128 _S11 = _mm_setr_ps(S1[xofs[dx] + 1], S1[xofs[dx + 1] + 1], S1[xofs[dx + 2] + 1], S1[xofs[dx + 3] + 1]); + + __m128 _a01 = _mm_loadu_ps(alphap); + __m128 _a23 = _mm_loadu_ps(alphap + 4); + + __m128 _a0 = _mm_shuffle_ps(_a01, _a23, _MM_SHUFFLE(2, 0, 2, 0)); + __m128 _a1 = _mm_shuffle_ps(_a01, _a23, _MM_SHUFFLE(3, 1, 3, 1)); + + __m128 _rows1 = _mm_mul_ps(_S10, _a0); + _rows1 = _mm_comp_fmadd_ps(_S11, _a1, _rows1); + _mm_storeu_ps(rows1p + dx, _rows1); + + alphap += 8; + } +#endif // __SSE2__ for (; dx < w; dx++) { int sx = xofs[dx]; @@ -92,6 +218,92 @@ static void resize_bilinear_image(const Mat& src, Mat& dst, float* alpha, int* x float* rows0p = rows0; float* rows1p = rows1; int dx = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + for (; dx + 15 < w; dx += 16) + { + __m512i _sx = _mm512_loadu_si512(xofs + dx); + __m512i _sx1 = _mm512_add_epi32(_sx, _mm512_set1_epi32(1)); + + __m512 _S00 = _mm512_i32gather_ps(_sx, S0, sizeof(float)); + __m512 _S01 = _mm512_i32gather_ps(_sx1, S0, sizeof(float)); + __m512 _S10 = _mm512_i32gather_ps(_sx, S1, sizeof(float)); + __m512 _S11 = _mm512_i32gather_ps(_sx1, S1, sizeof(float)); + + __m512i _alpha_idx = _mm512_setr_epi32(0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30); + __m512 _a0 = _mm512_i32gather_ps(_alpha_idx, alphap, sizeof(float)); + __m512i _alpha_idx1 = _mm512_add_epi32(_alpha_idx, _mm512_set1_epi32(1)); + __m512 _a1 = _mm512_i32gather_ps(_alpha_idx1, alphap, sizeof(float)); + + __m512 _rows0 = _mm512_mul_ps(_S00, _a0); + _rows0 = _mm512_fmadd_ps(_S01, _a1, _rows0); + __m512 _rows1 = _mm512_mul_ps(_S10, _a0); + _rows1 = _mm512_fmadd_ps(_S11, _a1, _rows1); + _mm512_storeu_ps(rows0p + dx, _rows0); + _mm512_storeu_ps(rows1p + 
dx, _rows1); + + alphap += 32; + } +#endif // __AVX512F__ + for (; dx + 7 < w; dx += 8) + { +#if __AVX2__ + __m256i _sx = _mm256_loadu_si256((const __m256i*)(xofs + dx)); + __m256i _sx1 = _mm256_add_epi32(_sx, _mm256_set1_epi32(1)); + + __m256 _S00 = _mm256_i32gather_ps(S0, _sx, sizeof(float)); + __m256 _S01 = _mm256_i32gather_ps(S0, _sx1, sizeof(float)); + __m256 _S10 = _mm256_i32gather_ps(S1, _sx, sizeof(float)); + __m256 _S11 = _mm256_i32gather_ps(S1, _sx1, sizeof(float)); + + __m256i _alpha_idx = _mm256_setr_epi32(0, 2, 4, 6, 8, 10, 12, 14); + __m256 _a0 = _mm256_i32gather_ps(alphap, _alpha_idx, sizeof(float)); + __m256i _alpha_idx1 = _mm256_setr_epi32(1, 3, 5, 7, 9, 11, 13, 15); + __m256 _a1 = _mm256_i32gather_ps(alphap, _alpha_idx1, sizeof(float)); +#else + __m256 _S00 = _mm256_setr_ps(S0[xofs[dx]], S0[xofs[dx + 1]], S0[xofs[dx + 2]], S0[xofs[dx + 3]], S0[xofs[dx + 4]], S0[xofs[dx + 5]], S0[xofs[dx + 6]], S0[xofs[dx + 7]]); + __m256 _S01 = _mm256_setr_ps(S0[xofs[dx] + 1], S0[xofs[dx + 1] + 1], S0[xofs[dx + 2] + 1], S0[xofs[dx + 3] + 1], S0[xofs[dx + 4] + 1], S0[xofs[dx + 5] + 1], S0[xofs[dx + 6] + 1], S0[xofs[dx + 7] + 1]); + __m256 _S10 = _mm256_setr_ps(S1[xofs[dx]], S1[xofs[dx + 1]], S1[xofs[dx + 2]], S1[xofs[dx + 3]], S1[xofs[dx + 4]], S1[xofs[dx + 5]], S1[xofs[dx + 6]], S1[xofs[dx + 7]]); + __m256 _S11 = _mm256_setr_ps(S1[xofs[dx] + 1], S1[xofs[dx + 1] + 1], S1[xofs[dx + 2] + 1], S1[xofs[dx + 3] + 1], S1[xofs[dx + 4] + 1], S1[xofs[dx + 5] + 1], S1[xofs[dx + 6] + 1], S1[xofs[dx + 7] + 1]); + + __m256 _a0 = _mm256_setr_ps(alphap[0], alphap[2], alphap[4], alphap[6], alphap[8], alphap[10], alphap[12], alphap[14]); + __m256 _a1 = _mm256_setr_ps(alphap[1], alphap[3], alphap[5], alphap[7], alphap[9], alphap[11], alphap[13], alphap[15]); +#endif + + __m256 _rows0 = _mm256_mul_ps(_S00, _a0); + _rows0 = _mm256_comp_fmadd_ps(_S01, _a1, _rows0); + __m256 _rows1 = _mm256_mul_ps(_S10, _a0); + _rows1 = _mm256_comp_fmadd_ps(_S11, _a1, _rows1); + _mm256_storeu_ps(rows0p + dx, _rows0); + _mm256_storeu_ps(rows1p + dx, _rows1); + + alphap += 16; + } +#endif // __AVX__ + for (; dx + 3 < w; dx += 4) + { + __m128 _S00 = _mm_setr_ps(S0[xofs[dx]], S0[xofs[dx + 1]], S0[xofs[dx + 2]], S0[xofs[dx + 3]]); + __m128 _S01 = _mm_setr_ps(S0[xofs[dx] + 1], S0[xofs[dx + 1] + 1], S0[xofs[dx + 2] + 1], S0[xofs[dx + 3] + 1]); + __m128 _S10 = _mm_setr_ps(S1[xofs[dx]], S1[xofs[dx + 1]], S1[xofs[dx + 2]], S1[xofs[dx + 3]]); + __m128 _S11 = _mm_setr_ps(S1[xofs[dx] + 1], S1[xofs[dx + 1] + 1], S1[xofs[dx + 2] + 1], S1[xofs[dx + 3] + 1]); + + __m128 _a01 = _mm_loadu_ps(alphap); + __m128 _a23 = _mm_loadu_ps(alphap + 4); + + __m128 _a0 = _mm_shuffle_ps(_a01, _a23, _MM_SHUFFLE(2, 0, 2, 0)); + __m128 _a1 = _mm_shuffle_ps(_a01, _a23, _MM_SHUFFLE(3, 1, 3, 1)); + + __m128 _rows0 = _mm_mul_ps(_S00, _a0); + _rows0 = _mm_comp_fmadd_ps(_S01, _a1, _rows0); + __m128 _rows1 = _mm_mul_ps(_S10, _a0); + _rows1 = _mm_comp_fmadd_ps(_S11, _a1, _rows1); + _mm_storeu_ps(rows0p + dx, _rows0); + _mm_storeu_ps(rows1p + dx, _rows1); + + alphap += 8; + } +#endif // __SSE2__ for (; dx < w; dx++) { int sx = xofs[dx]; @@ -110,50 +322,7 @@ static void resize_bilinear_image(const Mat& src, Mat& dst, float* alpha, int* x prev_sy1 = sy; // vresize - float b0 = beta[0]; - float b1 = beta[1]; - - float* rows0p = rows0; - float* rows1p = rows1; - float* Dp = dst.row(dy); - - int dx = 0; -#if __SSE2__ -#if __AVX__ - __m256 _b0_256 = _mm256_set1_ps(b0); - __m256 _b1_256 = _mm256_set1_ps(b1); - for (; dx + 7 < w; dx += 8) - { - __m256 _rows0 = 
_mm256_loadu_ps(rows0p); - __m256 _rows1 = _mm256_loadu_ps(rows1p); - __m256 _Dp = _mm256_mul_ps(_rows0, _b0_256); - _Dp = _mm256_comp_fmadd_ps(_rows1, _b1_256, _Dp); - _mm256_storeu_ps(Dp, _Dp); - - Dp += 8; - rows0p += 8; - rows1p += 8; - } -#endif // __AVX__ - __m128 _b0_128 = _mm_set1_ps(b0); - __m128 _b1_128 = _mm_set1_ps(b1); - for (; dx + 3 < w; dx += 4) - { - __m128 _rows0 = _mm_loadu_ps(rows0p); - __m128 _rows1 = _mm_loadu_ps(rows1p); - __m128 _Dp = _mm_mul_ps(_rows0, _b0_128); - _Dp = _mm_comp_fmadd_ps(_rows1, _b1_128, _Dp); - _mm_storeu_ps(Dp, _Dp); - - Dp += 4; - rows0p += 4; - rows1p += 4; - } -#endif // __SSE2__ - for (; dx < w; dx++) - { - *Dp++ = *rows0p++ * b0 + *rows1p++ * b1; - } + vresize_bilinear(rows0, rows1, dst.row(dy), w, beta[0], beta[1]); beta += 2; } diff --git a/src/layer/x86/interp_bilinear_pack16.h b/src/layer/x86/interp_bilinear_pack16.h index 3b1a3b9bd53..78ce62d11f2 100644 --- a/src/layer/x86/interp_bilinear_pack16.h +++ b/src/layer/x86/interp_bilinear_pack16.h @@ -87,25 +87,7 @@ static void resize_bilinear_image_pack16(const Mat& src, Mat& dst, float* alpha, prev_sy1 = sy; // vresize - __m512 _b0 = _mm512_set1_ps(beta[0]); - __m512 _b1 = _mm512_set1_ps(beta[1]); - - float* rows0p = rows0; - float* rows1p = rows1; - float* Dp = dst.row(dy); - - for (int dx = 0; dx < w; dx++) - { - __m512 _rows0 = _mm512_load_ps(rows0p); - __m512 _rows1 = _mm512_load_ps(rows1p); - __m512 _Dp = _mm512_mul_ps(_rows0, _b0); - _Dp = _mm512_fmadd_ps(_rows1, _b1, _Dp); - _mm512_store_ps(Dp, _Dp); - - Dp += 16; - rows0p += 16; - rows1p += 16; - } + vresize_bilinear(rows0, rows1, dst.row(dy), w * 16, beta[0], beta[1]); beta += 2; } diff --git a/src/layer/x86/interp_bilinear_pack4.h b/src/layer/x86/interp_bilinear_pack4.h index 69eebf1d18c..3fcf04129cd 100644 --- a/src/layer/x86/interp_bilinear_pack4.h +++ b/src/layer/x86/interp_bilinear_pack4.h @@ -33,6 +33,55 @@ static void resize_bilinear_image_pack4(const Mat& src, Mat& dst, float* alpha, const float* alphap = alpha; float* rows1p = rows1; int dx = 0; +#if __AVX__ +#if __AVX512F__ + for (; dx + 3 < w; dx += 4) + { + int sx0 = xofs[dx] * 4; + int sx1 = xofs[dx + 1] * 4; + int sx2 = xofs[dx + 2] * 4; + int sx3 = xofs[dx + 3] * 4; + + __m512 _a0 = _mm512_setr_ps(alphap[0], alphap[0], alphap[0], alphap[0], alphap[2], alphap[2], alphap[2], alphap[2], alphap[4], alphap[4], alphap[4], alphap[4], alphap[6], alphap[6], alphap[6], alphap[6]); + __m512 _a1 = _mm512_setr_ps(alphap[1], alphap[1], alphap[1], alphap[1], alphap[3], alphap[3], alphap[3], alphap[3], alphap[5], alphap[5], alphap[5], alphap[5], alphap[7], alphap[7], alphap[7], alphap[7]); + + __m128 _S10_0 = _mm_load_ps(S1 + sx0); + __m128 _S10_1 = _mm_load_ps(S1 + sx1); + __m128 _S10_2 = _mm_load_ps(S1 + sx2); + __m128 _S10_3 = _mm_load_ps(S1 + sx3); + __m128 _S11_0 = _mm_load_ps(S1 + sx0 + 4); + __m128 _S11_1 = _mm_load_ps(S1 + sx1 + 4); + __m128 _S11_2 = _mm_load_ps(S1 + sx2 + 4); + __m128 _S11_3 = _mm_load_ps(S1 + sx3 + 4); + + __m512 _S10 = combine4x4_ps(_S10_0, _S10_1, _S10_2, _S10_3); + __m512 _S11 = combine4x4_ps(_S11_0, _S11_1, _S11_2, _S11_3); + + __m512 _rows1 = _mm512_mul_ps(_S10, _a0); + _rows1 = _mm512_fmadd_ps(_S11, _a1, _rows1); + _mm512_storeu_ps(rows1p + dx * 4, _rows1); + + alphap += 8; + } +#endif // __AVX512F__ + for (; dx + 1 < w; dx += 2) + { + int sx0 = xofs[dx] * 4; + int sx1 = xofs[dx + 1] * 4; + + __m256 _a0 = _mm256_setr_ps(alphap[0], alphap[0], alphap[0], alphap[0], alphap[2], alphap[2], alphap[2], alphap[2]); + __m256 _a1 = _mm256_setr_ps(alphap[1], 
alphap[1], alphap[1], alphap[1], alphap[3], alphap[3], alphap[3], alphap[3]); + + __m256 _S10 = combine4x2_ps(_mm_load_ps(S1 + sx0), _mm_load_ps(S1 + sx1)); + __m256 _S11 = combine4x2_ps(_mm_load_ps(S1 + sx0 + 4), _mm_load_ps(S1 + sx1 + 4)); + + __m256 _rows1 = _mm256_mul_ps(_S10, _a0); + _rows1 = _mm256_comp_fmadd_ps(_S11, _a1, _rows1); + _mm256_storeu_ps(rows1p + dx * 4, _rows1); + + alphap += 4; + } +#endif // __AVX__ for (; dx < w; dx++) { int sx = xofs[dx] * 4; @@ -60,6 +109,76 @@ static void resize_bilinear_image_pack4(const Mat& src, Mat& dst, float* alpha, float* rows0p = rows0; float* rows1p = rows1; int dx = 0; +#if __AVX__ +#if __AVX512F__ + for (; dx + 3 < w; dx += 4) + { + int sx0 = xofs[dx] * 4; + int sx1 = xofs[dx + 1] * 4; + int sx2 = xofs[dx + 2] * 4; + int sx3 = xofs[dx + 3] * 4; + + __m512 _a0 = _mm512_setr_ps(alphap[0], alphap[0], alphap[0], alphap[0], alphap[2], alphap[2], alphap[2], alphap[2], alphap[4], alphap[4], alphap[4], alphap[4], alphap[6], alphap[6], alphap[6], alphap[6]); + __m512 _a1 = _mm512_setr_ps(alphap[1], alphap[1], alphap[1], alphap[1], alphap[3], alphap[3], alphap[3], alphap[3], alphap[5], alphap[5], alphap[5], alphap[5], alphap[7], alphap[7], alphap[7], alphap[7]); + + __m128 _S00_0 = _mm_load_ps(S0 + sx0); + __m128 _S00_1 = _mm_load_ps(S0 + sx1); + __m128 _S00_2 = _mm_load_ps(S0 + sx2); + __m128 _S00_3 = _mm_load_ps(S0 + sx3); + __m128 _S01_0 = _mm_load_ps(S0 + sx0 + 4); + __m128 _S01_1 = _mm_load_ps(S0 + sx1 + 4); + __m128 _S01_2 = _mm_load_ps(S0 + sx2 + 4); + __m128 _S01_3 = _mm_load_ps(S0 + sx3 + 4); + + __m512 _S00 = combine4x4_ps(_S00_0, _S00_1, _S00_2, _S00_3); + __m512 _S01 = combine4x4_ps(_S01_0, _S01_1, _S01_2, _S01_3); + + __m512 _rows0 = _mm512_mul_ps(_S00, _a0); + _rows0 = _mm512_fmadd_ps(_S01, _a1, _rows0); + _mm512_storeu_ps(rows0p + dx * 4, _rows0); + + __m128 _S10_0 = _mm_load_ps(S1 + sx0); + __m128 _S10_1 = _mm_load_ps(S1 + sx1); + __m128 _S10_2 = _mm_load_ps(S1 + sx2); + __m128 _S10_3 = _mm_load_ps(S1 + sx3); + __m128 _S11_0 = _mm_load_ps(S1 + sx0 + 4); + __m128 _S11_1 = _mm_load_ps(S1 + sx1 + 4); + __m128 _S11_2 = _mm_load_ps(S1 + sx2 + 4); + __m128 _S11_3 = _mm_load_ps(S1 + sx3 + 4); + + __m512 _S10 = combine4x4_ps(_S10_0, _S10_1, _S10_2, _S10_3); + __m512 _S11 = combine4x4_ps(_S11_0, _S11_1, _S11_2, _S11_3); + + __m512 _rows1 = _mm512_mul_ps(_S10, _a0); + _rows1 = _mm512_fmadd_ps(_S11, _a1, _rows1); + _mm512_storeu_ps(rows1p + dx * 4, _rows1); + + alphap += 8; + } +#endif // __AVX512F__ + for (; dx + 1 < w; dx += 2) + { + int sx0 = xofs[dx] * 4; + int sx1 = xofs[dx + 1] * 4; + + __m256 _a0 = _mm256_setr_ps(alphap[0], alphap[0], alphap[0], alphap[0], alphap[2], alphap[2], alphap[2], alphap[2]); + __m256 _a1 = _mm256_setr_ps(alphap[1], alphap[1], alphap[1], alphap[1], alphap[3], alphap[3], alphap[3], alphap[3]); + + __m256 _S00 = combine4x2_ps(_mm_load_ps(S0 + sx0), _mm_load_ps(S0 + sx1)); + __m256 _S01 = combine4x2_ps(_mm_load_ps(S0 + sx0 + 4), _mm_load_ps(S0 + sx1 + 4)); + __m256 _S10 = combine4x2_ps(_mm_load_ps(S1 + sx0), _mm_load_ps(S1 + sx1)); + __m256 _S11 = combine4x2_ps(_mm_load_ps(S1 + sx0 + 4), _mm_load_ps(S1 + sx1 + 4)); + + __m256 _rows0 = _mm256_mul_ps(_S00, _a0); + __m256 _rows1 = _mm256_mul_ps(_S10, _a0); + _rows0 = _mm256_comp_fmadd_ps(_S01, _a1, _rows0); + _rows1 = _mm256_comp_fmadd_ps(_S11, _a1, _rows1); + _mm256_storeu_ps(rows0p + dx * 4, _rows0); + _mm256_storeu_ps(rows1p + dx * 4, _rows1); + + alphap += 4; + } +#endif // __AVX__ for (; dx < w; dx++) { int sx = xofs[dx] * 4; @@ -87,25 +206,7 @@ static void 
resize_bilinear_image_pack4(const Mat& src, Mat& dst, float* alpha, prev_sy1 = sy; // vresize - __m128 _b0 = _mm_set1_ps(beta[0]); - __m128 _b1 = _mm_set1_ps(beta[1]); - - float* rows0p = rows0; - float* rows1p = rows1; - float* Dp = dst.row(dy); - - for (int dx = 0; dx < w; dx++) - { - __m128 _rows0 = _mm_load_ps(rows0p); - __m128 _rows1 = _mm_load_ps(rows1p); - __m128 _Dp = _mm_mul_ps(_rows0, _b0); - _Dp = _mm_comp_fmadd_ps(_rows1, _b1, _Dp); - _mm_store_ps(Dp, _Dp); - - Dp += 4; - rows0p += 4; - rows1p += 4; - } + vresize_bilinear(rows0, rows1, dst.row(dy), w * 4, beta[0], beta[1]); beta += 2; } diff --git a/src/layer/x86/interp_bilinear_pack8.h b/src/layer/x86/interp_bilinear_pack8.h index 71efb8175e0..657949faf60 100644 --- a/src/layer/x86/interp_bilinear_pack8.h +++ b/src/layer/x86/interp_bilinear_pack8.h @@ -33,6 +33,25 @@ static void resize_bilinear_image_pack8(const Mat& src, Mat& dst, float* alpha, const float* alphap = alpha; float* rows1p = rows1; int dx = 0; +#if __AVX512F__ + for (; dx + 1 < w; dx += 2) + { + int sx0 = xofs[dx] * 8; + int sx1 = xofs[dx + 1] * 8; + + __m512 _a0 = _mm512_setr_ps(alphap[0], alphap[0], alphap[0], alphap[0], alphap[0], alphap[0], alphap[0], alphap[0], alphap[2], alphap[2], alphap[2], alphap[2], alphap[2], alphap[2], alphap[2], alphap[2]); + __m512 _a1 = _mm512_setr_ps(alphap[1], alphap[1], alphap[1], alphap[1], alphap[1], alphap[1], alphap[1], alphap[1], alphap[3], alphap[3], alphap[3], alphap[3], alphap[3], alphap[3], alphap[3], alphap[3]); + + __m512 _S10 = combine8x2_ps(_mm256_load_ps(S1 + sx0), _mm256_load_ps(S1 + sx1)); + __m512 _S11 = combine8x2_ps(_mm256_load_ps(S1 + sx0 + 8), _mm256_load_ps(S1 + sx1 + 8)); + + __m512 _rows1 = _mm512_mul_ps(_S10, _a0); + _rows1 = _mm512_fmadd_ps(_S11, _a1, _rows1); + _mm512_storeu_ps(rows1p + dx * 8, _rows1); + + alphap += 4; + } +#endif // __AVX512F__ for (; dx < w; dx++) { int sx = xofs[dx] * 8; @@ -60,6 +79,30 @@ static void resize_bilinear_image_pack8(const Mat& src, Mat& dst, float* alpha, float* rows0p = rows0; float* rows1p = rows1; int dx = 0; +#if __AVX512F__ + for (; dx + 1 < w; dx += 2) + { + int sx0 = xofs[dx] * 8; + int sx1 = xofs[dx + 1] * 8; + + __m512 _a0 = _mm512_setr_ps(alphap[0], alphap[0], alphap[0], alphap[0], alphap[0], alphap[0], alphap[0], alphap[0], alphap[2], alphap[2], alphap[2], alphap[2], alphap[2], alphap[2], alphap[2], alphap[2]); + __m512 _a1 = _mm512_setr_ps(alphap[1], alphap[1], alphap[1], alphap[1], alphap[1], alphap[1], alphap[1], alphap[1], alphap[3], alphap[3], alphap[3], alphap[3], alphap[3], alphap[3], alphap[3], alphap[3]); + + __m512 _S00 = combine8x2_ps(_mm256_load_ps(S0 + sx0), _mm256_load_ps(S0 + sx1)); + __m512 _S01 = combine8x2_ps(_mm256_load_ps(S0 + sx0 + 8), _mm256_load_ps(S0 + sx1 + 8)); + __m512 _S10 = combine8x2_ps(_mm256_load_ps(S1 + sx0), _mm256_load_ps(S1 + sx1)); + __m512 _S11 = combine8x2_ps(_mm256_load_ps(S1 + sx0 + 8), _mm256_load_ps(S1 + sx1 + 8)); + + __m512 _rows0 = _mm512_mul_ps(_S00, _a0); + __m512 _rows1 = _mm512_mul_ps(_S10, _a0); + _rows0 = _mm512_fmadd_ps(_S01, _a1, _rows0); + _rows1 = _mm512_fmadd_ps(_S11, _a1, _rows1); + _mm512_storeu_ps(rows0p + dx * 8, _rows0); + _mm512_storeu_ps(rows1p + dx * 8, _rows1); + + alphap += 4; + } +#endif // __AVX512F__ for (; dx < w; dx++) { int sx = xofs[dx] * 8; @@ -87,25 +130,7 @@ static void resize_bilinear_image_pack8(const Mat& src, Mat& dst, float* alpha, prev_sy1 = sy; // vresize - __m256 _b0 = _mm256_set1_ps(beta[0]); - __m256 _b1 = _mm256_set1_ps(beta[1]); - - float* rows0p = rows0; - float* 
rows1p = rows1; - float* Dp = dst.row(dy); - - for (int dx = 0; dx < w; dx++) - { - __m256 _rows0 = _mm256_load_ps(rows0p); - __m256 _rows1 = _mm256_load_ps(rows1p); - __m256 _Dp = _mm256_mul_ps(_rows0, _b0); - _Dp = _mm256_comp_fmadd_ps(_rows1, _b1, _Dp); - _mm256_store_ps(Dp, _Dp); - - Dp += 8; - rows0p += 8; - rows1p += 8; - } + vresize_bilinear(rows0, rows1, dst.row(dy), w * 8, beta[0], beta[1]); beta += 2; } diff --git a/src/layer/x86/interp_x86.cpp b/src/layer/x86/interp_x86.cpp index 68352581798..0fe47cc3d6a 100644 --- a/src/layer/x86/interp_x86.cpp +++ b/src/layer/x86/interp_x86.cpp @@ -10,6 +10,9 @@ #endif // __AVX__ #endif // __SSE2__ +#include + +#include "cpu.h" #include "x86_usability.h" namespace ncnn { @@ -139,349 +142,48 @@ int Interp_x86::forward(const std::vector& bottom_blobs, std::vector& if (top_blob.empty()) return -100; -#if __SSE2__ -#if __AVX__ -#if __AVX512F__ - if (elempack == 16) + if (resize_type == 1) // nearest { - if (resize_type == 1) // nearest - { - const float ws = (output_width || !size_expr.empty()) ? w / (float)outw : 1.f / width_scale; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int y = 0; y < h; y++) - { - const float* ptr = bottom_blob.row(y); - float* outptr = top_blob.row(y); - for (int x = 0; x < outw; x++) - { - int in_x = std::min((int)(x * ws), (w - 1)); - - __m512 _p = _mm512_load_ps(ptr + in_x * 16); - _mm512_store_ps(outptr, _p); - - outptr += 16; - } - } - } - - if (resize_type == 2) // bilinear - { - int* buf = new int[outw + outw * 2]; - - int* xofs = buf; - float* alpha = (float*)(buf + outw); - - linear_coeffs(w, outw, xofs, alpha, align_corner); - - #pragma omp parallel for num_threads(opt.num_threads) - for (int y = 0; y < h; y++) - { - const float* ptr = bottom_blob.row(y); - float* outptr = top_blob.row(y); - const float* alphap = alpha; - - for (int x = 0; x < outw; x++) - { - int sx = xofs[x] * 16; - const float* Sp = ptr + sx; - - __m512 _a0 = _mm512_set1_ps(alphap[0]); - __m512 _a1 = _mm512_set1_ps(alphap[1]); - - __m512 _S0 = _mm512_load_ps(Sp); - __m512 _S1 = _mm512_load_ps(Sp + 16); - __m512 _p = _mm512_mul_ps(_S0, _a0); - _p = _mm512_fmadd_ps(_S1, _a1, _p); - _mm512_store_ps(outptr, _p); - - alphap += 2; - outptr += 16; - } - } - - delete[] buf; - } + const float ws = (output_width || !size_expr.empty()) ? 
w / (float)outw : 1.f / width_scale; - if (resize_type == 3) // bicubic + #pragma omp parallel for num_threads(opt.num_threads) + for (int y = 0; y < h; y++) { - int* buf = new int[outw + outw * 4]; - - int* xofs = buf; - float* alpha = (float*)(buf + outw); - - cubic_coeffs(w, outw, xofs, alpha, align_corner); - - #pragma omp parallel for num_threads(opt.num_threads) - for (int y = 0; y < h; y++) + const float* ptr = bottom_blob.row(y); + float* outptr = top_blob.row(y); + for (int x = 0; x < outw; x++) { - const float* ptr = bottom_blob.row(y); - float* outptr = top_blob.row(y); - const float* alphap = alpha; + int in_x = std::min((int)(x * ws), (w - 1)); + const float* Sp = ptr + in_x * elempack; - for (int x = 0; x < outw; x++) + int ep = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + for (; ep + 15 < elempack; ep += 16) { - int sx = xofs[x] * 16; - const float* Sp = ptr + sx; - - __m512 _a0 = _mm512_set1_ps(alphap[0]); - __m512 _a1 = _mm512_set1_ps(alphap[1]); - __m512 _a2 = _mm512_set1_ps(alphap[2]); - __m512 _a3 = _mm512_set1_ps(alphap[3]); - - __m512 _S0 = _mm512_load_ps(Sp - 16); - __m512 _S1 = _mm512_load_ps(Sp + 0); - __m512 _S2 = _mm512_load_ps(Sp + 16); - __m512 _S3 = _mm512_load_ps(Sp + 32); - __m512 _p = _mm512_mul_ps(_S0, _a0); - _p = _mm512_fmadd_ps(_S1, _a1, _p); - _p = _mm512_fmadd_ps(_S2, _a2, _p); - _p = _mm512_fmadd_ps(_S3, _a3, _p); - _mm512_store_ps(outptr, _p); - - alphap += 4; - outptr += 16; + __m512 _p = _mm512_load_ps(Sp + ep); + _mm512_store_ps(outptr + ep, _p); } - } - - delete[] buf; - } - - return 0; - } #endif // __AVX512F__ - - if (elempack == 8) - { - if (resize_type == 1) // nearest - { - const float ws = (output_width || !size_expr.empty()) ? w / (float)outw : 1.f / width_scale; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int y = 0; y < h; y++) - { - const float* ptr = bottom_blob.row(y); - float* outptr = top_blob.row(y); - for (int x = 0; x < outw; x++) - { - int in_x = std::min((int)(x * ws), (w - 1)); - - __m256 _p = _mm256_load_ps(ptr + in_x * 8); - _mm256_store_ps(outptr, _p); - - outptr += 8; - } - } - } - - if (resize_type == 2) // bilinear - { - int* buf = new int[outw + outw * 2]; - - int* xofs = buf; - float* alpha = (float*)(buf + outw); - - linear_coeffs(w, outw, xofs, alpha, align_corner); - - #pragma omp parallel for num_threads(opt.num_threads) - for (int y = 0; y < h; y++) - { - const float* ptr = bottom_blob.row(y); - float* outptr = top_blob.row(y); - const float* alphap = alpha; - - for (int x = 0; x < outw; x++) - { - int sx = xofs[x] * 8; - const float* Sp = ptr + sx; - - __m256 _a0 = _mm256_set1_ps(alphap[0]); - __m256 _a1 = _mm256_set1_ps(alphap[1]); - - __m256 _S0 = _mm256_load_ps(Sp); - __m256 _S1 = _mm256_load_ps(Sp + 8); - __m256 _p = _mm256_mul_ps(_S0, _a0); - _p = _mm256_comp_fmadd_ps(_S1, _a1, _p); - _mm256_store_ps(outptr, _p); - - alphap += 2; - outptr += 8; - } - } - - delete[] buf; - } - - if (resize_type == 3) // bicubic - { - int* buf = new int[outw + outw * 4]; - - int* xofs = buf; - float* alpha = (float*)(buf + outw); - - cubic_coeffs(w, outw, xofs, alpha, align_corner); - - #pragma omp parallel for num_threads(opt.num_threads) - for (int y = 0; y < h; y++) - { - const float* ptr = bottom_blob.row(y); - float* outptr = top_blob.row(y); - const float* alphap = alpha; - - for (int x = 0; x < outw; x++) + for (; ep + 7 < elempack; ep += 8) { - int sx = xofs[x] * 8; - const float* Sp = ptr + sx; - - __m256 _a0 = _mm256_set1_ps(alphap[0]); - __m256 _a1 = _mm256_set1_ps(alphap[1]); - __m256 _a2 = 
_mm256_set1_ps(alphap[2]); - __m256 _a3 = _mm256_set1_ps(alphap[3]); - - __m256 _S0 = _mm256_load_ps(Sp - 8); - __m256 _S1 = _mm256_load_ps(Sp + 0); - __m256 _S2 = _mm256_load_ps(Sp + 8); - __m256 _S3 = _mm256_load_ps(Sp + 16); - __m256 _p = _mm256_mul_ps(_S0, _a0); - _p = _mm256_comp_fmadd_ps(_S1, _a1, _p); - _p = _mm256_comp_fmadd_ps(_S2, _a2, _p); - _p = _mm256_comp_fmadd_ps(_S3, _a3, _p); - _mm256_store_ps(outptr, _p); - - alphap += 4; - outptr += 8; + __m256 _p = _mm256_load_ps(Sp + ep); + _mm256_store_ps(outptr + ep, _p); } - } - - delete[] buf; - } - - return 0; - } #endif // __AVX__ - - if (elempack == 4) - { - if (resize_type == 1) // nearest - { - const float ws = (output_width || !size_expr.empty()) ? w / (float)outw : 1.f / width_scale; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int y = 0; y < h; y++) - { - const float* ptr = bottom_blob.row(y); - float* outptr = top_blob.row(y); - for (int x = 0; x < outw; x++) + for (; ep + 3 < elempack; ep += 4) { - int in_x = std::min((int)(x * ws), (w - 1)); - - __m128 _p = _mm_load_ps(ptr + in_x * 4); - _mm_store_ps(outptr, _p); - - outptr += 4; + __m128 _p = _mm_load_ps(Sp + ep); + _mm_store_ps(outptr + ep, _p); } - } - } - - if (resize_type == 2) // bilinear - { - int* buf = new int[outw + outw * 2]; - - int* xofs = buf; - float* alpha = (float*)(buf + outw); - - linear_coeffs(w, outw, xofs, alpha, align_corner); - - #pragma omp parallel for num_threads(opt.num_threads) - for (int y = 0; y < h; y++) - { - const float* ptr = bottom_blob.row(y); - float* outptr = top_blob.row(y); - const float* alphap = alpha; - - for (int x = 0; x < outw; x++) - { - int sx = xofs[x] * 4; - const float* Sp = ptr + sx; - - __m128 _a0 = _mm_set1_ps(alphap[0]); - __m128 _a1 = _mm_set1_ps(alphap[1]); - - __m128 _S0 = _mm_load_ps(Sp); - __m128 _S1 = _mm_load_ps(Sp + 4); - __m128 _p = _mm_mul_ps(_S0, _a0); - _p = _mm_comp_fmadd_ps(_S1, _a1, _p); - _mm_store_ps(outptr, _p); - - alphap += 2; - outptr += 4; - } - } - - delete[] buf; - } - - if (resize_type == 3) // bicubic - { - int* buf = new int[outw + outw * 4]; - - int* xofs = buf; - float* alpha = (float*)(buf + outw); - - cubic_coeffs(w, outw, xofs, alpha, align_corner); - - #pragma omp parallel for num_threads(opt.num_threads) - for (int y = 0; y < h; y++) - { - const float* ptr = bottom_blob.row(y); - float* outptr = top_blob.row(y); - const float* alphap = alpha; - - for (int x = 0; x < outw; x++) +#endif // __SSE2__ + for (; ep < elempack; ep++) { - int sx = xofs[x] * 4; - const float* Sp = ptr + sx; - - __m128 _a0 = _mm_set1_ps(alphap[0]); - __m128 _a1 = _mm_set1_ps(alphap[1]); - __m128 _a2 = _mm_set1_ps(alphap[2]); - __m128 _a3 = _mm_set1_ps(alphap[3]); - - __m128 _S0 = _mm_load_ps(Sp - 4); - __m128 _S1 = _mm_load_ps(Sp + 0); - __m128 _S2 = _mm_load_ps(Sp + 4); - __m128 _S3 = _mm_load_ps(Sp + 8); - __m128 _p = _mm_mul_ps(_S0, _a0); - _p = _mm_comp_fmadd_ps(_S1, _a1, _p); - _p = _mm_comp_fmadd_ps(_S2, _a2, _p); - _p = _mm_comp_fmadd_ps(_S3, _a3, _p); - _mm_store_ps(outptr, _p); - - alphap += 4; - outptr += 4; + outptr[ep] = Sp[ep]; } - } - - delete[] buf; - } - return 0; - } -#endif // __SSE2__ - - if (resize_type == 1) // nearest - { - const float ws = (output_width || !size_expr.empty()) ? 
w / (float)outw : 1.f / width_scale; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int y = 0; y < h; y++) - { - const float* ptr = bottom_blob.row(y); - float* outptr = top_blob.row(y); - for (int x = 0; x < outw; x++) - { - int in_x = std::min((int)(x * ws), (w - 1)); - *outptr++ = ptr[in_x]; + outptr += elempack; } } } @@ -504,12 +206,61 @@ int Interp_x86::forward(const std::vector& bottom_blobs, std::vector& for (int x = 0; x < outw; x++) { - int sx = xofs[x]; + int sx = xofs[x] * elempack; const float* Sp = ptr + sx; float a0 = alphap[0]; float a1 = alphap[1]; - *outptr++ = Sp[0] * a0 + Sp[1] * a1; + + int ep = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + { + __m512 _a0 = _mm512_set1_ps(a0); + __m512 _a1 = _mm512_set1_ps(a1); + for (; ep + 15 < elempack; ep += 16) + { + __m512 _S0 = _mm512_load_ps(Sp + ep); + __m512 _S1 = _mm512_load_ps(Sp + ep + elempack); + __m512 _p = _mm512_mul_ps(_S0, _a0); + _p = _mm512_fmadd_ps(_S1, _a1, _p); + _mm512_store_ps(outptr + ep, _p); + } + } +#endif // __AVX512F__ + { + __m256 _a0 = _mm256_set1_ps(a0); + __m256 _a1 = _mm256_set1_ps(a1); + for (; ep + 7 < elempack; ep += 8) + { + __m256 _S0 = _mm256_load_ps(Sp + ep); + __m256 _S1 = _mm256_load_ps(Sp + ep + elempack); + __m256 _p = _mm256_mul_ps(_S0, _a0); + _p = _mm256_comp_fmadd_ps(_S1, _a1, _p); + _mm256_store_ps(outptr + ep, _p); + } + } +#endif // __AVX__ + { + __m128 _a0 = _mm_set1_ps(a0); + __m128 _a1 = _mm_set1_ps(a1); + for (; ep + 3 < elempack; ep += 4) + { + __m128 _S0 = _mm_load_ps(Sp + ep); + __m128 _S1 = _mm_load_ps(Sp + ep + elempack); + __m128 _p = _mm_mul_ps(_S0, _a0); + _p = _mm_comp_fmadd_ps(_S1, _a1, _p); + _mm_store_ps(outptr + ep, _p); + } + } +#endif // __SSE2__ + for (; ep < elempack; ep++) + { + outptr[ep] = Sp[ep] * a0 + Sp[ep + elempack] * a1; + } + alphap += 2; + outptr += elempack; } } @@ -534,293 +285,99 @@ int Interp_x86::forward(const std::vector& bottom_blobs, std::vector& for (int x = 0; x < outw; x++) { - int sx = xofs[x]; + int sx = xofs[x] * elempack; const float* Sp = ptr + sx; float a0 = alphap[0]; float a1 = alphap[1]; float a2 = alphap[2]; float a3 = alphap[3]; - *outptr++ = Sp[-1] * a0 + Sp[0] * a1 + Sp[1] * a2 + Sp[2] * a3; - alphap += 4; - } - } - - delete[] buf; - } - - return 0; - } - - if (outw == w && outh == h) - { - top_blob = bottom_blob; - return 0; - } - - top_blob.create(outw, outh, channels, elemsize, elempack, opt.blob_allocator); - if (top_blob.empty()) - return -100; + int ep = 0; #if __SSE2__ #if __AVX__ #if __AVX512F__ - if (elempack == 16) - { - if (resize_type == 1) // nearest - { - const float hs = (output_height || !size_expr.empty()) ? h / (float)outh : 1.f / height_scale; - const float ws = (output_width || !size_expr.empty()) ? 
w / (float)outw : 1.f / width_scale; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const Mat src = bottom_blob.channel(q); - Mat dst = top_blob.channel(q); - - for (int y = 0; y < outh; y++) - { - int in_y = std::min((int)(y * hs), (h - 1)); - - const float* ptr = src.row(in_y); - float* outptr = dst.row(y); - for (int x = 0; x < outw; x++) { - int in_x = std::min((int)(x * ws), (w - 1)); - - __m512 _p = _mm512_load_ps(ptr + in_x * 16); - _mm512_store_ps(outptr, _p); - - outptr += 16; + __m512 _a0 = _mm512_set1_ps(a0); + __m512 _a1 = _mm512_set1_ps(a1); + __m512 _a2 = _mm512_set1_ps(a2); + __m512 _a3 = _mm512_set1_ps(a3); + for (; ep + 15 < elempack; ep += 16) + { + __m512 _S0 = _mm512_load_ps(Sp + ep - elempack); + __m512 _S1 = _mm512_load_ps(Sp + ep); + __m512 _S2 = _mm512_load_ps(Sp + ep + elempack); + __m512 _S3 = _mm512_load_ps(Sp + ep + elempack * 2); + __m512 _p = _mm512_mul_ps(_S0, _a0); + _p = _mm512_fmadd_ps(_S1, _a1, _p); + _p = _mm512_fmadd_ps(_S2, _a2, _p); + _p = _mm512_fmadd_ps(_S3, _a3, _p); + _mm512_store_ps(outptr + ep, _p); + } } - } - } - } - - if (resize_type == 2) // bilinear - { - int* buf = new int[outw + outh + outw * 2 + outh * 2]; - - int* xofs = buf; //new int[outw]; - int* yofs = buf + outw; //new int[outh]; - - float* alpha = (float*)(buf + outw + outh); //new float[outw * 2]; - float* beta = (float*)(buf + outw + outh + outw * 2); //new float[outh * 2]; - - linear_coeffs(w, outw, xofs, alpha, align_corner); - linear_coeffs(h, outh, yofs, beta, align_corner); - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const Mat src = bottom_blob.channel(q); - Mat dst = top_blob.channel(q); - - resize_bilinear_image_pack16(src, dst, alpha, xofs, beta, yofs); - } - - delete[] buf; - } - - if (resize_type == 3) // bicubic - { - int* buf = new int[outw + outh + outw * 4 + outh * 4]; - - int* xofs = buf; //new int[outw]; - int* yofs = buf + outw; //new int[outh]; - - float* alpha = (float*)(buf + outw + outh); //new float[outw * 4]; - float* beta = (float*)(buf + outw + outh + outw * 4); //new float[outh * 4]; - - cubic_coeffs(w, outw, xofs, alpha, align_corner); - cubic_coeffs(h, outh, yofs, beta, align_corner); - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const Mat src = bottom_blob.channel(q); - Mat dst = top_blob.channel(q); - - resize_bicubic_image_pack16(src, dst, alpha, xofs, beta, yofs); - } - - delete[] buf; - } - - return 0; - } #endif // __AVX512F__ - - if (elempack == 8) - { - if (resize_type == 1) // nearest - { - const float hs = (output_height || !size_expr.empty()) ? h / (float)outh : 1.f / height_scale; - const float ws = (output_width || !size_expr.empty()) ? 
w / (float)outw : 1.f / width_scale; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const Mat src = bottom_blob.channel(q); - Mat dst = top_blob.channel(q); - - for (int y = 0; y < outh; y++) - { - int in_y = std::min((int)(y * hs), (h - 1)); - - const float* ptr = src.row(in_y); - float* outptr = dst.row(y); - for (int x = 0; x < outw; x++) { - int in_x = std::min((int)(x * ws), (w - 1)); - - __m256 _p = _mm256_load_ps(ptr + in_x * 8); - _mm256_store_ps(outptr, _p); - - outptr += 8; + __m256 _a0 = _mm256_set1_ps(a0); + __m256 _a1 = _mm256_set1_ps(a1); + __m256 _a2 = _mm256_set1_ps(a2); + __m256 _a3 = _mm256_set1_ps(a3); + for (; ep + 7 < elempack; ep += 8) + { + __m256 _S0 = _mm256_load_ps(Sp + ep - elempack); + __m256 _S1 = _mm256_load_ps(Sp + ep); + __m256 _S2 = _mm256_load_ps(Sp + ep + elempack); + __m256 _S3 = _mm256_load_ps(Sp + ep + elempack * 2); + __m256 _p = _mm256_mul_ps(_S0, _a0); + _p = _mm256_comp_fmadd_ps(_S1, _a1, _p); + _p = _mm256_comp_fmadd_ps(_S2, _a2, _p); + _p = _mm256_comp_fmadd_ps(_S3, _a3, _p); + _mm256_store_ps(outptr + ep, _p); + } } - } - } - } - - if (resize_type == 2) // bilinear - { - int* buf = new int[outw + outh + outw * 2 + outh * 2]; - - int* xofs = buf; //new int[outw]; - int* yofs = buf + outw; //new int[outh]; - - float* alpha = (float*)(buf + outw + outh); //new float[outw * 2]; - float* beta = (float*)(buf + outw + outh + outw * 2); //new float[outh * 2]; - - linear_coeffs(w, outw, xofs, alpha, align_corner); - linear_coeffs(h, outh, yofs, beta, align_corner); - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const Mat src = bottom_blob.channel(q); - Mat dst = top_blob.channel(q); - - resize_bilinear_image_pack8(src, dst, alpha, xofs, beta, yofs); - } - - delete[] buf; - } - - if (resize_type == 3) // bicubic - { - int* buf = new int[outw + outh + outw * 4 + outh * 4]; - - int* xofs = buf; //new int[outw]; - int* yofs = buf + outw; //new int[outh]; - - float* alpha = (float*)(buf + outw + outh); //new float[outw * 4]; - float* beta = (float*)(buf + outw + outh + outw * 4); //new float[outh * 4]; - - cubic_coeffs(w, outw, xofs, alpha, align_corner); - cubic_coeffs(h, outh, yofs, beta, align_corner); - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const Mat src = bottom_blob.channel(q); - Mat dst = top_blob.channel(q); - - resize_bicubic_image_pack8(src, dst, alpha, xofs, beta, yofs); - } - - delete[] buf; - } - - return 0; - } #endif // __AVX__ - - if (elempack == 4) - { - if (resize_type == 1) // nearest - { - const float hs = (output_height || !size_expr.empty()) ? h / (float)outh : 1.f / height_scale; - const float ws = (output_width || !size_expr.empty()) ? 
w / (float)outw : 1.f / width_scale; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const Mat src = bottom_blob.channel(q); - Mat dst = top_blob.channel(q); - - for (int y = 0; y < outh; y++) - { - int in_y = std::min((int)(y * hs), (h - 1)); - - const float* ptr = src.row(in_y); - float* outptr = dst.row(y); - for (int x = 0; x < outw; x++) { - int in_x = std::min((int)(x * ws), (w - 1)); - - __m128 _p = _mm_load_ps(ptr + in_x * 4); - _mm_store_ps(outptr, _p); - - outptr += 4; + __m128 _a0 = _mm_set1_ps(a0); + __m128 _a1 = _mm_set1_ps(a1); + __m128 _a2 = _mm_set1_ps(a2); + __m128 _a3 = _mm_set1_ps(a3); + for (; ep + 3 < elempack; ep += 4) + { + __m128 _S0 = _mm_load_ps(Sp + ep - elempack); + __m128 _S1 = _mm_load_ps(Sp + ep); + __m128 _S2 = _mm_load_ps(Sp + ep + elempack); + __m128 _S3 = _mm_load_ps(Sp + ep + elempack * 2); + __m128 _p = _mm_mul_ps(_S0, _a0); + _p = _mm_comp_fmadd_ps(_S1, _a1, _p); + _p = _mm_comp_fmadd_ps(_S2, _a2, _p); + _p = _mm_comp_fmadd_ps(_S3, _a3, _p); + _mm_store_ps(outptr + ep, _p); + } + } +#endif // __SSE2__ + for (; ep < elempack; ep++) + { + outptr[ep] = Sp[ep - elempack] * a0 + Sp[ep] * a1 + Sp[ep + elempack] * a2 + Sp[ep + elempack * 2] * a3; } - } - } - } - - if (resize_type == 2) // bilinear - { - int* buf = new int[outw + outh + outw * 2 + outh * 2]; - - int* xofs = buf; //new int[outw]; - int* yofs = buf + outw; //new int[outh]; - - float* alpha = (float*)(buf + outw + outh); //new float[outw * 2]; - float* beta = (float*)(buf + outw + outh + outw * 2); //new float[outh * 2]; - - linear_coeffs(w, outw, xofs, alpha, align_corner); - linear_coeffs(h, outh, yofs, beta, align_corner); - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const Mat src = bottom_blob.channel(q); - Mat dst = top_blob.channel(q); - resize_bilinear_image_pack4(src, dst, alpha, xofs, beta, yofs); + alphap += 4; + outptr += elempack; + } } delete[] buf; } - if (resize_type == 3) // bicubic - { - int* buf = new int[outw + outh + outw * 4 + outh * 4]; - - int* xofs = buf; //new int[outw]; - int* yofs = buf + outw; //new int[outh]; - - float* alpha = (float*)(buf + outw + outh); //new float[outw * 4]; - float* beta = (float*)(buf + outw + outh + outw * 4); //new float[outh * 4]; - - cubic_coeffs(w, outw, xofs, alpha, align_corner); - cubic_coeffs(h, outh, yofs, beta, align_corner); - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const Mat src = bottom_blob.channel(q); - Mat dst = top_blob.channel(q); - - resize_bicubic_image_pack4(src, dst, alpha, xofs, beta, yofs); - } - - delete[] buf; - } + return 0; + } + if (outw == w && outh == h) + { + top_blob = bottom_blob; return 0; } -#endif // __SSE2__ + + top_blob.create(outw, outh, channels, elemsize, elempack, opt.blob_allocator); + if (top_blob.empty()) + return -100; if (resize_type == 1) // nearest { @@ -842,7 +399,11 @@ int Interp_x86::forward(const std::vector& bottom_blobs, std::vector& for (int x = 0; x < outw; x++) { int in_x = std::min((int)(x * ws), (w - 1)); - *outptr++ = ptr[in_x]; + const float* Sp = ptr + in_x * elempack; + + memcpy(outptr, Sp, elempack * sizeof(float)); + + outptr += elempack; } } } @@ -867,7 +428,28 @@ int Interp_x86::forward(const std::vector& bottom_blobs, std::vector& const Mat src = bottom_blob.channel(q); Mat dst = top_blob.channel(q); - resize_bilinear_image(src, dst, alpha, xofs, beta, yofs); +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + if 
(elempack == 16) + { + resize_bilinear_image_pack16(src, dst, alpha, xofs, beta, yofs); + } +#endif // __AVX512F__ + if (elempack == 8) + { + resize_bilinear_image_pack8(src, dst, alpha, xofs, beta, yofs); + } +#endif // __AVX__ + if (elempack == 4) + { + resize_bilinear_image_pack4(src, dst, alpha, xofs, beta, yofs); + } +#endif // __SSE2__ + if (elempack == 1) + { + resize_bilinear_image(src, dst, alpha, xofs, beta, yofs); + } } delete[] buf; @@ -892,7 +474,28 @@ int Interp_x86::forward(const std::vector& bottom_blobs, std::vector& const Mat src = bottom_blob.channel(q); Mat dst = top_blob.channel(q); - resize_bicubic_image(src, dst, alpha, xofs, beta, yofs); +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + if (elempack == 16) + { + resize_bicubic_image_pack16(src, dst, alpha, xofs, beta, yofs); + } +#endif // __AVX512F__ + if (elempack == 8) + { + resize_bicubic_image_pack8(src, dst, alpha, xofs, beta, yofs); + } +#endif // __AVX__ + if (elempack == 4) + { + resize_bicubic_image_pack4(src, dst, alpha, xofs, beta, yofs); + } +#endif // __SSE2__ + if (elempack == 1) + { + resize_bicubic_image(src, dst, alpha, xofs, beta, yofs); + } } delete[] buf; diff --git a/src/layer/x86/interp_x86_avx2.cpp b/src/layer/x86/interp_x86_avx2.cpp new file mode 100644 index 00000000000..fcd1878dce7 --- /dev/null +++ b/src/layer/x86/interp_x86_avx2.cpp @@ -0,0 +1,23 @@ +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#include "cpu.h" +#include "mat.h" +#include "x86_usability.h" + +namespace ncnn { + +#include "interp_bilinear.h" +#include "interp_bicubic.h" + +void resize_bilinear_image_avx2(const Mat& src, Mat& dst, float* alpha, int* xofs, float* beta, int* yofs) +{ + resize_bilinear_image(src, dst, alpha, xofs, beta, yofs); +} + +void resize_bicubic_image_avx2(const Mat& src, Mat& dst, float* alpha, int* xofs, float* beta, int* yofs) +{ + resize_bicubic_image(src, dst, alpha, xofs, beta, yofs); +} + +} // namespace ncnn diff --git a/tests/test_interp.cpp b/tests/test_interp.cpp index 140ac162a87..4f159453a34 100644 --- a/tests/test_interp.cpp +++ b/tests/test_interp.cpp @@ -496,6 +496,201 @@ static int test_interp_6() || test_interp_ref(c, 1, 14, 17); } +// dims=3, nearest, pack8 channels (c%8==0 && c%16!=0) and pack1 channels (c%4!=0) +static int test_interp_7() +{ + ncnn::Mat a = RandomMat(15, 16, 8); + ncnn::Mat b = RandomMat(14, 17, 24); + ncnn::Mat c = RandomMat(13, 14, 3); + + return 0 + || test_interp(a, 1, 2.f, 2.f, 0, 0) + || test_interp(a, 1, 4.f, 0.5f, 0, 0) + || test_interp(a, 1, 1.2f, 1.2f, 0, 0) + || test_interp(a, 1, 0.8f, 0.8f, 0, 0) + || test_interp(a, 1, 1.f, 1.f, 10, 12) + || test_interp(a, 1, 1.f, 1.f, 15, 16) + || test_interp_ref(a, 1, 10, 12) + || test_interp_ref(a, 1, 15, 16) + + || test_interp(b, 1, 2.f, 2.f, 0, 0) + || test_interp(b, 1, 0.5f, 0.5f, 0, 0) + || test_interp(b, 1, 1.f, 1.f, 10, 12) + || test_interp(b, 1, 1.f, 1.f, 14, 17) + || test_interp_ref(b, 1, 10, 12) + || test_interp_ref(b, 1, 14, 17) + + || test_interp(c, 1, 2.f, 2.f, 0, 0) + || test_interp(c, 1, 0.5f, 0.5f, 0, 0) + || test_interp(c, 1, 1.f, 1.f, 10, 12) + || test_interp_ref(c, 1, 10, 12); +} + +// dims=3, bilinear, pack8 channels and pack1 channels +static int test_interp_8() +{ + ncnn::Mat a = RandomMat(15, 16, 8); + ncnn::Mat b = RandomMat(14, 17, 24); + ncnn::Mat c = RandomMat(13, 14, 3); + + return 0 + || test_interp(a, 2, 2.f, 2.f, 0, 0) + || test_interp(a, 2, 4.f, 0.5f, 0, 0) + || test_interp(a, 2, 1.2f, 1.2f, 0, 0) + || test_interp(a, 2, 0.8f, 0.8f, 0, 0) + || test_interp(a, 2, 1.f, 
1.f, 10, 12) + || test_interp(a, 2, 1.f, 1.f, 15, 16) + || test_interp_align_corner(a, 2, 2.f, 2.f, 0, 0, 1) + || test_interp_align_corner(a, 2, 0.8f, 0.8f, 0, 0, 1) + || test_interp_align_corner(a, 2, 1.f, 1.f, 10, 12, 1) + || test_interp_ref(a, 2, 10, 12) + || test_interp_ref(a, 2, 15, 16) + + || test_interp(b, 2, 2.f, 2.f, 0, 0) + || test_interp(b, 2, 0.5f, 0.5f, 0, 0) + || test_interp(b, 2, 1.f, 1.f, 10, 12) + || test_interp_align_corner(b, 2, 2.f, 2.f, 0, 0, 1) + || test_interp_ref(b, 2, 10, 12) + + || test_interp(c, 2, 2.f, 2.f, 0, 0) + || test_interp(c, 2, 0.5f, 0.5f, 0, 0) + || test_interp(c, 2, 1.f, 1.f, 10, 12) + || test_interp_align_corner(c, 2, 2.f, 2.f, 0, 0, 1) + || test_interp_ref(c, 2, 10, 12); +} + +// dims=3, bicubic, pack8 channels and pack1 channels +// Uses large output sizes (>= 16) to exercise AVX512 gather-based hresize loops in pack1 path +// Uses various downscale ratios to trigger sy row-reuse cases (+1/+2/+3/full) +static int test_interp_9() +{ + ncnn::Mat a = RandomMat(16, 17, 8); + ncnn::Mat b = RandomMat(18, 19, 24); + ncnn::Mat c = RandomMat(13, 14, 3); + ncnn::Mat d = RandomMat(32, 32, 5); + + return 0 + || test_interp(a, 3, 2.f, 2.f, 0, 0) + || test_interp(a, 3, 4.f, 0.5f, 0, 0) + || test_interp(a, 3, 1.2f, 1.2f, 0, 0) + || test_interp(a, 3, 0.8f, 0.8f, 0, 0) + || test_interp(a, 3, 1.f, 1.f, 10, 12) + || test_interp(a, 3, 1.f, 1.f, 6, 7) + || test_interp(a, 3, 1.f, 1.f, 16, 17) + || test_interp_align_corner(a, 3, 2.f, 2.f, 0, 0, 1) + || test_interp_align_corner(a, 3, 0.8f, 0.8f, 0, 0, 1) + || test_interp_align_corner(a, 3, 1.f, 1.f, 10, 12, 1) + || test_interp_ref(a, 3, 6, 7) + || test_interp_ref(a, 3, 16, 17) + + || test_interp(b, 3, 2.f, 2.f, 0, 0) + || test_interp(b, 3, 0.5f, 0.5f, 0, 0) + || test_interp(b, 3, 1.f, 1.f, 10, 12) + || test_interp_align_corner(b, 3, 2.f, 2.f, 0, 0, 1) + || test_interp_ref(b, 3, 10, 12) + + // pack1, large output width to hit AVX512 gather loops + || test_interp(c, 3, 2.f, 2.f, 0, 0) + || test_interp(c, 3, 0.5f, 0.5f, 0, 0) + || test_interp(c, 3, 1.f, 1.f, 10, 12) + || test_interp(c, 3, 1.f, 1.f, 20, 20) + || test_interp_align_corner(c, 3, 2.f, 2.f, 0, 0, 1) + || test_interp_align_corner(c, 3, 1.f, 1.f, 20, 20, 1) + || test_interp_ref(c, 3, 10, 12) + || test_interp_ref(c, 3, 20, 20) + + // pack1, large input downscaled to trigger sy jumps of +2 and +3 + || test_interp(d, 3, 0.25f, 0.25f, 0, 0) + || test_interp(d, 3, 0.3f, 0.3f, 0, 0) + || test_interp(d, 3, 1.f, 1.f, 8, 20) + || test_interp(d, 3, 1.f, 1.f, 10, 24) + || test_interp(d, 3, 1.f, 1.f, 16, 20) + || test_interp_align_corner(d, 3, 0.25f, 0.25f, 0, 0, 1) + || test_interp_align_corner(d, 3, 1.f, 1.f, 8, 20, 1) + || test_interp_align_corner(d, 3, 1.f, 1.f, 16, 20, 1) + || test_interp_ref(d, 3, 8, 20) + || test_interp_ref(d, 3, 16, 20); +} + +// dims=2, all resize types, pack8 (h=8) and pack1 (h=3) +static int test_interp_10() +{ + ncnn::Mat a = RandomMat(15, 8); + ncnn::Mat b = RandomMat(14, 24); + ncnn::Mat c = RandomMat(13, 3); + + return 0 + // nearest + || test_interp(a, 1, 2.f, 0) + || test_interp(a, 1, 0.5f, 0) + || test_interp(a, 1, 1.f, 12) + || test_interp_ref(a, 1, 12) + + || test_interp(b, 1, 2.f, 0) + || test_interp(b, 1, 1.f, 12) + || test_interp_ref(b, 1, 12) + + || test_interp(c, 1, 2.f, 0) + || test_interp(c, 1, 1.f, 12) + || test_interp_ref(c, 1, 12) + + // bilinear + || test_interp(a, 2, 2.f, 0) + || test_interp(a, 2, 0.5f, 0) + || test_interp(a, 2, 1.f, 12) + || test_interp_align_corner(a, 2, 2.f, 0, 1) + || test_interp_align_corner(a, 2, 1.f, 12, 1) 
+ || test_interp_ref(a, 2, 12) + + || test_interp(b, 2, 2.f, 0) + || test_interp(b, 2, 1.f, 12) + || test_interp_align_corner(b, 2, 2.f, 0, 1) + || test_interp_ref(b, 2, 12) + + || test_interp(c, 2, 2.f, 0) + || test_interp(c, 2, 1.f, 12) + || test_interp_align_corner(c, 2, 2.f, 0, 1) + || test_interp_ref(c, 2, 12) + + // bicubic + || test_interp(a, 3, 2.f, 0) + || test_interp(a, 3, 0.5f, 0) + || test_interp(a, 3, 1.f, 12) + || test_interp(a, 3, 1.f, 7) + || test_interp_align_corner(a, 3, 2.f, 0, 1) + || test_interp_align_corner(a, 3, 1.f, 12, 1) + || test_interp_ref(a, 3, 12) + + || test_interp(b, 3, 2.f, 0) + || test_interp(b, 3, 1.f, 12) + || test_interp_align_corner(b, 3, 2.f, 0, 1) + || test_interp_ref(b, 3, 12) + + || test_interp(c, 3, 2.f, 0) + || test_interp(c, 3, 1.f, 12) + || test_interp_align_corner(c, 3, 2.f, 0, 1) + || test_interp_ref(c, 3, 12); +} + +// dims=1, pack8 (w=8, w=24) +static int test_interp_11() +{ + ncnn::Mat a = RandomMat(8); + ncnn::Mat b = RandomMat(24); + + return 0 + || test_interp(a, 1, 2.f, 3.f, 0, 0) + || test_interp(a, 1, 1.f, 1.f, 10, 12) + || test_interp(a, 1, 1.f, 1.f, 4, 4) + || test_interp_ref(a, 1, 10, 12) + || test_interp_ref(a, 1, 4, 4) + + || test_interp(b, 1, 4.f, 5.f, 0, 0) + || test_interp(b, 1, 1.f, 1.f, 10, 12) + || test_interp_ref(b, 1, 10, 12) + || test_interp_ref(b, 1, 5, 5); +} + int main() { SRAND(7767517); @@ -507,5 +702,10 @@ int main() || test_interp_3() || test_interp_4() || test_interp_5() - || test_interp_6(); + || test_interp_6() + || test_interp_7() + || test_interp_8() + || test_interp_9() + || test_interp_10() + || test_interp_11(); } From 04b5117cb235f23167380b2955af1358428aadf4 Mon Sep 17 00:00:00 2001 From: nihui Date: Thu, 12 Mar 2026 15:09:43 +0800 Subject: [PATCH 20/36] x86 gemm support bf16 storage (#6598) --- src/layer/arm/gemm_arm.cpp | 12 +- src/layer/x86/gemm_bf16s.h | 2395 ++++++++++++++++++++++++++++++++++++ src/layer/x86/gemm_x86.cpp | 538 ++++++++ src/layer/x86/gemm_x86.h | 3 + 4 files changed, 2945 insertions(+), 3 deletions(-) create mode 100644 src/layer/x86/gemm_bf16s.h diff --git a/src/layer/arm/gemm_arm.cpp b/src/layer/arm/gemm_arm.cpp index 60a128e9898..a6e3a3bf33a 100644 --- a/src/layer/arm/gemm_arm.cpp +++ b/src/layer/arm/gemm_arm.cpp @@ -5210,9 +5210,15 @@ int Gemm_arm::forward_bf16s(const std::vector& bottom_blobs, std::vector 0) + { + int nn_K = (K + TILE_K - 1) / TILE_K; +#if __AVX512F__ + TILE_K = std::min(TILE_K, ((K + nn_K - 1) / nn_K + 15) / 16 * 16); +#elif __AVX__ + TILE_K = std::min(TILE_K, ((K + nn_K - 1) / nn_K + 7) / 8 * 8); +#elif __SSE2__ + TILE_K = std::min(TILE_K, ((K + nn_K - 1) / nn_K + 3) / 4 * 4); +#else + TILE_K = std::min(TILE_K, ((K + nn_K - 1) / nn_K + 1) / 2 * 2); +#endif + + if (nn_K == 1) + { + tile_size = (int)((float)l2_cache_size / 2 / sizeof(float) / TILE_K); + +#if __AVX512F__ + TILE_M = std::max(16, tile_size / 16 * 16); + TILE_N = std::max(16, tile_size / 16 * 16); +#elif __AVX__ + TILE_M = std::max(8, tile_size / 8 * 8); + TILE_N = std::max(4, tile_size / 4 * 4); +#elif __SSE2__ + TILE_M = std::max(4, tile_size / 4 * 4); + TILE_N = std::max(4, tile_size / 4 * 4); +#else + TILE_M = std::max(2, tile_size / 2 * 2); + TILE_N = std::max(1, tile_size); +#endif + } + } + + TILE_M *= std::min(nT, get_physical_cpu_count()); + + if (M > 0) + { + int nn_M = (M + TILE_M - 1) / TILE_M; +#if __AVX512F__ + TILE_M = std::min(TILE_M, ((M + nn_M - 1) / nn_M + 15) / 16 * 16); +#elif __AVX__ + TILE_M = std::min(TILE_M, ((M + nn_M - 1) / nn_M + 7) / 8 * 8); +#elif __SSE2__ + TILE_M = 
std::min(TILE_M, ((M + nn_M - 1) / nn_M + 3) / 4 * 4); +#else + TILE_M = std::min(TILE_M, ((M + nn_M - 1) / nn_M + 1) / 2 * 2); +#endif + } + + if (N > 0) + { + int nn_N = (N + TILE_N - 1) / TILE_N; +#if __AVX512F__ + TILE_N = std::min(TILE_N, ((N + nn_N - 1) / nn_N + 15) / 16 * 16); +#elif __AVX__ + TILE_N = std::min(TILE_N, ((N + nn_N - 1) / nn_N + 3) / 4 * 4); +#elif __SSE2__ + TILE_N = std::min(TILE_N, ((N + nn_N - 1) / nn_N + 3) / 4 * 4); +#else + TILE_N = std::min(TILE_N, (N + nn_N - 1) / nn_N); +#endif + } + + if (nT > 1) + { +#if __AVX512F__ + TILE_M = std::min(TILE_M, (std::max(1, TILE_M / nT) + 15) / 16 * 16); +#elif __AVX__ + TILE_M = std::min(TILE_M, (std::max(1, TILE_M / nT) + 7) / 8 * 8); +#elif __SSE2__ + TILE_M = std::min(TILE_M, (std::max(1, TILE_M / nT) + 3) / 4 * 4); +#else + TILE_M = std::min(TILE_M, (std::max(1, TILE_M / nT) + 1) / 2 * 2); +#endif + } + + // always take constant TILE_M/N/K value when provided + if (constant_TILE_M > 0) + { +#if __AVX512F__ + TILE_M = (constant_TILE_M + 15) / 16 * 16; +#elif __AVX__ + TILE_M = (constant_TILE_M + 7) / 8 * 8; +#elif __SSE2__ + TILE_M = (constant_TILE_M + 3) / 4 * 4; +#else + TILE_M = (constant_TILE_M + 1) / 2 * 2; +#endif + } + + if (constant_TILE_N > 0) + { +#if __AVX512F__ + TILE_N = (constant_TILE_N + 15) / 16 * 16; +#elif __AVX__ + TILE_N = (constant_TILE_N + 3) / 4 * 4; +#elif __SSE2__ + TILE_N = (constant_TILE_N + 3) / 4 * 4; +#else + TILE_N = constant_TILE_N; +#endif + } + + if (constant_TILE_K > 0) + { +#if __AVX512F__ + TILE_K = (constant_TILE_K + 15) / 16 * 16; +#elif __AVX__ + TILE_K = (constant_TILE_K + 7) / 8 * 8; +#elif __SSE2__ + TILE_K = (constant_TILE_K + 3) / 4 * 4; +#else + TILE_K = (constant_TILE_K + 1) / 2 * 2; +#endif + } +} diff --git a/src/layer/x86/gemm_x86.cpp b/src/layer/x86/gemm_x86.cpp index 8dd43b45d3c..0fda36ccd1d 100644 --- a/src/layer/x86/gemm_x86.cpp +++ b/src/layer/x86/gemm_x86.cpp @@ -24,12 +24,20 @@ namespace ncnn { #include "gemm_int8.h" #endif +#if NCNN_BF16 +#include "gemm_bf16s.h" +#endif + Gemm_x86::Gemm_x86() { #if __SSE2__ support_packing = true; #endif // __SSE2__ +#if NCNN_BF16 + support_bf16_storage = true; +#endif // NCNN_BF16 + nT = 0; } @@ -7205,6 +7213,7 @@ int Gemm_x86::create_pipeline(const Option& opt) #if NCNN_INT8 if (int8_scale_term) { + support_bf16_storage = false; return create_pipeline_int8(opt); } #endif @@ -7356,6 +7365,14 @@ int Gemm_x86::forward(const std::vector& bottom_blobs, std::vector& to } #endif + const Mat& bottom_blob = bottom_blobs.empty() ? AT_data : bottom_blobs[0]; + int elembits = bottom_blob.elembits(); + +#if NCNN_BF16 + if (opt.use_bf16_storage && elembits == 16) + return forward_bf16s(bottom_blobs, top_blobs, opt); +#endif + int M; int N; if (constantA && constantB) @@ -8363,4 +8380,525 @@ void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, Mat& t #endif } // namespace Gemm_x86_utility +#if NCNN_BF16 +static int gemm_x86_bf16s(const Mat& A, const Mat& B, const Mat& C, Mat& top_blob, int broadcast_type_C, int transA, int transB, int output_transpose, float alpha, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt) +{ + const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; + const int K = transA ? (A.dims == 3 ? A.c : A.h) * A.elempack : A.w; + const int N = transB ? (B.dims == 3 ? 
B.c : B.h) * B.elempack : B.w; + + int TILE_M, TILE_N, TILE_K; + get_optimal_tile_mnk_bf16s(M, N, K, constant_TILE_M, constant_TILE_N, constant_TILE_K, TILE_M, TILE_N, TILE_K, nT); + + const int nn_M = (M + TILE_M - 1) / TILE_M; + const int nn_N = (N + TILE_N - 1) / TILE_N; + const int nn_K = (K + TILE_K - 1) / TILE_K; + + // BT is fp32 packed tile + Mat BT(TILE_K * TILE_N, nn_K, nn_N, 4u, opt.workspace_allocator); + if (BT.empty()) + return -100; + + // pack B (bf16 -> fp32) + const int nn_NK = nn_N * nn_K; + #pragma omp parallel for num_threads(nT) + for (int ppjk = 0; ppjk < nn_NK; ppjk++) + { + const int ppj = ppjk / nn_K; + const int ppk = ppjk % nn_K; + + const int j = ppj * TILE_N; + const int k = ppk * TILE_K; + + const int max_jj = std::min((N - j), TILE_N); + const int max_kk = std::min((K - k), TILE_K); + + Mat BT_tile = BT.channel(j / TILE_N).row_range(k / TILE_K, 1); + + if (transB) + { + pack_B_tile_bf16s(B, BT_tile, j, max_jj, k, max_kk); + } + else + { + transpose_pack_B_tile_bf16s(B, BT_tile, j, max_jj, k, max_kk); + } + } + + // topT is always needed for bf16 path (accumulate fp32, then convert to bf16) + Mat topT(TILE_N * TILE_M, 1, nT, 4u, opt.workspace_allocator); + if (topT.empty()) + return -100; + + Mat ATX(TILE_K * TILE_M, nn_K, nT, 4u, opt.workspace_allocator); + if (ATX.empty()) + return -100; + + #pragma omp parallel for num_threads(nT) + for (int ppi = 0; ppi < nn_M; ppi++) + { + const int i = ppi * TILE_M; + + const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; + const int K = transA ? (A.dims == 3 ? A.c : A.h) * A.elempack : A.w; + + const int max_ii = std::min((M - i), TILE_M); + + Mat topT_tile = topT.channel(get_omp_thread_num()); + + for (int j = 0; j < N; j += TILE_N) + { + const int max_jj = std::min((N - j), TILE_N); + + if (broadcast_type_C == 3) + { + pack_A_tile(C, topT_tile, i, max_ii, j, max_jj); + } + + const Mat& CT_tile = broadcast_type_C == 3 ? topT_tile : C; + + for (int k = 0; k < K; k += TILE_K) + { + const int max_kk = std::min((K - k), TILE_K); + + Mat AT_tile = ATX.channel(get_omp_thread_num()).row_range(k / TILE_K, 1); + + Mat BT_tile = BT.channel(j / TILE_N).row_range(k / TILE_K, 1); + + if (j == 0) + { + if (transA) + { + transpose_pack_A_tile_bf16s(A, AT_tile, i, max_ii, k, max_kk); + } + else + { + pack_A_tile_bf16s(A, AT_tile, i, max_ii, k, max_kk); + } + } + + // always k_end=false, accumulate to topT as fp32 + gemm_transB_packed_tile(AT_tile, BT_tile, CT_tile, topT_tile, top_blob, broadcast_type_C, i, max_ii, j, max_jj, k, max_kk, false /*k_end*/); + } + + // multiply alpha + if (alpha != 1.f) + { + float* outptr = topT_tile; + int size = max_ii * max_jj; + for (int q = 0; q < size; q++) + { + outptr[q] *= alpha; + } + } + + // convert fp32 topT to bf16 output + unpack_output_tile_bf16s(topT_tile, top_blob, i, max_ii, j, max_jj, output_transpose); + } + } + + return 0; +} + +static int gemm_AT_x86_bf16s(const Mat& AT, const Mat& B, const Mat& C, Mat& top_blob, int broadcast_type_C, int M, int K, int transB, int output_transpose, float alpha, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt) +{ + const int N = transB ? (B.dims == 3 ? 
B.c : B.h) * B.elempack : B.w; + + int TILE_M, TILE_N, TILE_K; + get_optimal_tile_mnk_bf16s(M, N, K, constant_TILE_M, constant_TILE_N, constant_TILE_K, TILE_M, TILE_N, TILE_K, nT); + + const int nn_M = (M + TILE_M - 1) / TILE_M; + const int nn_N = (N + TILE_N - 1) / TILE_N; + const int nn_K = (K + TILE_K - 1) / TILE_K; + + Mat BT(TILE_K * TILE_N, nn_K, nn_N, 4u, opt.workspace_allocator); + if (BT.empty()) + return -100; + + const int nn_NK = nn_N * nn_K; + #pragma omp parallel for num_threads(nT) + for (int ppjk = 0; ppjk < nn_NK; ppjk++) + { + const int ppj = ppjk / nn_K; + const int ppk = ppjk % nn_K; + + const int j = ppj * TILE_N; + const int k = ppk * TILE_K; + + const int max_jj = std::min((N - j), TILE_N); + const int max_kk = std::min((K - k), TILE_K); + + Mat BT_tile = BT.channel(j / TILE_N).row_range(k / TILE_K, 1); + + if (transB) + { + pack_B_tile_bf16s(B, BT_tile, j, max_jj, k, max_kk); + } + else + { + transpose_pack_B_tile_bf16s(B, BT_tile, j, max_jj, k, max_kk); + } + } + + Mat topT(TILE_N * TILE_M, 1, nT, 4u, opt.workspace_allocator); + if (topT.empty()) + return -100; + + #pragma omp parallel for num_threads(nT) + for (int ppi = 0; ppi < nn_M; ppi++) + { + const int i = ppi * TILE_M; + + const int max_ii = std::min((M - i), TILE_M); + + Mat topT_tile = topT.channel(get_omp_thread_num()); + + for (int j = 0; j < N; j += TILE_N) + { + const int max_jj = std::min((N - j), TILE_N); + + if (broadcast_type_C == 3) + { + pack_A_tile(C, topT_tile, i, max_ii, j, max_jj); + } + + const Mat& CT_tile = broadcast_type_C == 3 ? topT_tile : C; + + for (int k = 0; k < K; k += TILE_K) + { + const int max_kk = std::min((K - k), TILE_K); + + // AT is pre-packed fp32 + Mat AT_tile = AT.channel(i / TILE_M).row_range(k / TILE_K, 1); + + Mat BT_tile = BT.channel(j / TILE_N).row_range(k / TILE_K, 1); + + gemm_transB_packed_tile(AT_tile, BT_tile, CT_tile, topT_tile, top_blob, broadcast_type_C, i, max_ii, j, max_jj, k, max_kk, false /*k_end*/); + } + + if (alpha != 1.f) + { + float* outptr = topT_tile; + int size = max_ii * max_jj; + for (int q = 0; q < size; q++) + { + outptr[q] *= alpha; + } + } + + unpack_output_tile_bf16s(topT_tile, top_blob, i, max_ii, j, max_jj, output_transpose); + } + } + + return 0; +} + +static int gemm_BT_x86_bf16s(const Mat& A, const Mat& BT, const Mat& C, Mat& top_blob, int broadcast_type_C, int N, int K, int transA, int output_transpose, float alpha, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt) +{ + const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; + + int TILE_M, TILE_N, TILE_K; + get_optimal_tile_mnk_bf16s(M, N, K, constant_TILE_M, constant_TILE_N, constant_TILE_K, TILE_M, TILE_N, TILE_K, nT); + + const int nn_M = (M + TILE_M - 1) / TILE_M; + const int nn_K = (K + TILE_K - 1) / TILE_K; + + Mat topT(TILE_N * TILE_M, 1, nT, 4u, opt.workspace_allocator); + if (topT.empty()) + return -100; + + Mat ATX(TILE_K * TILE_M, nn_K, nT, 4u, opt.workspace_allocator); + if (ATX.empty()) + return -100; + + #pragma omp parallel for num_threads(nT) + for (int ppi = 0; ppi < nn_M; ppi++) + { + const int i = ppi * TILE_M; + + const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; + const int K = transA ? (A.dims == 3 ? 
A.c : A.h) * A.elempack : A.w; + + const int max_ii = std::min((M - i), TILE_M); + + Mat topT_tile = topT.channel(get_omp_thread_num()); + + for (int j = 0; j < N; j += TILE_N) + { + const int max_jj = std::min((N - j), TILE_N); + + if (broadcast_type_C == 3) + { + pack_A_tile(C, topT_tile, i, max_ii, j, max_jj); + } + + const Mat& CT_tile = broadcast_type_C == 3 ? topT_tile : C; + + for (int k = 0; k < K; k += TILE_K) + { + const int max_kk = std::min((K - k), TILE_K); + + Mat AT_tile = ATX.channel(get_omp_thread_num()).row_range(k / TILE_K, 1); + + // BT is pre-packed fp32 + Mat BT_tile = BT.channel(j / TILE_N).row_range(k / TILE_K, 1); + + if (j == 0) + { + if (transA) + { + transpose_pack_A_tile_bf16s(A, AT_tile, i, max_ii, k, max_kk); + } + else + { + pack_A_tile_bf16s(A, AT_tile, i, max_ii, k, max_kk); + } + } + + gemm_transB_packed_tile(AT_tile, BT_tile, CT_tile, topT_tile, top_blob, broadcast_type_C, i, max_ii, j, max_jj, k, max_kk, false /*k_end*/); + } + + if (alpha != 1.f) + { + float* outptr = topT_tile; + int size = max_ii * max_jj; + for (int q = 0; q < size; q++) + { + outptr[q] *= alpha; + } + } + + unpack_output_tile_bf16s(topT_tile, top_blob, i, max_ii, j, max_jj, output_transpose); + } + } + + return 0; +} + +static int gemm_AT_BT_x86_bf16s(const Mat& AT, const Mat& BT, const Mat& C, Mat& top_blob, int broadcast_type_C, int M, int N, int K, int output_transpose, float alpha, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt) +{ + int TILE_M, TILE_N, TILE_K; + get_optimal_tile_mnk_bf16s(M, N, K, constant_TILE_M, constant_TILE_N, constant_TILE_K, TILE_M, TILE_N, TILE_K, nT); + + const int nn_M = (M + TILE_M - 1) / TILE_M; + + Mat topT(TILE_N * TILE_M, 1, nT, 4u, opt.workspace_allocator); + if (topT.empty()) + return -100; + + #pragma omp parallel for num_threads(nT) + for (int ppi = 0; ppi < nn_M; ppi++) + { + const int i = ppi * TILE_M; + + const int max_ii = std::min((M - i), TILE_M); + + Mat topT_tile = topT.channel(get_omp_thread_num()); + + for (int j = 0; j < N; j += TILE_N) + { + const int max_jj = std::min((N - j), TILE_N); + + if (broadcast_type_C == 3) + { + pack_A_tile(C, topT_tile, i, max_ii, j, max_jj); + } + + const Mat& CT_tile = broadcast_type_C == 3 ? topT_tile : C; + + for (int k = 0; k < K; k += TILE_K) + { + const int max_kk = std::min((K - k), TILE_K); + + Mat AT_tile = AT.channel(i / TILE_M).row_range(k / TILE_K, 1); + Mat BT_tile = BT.channel(j / TILE_N).row_range(k / TILE_K, 1); + + gemm_transB_packed_tile(AT_tile, BT_tile, CT_tile, topT_tile, top_blob, broadcast_type_C, i, max_ii, j, max_jj, k, max_kk, false /*k_end*/); + } + + if (alpha != 1.f) + { + float* outptr = topT_tile; + int size = max_ii * max_jj; + for (int q = 0; q < size; q++) + { + outptr[q] *= alpha; + } + } + + unpack_output_tile_bf16s(topT_tile, top_blob, i, max_ii, j, max_jj, output_transpose); + } + } + + return 0; +} + +int Gemm_x86::forward_bf16s(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const +{ + int M; + int N; + if (constantA && constantB) + { + M = constantM; + N = constantN; + } + else if (constantA) + { + const Mat& B = bottom_blobs[0]; + M = constantM; + N = transB ? (B.dims == 3 ? B.c : B.h) * B.elempack : B.w; + } + else if (constantB) + { + const Mat& A = bottom_blobs[0]; + M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; + N = constantN; + } + else + { + const Mat& A = bottom_blobs[0]; + const Mat& B = bottom_blobs[1]; + M = transA ? A.w : (A.dims == 3 ? 
A.c : A.h) * A.elempack; + N = transB ? (B.dims == 3 ? B.c : B.h) * B.elempack : B.w; + } + + Mat C; + int broadcast_type_C = 0; + if (constantC) + { + C = CT_data; + broadcast_type_C = constant_broadcast_type_C; + } + else + { + if (constantA && constantB) + { + C = bottom_blobs.size() == 1 ? bottom_blobs[0] : Mat(); + } + else if (constantA) + { + C = bottom_blobs.size() == 2 ? bottom_blobs[1] : Mat(); + } + else if (constantB) + { + C = bottom_blobs.size() == 2 ? bottom_blobs[1] : Mat(); + } + else + { + C = bottom_blobs.size() == 3 ? bottom_blobs[2] : Mat(); + } + + if (!C.empty()) + { + if (C.dims == 1 && C.w == 1) + { + broadcast_type_C = 0; + } + if (C.dims == 1 && C.w * C.elempack == M) + { + broadcast_type_C = 1; + } + if (C.dims == 1 && C.w * C.elempack == N) + { + broadcast_type_C = 4; + } + if (C.dims == 2 && C.w == 1 && C.h * C.elempack == M) + { + broadcast_type_C = 2; + } + if (C.dims == 2 && C.w == N && C.h * C.elempack == M) + { + broadcast_type_C = 3; + } + if (C.dims == 2 && C.w == N && C.h * C.elempack == 1) + { + broadcast_type_C = 4; + } + + // cast to fp32 + { + Option opt_cast = opt; + opt_cast.blob_allocator = opt.workspace_allocator; + + Mat C_fp32; + cast_bfloat16_to_float32(C, C_fp32, opt_cast); + if (C_fp32.empty()) + return -100; + + C = C_fp32; + } + + // pre-multiply C with beta + if (beta != 1.f) + { + Mat C2; + C2.create_like(C, opt.workspace_allocator); + if (C2.empty()) + return -100; + + const int size = C.total() * C.elempack; + for (int i = 0; i < size; i++) + { + C2[i] = C[i] * beta; + } + + C = C2; + } + } + } + + // bf16 output, elempack=1 only for now + int out_elempack = 1; + size_t out_elemsize = 2u * out_elempack; + + Mat& top_blob = top_blobs[0]; + if (output_transpose) + { + if (output_N1M) + top_blob.create(M, 1, N / out_elempack, out_elemsize, out_elempack, opt.blob_allocator); + else + top_blob.create(M, N / out_elempack, out_elemsize, out_elempack, opt.blob_allocator); + } + else + { + if (output_N1M) + top_blob.create(N, 1, M / out_elempack, out_elemsize, out_elempack, opt.blob_allocator); + else + top_blob.create(N, M / out_elempack, out_elemsize, out_elempack, opt.blob_allocator); + } + if (top_blob.empty()) + return -100; + + int _nT = nT ? 
nT : opt.num_threads;
+    if (nT != 0 && opt.num_threads != nT)
+    {
+        NCNN_LOGE("opt.num_threads %d changed, gemm will use load-time value %d", opt.num_threads, nT);
+    }
+
+    int ret = 0;
+    if (constantA && constantB)
+    {
+        ret = gemm_AT_BT_x86_bf16s(AT_data, BT_data, C, top_blob, broadcast_type_C, constantM, constantN, constantK, output_transpose, alpha, constant_TILE_M, constant_TILE_N, constant_TILE_K, _nT, opt);
+    }
+    else if (constantA)
+    {
+        const Mat& B = bottom_blobs[0];
+        ret = gemm_AT_x86_bf16s(AT_data, B, C, top_blob, broadcast_type_C, constantM, constantK, transB, output_transpose, alpha, constant_TILE_M, constant_TILE_N, constant_TILE_K, _nT, opt);
+    }
+    else if (constantB)
+    {
+        const Mat& A = bottom_blobs[0];
+        ret = gemm_BT_x86_bf16s(A, BT_data, C, top_blob, broadcast_type_C, constantN, constantK, transA, output_transpose, alpha, constant_TILE_M, constant_TILE_N, constant_TILE_K, _nT, opt);
+    }
+    else
+    {
+        const Mat& A = bottom_blobs[0];
+        const Mat& B = bottom_blobs[1];
+        ret = gemm_x86_bf16s(A, B, C, top_blob, broadcast_type_C, transA, transB, output_transpose, alpha, constant_TILE_M, constant_TILE_N, constant_TILE_K, _nT, opt);
+    }
+
+    return ret;
+}
+#endif // NCNN_BF16
+
 } // namespace ncnn
diff --git a/src/layer/x86/gemm_x86.h b/src/layer/x86/gemm_x86.h
index 55d3c19ba88..9710d129d27 100644
--- a/src/layer/x86/gemm_x86.h
+++ b/src/layer/x86/gemm_x86.h
@@ -18,6 +18,9 @@ class Gemm_x86 : public Gemm
     virtual int forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const;
 
 protected:
+#if NCNN_BF16
+    int forward_bf16s(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const;
+#endif
 #if NCNN_INT8
     int create_pipeline_int8(const Option& opt);
     int forward_int8(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const;

From bba0de3a0aff7f761d321b4c3be0d4820100ddbe Mon Sep 17 00:00:00 2001
From: nihui
Date: Mon, 16 Mar 2026 12:43:06 +0800
Subject: [PATCH 21/36] x86 gemm int8 optimization with alignr (#6600)

---
 src/layer/x86/gemm_int8.h | 175 ++++++++++++++++++++++----------------
 1 file changed, 102 insertions(+), 73 deletions(-)

diff --git a/src/layer/x86/gemm_int8.h b/src/layer/x86/gemm_int8.h
index 5bd4e1cb17c..9227e486028 100644
--- a/src/layer/x86/gemm_int8.h
+++ b/src/layer/x86/gemm_int8.h
@@ -12598,12 +12598,12 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
         {
             __m512i _pA0 = _mm512_loadu_si512((const __m512i*)pA);
             __m512i _pB0 = _mm512_loadu_si512((const __m512i*)pB);
-            __m512i _pA1 = _mm512_shuffle_epi32(_pA0, _MM_PERM_BADC);
+            __m512i _pA1 = _mm512_alignr_epi8(_pA0, _pA0, 8);
             __m512i _pA2 = _mm512_shuffle_i32x4(_pA0, _pA0, _MM_SHUFFLE(2, 3, 0, 1));
-            __m512i _pA3 = _mm512_shuffle_epi32(_pA2, _MM_PERM_BADC);
-            __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB);
+            __m512i _pA3 = _mm512_alignr_epi8(_pA2, _pA2, 8);
+            __m512i _pB1 = _mm512_alignr_epi8(_pB0, _pB0, 4);
             __m512i _pB2 = _mm512_shuffle_i32x4(_pB0, _pB0, _MM_SHUFFLE(1, 0, 3, 2));
-            __m512i _pB3 = _mm512_shuffle_epi32(_pB2, _MM_PERM_ADCB);
+            __m512i _pB3 = _mm512_alignr_epi8(_pB2, _pB2, 4);
             _sum0 = _mm512_dpbusd_epi32(_sum0, _pB0, _pA0);
             _sum1 = _mm512_dpbusd_epi32(_sum1, _pB1, _pA0);
             _sum2 = _mm512_dpbusd_epi32(_sum2, _pB0, _pA1);
@@ -12626,9 +12626,9 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
             if (max_kk >= 4)
             {
                 __m512i _w_shift0 = _mm512_loadu_si512((const __m512i*)pA);
-                __m512i _w_shift1 = _mm512_shuffle_epi32(_w_shift0, _MM_PERM_BADC);
+                __m512i
_w_shift1 = _mm512_alignr_epi8(_w_shift0, _w_shift0, 8); __m512i _w_shift2 = _mm512_shuffle_i32x4(_w_shift0, _w_shift0, _MM_SHUFFLE(2, 3, 0, 1)); - __m512i _w_shift3 = _mm512_shuffle_epi32(_w_shift2, _MM_PERM_BADC); + __m512i _w_shift3 = _mm512_alignr_epi8(_w_shift2, _w_shift2, 8); _sum0 = _mm512_sub_epi32(_sum0, _w_shift0); _sum1 = _mm512_sub_epi32(_sum1, _w_shift0); _sum2 = _mm512_sub_epi32(_sum2, _w_shift1); @@ -12660,17 +12660,17 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, // 2301 6745 ab89 efcd // 4567 0123 cdef 89ab // 6745 2301 efcd ab89 - __m512i _pA1 = _mm512_shuffle_epi32(_pA0, _MM_PERM_BADC); + __m512i _pA1 = _mm512_alignr_epi8(_pA0, _pA0, 8); __m512i _pA2 = _mm512_shuffle_i32x4(_pA0, _pA0, _MM_SHUFFLE(2, 3, 0, 1)); - __m512i _pA3 = _mm512_shuffle_epi32(_pA2, _MM_PERM_BADC); + __m512i _pA3 = _mm512_alignr_epi8(_pA2, _pA2, 8); // 0123 4567 89ab cdef // 1230 5674 9ab8 defc // 89ab cdef 0123 4567 // 9ab8 defc 1230 5674 - __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); + __m512i _pB1 = _mm512_alignr_epi8(_pB0, _pB0, 4); __m512i _pB2 = _mm512_shuffle_i32x4(_pB0, _pB0, _MM_SHUFFLE(1, 0, 3, 2)); - __m512i _pB3 = _mm512_shuffle_epi32(_pB2, _MM_PERM_ADCB); + __m512i _pB3 = _mm512_alignr_epi8(_pB2, _pB2, 4); _sum0 = _mm512_comp_dpwssd_epi32(_sum0, _pA0, _pB0); _sum1 = _mm512_comp_dpwssd_epi32(_sum1, _pA0, _pB1); @@ -12790,10 +12790,10 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, __m512i _pA0 = _mm512_loadu_si512((const __m512i*)pA); __m256i _pB = _mm256_loadu_si256((const __m256i*)pB); __m512i _pB0 = combine8x2_epi32(_pB, _pB); - __m512i _pA1 = _mm512_shuffle_epi32(_pA0, _MM_PERM_BADC); - __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); + __m512i _pA1 = _mm512_alignr_epi8(_pA0, _pA0, 8); + __m512i _pB1 = _mm512_alignr_epi8(_pB0, _pB0, 4); __m512i _pB2 = _mm512_permutex_epi64(_pB0, _MM_SHUFFLE(1, 0, 3, 2)); - __m512i _pB3 = _mm512_shuffle_epi32(_pB2, _MM_PERM_ADCB); + __m512i _pB3 = _mm512_alignr_epi8(_pB2, _pB2, 4); _sum0 = _mm512_dpbusd_epi32(_sum0, _pB0, _pA0); _sum1 = _mm512_dpbusd_epi32(_sum1, _pB1, _pA0); _sum2 = _mm512_dpbusd_epi32(_sum2, _pB0, _pA1); @@ -12808,7 +12808,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, if (max_kk >= 4) { __m512i _w_shift0 = _mm512_loadu_si512((const __m512i*)pA); - __m512i _w_shift1 = _mm512_shuffle_epi32(_w_shift0, _MM_PERM_BADC); + __m512i _w_shift1 = _mm512_alignr_epi8(_w_shift0, _w_shift0, 8); _sum0 = _mm512_sub_epi32(_sum0, _w_shift0); _sum1 = _mm512_sub_epi32(_sum1, _w_shift0); _sum2 = _mm512_sub_epi32(_sum2, _w_shift1); @@ -12830,16 +12830,16 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, // 0123 4567 89ab cdef // 2301 6745 ab89 efcd - __m512i _pA1 = _mm512_shuffle_epi32(_pA0, _MM_PERM_BADC); + __m512i _pA1 = _mm512_alignr_epi8(_pA0, _pA0, 8); // 0123 4567 0123 4567 // 1230 5674 1230 5674 // 4567 0123 4567 0123 // 5674 1230 5674 1230 __m512i _pB0 = combine8x2_epi32(_pBB, _pBB); - __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); + __m512i _pB1 = _mm512_alignr_epi8(_pB0, _pB0, 4); __m512i _pB2 = _mm512_permutex_epi64(_pB0, _MM_SHUFFLE(1, 0, 3, 2)); - __m512i _pB3 = _mm512_shuffle_epi32(_pB2, _MM_PERM_ADCB); + __m512i _pB3 = _mm512_alignr_epi8(_pB2, _pB2, 4); _sum0 = _mm512_comp_dpwssd_epi32(_sum0, _pA0, _pB0); _sum1 = _mm512_comp_dpwssd_epi32(_sum1, _pA0, _pB1); @@ -12865,7 +12865,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, 
__m256i _pB0 = combine4x2_epi32(_pB, _pB); __m256i _pB1 = _mm256_shufflehi_epi16(_mm256_shufflelo_epi16(_pB0, _MM_SHUFFLE(0, 3, 2, 1)), _MM_SHUFFLE(0, 3, 2, 1)); - __m256i _pB2 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(1, 0, 3, 2)); + __m256i _pB2 = _mm256_alignr_epi8(_pB0, _pB0, 8); __m256i _pB3 = _mm256_shufflehi_epi16(_mm256_shufflelo_epi16(_pB2, _MM_SHUFFLE(0, 3, 2, 1)), _MM_SHUFFLE(0, 3, 2, 1)); _sum0 = _mm512_add_epi32(_sum0, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA0, _pB0))); @@ -12922,8 +12922,8 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, { __m512i _pA0 = _mm512_loadu_si512((const __m512i*)pA); __m512i _pB0 = _mm512_broadcast_i32x4(_mm_loadu_si128((const __m128i*)pB)); - __m512i _pA1 = _mm512_shuffle_epi32(_pA0, _MM_PERM_BADC); - __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); + __m512i _pA1 = _mm512_alignr_epi8(_pA0, _pA0, 8); + __m512i _pB1 = _mm512_alignr_epi8(_pB0, _pB0, 4); _sum0 = _mm512_dpbusd_epi32(_sum0, _pB0, _pA0); _sum1 = _mm512_dpbusd_epi32(_sum1, _pB1, _pA0); _sum2 = _mm512_dpbusd_epi32(_sum2, _pB0, _pA1); @@ -12934,7 +12934,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, if (max_kk >= 4) { __m512i _w_shift0 = _mm512_loadu_si512((const __m512i*)pA); - __m512i _w_shift1 = _mm512_shuffle_epi32(_w_shift0, _MM_PERM_BADC); + __m512i _w_shift1 = _mm512_alignr_epi8(_w_shift0, _w_shift0, 8); _sum0 = _mm512_sub_epi32(_sum0, _w_shift0); _sum1 = _mm512_sub_epi32(_sum1, _w_shift0); _sum2 = _mm512_sub_epi32(_sum2, _w_shift1); @@ -12952,11 +12952,11 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, // 0123 4567 89ab cdef // 2301 6745 ab89 efcd - __m512i _pA1 = _mm512_shuffle_epi32(_pA0, _MM_PERM_BADC); + __m512i _pA1 = _mm512_alignr_epi8(_pA0, _pA0, 8); // 0123 0123 0123 0123 // 1230 1230 1230 1230 - __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); + __m512i _pB1 = _mm512_alignr_epi8(_pB0, _pB0, 4); _sum0 = _mm512_comp_dpwssd_epi32(_sum0, _pA0, _pB0); _sum1 = _mm512_comp_dpwssd_epi32(_sum1, _pA0, _pB1); @@ -13017,7 +13017,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, { __m512i _pA = _mm512_loadu_si512((const __m512i*)pA); __m512i _pB0 = _mm512_castpd_si512(_mm512_set1_pd(((const double*)pB)[0])); - __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_CDAB); + __m512i _pB1 = _mm512_alignr_epi8(_pB0, _pB0, 4); _sum0 = _mm512_dpbusd_epi32(_sum0, _pB0, _pA); _sum1 = _mm512_dpbusd_epi32(_sum1, _pB1, _pA); pA += 64; @@ -13043,7 +13043,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, // 0101 0101 0101 0101 // 1010 1010 1010 1010 - __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_CDAB); + __m512i _pB1 = _mm512_alignr_epi8(_pB0, _pB0, 4); _sum0 = _mm512_comp_dpwssd_epi32(_sum0, _pA0, _pB0); _sum1 = _mm512_comp_dpwssd_epi32(_sum1, _pA0, _pB1); @@ -13196,10 +13196,10 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, __m256i _pA0 = _mm256_loadu_si256((const __m256i*)pA); __m512i _pB0 = _mm512_loadu_si512((const __m512i*)pB); __m512i _pA00 = combine8x2_epi32(_pA0, _pA0); - __m512i _pA11 = _mm512_shuffle_epi32(_pA00, _MM_PERM_BADC); - __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); + __m512i _pA11 = _mm512_alignr_epi8(_pA00, _pA00, 8); + __m512i _pB1 = _mm512_alignr_epi8(_pB0, _pB0, 4); __m512i _pB2 = _mm512_permutex_epi64(_pB0, _MM_SHUFFLE(1, 0, 3, 2)); - __m512i _pB3 = _mm512_shuffle_epi32(_pB2, _MM_PERM_ADCB); + __m512i _pB3 = 
_mm512_alignr_epi8(_pB2, _pB2, 4); _sum0 = _mm512_dpbusd_epi32(_sum0, _pB0, _pA00); _sum1 = _mm512_dpbusd_epi32(_sum1, _pB1, _pA00); _sum2 = _mm512_dpbusd_epi32(_sum2, _pB0, _pA11); @@ -13215,7 +13215,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, { __m256i _w_shift0 = _mm256_loadu_si256((const __m256i*)pA); __m512i _w_shift00 = combine8x2_epi32(_w_shift0, _w_shift0); - __m512i _w_shift11 = _mm512_shuffle_epi32(_w_shift00, _MM_PERM_BADC); + __m512i _w_shift11 = _mm512_alignr_epi8(_w_shift00, _w_shift00, 8); _sum0 = _mm512_sub_epi32(_sum0, _w_shift00); _sum1 = _mm512_sub_epi32(_sum1, _w_shift00); _sum2 = _mm512_sub_epi32(_sum2, _w_shift11); @@ -13238,15 +13238,15 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, // 0123 4567 0123 4567 // 2301 6745 2301 6745 __m512i _pA00 = combine8x2_epi32(_pA0, _pA0); - __m512i _pA11 = _mm512_shuffle_epi32(_pA00, _MM_PERM_BADC); + __m512i _pA11 = _mm512_alignr_epi8(_pA00, _pA00, 8); // 0123 4567 89ab cdef // 1230 5674 9ab8 defc // 4567 0123 cdef 89ab // 5674 1230 defc 9ab8 - __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); + __m512i _pB1 = _mm512_alignr_epi8(_pB0, _pB0, 4); __m512i _pB2 = _mm512_permutex_epi64(_pB0, _MM_SHUFFLE(1, 0, 3, 2)); - __m512i _pB3 = _mm512_shuffle_epi32(_pB2, _MM_PERM_ADCB); + __m512i _pB3 = _mm512_alignr_epi8(_pB2, _pB2, 4); _sum0 = _mm512_comp_dpwssd_epi32(_sum0, _pA00, _pB0); _sum1 = _mm512_comp_dpwssd_epi32(_sum1, _pA00, _pB1); @@ -13272,7 +13272,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, __m256i _pA11 = _mm256_shuffle_epi32(_pA00, _MM_SHUFFLE(2, 3, 0, 1)); __m256i _pB1 = _mm256_shufflehi_epi16(_mm256_shufflelo_epi16(_pB0, _MM_SHUFFLE(0, 3, 2, 1)), _MM_SHUFFLE(0, 3, 2, 1)); - __m256i _pB2 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(1, 0, 3, 2)); + __m256i _pB2 = _mm256_alignr_epi8(_pB0, _pB0, 8); __m256i _pB3 = _mm256_shufflehi_epi16(_mm256_shufflelo_epi16(_pB2, _MM_SHUFFLE(0, 3, 2, 1)), _MM_SHUFFLE(0, 3, 2, 1)); _sum0 = _mm512_add_epi32(_sum0, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA00, _pB0))); @@ -13342,10 +13342,10 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, { __m256i _pA0 = _mm256_loadu_si256((const __m256i*)pA); __m256i _pB0 = _mm256_loadu_si256((const __m256i*)pB); - __m256i _pA1 = _mm256_shuffle_epi32(_pA0, _MM_SHUFFLE(1, 0, 3, 2)); - __m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); + __m256i _pA1 = _mm256_alignr_epi8(_pA0, _pA0, 8); + __m256i _pB1 = _mm256_alignr_epi8(_pB0, _pB0, 4); __m256i _pB2 = _mm256_permute4x64_epi64(_pB0, _MM_SHUFFLE(1, 0, 3, 2)); - __m256i _pB3 = _mm256_shuffle_epi32(_pB2, _MM_SHUFFLE(0, 3, 2, 1)); + __m256i _pB3 = _mm256_alignr_epi8(_pB2, _pB2, 4); #if __AVXVNNIINT8__ _sum0 = _mm256_dpbssd_epi32(_sum0, _pB0, _pA0); _sum1 = _mm256_dpbssd_epi32(_sum1, _pB1, _pA0); @@ -13372,7 +13372,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, if (max_kk >= 4) { __m256i _w_shift0 = _mm256_loadu_si256((const __m256i*)pA); - __m256i _w_shift1 = _mm256_shuffle_epi32(_w_shift0, _MM_SHUFFLE(1, 0, 3, 2)); + __m256i _w_shift1 = _mm256_alignr_epi8(_w_shift0, _w_shift0, 8); _sum0 = _mm256_sub_epi32(_sum0, _w_shift0); _sum1 = _mm256_sub_epi32(_sum1, _w_shift0); _sum2 = _mm256_sub_epi32(_sum2, _w_shift1); @@ -13399,15 +13399,15 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, // 0123 4567 // 2301 6745 - __m256i _pA1 = _mm256_shuffle_epi32(_pA0, _MM_SHUFFLE(1, 0, 3, 
2)); + __m256i _pA1 = _mm256_alignr_epi8(_pA0, _pA0, 8); // 0123 4567 // 1230 5674 // 4567 0123 // 5674 1230 - __m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); + __m256i _pB1 = _mm256_alignr_epi8(_pB0, _pB0, 4); __m256i _pB2 = _mm256_permute4x64_epi64(_pB0, _MM_SHUFFLE(1, 0, 3, 2)); - __m256i _pB3 = _mm256_shuffle_epi32(_pB2, _MM_SHUFFLE(0, 3, 2, 1)); + __m256i _pB3 = _mm256_alignr_epi8(_pB2, _pB2, 4); _sum0 = _mm256_comp_dpwssd_epi32(_sum0, _pA0, _pB0); _sum1 = _mm256_comp_dpwssd_epi32(_sum1, _pA0, _pB1); @@ -13431,7 +13431,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, __m128i _pA1 = _mm_shufflehi_epi16(_mm_shufflelo_epi16(_pA0, _MM_SHUFFLE(1, 0, 3, 2)), _MM_SHUFFLE(1, 0, 3, 2)); __m128i _pB1 = _mm_shufflehi_epi16(_mm_shufflelo_epi16(_pB0, _MM_SHUFFLE(0, 3, 2, 1)), _MM_SHUFFLE(0, 3, 2, 1)); - __m128i _pB2 = _mm_shuffle_epi32(_pB0, _MM_SHUFFLE(1, 0, 3, 2)); + __m128i _pB2 = _mm_alignr_epi8(_pB0, _pB0, 8); __m128i _pB3 = _mm_shufflehi_epi16(_mm_shufflelo_epi16(_pB2, _MM_SHUFFLE(0, 3, 2, 1)), _MM_SHUFFLE(0, 3, 2, 1)); _sum0 = _mm256_add_epi32(_sum0, _mm256_cvtepi16_epi32(_mm_mullo_epi16(_pA0, _pB0))); @@ -13490,8 +13490,8 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, __m256i _pA0 = _mm256_loadu_si256((const __m256i*)pA); __m128i _pB = _mm_loadu_si128((const __m128i*)pB); __m256i _pB0 = combine4x2_epi32(_pB, _pB); - __m256i _pA1 = _mm256_shuffle_epi32(_pA0, _MM_SHUFFLE(1, 0, 3, 2)); - __m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); + __m256i _pA1 = _mm256_alignr_epi8(_pA0, _pA0, 8); + __m256i _pB1 = _mm256_alignr_epi8(_pB0, _pB0, 4); #if __AVXVNNIINT8__ _sum0 = _mm256_dpbssd_epi32(_sum0, _pB0, _pA0); _sum1 = _mm256_dpbssd_epi32(_sum1, _pB1, _pA0); @@ -13510,7 +13510,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, if (max_kk >= 4) { __m256i _w_shift0 = _mm256_loadu_si256((const __m256i*)pA); - __m256i _w_shift1 = _mm256_shuffle_epi32(_w_shift0, _MM_SHUFFLE(1, 0, 3, 2)); + __m256i _w_shift1 = _mm256_alignr_epi8(_w_shift0, _w_shift0, 8); _sum0 = _mm256_sub_epi32(_sum0, _w_shift0); _sum1 = _mm256_sub_epi32(_sum1, _w_shift0); _sum2 = _mm256_sub_epi32(_sum2, _w_shift1); @@ -13533,11 +13533,11 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, // 0123 4567 // 2301 6745 - __m256i _pA1 = _mm256_shuffle_epi32(_pA0, _MM_SHUFFLE(1, 0, 3, 2)); + __m256i _pA1 = _mm256_alignr_epi8(_pA0, _pA0, 8); // 0123 0123 // 1230 1230 - __m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); + __m256i _pB1 = _mm256_alignr_epi8(_pB0, _pB0, 4); _sum0 = _mm256_comp_dpwssd_epi32(_sum0, _pA0, _pB0); _sum1 = _mm256_comp_dpwssd_epi32(_sum1, _pA0, _pB1); @@ -13599,7 +13599,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, { __m256i _pA = _mm256_loadu_si256((const __m256i*)pA); __m256i _pB0 = _mm256_castpd_si256(_mm256_broadcast_sd((const double*)pB)); - __m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 1, 0, 1)); + __m256i _pB1 = _mm256_alignr_epi8(_pB0, _pB0, 4); #if __AVXVNNIINT8__ _sum0 = _mm256_dpbssd_epi32(_sum0, _pB0, _pA); _sum1 = _mm256_dpbssd_epi32(_sum1, _pB1, _pA); @@ -13636,7 +13636,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, // 0101 0101 // 1010 1010 - __m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 1, 0, 1)); + __m256i _pB1 = _mm256_alignr_epi8(_pB0, _pB0, 4); _sum0 = _mm256_comp_dpwssd_epi32(_sum0, _pA0, _pB0); _sum1 = 
_mm256_comp_dpwssd_epi32(_sum1, _pA0, _pB1); @@ -13801,8 +13801,8 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, { __m512i _pA0 = _mm512_broadcast_i32x4(_mm_loadu_si128((const __m128i*)pA)); __m512i _pB0 = _mm512_loadu_si512((const __m512i*)pB); - __m512i _pA1 = _mm512_shuffle_epi32(_pA0, _MM_PERM_BADC); - __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); + __m512i _pA1 = _mm512_alignr_epi8(_pA0, _pA0, 8); + __m512i _pB1 = _mm512_alignr_epi8(_pB0, _pB0, 4); _sum0 = _mm512_dpbusd_epi32(_sum0, _pB0, _pA0); _sum1 = _mm512_dpbusd_epi32(_sum1, _pB1, _pA0); _sum2 = _mm512_dpbusd_epi32(_sum2, _pB0, _pA1); @@ -13813,7 +13813,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, if (max_kk >= 4) { __m512i _w_shift0 = _mm512_broadcast_i32x4(_mm_loadu_si128((const __m128i*)pA)); - __m512i _w_shift1 = _mm512_shuffle_epi32(_w_shift0, _MM_PERM_BADC); + __m512i _w_shift1 = _mm512_alignr_epi8(_w_shift0, _w_shift0, 8); _sum0 = _mm512_sub_epi32(_sum0, _w_shift0); _sum1 = _mm512_sub_epi32(_sum1, _w_shift0); _sum2 = _mm512_sub_epi32(_sum2, _w_shift1); @@ -13831,11 +13831,11 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, // 0123 0123 0123 0123 // 2301 2301 2301 2301 - __m512i _pA1 = _mm512_shuffle_epi32(_pA0, _MM_PERM_BADC); + __m512i _pA1 = _mm512_alignr_epi8(_pA0, _pA0, 8); // 0123 4567 89ab cdef // 1230 5674 9ab8 defc - __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); + __m512i _pB1 = _mm512_alignr_epi8(_pB0, _pB0, 4); _sum0 = _mm512_comp_dpwssd_epi32(_sum0, _pA0, _pB0); _sum1 = _mm512_comp_dpwssd_epi32(_sum1, _pA0, _pB1); @@ -13938,8 +13938,8 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, __m128i _pA0 = _mm_loadu_si128((const __m128i*)pA); __m256i _pB01 = _mm256_loadu_si256((const __m256i*)pB); __m256i _pA00 = combine4x2_epi32(_pA0, _pA0); - __m256i _pA11 = _mm256_shuffle_epi32(_pA00, _MM_SHUFFLE(1, 0, 3, 2)); - __m256i _pB23 = _mm256_shuffle_epi32(_pB01, _MM_SHUFFLE(0, 3, 2, 1)); + __m256i _pA11 = _mm256_alignr_epi8(_pA00, _pA00, 8); + __m256i _pB23 = _mm256_alignr_epi8(_pB01, _pB01, 4); #if __AVXVNNIINT8__ _sum0 = _mm256_dpbssd_epi32(_sum0, _pB01, _pA00); _sum1 = _mm256_dpbssd_epi32(_sum1, _pB01, _pA11); @@ -13959,7 +13959,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, { __m128i _w_shift0 = _mm_loadu_si128((const __m128i*)pA); __m256i _w_shift00 = combine4x2_epi32(_w_shift0, _w_shift0); - __m256i _w_shift11 = _mm256_shuffle_epi32(_w_shift00, _MM_SHUFFLE(1, 0, 3, 2)); + __m256i _w_shift11 = _mm256_alignr_epi8(_w_shift00, _w_shift00, 8); _sum0 = _mm256_sub_epi32(_sum0, _w_shift00); _sum1 = _mm256_sub_epi32(_sum1, _w_shift11); _sum2 = _mm256_sub_epi32(_sum2, _w_shift00); @@ -13977,8 +13977,8 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, __m256i _pA00 = _mm256_cvtepi8_epi16(_pA); __m256i _pB01 = _mm256_cvtepi8_epi16(_pB); - __m256i _pA11 = _mm256_shuffle_epi32(_pA00, _MM_SHUFFLE(1, 0, 3, 2)); - __m256i _pB23 = _mm256_shuffle_epi32(_pB01, _MM_SHUFFLE(0, 3, 2, 1)); + __m256i _pA11 = _mm256_alignr_epi8(_pA00, _pA00, 8); + __m256i _pB23 = _mm256_alignr_epi8(_pB01, _pB01, 4); _sum0 = _mm256_comp_dpwssd_epi32(_sum0, _pA00, _pB01); _sum1 = _mm256_comp_dpwssd_epi32(_sum1, _pA11, _pB01); @@ -13997,7 +13997,11 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, // 0123 // 2301 __m128i _pA0 = _pA; +#if __SSSE3__ + __m128i _pA1 = _mm_alignr_epi8(_pA, _pA, 
8); +#else __m128i _pA1 = _mm_shuffle_epi32(_pA, _MM_SHUFFLE(1, 0, 3, 2)); +#endif // 0123 // 4567 @@ -14005,8 +14009,13 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, // 5674 __m128i _pB0 = _pBl; __m128i _pB1 = _pBh; +#if __SSSE3__ + __m128i _pB2 = _mm_alignr_epi8(_pBl, _pBl, 4); + __m128i _pB3 = _mm_alignr_epi8(_pBh, _pBh, 4); +#else __m128i _pB2 = _mm_shuffle_epi32(_pBl, _MM_SHUFFLE(0, 3, 2, 1)); __m128i _pB3 = _mm_shuffle_epi32(_pBh, _MM_SHUFFLE(0, 3, 2, 1)); +#endif _sum0 = _mm_comp_dpwssd_epi32(_sum0, _pA0, _pB0); _sum1 = _mm_comp_dpwssd_epi32(_sum1, _pA0, _pB1); @@ -14058,15 +14067,15 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, // 00112233 // 22330011 __m128i _pA0 = _mm_unpacklo_epi16(_pA, _pA); - __m128i _pA1 = _mm_shuffle_epi32(_pA0, _MM_SHUFFLE(1, 0, 3, 2)); + __m128i _pA1 = _mm_alignr_epi8(_pA0, _pA0, 8); // 00112233 // 44556677 // 1.2.3.0. __m128i _pB0 = _mm_unpacklo_epi16(_pB, _pB); __m128i _pB1 = _mm_unpackhi_epi16(_pB, _pB); - __m128i _pB2 = _mm_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); - __m128i _pB3 = _mm_shuffle_epi32(_pB1, _MM_SHUFFLE(0, 3, 2, 1)); + __m128i _pB2 = _mm_alignr_epi8(_pB0, _pB0, 4); + __m128i _pB3 = _mm_alignr_epi8(_pB1, _pB1, 4); _sum0 = _mm_maccd_epi16(_pA0, _pB0, _sum0); _sum1 = _mm_maccd_epi16(_pA0, _pB1, _sum1); @@ -14168,8 +14177,8 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, { __m128i _pA0 = _mm_loadu_si128((const __m128i*)pA); __m128i _pB0 = _mm_loadu_si128((const __m128i*)pB); - __m128i _pA1 = _mm_shuffle_epi32(_pA0, _MM_SHUFFLE(1, 0, 3, 2)); - __m128i _pB1 = _mm_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); + __m128i _pA1 = _mm_alignr_epi8(_pA0, _pA0, 8); + __m128i _pB1 = _mm_alignr_epi8(_pB0, _pB0, 4); #if __AVXVNNIINT8__ _sum0 = _mm_dpbssd_epi32(_sum0, _pB0, _pA0); _sum1 = _mm_dpbssd_epi32(_sum1, _pB1, _pA0); @@ -14188,7 +14197,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, if (max_kk >= 4) { __m128i _w_shift0 = _mm_loadu_si128((const __m128i*)pA); - __m128i _w_shift1 = _mm_shuffle_epi32(_w_shift0, _MM_SHUFFLE(1, 0, 3, 2)); + __m128i _w_shift1 = _mm_alignr_epi8(_w_shift0, _w_shift0, 8); _sum0 = _mm_sub_epi32(_sum0, _w_shift0); _sum1 = _mm_sub_epi32(_sum1, _w_shift0); _sum2 = _mm_sub_epi32(_sum2, _w_shift1); @@ -14213,12 +14222,20 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, // 0123 // 2301 __m128i _pA0 = _pA; +#if __SSSE3__ + __m128i _pA1 = _mm_alignr_epi8(_pA, _pA, 8); +#else __m128i _pA1 = _mm_shuffle_epi32(_pA, _MM_SHUFFLE(1, 0, 3, 2)); +#endif // 0123 // 1230 __m128i _pB0 = _pB; +#if __SSSE3__ + __m128i _pB1 = _mm_alignr_epi8(_pB, _pB, 4); +#else __m128i _pB1 = _mm_shuffle_epi32(_pB, _MM_SHUFFLE(0, 3, 2, 1)); +#endif _sum0 = _mm_comp_dpwssd_epi32(_sum0, _pA0, _pB0); _sum1 = _mm_comp_dpwssd_epi32(_sum1, _pA0, _pB1); @@ -14245,12 +14262,12 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, // 00112233 // 22330011 __m128i _pA0 = _mm_unpacklo_epi16(_pA, _pA); - __m128i _pA1 = _mm_shuffle_epi32(_pA0, _MM_SHUFFLE(1, 0, 3, 2)); + __m128i _pA1 = _mm_alignr_epi8(_pA0, _pA0, 8); // 00112233 // 1.2.3.0. 
__m128i _pB0 = _mm_unpacklo_epi16(_pB, _pB); - __m128i _pB1 = _mm_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); + __m128i _pB1 = _mm_alignr_epi8(_pB0, _pB0, 4); _sum0 = _mm_maccd_epi16(_pA0, _pB0, _sum0); _sum1 = _mm_maccd_epi16(_pA0, _pB1, _sum1); @@ -14315,7 +14332,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, { __m128i _pA = _mm_loadu_si128((const __m128i*)pA); __m128i _pB0 = _mm_castpd_si128(_mm_load1_pd((const double*)pB)); - __m128i _pB1 = _mm_shuffle_epi32(_pB0, _MM_SHUFFLE(2, 3, 0, 1)); + __m128i _pB1 = _mm_alignr_epi8(_pB0, _pB0, 4); #if __AVXVNNIINT8__ _sum0 = _mm_dpbssd_epi32(_sum0, _pB0, _pA); _sum1 = _mm_dpbssd_epi32(_sum1, _pB1, _pA); @@ -14353,7 +14370,11 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, // 0101 // 1010 +#if __SSSE3__ + __m128i _pB1 = _mm_alignr_epi8(_pB0, _pB0, 4); +#else __m128i _pB1 = _mm_shuffle_epi32(_pB0, _MM_SHUFFLE(2, 3, 0, 1)); +#endif _sum0 = _mm_comp_dpwssd_epi32(_sum0, _pA, _pB0); _sum1 = _mm_comp_dpwssd_epi32(_sum1, _pA, _pB1); @@ -14545,7 +14566,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, { __m512i _pA = _mm512_castpd_si512(_mm512_set1_pd(((const double*)pA)[0])); __m512i _pB0 = _mm512_loadu_si512((const __m512i*)pB); - __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); + __m512i _pB1 = _mm512_alignr_epi8(_pB0, _pB0, 4); _sum0 = _mm512_dpbusd_epi32(_sum0, _pB0, _pA); _sum1 = _mm512_dpbusd_epi32(_sum1, _pB1, _pA); pA += 8; @@ -14571,7 +14592,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, // 0123 4567 89ab cdef // 1230 5674 9ab8 defc - __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); + __m512i _pB1 = _mm512_alignr_epi8(_pB0, _pB0, 4); _sum0 = _mm512_comp_dpwssd_epi32(_sum0, _pA0, _pB0); _sum1 = _mm512_comp_dpwssd_epi32(_sum1, _pA0, _pB1); @@ -14649,7 +14670,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, for (; kk + 3 < max_kk; kk += 4) { __m256i _pA00 = _mm256_castpd_si256(_mm256_broadcast_sd((const double*)pA)); - __m256i _pA11 = _mm256_shuffle_epi32(_pA00, _MM_SHUFFLE(2, 3, 0, 1)); + __m256i _pA11 = _mm256_alignr_epi8(_pA00, _pA00, 4); __m256i _pB01 = _mm256_loadu_si256((const __m256i*)pB); #if __AVXVNNIINT8__ _sum0 = _mm256_dpbssd_epi32(_sum0, _pB01, _pA00); @@ -14665,7 +14686,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, if (max_kk >= 4) { __m256i _w_shift00 = _mm256_castpd_si256(_mm256_broadcast_sd((const double*)pA)); - __m256i _w_shift11 = _mm256_shuffle_epi32(_w_shift00, _MM_SHUFFLE(2, 3, 0, 1)); + __m256i _w_shift11 = _mm256_alignr_epi8(_w_shift00, _w_shift00, 4); _sum0 = _mm256_sub_epi32(_sum0, _w_shift00); _sum1 = _mm256_sub_epi32(_sum1, _w_shift11); pA += 8; @@ -14681,7 +14702,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, __m256i _pA00 = _mm256_cvtepi8_epi16(_pA); __m256i _pB01 = _mm256_cvtepi8_epi16(_pB); - __m256i _pA11 = _mm256_shuffle_epi32(_pA00, _MM_SHUFFLE(2, 3, 0, 1)); + __m256i _pA11 = _mm256_alignr_epi8(_pA00, _pA00, 4); _sum0 = _mm256_comp_dpwssd_epi32(_sum0, _pA00, _pB01); _sum1 = _mm256_comp_dpwssd_epi32(_sum1, _pA11, _pB01); @@ -14699,7 +14720,11 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, // 0101 // 1010 __m128i _pA0 = _pA; +#if __SSSE3__ + __m128i _pA1 = _mm_alignr_epi8(_pA, _pA, 4); +#else __m128i _pA1 = _mm_shuffle_epi32(_pA, _MM_SHUFFLE(2, 3, 0, 1)); +#endif // 0123 // 4567 @@ -14795,7 
+14820,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, { __m128i _pA = _mm_castpd_si128(_mm_load1_pd((const double*)pA)); __m128i _pB0 = _mm_loadu_si128((const __m128i*)pB); - __m128i _pB1 = _mm_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); + __m128i _pB1 = _mm_alignr_epi8(_pB0, _pB0, 4); #if __AVXVNNIINT8__ _sum0 = _mm_dpbssd_epi32(_sum0, _pB0, _pA); _sum1 = _mm_dpbssd_epi32(_sum1, _pB1, _pA); @@ -14834,7 +14859,11 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, // 0123 // 1230 __m128i _pB0 = _pB; +#if __SSSE3__ + __m128i _pB1 = _mm_alignr_epi8(_pB, _pB, 4); +#else __m128i _pB1 = _mm_shuffle_epi32(_pB, _MM_SHUFFLE(0, 3, 2, 1)); +#endif _sum0 = _mm_comp_dpwssd_epi32(_sum0, _pA, _pB0); _sum1 = _mm_comp_dpwssd_epi32(_sum1, _pA, _pB1); From 9fffd13fae8ba73020c91c0aa4b211119bde4cbc Mon Sep 17 00:00:00 2001 From: nihui Date: Tue, 17 Mar 2026 22:57:09 +0800 Subject: [PATCH 22/36] fix binaryop vulkan shader compilation on moltenvk (#6602) --- src/layer/vulkan/shader/binaryop.comp | 2 +- src/layer/vulkan/shader/binaryop_broadcast.comp | 2 +- src/layer/vulkan/shader/binaryop_broadcast_pack1to4.comp | 2 +- src/layer/vulkan/shader/binaryop_broadcast_pack4.comp | 2 +- src/layer/vulkan/shader/binaryop_pack4.comp | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/layer/vulkan/shader/binaryop.comp b/src/layer/vulkan/shader/binaryop.comp index e5c09bcf241..569f6dd95ec 100644 --- a/src/layer/vulkan/shader/binaryop.comp +++ b/src/layer/vulkan/shader/binaryop.comp @@ -103,7 +103,7 @@ void main() #endif if (op_type == 12) res = v1 - trunc(v1 / v2) * v2; if (op_type == 13) res = v2 - trunc(v2 / v1) * v1; - if (op_type == 14) res = max(v1, v2) + log(1.0 + exp(min(v1, v2) - max(v1, v2))); + if (op_type == 14) res = max(v1, v2) + log(afp(1.0) + exp(min(v1, v2) - max(v1, v2))); if (op_type == 15) res = floor(v1 / v2); if (op_type == 16) res = floor(v2 / v1); if (op_type == 17) res = v1 - roundEven(v1 / v2) * v2; diff --git a/src/layer/vulkan/shader/binaryop_broadcast.comp b/src/layer/vulkan/shader/binaryop_broadcast.comp index c71d512d35d..9320b5952b4 100644 --- a/src/layer/vulkan/shader/binaryop_broadcast.comp +++ b/src/layer/vulkan/shader/binaryop_broadcast.comp @@ -172,7 +172,7 @@ void main() #endif if (op_type == 12) res = v1 - trunc(v1 / v2) * v2; if (op_type == 13) res = v2 - trunc(v2 / v1) * v1; - if (op_type == 14) res = max(v1, v2) + log(1.0 + exp(min(v1, v2) - max(v1, v2))); + if (op_type == 14) res = max(v1, v2) + log(afp(1.0) + exp(min(v1, v2) - max(v1, v2))); if (op_type == 15) res = floor(v1 / v2); if (op_type == 16) res = floor(v2 / v1); if (op_type == 17) res = v1 - roundEven(v1 / v2) * v2; diff --git a/src/layer/vulkan/shader/binaryop_broadcast_pack1to4.comp b/src/layer/vulkan/shader/binaryop_broadcast_pack1to4.comp index fe45c61ad75..cd40e1bd951 100644 --- a/src/layer/vulkan/shader/binaryop_broadcast_pack1to4.comp +++ b/src/layer/vulkan/shader/binaryop_broadcast_pack1to4.comp @@ -103,7 +103,7 @@ void main() #endif if (op_type == 12) res = v1 - trunc(v1 / v2) * v2; if (op_type == 13) res = v2 - trunc(v2 / v1) * v1; - if (op_type == 14) res = max(v1, v2) + log(1.0 + exp(min(v1, v2) - max(v1, v2))); + if (op_type == 14) res = max(v1, v2) + log(afpvec4(1.0) + exp(min(v1, v2) - max(v1, v2))); if (op_type == 15) res = floor(v1 / v2); if (op_type == 16) res = floor(v2 / v1); if (op_type == 17) res = v1 - roundEven(v1 / v2) * v2; diff --git a/src/layer/vulkan/shader/binaryop_broadcast_pack4.comp 
b/src/layer/vulkan/shader/binaryop_broadcast_pack4.comp index de9df649188..e9eaf656346 100644 --- a/src/layer/vulkan/shader/binaryop_broadcast_pack4.comp +++ b/src/layer/vulkan/shader/binaryop_broadcast_pack4.comp @@ -172,7 +172,7 @@ void main() #endif if (op_type == 12) res = v1 - trunc(v1 / v2) * v2; if (op_type == 13) res = v2 - trunc(v2 / v1) * v1; - if (op_type == 14) res = max(v1, v2) + log(1.0 + exp(min(v1, v2) - max(v1, v2))); + if (op_type == 14) res = max(v1, v2) + log(afpvec4(1.0) + exp(min(v1, v2) - max(v1, v2))); if (op_type == 15) res = floor(v1 / v2); if (op_type == 16) res = floor(v2 / v1); if (op_type == 17) res = v1 - roundEven(v1 / v2) * v2; diff --git a/src/layer/vulkan/shader/binaryop_pack4.comp b/src/layer/vulkan/shader/binaryop_pack4.comp index 3a4c230acb2..35eb949e165 100644 --- a/src/layer/vulkan/shader/binaryop_pack4.comp +++ b/src/layer/vulkan/shader/binaryop_pack4.comp @@ -98,7 +98,7 @@ void main() #endif if (op_type == 12) res = v1 - trunc(v1 / v2) * v2; if (op_type == 13) res = v2 - trunc(v2 / v1) * v1; - if (op_type == 14) res = max(v1, v2) + log(1.0 + exp(min(v1, v2) - max(v1, v2))); + if (op_type == 14) res = max(v1, v2) + log(afpvec4(1.0) + exp(min(v1, v2) - max(v1, v2))); if (op_type == 15) res = floor(v1 / v2); if (op_type == 16) res = floor(v2 / v1); if (op_type == 17) res = v1 - roundEven(v1 / v2) * v2; From 039f0c043638d74f63fc4c62fd48933b4dddd51c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=A2=A8=E7=9A=84=E5=BD=B7=E5=BE=A8?= <56149058+futz12@users.noreply.github.com> Date: Wed, 18 Mar 2026 08:32:11 +0800 Subject: [PATCH 23/36] x86 erf and gelu optimization (#6604) --- src/layer/x86/avx512_mathfun.h | 48 +++++++++ src/layer/x86/avx_mathfun.h | 48 +++++++++ src/layer/x86/erf_x86.cpp | 89 ++++++++++++++++ src/layer/x86/erf_x86.h | 21 ++++ src/layer/x86/gelu_x86.cpp | 187 ++++++++++++++++++++++----------- src/layer/x86/sse_mathfun.h | 48 +++++++++ 6 files changed, 380 insertions(+), 61 deletions(-) create mode 100644 src/layer/x86/erf_x86.cpp create mode 100644 src/layer/x86/erf_x86.h diff --git a/src/layer/x86/avx512_mathfun.h b/src/layer/x86/avx512_mathfun.h index 0068edf79c7..e2e132ca678 100644 --- a/src/layer/x86/avx512_mathfun.h +++ b/src/layer/x86/avx512_mathfun.h @@ -217,6 +217,54 @@ static NCNN_FORCEINLINE __m512 tanh512_ps(const __m512& x) return dst; } +_PS512_CONST(erf_threshold, 0.927734375f); + +_PS512_CONST(erf_c0, -1.72853470e-5f); +_PS512_CONST(erf_c1, 3.83197126e-4f); +_PS512_CONST(erf_c2, -3.88396438e-3f); +_PS512_CONST(erf_c3, 2.42546219e-2f); +_PS512_CONST(erf_c4, -1.06777877e-1f); +_PS512_CONST(erf_c5, -6.34846687e-1f); +_PS512_CONST(erf_c6, -1.28717512e-1f); + +_PS512_CONST(erf_p0, -5.96761703e-4f); +_PS512_CONST(erf_p1, 4.99119423e-3f); +_PS512_CONST(erf_p2, -2.67681349e-2f); +_PS512_CONST(erf_p3, 1.12819925e-1f); +_PS512_CONST(erf_p4, -3.76125336e-1f); +_PS512_CONST(erf_p5, 1.28379166e-1f); + +static NCNN_FORCEINLINE __m512 erf512_ps(const __m512& a) +{ + __m512 t = _mm512_castsi512_ps(_mm512_and_epi32(_mm512_castps_si512(a), _mm512_set1_epi32(0x7fffffff))); + __m512 s = _mm512_mul_ps(a, a); + + __mmask16 mask = _mm512_cmp_ps_mask(t, *(__m512*)_ps512_erf_threshold, _CMP_GT_OQ); + + __m512 r_large = _mm512_fmadd_ps(*(__m512*)_ps512_erf_c0, t, *(__m512*)_ps512_erf_c1); + __m512 u = _mm512_fmadd_ps(*(__m512*)_ps512_erf_c2, t, *(__m512*)_ps512_erf_c3); + r_large = _mm512_fmadd_ps(r_large, s, u); + r_large = _mm512_fmadd_ps(r_large, t, *(__m512*)_ps512_erf_c4); + r_large = _mm512_fmadd_ps(r_large, t, *(__m512*)_ps512_erf_c5); + r_large 
= _mm512_fmadd_ps(r_large, t, *(__m512*)_ps512_erf_c6); + r_large = _mm512_fmadd_ps(r_large, t, _mm512_sub_ps(_mm512_setzero_ps(), t)); + r_large = _mm512_sub_ps(*(__m512*)_ps512_1, exp512_ps(r_large)); + + __m512 sign_mask = _mm512_and_ps(a, *(__m512*)_ps512_sign_mask); + r_large = _mm512_xor_ps(r_large, sign_mask); + + __m512 r_small = *(__m512*)_ps512_erf_p0; + r_small = _mm512_fmadd_ps(r_small, s, *(__m512*)_ps512_erf_p1); + r_small = _mm512_fmadd_ps(r_small, s, *(__m512*)_ps512_erf_p2); + r_small = _mm512_fmadd_ps(r_small, s, *(__m512*)_ps512_erf_p3); + r_small = _mm512_fmadd_ps(r_small, s, *(__m512*)_ps512_erf_p4); + r_small = _mm512_fmadd_ps(r_small, s, *(__m512*)_ps512_erf_p5); + r_small = _mm512_fmadd_ps(r_small, a, a); + + __m512 r = _mm512_mask_mov_ps(r_small, mask, r_large); + return r; +} + _PS512_CONST(minus_cephes_DP1, -0.78515625f); _PS512_CONST(minus_cephes_DP2, -2.4187564849853515625e-4f); _PS512_CONST(minus_cephes_DP3, -3.77489497744594108e-8f); diff --git a/src/layer/x86/avx_mathfun.h b/src/layer/x86/avx_mathfun.h index 4f5ef64012b..97fbf46bf73 100644 --- a/src/layer/x86/avx_mathfun.h +++ b/src/layer/x86/avx_mathfun.h @@ -341,6 +341,54 @@ static NCNN_FORCEINLINE __m256 tanh256_ps(const __m256& x) return dst; } +_PS256_CONST(erf_threshold, 0.927734375f); + +_PS256_CONST(erf_c0, -1.72853470e-5f); +_PS256_CONST(erf_c1, 3.83197126e-4f); +_PS256_CONST(erf_c2, -3.88396438e-3f); +_PS256_CONST(erf_c3, 2.42546219e-2f); +_PS256_CONST(erf_c4, -1.06777877e-1f); +_PS256_CONST(erf_c5, -6.34846687e-1f); +_PS256_CONST(erf_c6, -1.28717512e-1f); + +_PS256_CONST(erf_p0, -5.96761703e-4f); +_PS256_CONST(erf_p1, 4.99119423e-3f); +_PS256_CONST(erf_p2, -2.67681349e-2f); +_PS256_CONST(erf_p3, 1.12819925e-1f); +_PS256_CONST(erf_p4, -3.76125336e-1f); +_PS256_CONST(erf_p5, 1.28379166e-1f); + +static NCNN_FORCEINLINE __m256 erf256_ps(const __m256& a) +{ + __m256 t = _mm256_and_ps(a, *(__m256*)_ps256_inv_sign_mask); + __m256 s = _mm256_mul_ps(a, a); + + __m256 mask = _mm256_cmp_ps(t, *(__m256*)_ps256_erf_threshold, _CMP_GT_OQ); + + __m256 r_large = _mm256_comp_fmadd_ps(*(__m256*)_ps256_erf_c0, t, *(__m256*)_ps256_erf_c1); + __m256 u = _mm256_comp_fmadd_ps(*(__m256*)_ps256_erf_c2, t, *(__m256*)_ps256_erf_c3); + r_large = _mm256_comp_fmadd_ps(r_large, s, u); + r_large = _mm256_comp_fmadd_ps(r_large, t, *(__m256*)_ps256_erf_c4); + r_large = _mm256_comp_fmadd_ps(r_large, t, *(__m256*)_ps256_erf_c5); + r_large = _mm256_comp_fmadd_ps(r_large, t, *(__m256*)_ps256_erf_c6); + r_large = _mm256_comp_fmadd_ps(r_large, t, _mm256_sub_ps(_mm256_setzero_ps(), t)); + r_large = _mm256_sub_ps(*(__m256*)_ps256_1, exp256_ps(r_large)); + + __m256 sign_mask = _mm256_and_ps(a, *(__m256*)_ps256_sign_mask); + r_large = _mm256_xor_ps(r_large, sign_mask); + + __m256 r_small = *(__m256*)_ps256_erf_p0; + r_small = _mm256_comp_fmadd_ps(r_small, s, *(__m256*)_ps256_erf_p1); + r_small = _mm256_comp_fmadd_ps(r_small, s, *(__m256*)_ps256_erf_p2); + r_small = _mm256_comp_fmadd_ps(r_small, s, *(__m256*)_ps256_erf_p3); + r_small = _mm256_comp_fmadd_ps(r_small, s, *(__m256*)_ps256_erf_p4); + r_small = _mm256_comp_fmadd_ps(r_small, s, *(__m256*)_ps256_erf_p5); + r_small = _mm256_comp_fmadd_ps(r_small, a, a); + + __m256 r = _mm256_or_ps(_mm256_and_ps(mask, r_large), _mm256_andnot_ps(mask, r_small)); + return r; +} + _PS256_CONST(minus_cephes_DP1, -0.78515625f); _PS256_CONST(minus_cephes_DP2, -2.4187564849853515625e-4f); _PS256_CONST(minus_cephes_DP3, -3.77489497744594108e-8f); diff --git a/src/layer/x86/erf_x86.cpp 
b/src/layer/x86/erf_x86.cpp
new file mode 100644
index 00000000000..5a1c396cef4
--- /dev/null
+++ b/src/layer/x86/erf_x86.cpp
@@ -0,0 +1,89 @@
+// Copyright 2026 Futz12
+// SPDX-License-Identifier: BSD-3-Clause
+
+#include "erf_x86.h"
+
+#if __SSE2__
+#include <emmintrin.h>
+#include "sse_mathfun.h"
+#if __AVX__
+#include <immintrin.h>
+#include "avx_mathfun.h"
+#if __AVX512F__
+#include "avx512_mathfun.h"
+#endif // __AVX512F__
+#endif // __AVX__
+#endif // __SSE2__
+
+namespace ncnn {
+
+Erf_x86::Erf_x86()
+{
+#if __SSE2__
+    support_packing = true;
+#endif // __SSE2__
+}
+
+int Erf_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
+{
+    int w = bottom_top_blob.w;
+    int h = bottom_top_blob.h;
+    int d = bottom_top_blob.d;
+    int channels = bottom_top_blob.c;
+    int elempack = bottom_top_blob.elempack;
+    int size = w * h * d * elempack;
+
+    #pragma omp parallel for num_threads(opt.num_threads)
+    for (int q = 0; q < channels; q++)
+    {
+        float* ptr = bottom_top_blob.channel(q);
+
+        int i = 0;
+
+#if __SSE2__
+#if __AVX__
+#if __AVX512F__
+        for (; i + 15 < size; i += 16)
+        {
+            __m512 _p = _mm512_loadu_ps(ptr);
+            _p = erf512_ps(_p);
+            _mm512_storeu_ps(ptr, _p);
+            ptr += 16;
+        }
+        if (i < size)
+        {
+            const unsigned int remain = size - i;
+            __mmask16 _mask = (__mmask16)((1u << remain) - 1);
+            __m512 _p = _mm512_maskz_loadu_ps(_mask, ptr);
+            _p = erf512_ps(_p);
+            _mm512_mask_storeu_ps(ptr, _mask, _p);
+        }
+        i = size;
+#endif // __AVX512F__
+        for (; i + 7 < size; i += 8)
+        {
+            __m256 _p = _mm256_loadu_ps(ptr);
+            _p = erf256_ps(_p);
+            _mm256_storeu_ps(ptr, _p);
+            ptr += 8;
+        }
+#endif // __AVX__
+        for (; i + 3 < size; i += 4)
+        {
+            __m128 _p = _mm_loadu_ps(ptr);
+            _p = erf_ps(_p);
+            _mm_storeu_ps(ptr, _p);
+            ptr += 4;
+        }
+#endif // __SSE2__
+        for (; i < size; i++)
+        {
+            *ptr = erff(*ptr);
+            ptr++;
+        }
+    }
+
+    return 0;
+}
+
+} // namespace ncnn
diff --git a/src/layer/x86/erf_x86.h b/src/layer/x86/erf_x86.h
new file mode 100644
index 00000000000..f8366bf50b3
--- /dev/null
+++ b/src/layer/x86/erf_x86.h
@@ -0,0 +1,21 @@
+// Copyright 2026 Futz12
+// SPDX-License-Identifier: BSD-3-Clause
+
+#ifndef LAYER_ERF_X86_H
+#define LAYER_ERF_X86_H
+
+#include "erf.h"
+
+namespace ncnn {
+
+class Erf_x86 : public Erf
+{
+public:
+    Erf_x86();
+
+    virtual int forward_inplace(Mat& bottom_top_blob, const Option& opt) const;
+};
+
+} // namespace ncnn
+
+#endif // LAYER_ERF_X86_H
diff --git a/src/layer/x86/gelu_x86.cpp b/src/layer/x86/gelu_x86.cpp
index eac075d4bff..c4cdb36b7b9 100644
--- a/src/layer/x86/gelu_x86.cpp
+++ b/src/layer/x86/gelu_x86.cpp
@@ -26,20 +26,11 @@ GELU_x86::GELU_x86()
 
 int GELU_x86::create_pipeline(const Option& /*opt*/)
 {
-    if (!fast_gelu)
-    {
-        support_packing = false;
-    }
     return 0;
 }
 
 int GELU_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
 {
-    if (!fast_gelu)
-    {
-        return GELU::forward_inplace(bottom_top_blob, opt);
-    }
-
     int w = bottom_top_blob.w;
     int h = bottom_top_blob.h;
     int d = bottom_top_blob.d;
@@ -57,84 +48,158 @@ int GELU_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
 #if __SSE2__
 #if __AVX__
 #if __AVX512F__
-    __m512 _half512 = _mm512_set1_ps(0.5f);
-    __m512 _one512 = _mm512_set1_ps(1.f);
-    __m512 _fast1c512 = _mm512_set1_ps(0.79788452f);
-    __m512 _fast2c512 = _mm512_set1_ps(0.044715f);
-    for (; i + 15 < size; i += 16)
+    if (fast_gelu)
     {
-        __m512 _pLoad = _mm512_loadu_ps(ptr);
+        __m512 _half512 = _mm512_set1_ps(0.5f);
+        __m512 _one512 = _mm512_set1_ps(1.f);
+        __m512 _fast1c512 = _mm512_set1_ps(0.79788452f);
+        __m512 _fast2c512 = _mm512_set1_ps(0.044715f);
+
for (; i + 15 < size; i += 16) + { + __m512 _pLoad = _mm512_loadu_ps(ptr); + + __m512 _cube = _mm512_mul_ps(_pLoad, _pLoad); + _cube = _mm512_mul_ps(_pLoad, _cube); - __m512 _cube = _mm512_mul_ps(_pLoad, _pLoad); - _cube = _mm512_mul_ps(_pLoad, _cube); + __m512 _blob = _mm512_mul_ps(_fast2c512, _cube); + _blob = _mm512_add_ps(_pLoad, _blob); + _blob = _mm512_mul_ps(_fast1c512, _blob); + _blob = tanh512_ps(_blob); + _blob = _mm512_add_ps(_one512, _blob); - __m512 _blob = _mm512_mul_ps(_fast2c512, _cube); - _blob = _mm512_add_ps(_pLoad, _blob); - _blob = _mm512_mul_ps(_fast1c512, _blob); - _blob = tanh512_ps(_blob); - _blob = _mm512_add_ps(_one512, _blob); + _blob = _mm512_mul_ps(_half512, _mm512_mul_ps(_blob, _pLoad)); - _blob = _mm512_mul_ps(_half512, _mm512_mul_ps(_blob, _pLoad)); + _mm512_storeu_ps(ptr, _blob); - _mm512_storeu_ps(ptr, _blob); + ptr += 16; + } + } + else + { + __m512 _half512 = _mm512_set1_ps(0.5f); + __m512 _one512 = _mm512_set1_ps(1.f); + __m512 _inv_sqrt2_512 = _mm512_set1_ps(0.70710678f); + for (; i + 15 < size; i += 16) + { + __m512 _pLoad = _mm512_loadu_ps(ptr); + + __m512 _erf = erf512_ps(_mm512_mul_ps(_pLoad, _inv_sqrt2_512)); + __m512 _blob = _mm512_add_ps(_one512, _erf); + _blob = _mm512_mul_ps(_half512, _mm512_mul_ps(_blob, _pLoad)); - ptr += 16; + _mm512_storeu_ps(ptr, _blob); + + ptr += 16; + } } #endif // __AVX512F__ - __m256 _half256 = _mm256_set1_ps(0.5f); - __m256 _one256 = _mm256_set1_ps(1.f); - __m256 _fast1c256 = _mm256_set1_ps(0.79788452f); - __m256 _fast2c256 = _mm256_set1_ps(0.044715f); - for (; i + 7 < size; i += 8) + if (fast_gelu) { - __m256 _pLoad = _mm256_loadu_ps(ptr); + __m256 _half256 = _mm256_set1_ps(0.5f); + __m256 _one256 = _mm256_set1_ps(1.f); + __m256 _fast1c256 = _mm256_set1_ps(0.79788452f); + __m256 _fast2c256 = _mm256_set1_ps(0.044715f); + for (; i + 7 < size; i += 8) + { + __m256 _pLoad = _mm256_loadu_ps(ptr); - __m256 _cube = _mm256_mul_ps(_pLoad, _pLoad); - _cube = _mm256_mul_ps(_pLoad, _cube); + __m256 _cube = _mm256_mul_ps(_pLoad, _pLoad); + _cube = _mm256_mul_ps(_pLoad, _cube); - __m256 _blob = _mm256_mul_ps(_fast2c256, _cube); - _blob = _mm256_add_ps(_pLoad, _blob); - _blob = _mm256_mul_ps(_fast1c256, _blob); - _blob = tanh256_ps(_blob); - _blob = _mm256_add_ps(_one256, _blob); + __m256 _blob = _mm256_mul_ps(_fast2c256, _cube); + _blob = _mm256_add_ps(_pLoad, _blob); + _blob = _mm256_mul_ps(_fast1c256, _blob); + _blob = tanh256_ps(_blob); + _blob = _mm256_add_ps(_one256, _blob); - _blob = _mm256_mul_ps(_half256, _mm256_mul_ps(_blob, _pLoad)); + _blob = _mm256_mul_ps(_half256, _mm256_mul_ps(_blob, _pLoad)); - _mm256_storeu_ps(ptr, _blob); + _mm256_storeu_ps(ptr, _blob); - ptr += 8; + ptr += 8; + } + } + else + { + __m256 _half256 = _mm256_set1_ps(0.5f); + __m256 _one256 = _mm256_set1_ps(1.f); + __m256 _inv_sqrt2_256 = _mm256_set1_ps(0.70710678f); + for (; i + 7 < size; i += 8) + { + __m256 _pLoad = _mm256_loadu_ps(ptr); + + __m256 _erf = erf256_ps(_mm256_mul_ps(_pLoad, _inv_sqrt2_256)); + __m256 _blob = _mm256_add_ps(_one256, _erf); + _blob = _mm256_mul_ps(_half256, _mm256_mul_ps(_blob, _pLoad)); + + _mm256_storeu_ps(ptr, _blob); + + ptr += 8; + } } #endif // __AVX__ - __m128 _half128 = _mm_set1_ps(0.5f); - __m128 _one128 = _mm_set1_ps(1.f); - __m128 _fast1c128 = _mm_set1_ps(0.79788452f); - __m128 _fast2c128 = _mm_set1_ps(0.044715f); - for (; i + 3 < size; i += 4) + if (fast_gelu) { - __m128 _pLoad = _mm_loadu_ps(ptr); + __m128 _half128 = _mm_set1_ps(0.5f); + __m128 _one128 = _mm_set1_ps(1.f); + __m128 _fast1c128 = 
_mm_set1_ps(0.79788452f); + __m128 _fast2c128 = _mm_set1_ps(0.044715f); + for (; i + 3 < size; i += 4) + { + __m128 _pLoad = _mm_loadu_ps(ptr); - __m128 _cube = _mm_mul_ps(_pLoad, _pLoad); - _cube = _mm_mul_ps(_pLoad, _cube); + __m128 _cube = _mm_mul_ps(_pLoad, _pLoad); + _cube = _mm_mul_ps(_pLoad, _cube); - __m128 _blob = _mm_mul_ps(_fast2c128, _cube); - _blob = _mm_add_ps(_pLoad, _blob); - _blob = _mm_mul_ps(_fast1c128, _blob); - _blob = tanh_ps(_blob); - _blob = _mm_add_ps(_one128, _blob); + __m128 _blob = _mm_mul_ps(_fast2c128, _cube); + _blob = _mm_add_ps(_pLoad, _blob); + _blob = _mm_mul_ps(_fast1c128, _blob); + _blob = tanh_ps(_blob); + _blob = _mm_add_ps(_one128, _blob); - _blob = _mm_mul_ps(_half128, _mm_mul_ps(_blob, _pLoad)); + _blob = _mm_mul_ps(_half128, _mm_mul_ps(_blob, _pLoad)); - _mm_storeu_ps(ptr, _blob); + _mm_storeu_ps(ptr, _blob); - ptr += 4; + ptr += 4; + } + } + else + { + __m128 _half128 = _mm_set1_ps(0.5f); + __m128 _one128 = _mm_set1_ps(1.f); + __m128 _inv_sqrt2_128 = _mm_set1_ps(0.70710678f); + for (; i + 3 < size; i += 4) + { + __m128 _pLoad = _mm_loadu_ps(ptr); + + __m128 _erf = erf_ps(_mm_mul_ps(_pLoad, _inv_sqrt2_128)); + __m128 _blob = _mm_add_ps(_one128, _erf); + _blob = _mm_mul_ps(_half128, _mm_mul_ps(_blob, _pLoad)); + + _mm_storeu_ps(ptr, _blob); + + ptr += 4; + } } #endif // __SSE2__ - for (; i < size; i++) + if (fast_gelu) + { + for (; i < size; i++) + { + *ptr = 0.5f * *ptr * (1.0f + tanhf(0.79788452f * (*ptr + 0.044715f * *ptr * *ptr * *ptr))); + + ptr++; + } + } + else { - // y = 0.5x * (1 + tanh(sqrt(2/Pi) * (x + 0.044715x^3))) - *ptr = 0.5f * *ptr * (1.0f + tanhf(0.79788452f * (*ptr + 0.044715f * *ptr * *ptr * *ptr))); + for (; i < size; i++) + { + *ptr = 0.5f * *ptr * (1.0f + erff(0.70710678f * *ptr)); - ptr++; + ptr++; + } } } diff --git a/src/layer/x86/sse_mathfun.h b/src/layer/x86/sse_mathfun.h index 2fc3ae3e27d..5f07cc9dc1c 100644 --- a/src/layer/x86/sse_mathfun.h +++ b/src/layer/x86/sse_mathfun.h @@ -331,6 +331,54 @@ static inline v4sf tanh_ps(const v4sf& x) return dst; } +_PS_CONST(erf_threshold, 0.927734375f); + +_PS_CONST(erf_c0, -1.72853470e-5f); +_PS_CONST(erf_c1, 3.83197126e-4f); +_PS_CONST(erf_c2, -3.88396438e-3f); +_PS_CONST(erf_c3, 2.42546219e-2f); +_PS_CONST(erf_c4, -1.06777877e-1f); +_PS_CONST(erf_c5, -6.34846687e-1f); +_PS_CONST(erf_c6, -1.28717512e-1f); + +_PS_CONST(erf_p0, -5.96761703e-4f); +_PS_CONST(erf_p1, 4.99119423e-3f); +_PS_CONST(erf_p2, -2.67681349e-2f); +_PS_CONST(erf_p3, 1.12819925e-1f); +_PS_CONST(erf_p4, -3.76125336e-1f); +_PS_CONST(erf_p5, 1.28379166e-1f); + +static NCNN_FORCEINLINE v4sf erf_ps(const v4sf& a) +{ + v4sf t = _mm_and_ps(a, *(v4sf*)_ps_inv_sign_mask); + v4sf s = _mm_mul_ps(a, a); + + v4sf mask = _mm_cmpgt_ps(t, *(v4sf*)_ps_erf_threshold); + + v4sf r_large = _mm_comp_fmadd_ps(*(v4sf*)_ps_erf_c0, t, *(v4sf*)_ps_erf_c1); + v4sf u = _mm_comp_fmadd_ps(*(v4sf*)_ps_erf_c2, t, *(v4sf*)_ps_erf_c3); + r_large = _mm_comp_fmadd_ps(r_large, s, u); + r_large = _mm_comp_fmadd_ps(r_large, t, *(v4sf*)_ps_erf_c4); + r_large = _mm_comp_fmadd_ps(r_large, t, *(v4sf*)_ps_erf_c5); + r_large = _mm_comp_fmadd_ps(r_large, t, *(v4sf*)_ps_erf_c6); + r_large = _mm_comp_fmadd_ps(r_large, t, _mm_sub_ps(_mm_setzero_ps(), t)); + r_large = _mm_sub_ps(*(v4sf*)_ps_1, exp_ps(r_large)); + + v4sf sign_mask = _mm_and_ps(a, *(v4sf*)_ps_sign_mask); + r_large = _mm_xor_ps(r_large, sign_mask); + + v4sf r_small = *(v4sf*)_ps_erf_p0; + r_small = _mm_comp_fmadd_ps(r_small, s, *(v4sf*)_ps_erf_p1); + r_small = _mm_comp_fmadd_ps(r_small, s, 
*(v4sf*)_ps_erf_p2); + r_small = _mm_comp_fmadd_ps(r_small, s, *(v4sf*)_ps_erf_p3); + r_small = _mm_comp_fmadd_ps(r_small, s, *(v4sf*)_ps_erf_p4); + r_small = _mm_comp_fmadd_ps(r_small, s, *(v4sf*)_ps_erf_p5); + r_small = _mm_comp_fmadd_ps(r_small, a, a); + + v4sf r = _mm_or_ps(_mm_and_ps(mask, r_large), _mm_andnot_ps(mask, r_small)); + return r; +} + _PS_CONST(minus_cephes_DP1, -0.78515625f); _PS_CONST(minus_cephes_DP2, -2.4187564849853515625e-4f); _PS_CONST(minus_cephes_DP3, -3.77489497744594108e-8f); From 2426a8bd32a31b11c3f49f9c00a1440b1a3816d7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=A2=A8=E7=9A=84=E5=BD=B7=E5=BE=A8?= <56149058+futz12@users.noreply.github.com> Date: Wed, 18 Mar 2026 22:25:43 +0800 Subject: [PATCH 24/36] support mips elu erf gelu selu (#6607) --- src/layer/mips/elu_mips.cpp | 60 ++++++++++++++++++++ src/layer/mips/elu_mips.h | 21 +++++++ src/layer/mips/erf_mips.cpp | 56 ++++++++++++++++++ src/layer/mips/erf_mips.h | 21 +++++++ src/layer/mips/gelu_mips.cpp | 97 ++++++++++++++++++++++++++++++++ src/layer/mips/gelu_mips.h | 21 +++++++ src/layer/mips/mips_activation.h | 10 ++++ src/layer/mips/msa_mathfun.h | 48 ++++++++++++++++ src/layer/mips/selu_mips.cpp | 73 ++++++++++++++++++++++++ src/layer/mips/selu_mips.h | 21 +++++++ 10 files changed, 428 insertions(+) create mode 100644 src/layer/mips/elu_mips.cpp create mode 100644 src/layer/mips/elu_mips.h create mode 100644 src/layer/mips/erf_mips.cpp create mode 100644 src/layer/mips/erf_mips.h create mode 100644 src/layer/mips/gelu_mips.cpp create mode 100644 src/layer/mips/gelu_mips.h create mode 100644 src/layer/mips/selu_mips.cpp create mode 100644 src/layer/mips/selu_mips.h diff --git a/src/layer/mips/elu_mips.cpp b/src/layer/mips/elu_mips.cpp new file mode 100644 index 00000000000..a144161a9bb --- /dev/null +++ b/src/layer/mips/elu_mips.cpp @@ -0,0 +1,60 @@ +// Copyright 2025 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#include "elu_mips.h" + +#if __mips_msa +#include +#include "msa_mathfun.h" +#include "mips_activation.h" +#endif // __mips_msa + +namespace ncnn { + +ELU_mips::ELU_mips() +{ +#if __mips_msa + support_packing = true; +#endif // __mips_msa +} + +int ELU_mips::forward_inplace(Mat& bottom_top_blob, const Option& opt) const +{ + int w = bottom_top_blob.w; + int h = bottom_top_blob.h; + int d = bottom_top_blob.d; + int channels = bottom_top_blob.c; + int elempack = bottom_top_blob.elempack; + int size = w * h * d * elempack; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + float* ptr = bottom_top_blob.channel(q); + + int i = 0; +#if __mips_msa + v4f32 _alpha = (v4f32)__msa_fill_w_f32(alpha); + for (; i + 3 < size; i += 4) + { + __builtin_prefetch(ptr + 16); + v4f32 _p = (v4f32)__msa_ld_w(ptr, 0); + _p = elu_ps(_p, _alpha); + __msa_st_w((v4i32)_p, ptr, 0); + + ptr += 4; + } +#endif // __mips_msa + for (; i < size; i++) + { + if (*ptr < 0.f) + *ptr = alpha * (expf(*ptr) - 1.f); + + ptr++; + } + } + + return 0; +} + +} // namespace ncnn diff --git a/src/layer/mips/elu_mips.h b/src/layer/mips/elu_mips.h new file mode 100644 index 00000000000..ec9a94d857c --- /dev/null +++ b/src/layer/mips/elu_mips.h @@ -0,0 +1,21 @@ +// Copyright 2025 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef LAYER_ELU_MIPS_H +#define LAYER_ELU_MIPS_H + +#include "elu.h" + +namespace ncnn { + +class ELU_mips : public ELU +{ +public: + ELU_mips(); + + virtual int forward_inplace(Mat& bottom_top_blob, const Option& opt) const; +}; + +} // namespace ncnn + +#endif // 
LAYER_ELU_MIPS_H diff --git a/src/layer/mips/erf_mips.cpp b/src/layer/mips/erf_mips.cpp new file mode 100644 index 00000000000..e0503d11174 --- /dev/null +++ b/src/layer/mips/erf_mips.cpp @@ -0,0 +1,56 @@ +// Copyright 2026 Futz12 +// SPDX-License-Identifier: BSD-3-Clause + +#include "erf_mips.h" + +#if __mips_msa +#include +#include "msa_mathfun.h" +#endif // __mips_msa + +namespace ncnn { + +Erf_mips::Erf_mips() +{ +#if __mips_msa + support_packing = true; +#endif +} + +int Erf_mips::forward_inplace(Mat& bottom_top_blob, const Option& opt) const +{ + int w = bottom_top_blob.w; + int h = bottom_top_blob.h; + int d = bottom_top_blob.d; + int channels = bottom_top_blob.c; + int elempack = bottom_top_blob.elempack; + int size = w * h * d * elempack; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + float* ptr = bottom_top_blob.channel(q); + + int i = 0; +#if __mips_msa + for (; i + 3 < size; i += 4) + { + __builtin_prefetch(ptr + 16); + v4f32 _p = (v4f32)__msa_ld_w(ptr, 0); + _p = erf_ps(_p); + __msa_st_w((v4i32)_p, ptr, 0); + + ptr += 4; + } +#endif // __mips_msa + for (; i < size; i++) + { + *ptr = erff(*ptr); + ptr++; + } + } + + return 0; +} + +} // namespace ncnn diff --git a/src/layer/mips/erf_mips.h b/src/layer/mips/erf_mips.h new file mode 100644 index 00000000000..08951c23f7e --- /dev/null +++ b/src/layer/mips/erf_mips.h @@ -0,0 +1,21 @@ +// Copyright 2026 Futz12 +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef LAYER_ERF_MIPS_H +#define LAYER_ERF_MIPS_H + +#include "erf.h" + +namespace ncnn { + +class Erf_mips : public Erf +{ +public: + Erf_mips(); + + virtual int forward_inplace(Mat& bottom_top_blob, const Option& opt) const; +}; + +} // namespace ncnn + +#endif // LAYER_ERF_MIPS_H diff --git a/src/layer/mips/gelu_mips.cpp b/src/layer/mips/gelu_mips.cpp new file mode 100644 index 00000000000..cafb697bd0e --- /dev/null +++ b/src/layer/mips/gelu_mips.cpp @@ -0,0 +1,97 @@ +// Copyright 2025 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#include "gelu_mips.h" + +#if __mips_msa +#include +#include "msa_mathfun.h" +#endif // __mips_msa + +namespace ncnn { + +GELU_mips::GELU_mips() +{ +#if __mips_msa + support_packing = true; +#endif // __mips_msa +} + +int GELU_mips::forward_inplace(Mat& bottom_top_blob, const Option& opt) const +{ + int w = bottom_top_blob.w; + int h = bottom_top_blob.h; + int d = bottom_top_blob.d; + int channels = bottom_top_blob.c; + int elempack = bottom_top_blob.elempack; + int size = w * h * d * elempack; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + float* ptr = bottom_top_blob.channel(q); + + int i = 0; +#if __mips_msa + if (fast_gelu) + { + v4f32 _half = (v4f32)__msa_fill_w_f32(0.5f); + v4f32 _one = (v4f32)__msa_fill_w_f32(1.f); + v4f32 _fast1c = (v4f32)__msa_fill_w_f32(0.79788452f); + v4f32 _fast2c = (v4f32)__msa_fill_w_f32(0.044715f * 0.79788452f); + for (; i + 3 < size; i += 4) + { + __builtin_prefetch(ptr + 16); + v4f32 _p = (v4f32)__msa_ld_w(ptr, 0); + + v4f32 _cube = __msa_fmul_w(_p, _p); + _cube = __msa_fmul_w(_p, _cube); + v4f32 _blob = __msa_fmul_w(_fast2c, _cube); + _blob = __msa_fmadd_w(_blob, _fast1c, _p); + _blob = tanh_ps(_blob); + _blob = __msa_fadd_w(_one, _blob); + _blob = __msa_fmul_w(_half, __msa_fmul_w(_blob, _p)); + __msa_st_w((v4i32)_blob, ptr, 0); + + ptr += 4; + } + } + else + { + v4f32 _half = (v4f32)__msa_fill_w_f32(0.5f); + v4f32 _one = (v4f32)__msa_fill_w_f32(1.f); + v4f32 _inv_sqrt2 = 
(v4f32)__msa_fill_w_f32(0.70710678f);
+ for (; i + 3 < size; i += 4)
+ {
+ __builtin_prefetch(ptr + 16);
+ v4f32 _p = (v4f32)__msa_ld_w(ptr, 0);
+
+ v4f32 _blob = __msa_fmul_w(_inv_sqrt2, _p);
+ _blob = erf_ps(_blob);
+ _blob = __msa_fadd_w(_one, _blob);
+ _blob = __msa_fmul_w(_half, __msa_fmul_w(_blob, _p));
+ __msa_st_w((v4i32)_blob, ptr, 0);
+
+ ptr += 4;
+ }
+ }
+#endif // __mips_msa
+ for (; i < size; i++)
+ {
+ if (fast_gelu)
+ {
+ *ptr = 0.5f * *ptr * (1.0f + tanhf(0.79788452f * (*ptr + 0.044715f * *ptr * *ptr * *ptr)));
+ }
+ else
+ {
+ *ptr = 0.5f * *ptr * (1.0f + erff(0.70710678f * *ptr));
+ }
+
+ ptr++;
+ }
+ }
+
+ return 0;
+}
+
+} // namespace ncnn
diff --git a/src/layer/mips/gelu_mips.h b/src/layer/mips/gelu_mips.h
new file mode 100644
index 00000000000..50d96606c29
--- /dev/null
+++ b/src/layer/mips/gelu_mips.h
@@ -0,0 +1,21 @@
+// Copyright 2025 Tencent
+// SPDX-License-Identifier: BSD-3-Clause
+
+#ifndef LAYER_GELU_MIPS_H
+#define LAYER_GELU_MIPS_H
+
+#include "gelu.h"
+
+namespace ncnn {
+
+class GELU_mips : public GELU
+{
+public:
+ GELU_mips();
+
+ virtual int forward_inplace(Mat& bottom_top_blob, const Option& opt) const;
+};
+
+} // namespace ncnn
+
+#endif // LAYER_GELU_MIPS_H
diff --git a/src/layer/mips/mips_activation.h b/src/layer/mips/mips_activation.h
index f43849e36d8..a860a8b9fa3 100644
--- a/src/layer/mips/mips_activation.h
+++ b/src/layer/mips/mips_activation.h
@@ -10,6 +10,16 @@
 #include <msa.h>
 #include "msa_mathfun.h"

+static inline v4f32 elu_ps(v4f32 inputs, v4f32 alphas)
+{
+ v4f32 _zero = (v4f32)__msa_fill_w(0);
+ v4f32 _one = (v4f32)__msa_fill_w_f32(1.f);
+ v4f32 _pos = __msa_fmax_w(inputs, _zero);
+ v4f32 _neg = __msa_fmin_w(inputs, _zero);
+ _neg = __msa_fsub_w(exp_ps(_neg), _one);
+ return __msa_fadd_w(_pos, __msa_fmul_w(alphas, _neg));
+}
+
 static inline v4f32 activation_ps(v4f32 _v, int activation_type, const ncnn::Mat& activation_params)
 {
 if (activation_type == 1)
diff --git a/src/layer/mips/msa_mathfun.h b/src/layer/mips/msa_mathfun.h
index cc4b65fe0d6..33bca9193c9 100644
--- a/src/layer/mips/msa_mathfun.h
+++ b/src/layer/mips/msa_mathfun.h
@@ -238,6 +238,54 @@ static inline v4f32 tanh_ps(v4f32 x)
 return y;
 }

+_MIPS_FLOAT_CONST(c_erf_threshold, 0.927734375f);
+_MIPS_FLOAT_CONST(c_erf_c0, -1.72853470e-5f);
+_MIPS_FLOAT_CONST(c_erf_c1, 3.83197126e-4f);
+_MIPS_FLOAT_CONST(c_erf_c2, -3.88396438e-3f);
+_MIPS_FLOAT_CONST(c_erf_c3, 2.42546219e-2f);
+_MIPS_FLOAT_CONST(c_erf_c4, -1.06777877e-1f);
+_MIPS_FLOAT_CONST(c_erf_c5, -6.34846687e-1f);
+_MIPS_FLOAT_CONST(c_erf_c6, -1.28717512e-1f);
+_MIPS_FLOAT_CONST(c_erf_p0, -5.96761703e-4f);
+_MIPS_FLOAT_CONST(c_erf_p1, 4.99119423e-3f);
+_MIPS_FLOAT_CONST(c_erf_p2, -2.67681349e-2f);
+_MIPS_FLOAT_CONST(c_erf_p3, 1.12819925e-1f);
+_MIPS_FLOAT_CONST(c_erf_p4, -3.76125336e-1f);
+_MIPS_FLOAT_CONST(c_erf_p5, 1.28379166e-1f);
+
+static inline v4f32 erf_ps(v4f32 a)
+{
+ v4f32 one = (v4f32)__msa_fill_w(c_1.i);
+
+ v4f32 t = (v4f32)__msa_bclri_w((v4u32)a, 31);
+ v4f32 s = __msa_fmul_w(a, a);
+
+ v4i32 mask = __msa_fclt_w((v4f32)__msa_fill_w(c_erf_threshold.i), t);
+
+ v4f32 r1 = __msa_fmadd_w((v4f32)__msa_fill_w(c_erf_c1.i), (v4f32)__msa_fill_w(c_erf_c0.i), t);
+ v4f32 u = __msa_fmadd_w((v4f32)__msa_fill_w(c_erf_c3.i), (v4f32)__msa_fill_w(c_erf_c2.i), t);
+ r1 = __msa_fmadd_w(u, r1, s);
+ r1 = __msa_fmadd_w((v4f32)__msa_fill_w(c_erf_c4.i), r1, t);
+ r1 = __msa_fmadd_w((v4f32)__msa_fill_w(c_erf_c5.i), r1, t);
+ r1 = __msa_fmadd_w((v4f32)__msa_fill_w(c_erf_c6.i), r1, t);
+ v4f32 neg_t = (v4f32)__msa_bnegi_w((v4u32)t, 31);
+ r1 = __msa_fmadd_w(neg_t, r1, t);
+ r1 = __msa_fsub_w(one, exp_ps(r1));
+ r1 = (v4f32)__msa_binsli_w((v4u32)r1, (v4u32)a, 0);
+
+ v4f32 r2 = (v4f32)__msa_fill_w(c_erf_p0.i);
+ r2 = __msa_fmadd_w((v4f32)__msa_fill_w(c_erf_p1.i), r2, s);
+ r2 = __msa_fmadd_w((v4f32)__msa_fill_w(c_erf_p2.i), r2, s);
+ r2 = __msa_fmadd_w((v4f32)__msa_fill_w(c_erf_p3.i), r2, s);
+ r2 = __msa_fmadd_w((v4f32)__msa_fill_w(c_erf_p4.i), r2, s);
+ r2 = __msa_fmadd_w((v4f32)__msa_fill_w(c_erf_p5.i), r2, s);
+ r2 = __msa_fmadd_w(a, r2, a);
+
+ v4f32 r = (v4f32)__msa_bsel_v((v16u8)mask, (v16u8)r2, (v16u8)r1);
+
+ return r;
+}
+
 static inline v4f32 pow_ps(v4f32 a, v4f32 b)
 {
 // pow(x, m) = exp(m * log(x))
diff --git a/src/layer/mips/selu_mips.cpp b/src/layer/mips/selu_mips.cpp
new file mode 100644
index 00000000000..609c33bfa8f
--- /dev/null
+++ b/src/layer/mips/selu_mips.cpp
@@ -0,0 +1,73 @@
+// Copyright 2025 Tencent
+// SPDX-License-Identifier: BSD-3-Clause
+
+#include "selu_mips.h"
+
+#if __mips_msa
+#include <msa.h>
+#include "msa_mathfun.h"
+#endif // __mips_msa
+
+namespace ncnn {
+
+SELU_mips::SELU_mips()
+{
+#if __mips_msa
+ support_packing = true;
+#endif // __mips_msa
+}
+
+int SELU_mips::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
+{
+ int w = bottom_top_blob.w;
+ int h = bottom_top_blob.h;
+ int d = bottom_top_blob.d;
+ int channels = bottom_top_blob.c;
+ int elempack = bottom_top_blob.elempack;
+ int size = w * h * d * elempack;
+ float alphaxlambda = alpha * lambda;
+
+ #pragma omp parallel for num_threads(opt.num_threads)
+ for (int q = 0; q < channels; q++)
+ {
+ float* ptr = bottom_top_blob.channel(q);
+
+ int i = 0;
+#if __mips_msa
+ v4f32 _one = (v4f32)__msa_fill_w_f32(1.f);
+ v4f32 _zero = (v4f32)__msa_fill_w(0);
+ v4f32 _alphaxlambda = (v4f32)__msa_fill_w_f32(alphaxlambda);
+ v4f32 _lambda = (v4f32)__msa_fill_w_f32(lambda);
+ for (; i + 3 < size; i += 4)
+ {
+ __builtin_prefetch(ptr + 16);
+ v4f32 _p = (v4f32)__msa_ld_w(ptr, 0);
+ v4i32_w _lemask = __msa_fcle_w(_p, _zero);
+
+ v4f32 _nps = exp_ps(_p);
+ _nps = __msa_fsub_w(_nps, _one);
+ _nps = __msa_fmul_w(_nps, _alphaxlambda);
+
+ _p = __msa_fmul_w(_p, _lambda);
+
+ _p = (v4f32)__msa_bsel_v((v16u8)_lemask, (v16u8)_p, (v16u8)_nps);
+ __msa_st_w((v4i32)_p, ptr, 0);
+
+ ptr += 4;
+ }
+#endif // __mips_msa
+ for (; i < size; i++)
+ {
+ if (*ptr < 0.f)
+ *ptr = (expf(*ptr) - 1.f) * alphaxlambda;
+ else
+ *ptr *= lambda;
+
+ ptr++;
+ }
+ }
+
+ return 0;
+}
+
+} // namespace ncnn
diff --git a/src/layer/mips/selu_mips.h b/src/layer/mips/selu_mips.h
new file mode 100644
index 00000000000..79f390216d4
--- /dev/null
+++ b/src/layer/mips/selu_mips.h
@@ -0,0 +1,21 @@
+// Copyright 2025 Tencent
+// SPDX-License-Identifier: BSD-3-Clause
+
+#ifndef LAYER_SELU_MIPS_H
+#define LAYER_SELU_MIPS_H
+
+#include "selu.h"
+
+namespace ncnn {
+
+class SELU_mips : public SELU
+{
+public:
+ SELU_mips();
+
+ virtual int forward_inplace(Mat& bottom_top_blob, const Option& opt) const;
+};
+
+} // namespace ncnn
+
+#endif // LAYER_SELU_MIPS_H
From 1f0c2a8b93a280635d302d5995ea15cd67252522 Mon Sep 17 00:00:00 2001
From: NKID00
Date: Wed, 25 Mar 2026 14:06:48 +0800
Subject: [PATCH 25/36] fix: subgroup float16 extension should be required (#6615)
---
 src/layer/vulkan/shader/reduction.comp | 3 +++
 1 file changed, 3 insertions(+)
diff --git a/src/layer/vulkan/shader/reduction.comp b/src/layer/vulkan/shader/reduction.comp
index 886d0af4889..3dfe2867e2d 100644
--- a/src/layer/vulkan/shader/reduction.comp
+++ b/src/layer/vulkan/shader/reduction.comp
@@ -6,6 +6,9 @@ #if 
ncnn_subgroup_arithmetic #extension GL_KHR_shader_subgroup_basic : enable #extension GL_KHR_shader_subgroup_arithmetic : enable +#if NCNN_fp16_storage +#extension GL_EXT_shader_subgroup_extended_types_float16 : require +#endif #endif layout(constant_id = 0) const int op = 0; From e1d413dbb0a171f0812cafc280576a7206b71a28 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 27 Mar 2026 17:13:02 +0800 Subject: [PATCH 26/36] Bump codecov/codecov-action from 5 to 6 (#6616) Bumps [codecov/codecov-action](https://github.com/codecov/codecov-action) from 5 to 6. - [Release notes](https://github.com/codecov/codecov-action/releases) - [Changelog](https://github.com/codecov/codecov-action/blob/main/CHANGELOG.md) - [Commits](https://github.com/codecov/codecov-action/compare/v5...v6) --- updated-dependencies: - dependency-name: codecov/codecov-action dependency-version: '6' dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/test-coverage.yml | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/.github/workflows/test-coverage.yml b/.github/workflows/test-coverage.yml index e67f23c0b77..ffaeab8be2b 100644 --- a/.github/workflows/test-coverage.yml +++ b/.github/workflows/test-coverage.yml @@ -56,7 +56,7 @@ jobs: lcov --list lcov.info - name: codecov - uses: codecov/codecov-action@v5 + uses: codecov/codecov-action@v6 with: token: ${{ secrets.CODECOV_TOKEN }} disable_search: true @@ -138,7 +138,7 @@ jobs: lcov --ignore-errors inconsistent -r lcov.info '*/build-openmp/*' -o lcov.info lcov --ignore-errors inconsistent --list lcov.info - name: codecov - uses: codecov/codecov-action@v5 + uses: codecov/codecov-action@v6 with: token: ${{ secrets.CODECOV_TOKEN }} disable_search: true @@ -170,7 +170,7 @@ jobs: lcov --ignore-errors inconsistent -r lcov.info '*/build/*' -o lcov.info lcov --ignore-errors inconsistent --list lcov.info - name: codecov - uses: codecov/codecov-action@v5 + uses: codecov/codecov-action@v6 with: token: ${{ secrets.CODECOV_TOKEN }} disable_search: true @@ -222,7 +222,7 @@ jobs: lcov --ignore-errors inconsistent -r lcov.info '*/build/*' -o lcov.info lcov --ignore-errors inconsistent --list lcov.info - name: codecov - uses: codecov/codecov-action@v5 + uses: codecov/codecov-action@v6 with: token: ${{ secrets.CODECOV_TOKEN }} disable_search: true @@ -282,7 +282,7 @@ jobs: lcov --ignore-errors inconsistent -r lcov.info '*/build/*' -o lcov.info lcov --ignore-errors inconsistent --list lcov.info - name: codecov - uses: codecov/codecov-action@v5 + uses: codecov/codecov-action@v6 with: token: ${{ secrets.CODECOV_TOKEN }} disable_search: true @@ -327,7 +327,7 @@ jobs: lcov --list lcov.info - name: codecov - uses: codecov/codecov-action@v5 + uses: codecov/codecov-action@v6 with: token: ${{ secrets.CODECOV_TOKEN }} disable_search: true @@ -358,7 +358,7 @@ jobs: lcov --ignore-errors inconsistent --list lcov.info - name: codecov - uses: codecov/codecov-action@v5 + uses: codecov/codecov-action@v6 with: token: ${{ secrets.CODECOV_TOKEN }} disable_search: true @@ -416,7 +416,7 @@ jobs: lcov --ignore-errors inconsistent --list lcov.info - name: codecov - uses: codecov/codecov-action@v5 + uses: codecov/codecov-action@v6 with: token: ${{ secrets.CODECOV_TOKEN }} disable_search: true @@ -575,7 +575,7 @@ jobs: lcov --ignore-errors inconsistent --list lcov.info - 
name: codecov - uses: codecov/codecov-action@v5 + uses: codecov/codecov-action@v6 with: token: ${{ secrets.CODECOV_TOKEN }} disable_search: true From 191d239e7dc61904ff464a332158d6be22533e66 Mon Sep 17 00:00:00 2001 From: nihui Date: Fri, 27 Mar 2026 18:25:56 +0800 Subject: [PATCH 27/36] x86 gemm bf16s optimization with avx512bf16 (#6609) --- src/layer/x86/gemm_bf16s.h | 9364 +++++++++++++++++++++---- src/layer/x86/gemm_x86.cpp | 407 +- src/layer/x86/gemm_x86.h | 1 + src/layer/x86/gemm_x86_avx512bf16.cpp | 43 + src/layer/x86/x86_usability.h | 121 +- 5 files changed, 8331 insertions(+), 1605 deletions(-) create mode 100644 src/layer/x86/gemm_x86_avx512bf16.cpp diff --git a/src/layer/x86/gemm_bf16s.h b/src/layer/x86/gemm_bf16s.h index 2aa5f25fe39..4f71ac5ca17 100644 --- a/src/layer/x86/gemm_bf16s.h +++ b/src/layer/x86/gemm_bf16s.h @@ -1,12 +1,29 @@ // Copyright 2026 Tencent // SPDX-License-Identifier: BSD-3-Clause -static void pack_A_tile_bf16s(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk) +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ +void pack_A_tile_bf16_avx512bf16(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk); +void transpose_pack_A_tile_bf16_avx512bf16(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk); +void pack_B_tile_bf16_avx512bf16(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk); +void transpose_pack_B_tile_bf16_avx512bf16(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk); +void gemm_transB_packed_tile_bf16s_avx512bf16(const Mat& AT_tile, const Mat& BT_tile, Mat& topT_tile, int i, int max_ii, int j, int max_jj, int k, int max_kk); +#endif + +static void pack_A_tile_bf16(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk) { +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ + if (ncnn::cpu_support_x86_avx512_bf16()) + { + pack_A_tile_bf16_avx512bf16(A, AT, i, max_ii, k, max_kk); + return; + } +#endif + + // NCNN_LOGE("pack_A_tile_bf16 %d %d %d %d", i, max_ii, k, max_kk); const int elempack = A.elempack; const size_t A_hstep = A.dims == 3 ? 
A.cstep : (size_t)A.w; - float* pp = AT; + unsigned short* pp = (unsigned short*)AT; int ii = 0; #if __SSE2__ @@ -14,26 +31,52 @@ static void pack_A_tile_bf16s(const Mat& A, Mat& AT, int i, int max_ii, int k, i #if __AVX512F__ for (; ii + 15 < max_ii; ii += 16) { + const unsigned short* p0 = (const unsigned short*)A + (i + ii) * A_hstep + k * elempack; + if (elempack == 16) { - const unsigned short* p0 = (const unsigned short*)A + (i + ii) * A_hstep + k * 16; - - for (int kk = 0; kk < max_kk; kk++) + int kk = 0; +#if __AVX512BF16__ + __m512i _idx = _mm512_set_epi16(31, 15, 30, 14, 29, 13, 28, 12, 27, 11, 26, 10, 25, 9, 24, 8, 23, 7, 22, 6, 21, 5, 20, 4, 19, 3, 18, 2, 17, 1, 16, 0); + for (; kk + 1 < max_kk; kk += 2) + { + __m512i _p = _mm512_loadu_si512((const __m512i*)p0); + _p = _mm512_permutexvar_epi16(_idx, _p); + _mm512_storeu_si512((__m512i*)pp, _p); + pp += 32; + p0 += 32; + } +#endif // __AVX512BF16__ + for (; kk < max_kk; kk++) { - _mm512_store_ps(pp, bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)p0))); + _mm256_storeu_si256((__m256i*)pp, _mm256_loadu_si256((const __m256i*)p0)); pp += 16; p0 += 16; } } if (elempack == 8) { - const unsigned short* p0 = (const unsigned short*)A + (i + ii) * A_hstep + k * 8; - const unsigned short* p1 = (const unsigned short*)A + (i + ii + 8) * A_hstep + k * 8; + const unsigned short* p1 = p0 + A_hstep * 8; - for (int kk = 0; kk < max_kk; kk++) + int kk = 0; +#if __AVX512BF16__ + __m512i _idx = _mm512_set_epi16(31, 23, 30, 22, 29, 21, 28, 20, 27, 19, 26, 18, 25, 17, 24, 16, 15, 7, 14, 6, 13, 5, 12, 4, 11, 3, 10, 2, 9, 1, 8, 0); + for (; kk + 1 < max_kk; kk += 2) + { + __m256i _a = _mm256_loadu_si256((const __m256i*)p0); + __m256i _b = _mm256_loadu_si256((const __m256i*)p1); + __m512i _ab = combine8x2_epi32(_a, _b); + __m512i _p = _mm512_permutexvar_epi16(_idx, _ab); + _mm512_storeu_si512((__m512i*)pp, _p); + pp += 32; + p0 += 16; + p1 += 16; + } +#endif // __AVX512BF16__ + for (; kk < max_kk; kk++) { - _mm256_store_ps(pp, bfloat2float_avx(_mm_loadu_si128((const __m128i*)p0))); - _mm256_store_ps(pp + 8, bfloat2float_avx(_mm_loadu_si128((const __m128i*)p1))); + _mm_storeu_si128((__m128i*)pp, _mm_loadu_si128((const __m128i*)p0)); + _mm_storeu_si128((__m128i*)(pp + 8), _mm_loadu_si128((const __m128i*)p1)); pp += 16; p0 += 8; p1 += 8; @@ -41,17 +84,39 @@ static void pack_A_tile_bf16s(const Mat& A, Mat& AT, int i, int max_ii, int k, i } if (elempack == 4) { - const unsigned short* p0 = (const unsigned short*)A + (i + ii) * A_hstep + k * 4; - const unsigned short* p1 = (const unsigned short*)A + (i + ii + 4) * A_hstep + k * 4; - const unsigned short* p2 = (const unsigned short*)A + (i + ii + 8) * A_hstep + k * 4; - const unsigned short* p3 = (const unsigned short*)A + (i + ii + 12) * A_hstep + k * 4; + const unsigned short* p1 = p0 + A_hstep * 4; + const unsigned short* p2 = p0 + A_hstep * 8; + const unsigned short* p3 = p0 + A_hstep * 12; - for (int kk = 0; kk < max_kk; kk++) + int kk = 0; +#if __AVX512BF16__ + for (; kk + 1 < max_kk; kk += 2) + { + __m128i _a0 = _mm_loadu_si128((const __m128i*)p0); + __m128i _a1 = _mm_loadu_si128((const __m128i*)p1); + __m128i _a2 = _mm_loadu_si128((const __m128i*)p2); + __m128i _a3 = _mm_loadu_si128((const __m128i*)p3); + __m128i _t0 = _mm_unpacklo_epi16(_a0, _mm_srli_si128(_a0, 8)); + __m128i _t1 = _mm_unpacklo_epi16(_a1, _mm_srli_si128(_a1, 8)); + __m128i _t2 = _mm_unpacklo_epi16(_a2, _mm_srli_si128(_a2, 8)); + __m128i _t3 = _mm_unpacklo_epi16(_a3, _mm_srli_si128(_a3, 8)); + _mm_storeu_si128((__m128i*)pp, 
_t0); + _mm_storeu_si128((__m128i*)(pp + 8), _t1); + _mm_storeu_si128((__m128i*)(pp + 16), _t2); + _mm_storeu_si128((__m128i*)(pp + 24), _t3); + pp += 32; + p0 += 8; + p1 += 8; + p2 += 8; + p3 += 8; + } +#endif // __AVX512BF16__ + for (; kk < max_kk; kk++) { - _mm_store_ps(pp, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)p0))); - _mm_store_ps(pp + 4, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)p1))); - _mm_store_ps(pp + 8, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)p2))); - _mm_store_ps(pp + 12, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)p3))); + _mm_storel_epi64((__m128i*)pp, _mm_loadl_epi64((const __m128i*)p0)); + _mm_storel_epi64((__m128i*)(pp + 4), _mm_loadl_epi64((const __m128i*)p1)); + _mm_storel_epi64((__m128i*)(pp + 8), _mm_loadl_epi64((const __m128i*)p2)); + _mm_storel_epi64((__m128i*)(pp + 12), _mm_loadl_epi64((const __m128i*)p3)); pp += 16; p0 += 4; p1 += 4; @@ -61,138 +126,79 @@ static void pack_A_tile_bf16s(const Mat& A, Mat& AT, int i, int max_ii, int k, i } if (elempack == 1) { - const unsigned short* p0 = (const unsigned short*)A + (i + ii) * A_hstep + k; - const unsigned short* p1 = (const unsigned short*)A + (i + ii + 1) * A_hstep + k; - const unsigned short* p2 = (const unsigned short*)A + (i + ii + 2) * A_hstep + k; - const unsigned short* p3 = (const unsigned short*)A + (i + ii + 3) * A_hstep + k; - const unsigned short* p4 = (const unsigned short*)A + (i + ii + 4) * A_hstep + k; - const unsigned short* p5 = (const unsigned short*)A + (i + ii + 5) * A_hstep + k; - const unsigned short* p6 = (const unsigned short*)A + (i + ii + 6) * A_hstep + k; - const unsigned short* p7 = (const unsigned short*)A + (i + ii + 7) * A_hstep + k; - const unsigned short* p8 = (const unsigned short*)A + (i + ii + 8) * A_hstep + k; - const unsigned short* p9 = (const unsigned short*)A + (i + ii + 9) * A_hstep + k; - const unsigned short* pa = (const unsigned short*)A + (i + ii + 10) * A_hstep + k; - const unsigned short* pb = (const unsigned short*)A + (i + ii + 11) * A_hstep + k; - const unsigned short* pc = (const unsigned short*)A + (i + ii + 12) * A_hstep + k; - const unsigned short* pd = (const unsigned short*)A + (i + ii + 13) * A_hstep + k; - const unsigned short* pe = (const unsigned short*)A + (i + ii + 14) * A_hstep + k; - const unsigned short* pf = (const unsigned short*)A + (i + ii + 15) * A_hstep + k; + __m512i _vindex = _mm512_mullo_epi32(_mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15), _mm512_set1_epi32(A_hstep)); int kk = 0; - for (; kk + 15 < max_kk; kk += 16) +#if __AVX512BF16__ + for (; kk + 1 < max_kk; kk += 2) { - __m512 _r0 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)p0)); - __m512 _r1 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)p1)); - __m512 _r2 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)p2)); - __m512 _r3 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)p3)); - __m512 _r4 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)p4)); - __m512 _r5 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)p5)); - __m512 _r6 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)p6)); - __m512 _r7 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)p7)); - __m512 _r8 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)p8)); - __m512 _r9 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)p9)); - __m512 _ra = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)pa)); - __m512 _rb = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)pb)); - __m512 _rc = 
bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)pc)); - __m512 _rd = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)pd)); - __m512 _re = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)pe)); - __m512 _rf = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)pf)); - transpose16x16_ps(_r0, _r1, _r2, _r3, _r4, _r5, _r6, _r7, _r8, _r9, _ra, _rb, _rc, _rd, _re, _rf); - _mm512_store_ps(pp, _r0); - _mm512_store_ps(pp + 16, _r1); - _mm512_store_ps(pp + 16 * 2, _r2); - _mm512_store_ps(pp + 16 * 3, _r3); - _mm512_store_ps(pp + 16 * 4, _r4); - _mm512_store_ps(pp + 16 * 5, _r5); - _mm512_store_ps(pp + 16 * 6, _r6); - _mm512_store_ps(pp + 16 * 7, _r7); - _mm512_store_ps(pp + 16 * 8, _r8); - _mm512_store_ps(pp + 16 * 9, _r9); - _mm512_store_ps(pp + 16 * 10, _ra); - _mm512_store_ps(pp + 16 * 11, _rb); - _mm512_store_ps(pp + 16 * 12, _rc); - _mm512_store_ps(pp + 16 * 13, _rd); - _mm512_store_ps(pp + 16 * 14, _re); - _mm512_store_ps(pp + 16 * 15, _rf); - pp += 256; - p0 += 16; - p1 += 16; - p2 += 16; - p3 += 16; - p4 += 16; - p5 += 16; - p6 += 16; - p7 += 16; - p8 += 16; - p9 += 16; - pa += 16; - pb += 16; - pc += 16; - pd += 16; - pe += 16; - pf += 16; + __m512i _p = _mm512_i32gather_epi32(_vindex, (const int*)p0, sizeof(unsigned short)); + _mm512_storeu_si512((__m512i*)pp, _p); + pp += 32; + p0 += 2; } +#endif // __AVX512BF16__ for (; kk < max_kk; kk++) { - pp[0] = bfloat16_to_float32(p0[0]); - pp[1] = bfloat16_to_float32(p1[0]); - pp[2] = bfloat16_to_float32(p2[0]); - pp[3] = bfloat16_to_float32(p3[0]); - pp[4] = bfloat16_to_float32(p4[0]); - pp[5] = bfloat16_to_float32(p5[0]); - pp[6] = bfloat16_to_float32(p6[0]); - pp[7] = bfloat16_to_float32(p7[0]); - pp[8] = bfloat16_to_float32(p8[0]); - pp[9] = bfloat16_to_float32(p9[0]); - pp[10] = bfloat16_to_float32(pa[0]); - pp[11] = bfloat16_to_float32(pb[0]); - pp[12] = bfloat16_to_float32(pc[0]); - pp[13] = bfloat16_to_float32(pd[0]); - pp[14] = bfloat16_to_float32(pe[0]); - pp[15] = bfloat16_to_float32(pf[0]); + __m512i _p = _mm512_i32gather_epi32(_vindex, (const int*)p0, sizeof(unsigned short)); + __m256i _p16 = _mm512_cvtepi32_epi16(_p); + _mm256_storeu_si256((__m256i*)pp, _p16); pp += 16; p0++; - p1++; - p2++; - p3++; - p4++; - p5++; - p6++; - p7++; - p8++; - p9++; - pa++; - pb++; - pc++; - pd++; - pe++; - pf++; } } } #endif // __AVX512F__ for (; ii + 7 < max_ii; ii += 8) { + const unsigned short* p0 = (const unsigned short*)A + (i + ii) * A_hstep + k * elempack; + if (elempack == 8) { - const unsigned short* p0 = (const unsigned short*)A + (i + ii) * A_hstep + k * 8; - - for (int kk = 0; kk < max_kk; kk++) + int kk = 0; +#if __AVX512BF16__ + for (; kk + 1 < max_kk; kk += 2) { - _mm256_store_ps(pp, bfloat2float_avx(_mm_loadu_si128((const __m128i*)p0))); + __m128i _r0 = _mm_loadu_si128((const __m128i*)p0); + __m128i _r1 = _mm_loadu_si128((const __m128i*)(p0 + 8)); + __m128i _t0 = _mm_unpacklo_epi16(_r0, _r1); + __m128i _t1 = _mm_unpackhi_epi16(_r0, _r1); + _mm_storeu_si128((__m128i*)pp, _t0); + _mm_storeu_si128((__m128i*)(pp + 8), _t1); + pp += 16; + p0 += 16; + } +#endif // __AVX512BF16__ + for (; kk < max_kk; kk++) + { + _mm_storeu_si128((__m128i*)pp, _mm_loadu_si128((const __m128i*)p0)); pp += 8; p0 += 8; } } if (elempack == 4) { - const unsigned short* p0 = (const unsigned short*)A + (i + ii) * A_hstep + k * 4; - const unsigned short* p1 = (const unsigned short*)A + (i + ii + 4) * A_hstep + k * 4; + const unsigned short* p1 = p0 + A_hstep * 4; - for (int kk = 0; kk < max_kk; kk++) + int kk = 0; +#if __AVX512BF16__ + for (; kk 
+ 1 < max_kk; kk += 2) + { + __m128i _a0 = _mm_loadu_si128((const __m128i*)p0); + __m128i _b0 = _mm_loadu_si128((const __m128i*)p1); + __m128i _t0 = _mm_unpacklo_epi16(_a0, _mm_srli_si128(_a0, 8)); + __m128i _t1 = _mm_unpacklo_epi16(_b0, _mm_srli_si128(_b0, 8)); + _mm_storeu_si128((__m128i*)pp, _t0); + _mm_storeu_si128((__m128i*)(pp + 8), _t1); + pp += 16; + p0 += 8; + p1 += 8; + } +#endif // __AVX512BF16__ + for (; kk < max_kk; kk++) { - _mm_store_ps(pp, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)p0))); - _mm_store_ps(pp + 4, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)p1))); + _mm_storel_epi64((__m128i*)pp, _mm_loadl_epi64((const __m128i*)p0)); + _mm_storel_epi64((__m128i*)(pp + 4), _mm_loadl_epi64((const __m128i*)p1)); pp += 8; p0 += 4; p1 += 4; @@ -200,131 +206,118 @@ static void pack_A_tile_bf16s(const Mat& A, Mat& AT, int i, int max_ii, int k, i } if (elempack == 1) { - const unsigned short* p0 = (const unsigned short*)A + (i + ii) * A_hstep + k; - const unsigned short* p1 = (const unsigned short*)A + (i + ii + 1) * A_hstep + k; - const unsigned short* p2 = (const unsigned short*)A + (i + ii + 2) * A_hstep + k; - const unsigned short* p3 = (const unsigned short*)A + (i + ii + 3) * A_hstep + k; - const unsigned short* p4 = (const unsigned short*)A + (i + ii + 4) * A_hstep + k; - const unsigned short* p5 = (const unsigned short*)A + (i + ii + 5) * A_hstep + k; - const unsigned short* p6 = (const unsigned short*)A + (i + ii + 6) * A_hstep + k; - const unsigned short* p7 = (const unsigned short*)A + (i + ii + 7) * A_hstep + k; +#if __AVX2__ + __m256i _vindex = _mm256_mullo_epi32(_mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7), _mm256_set1_epi32((int)A_hstep)); +#endif int kk = 0; - for (; kk + 7 < max_kk; kk += 8) +#if __AVX512BF16__ + for (; kk + 1 < max_kk; kk += 2) { - __m256 _r0 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)p0)); - __m256 _r1 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)p1)); - __m256 _r2 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)p2)); - __m256 _r3 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)p3)); - __m256 _r4 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)p4)); - __m256 _r5 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)p5)); - __m256 _r6 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)p6)); - __m256 _r7 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)p7)); - transpose8x8_ps(_r0, _r1, _r2, _r3, _r4, _r5, _r6, _r7); - _mm256_store_ps(pp, _r0); - _mm256_store_ps(pp + 8, _r1); - _mm256_store_ps(pp + 8 * 2, _r2); - _mm256_store_ps(pp + 8 * 3, _r3); - _mm256_store_ps(pp + 8 * 4, _r4); - _mm256_store_ps(pp + 8 * 5, _r5); - _mm256_store_ps(pp + 8 * 6, _r6); - _mm256_store_ps(pp + 8 * 7, _r7); - pp += 64; - p0 += 8; - p1 += 8; - p2 += 8; - p3 += 8; - p4 += 8; - p5 += 8; - p6 += 8; - p7 += 8; + __m256i _p = _mm256_i32gather_epi32((const int*)p0, _vindex, sizeof(unsigned short)); + _mm256_storeu_si256((__m256i*)pp, _p); + pp += 16; + p0 += 2; } +#endif // __AVX512BF16__ for (; kk < max_kk; kk++) { - pp[0] = bfloat16_to_float32(p0[0]); - pp[1] = bfloat16_to_float32(p1[0]); - pp[2] = bfloat16_to_float32(p2[0]); - pp[3] = bfloat16_to_float32(p3[0]); - pp[4] = bfloat16_to_float32(p4[0]); - pp[5] = bfloat16_to_float32(p5[0]); - pp[6] = bfloat16_to_float32(p6[0]); - pp[7] = bfloat16_to_float32(p7[0]); +#if __AVX2__ + __m256i _p = _mm256_i32gather_epi32((const int*)p0, _vindex, sizeof(unsigned short)); + __m128i _p16 = _mm256_comp_cvtepi32_epi16(_p); + _mm_storeu_si128((__m128i*)pp, _p16); +#else + pp[0] = p0[0]; + pp[1] = p0[A_hstep]; + 
pp[2] = p0[A_hstep * 2]; + pp[3] = p0[A_hstep * 3]; + pp[4] = p0[A_hstep * 4]; + pp[5] = p0[A_hstep * 5]; + pp[6] = p0[A_hstep * 6]; + pp[7] = p0[A_hstep * 7]; +#endif pp += 8; p0++; - p1++; - p2++; - p3++; - p4++; - p5++; - p6++; - p7++; } } } #endif // __AVX__ for (; ii + 3 < max_ii; ii += 4) { + const unsigned short* p0 = (const unsigned short*)A + (i + ii) * A_hstep + k * elempack; + if (elempack == 4) { - const unsigned short* p0 = (const unsigned short*)A + (i + ii) * A_hstep + k * 4; - - for (int kk = 0; kk < max_kk; kk++) + int kk = 0; +#if __AVX512BF16__ + for (; kk + 1 < max_kk; kk += 2) + { + __m128i _r0 = _mm_loadu_si128((const __m128i*)p0); + __m128i _t0 = _mm_unpacklo_epi16(_r0, _mm_srli_si128(_r0, 8)); + __m128i _t1 = _mm_unpackhi_epi16(_mm_slli_si128(_r0, 8), _r0); + (void)_t1; + _mm_storeu_si128((__m128i*)pp, _t0); + pp += 8; + p0 += 8; + } +#endif // __AVX512BF16__ + for (; kk < max_kk; kk++) { - _mm_store_ps(pp, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)p0))); + _mm_storel_epi64((__m128i*)pp, _mm_loadl_epi64((const __m128i*)p0)); pp += 4; p0 += 4; } } if (elempack == 1) { - const unsigned short* p0 = (const unsigned short*)A + (i + ii) * A_hstep + k; - const unsigned short* p1 = (const unsigned short*)A + (i + ii + 1) * A_hstep + k; - const unsigned short* p2 = (const unsigned short*)A + (i + ii + 2) * A_hstep + k; - const unsigned short* p3 = (const unsigned short*)A + (i + ii + 3) * A_hstep + k; + const unsigned short* p1 = p0 + A_hstep * 1; + const unsigned short* p2 = p0 + A_hstep * 2; + const unsigned short* p3 = p0 + A_hstep * 3; int kk = 0; -#if __AVX__ - for (; kk + 7 < max_kk; kk += 8) +#if __AVX512BF16__ + for (; kk + 1 < max_kk; kk += 2) { - __m256 _r0 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)p0)); - __m256 _r1 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)p1)); - __m256 _r2 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)p2)); - __m256 _r3 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)p3)); - transpose8x4_ps(_r0, _r1, _r2, _r3); - _mm256_store_ps(pp, _r0); - _mm256_store_ps(pp + 8, _r1); - _mm256_store_ps(pp + 16, _r2); - _mm256_store_ps(pp + 24, _r3); - pp += 32; - p0 += 8; - p1 += 8; - p2 += 8; - p3 += 8; + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p1[0]; + pp[3] = p1[1]; + pp[4] = p2[0]; + pp[5] = p2[1]; + pp[6] = p3[0]; + pp[7] = p3[1]; + pp += 8; + p0 += 2; + p1 += 2; + p2 += 2; + p3 += 2; } -#endif // __AVX__ +#else // __AVX512BF16__ for (; kk + 3 < max_kk; kk += 4) { - __m128 _r0 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)p0)); - __m128 _r1 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)p1)); - __m128 _r2 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)p2)); - __m128 _r3 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)p3)); - _MM_TRANSPOSE4_PS(_r0, _r1, _r2, _r3); - _mm_store_ps(pp, _r0); - _mm_store_ps(pp + 4, _r1); - _mm_store_ps(pp + 8, _r2); - _mm_store_ps(pp + 12, _r3); + __m128i _r0 = _mm_loadl_epi64((const __m128i*)p0); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)p1); + __m128i _r2 = _mm_loadl_epi64((const __m128i*)p2); + __m128i _r3 = _mm_loadl_epi64((const __m128i*)p3); + __m128i _t0 = _mm_unpacklo_epi16(_r0, _r1); + __m128i _t1 = _mm_unpacklo_epi16(_r2, _r3); + _r0 = _mm_unpacklo_epi32(_t0, _t1); + _r1 = _mm_unpackhi_epi32(_t0, _t1); + _mm_storeu_si128((__m128i*)pp, _r0); + _mm_storeu_si128((__m128i*)(pp + 8), _r1); pp += 16; p0 += 4; p1 += 4; p2 += 4; p3 += 4; } +#endif // __AVX512BF16__ for (; kk < max_kk; kk++) { - pp[0] = bfloat16_to_float32(p0[0]); - pp[1] = bfloat16_to_float32(p1[0]); 
- pp[2] = bfloat16_to_float32(p2[0]); - pp[3] = bfloat16_to_float32(p3[0]); + pp[0] = p0[0]; + pp[1] = p1[0]; + pp[2] = p2[0]; + pp[3] = p3[0]; pp += 4; p0++; p1++; @@ -336,43 +329,28 @@ static void pack_A_tile_bf16s(const Mat& A, Mat& AT, int i, int max_ii, int k, i #endif // __SSE2__ for (; ii + 1 < max_ii; ii += 2) { + const unsigned short* p0 = (const unsigned short*)A + (i + ii) * A_hstep + k; + const unsigned short* p1 = p0 + A_hstep; + // if (elempack == 1) { - const unsigned short* p0 = (const unsigned short*)A + (i + ii) * A_hstep + k; - const unsigned short* p1 = (const unsigned short*)A + (i + ii + 1) * A_hstep + k; - int kk = 0; -#if __SSE2__ -#if __AVX__ - for (; kk + 7 < max_kk; kk += 8) - { - __m256 _r0 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)p0)); - __m256 _r1 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)p1)); - transpose8x2_ps(_r0, _r1); - _mm256_storeu_ps(pp, _r0); - _mm256_storeu_ps(pp + 8, _r1); - pp += 16; - p0 += 8; - p1 += 8; - } -#endif // __AVX__ - for (; kk + 3 < max_kk; kk += 4) +#if __AVX512BF16__ + for (; kk + 1 < max_kk; kk += 2) { - __m128 _r0 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)p0)); - __m128 _r1 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)p1)); - __m128 _tmp0 = _mm_unpacklo_ps(_r0, _r1); - __m128 _tmp1 = _mm_unpackhi_ps(_r0, _r1); - _mm_store_ps(pp, _tmp0); - _mm_store_ps(pp + 4, _tmp1); - pp += 8; - p0 += 4; - p1 += 4; + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p1[0]; + pp[3] = p1[1]; + pp += 4; + p0 += 2; + p1 += 2; } -#endif // __SSE2__ +#endif // __AVX512BF16__ for (; kk < max_kk; kk++) { - pp[0] = bfloat16_to_float32(p0[0]); - pp[1] = bfloat16_to_float32(p1[0]); + pp[0] = p0[0]; + pp[1] = p1[0]; pp += 2; p0++; p1++; @@ -381,30 +359,23 @@ static void pack_A_tile_bf16s(const Mat& A, Mat& AT, int i, int max_ii, int k, i } for (; ii < max_ii; ii += 1) { + const unsigned short* p0 = (const unsigned short*)A + (i + ii) * A_hstep + k; + // if (elempack == 1) { - const unsigned short* p0 = (const unsigned short*)A + (i + ii) * A_hstep + k; - int kk = 0; -#if __SSE2__ -#if __AVX__ - for (; kk + 7 < max_kk; kk += 8) - { - _mm256_storeu_ps(pp, bfloat2float_avx(_mm_loadu_si128((const __m128i*)p0))); - pp += 8; - p0 += 8; - } -#endif // __AVX__ - for (; kk + 3 < max_kk; kk += 4) +#if __AVX512BF16__ + for (; kk + 1 < max_kk; kk += 2) { - _mm_storeu_ps(pp, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)p0))); - pp += 4; - p0 += 4; + pp[0] = p0[0]; + pp[1] = p0[1]; + pp += 2; + p0 += 2; } -#endif // __SSE2__ +#endif // __AVX512BF16__ for (; kk < max_kk; kk++) { - pp[0] = bfloat16_to_float32(p0[0]); + pp[0] = p0[0]; pp += 1; p0++; } @@ -412,12 +383,21 @@ static void pack_A_tile_bf16s(const Mat& A, Mat& AT, int i, int max_ii, int k, i } } -static void transpose_pack_A_tile_bf16s(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk) +static void transpose_pack_A_tile_bf16(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk) { +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ + if (ncnn::cpu_support_x86_avx512_bf16()) + { + transpose_pack_A_tile_bf16_avx512bf16(A, AT, i, max_ii, k, max_kk); + return; + } +#endif + + // NCNN_LOGE("transpose_pack_A_tile_bf16 %d %d %d %d", i, max_ii, k, max_kk); const int elempack = A.elempack; const size_t A_hstep = A.dims == 3 ? 
A.cstep : (size_t)A.w; - float* pp = AT; + unsigned short* pp = (unsigned short*)AT; int ii = 0; #if __SSE2__ @@ -425,133 +405,227 @@ static void transpose_pack_A_tile_bf16s(const Mat& A, Mat& AT, int i, int max_ii #if __AVX512F__ for (; ii + 15 < max_ii; ii += 16) { + const unsigned short* p0 = (const unsigned short*)A + k * A_hstep + (i + ii) * elempack; + if (elempack == 16) { - const unsigned short* p0 = (const unsigned short*)A + k * A_hstep + (i + ii) * 16; - int kk = 0; for (; kk + 15 < max_kk; kk += 16) { - __m256i _r0 = _mm256_loadu_si256((const __m256i*)p0); - __m256i _r1 = _mm256_loadu_si256((const __m256i*)(p0 + 16 * 1)); - __m256i _r2 = _mm256_loadu_si256((const __m256i*)(p0 + 16 * 2)); - __m256i _r3 = _mm256_loadu_si256((const __m256i*)(p0 + 16 * 3)); - __m256i _r4 = _mm256_loadu_si256((const __m256i*)(p0 + 16 * 4)); - __m256i _r5 = _mm256_loadu_si256((const __m256i*)(p0 + 16 * 5)); - __m256i _r6 = _mm256_loadu_si256((const __m256i*)(p0 + 16 * 6)); - __m256i _r7 = _mm256_loadu_si256((const __m256i*)(p0 + 16 * 7)); - __m256i _r8 = _mm256_loadu_si256((const __m256i*)(p0 + 16 * 8)); - __m256i _r9 = _mm256_loadu_si256((const __m256i*)(p0 + 16 * 9)); - __m256i _ra = _mm256_loadu_si256((const __m256i*)(p0 + 16 * 10)); - __m256i _rb = _mm256_loadu_si256((const __m256i*)(p0 + 16 * 11)); - __m256i _rc = _mm256_loadu_si256((const __m256i*)(p0 + 16 * 12)); - __m256i _rd = _mm256_loadu_si256((const __m256i*)(p0 + 16 * 13)); - __m256i _re = _mm256_loadu_si256((const __m256i*)(p0 + 16 * 14)); - __m256i _rf = _mm256_loadu_si256((const __m256i*)(p0 + 16 * 15)); - transpose16x16_epi16(_r0, _r1, _r2, _r3, _r4, _r5, _r6, _r7, _r8, _r9, _ra, _rb, _rc, _rd, _re, _rf); - _mm512_store_ps(pp, bfloat2float_avx512(_r0)); - _mm512_store_ps(pp + 16 * 1, bfloat2float_avx512(_r1)); - _mm512_store_ps(pp + 16 * 2, bfloat2float_avx512(_r2)); - _mm512_store_ps(pp + 16 * 3, bfloat2float_avx512(_r3)); - _mm512_store_ps(pp + 16 * 4, bfloat2float_avx512(_r4)); - _mm512_store_ps(pp + 16 * 5, bfloat2float_avx512(_r5)); - _mm512_store_ps(pp + 16 * 6, bfloat2float_avx512(_r6)); - _mm512_store_ps(pp + 16 * 7, bfloat2float_avx512(_r7)); - _mm512_store_ps(pp + 16 * 8, bfloat2float_avx512(_r8)); - _mm512_store_ps(pp + 16 * 9, bfloat2float_avx512(_r9)); - _mm512_store_ps(pp + 16 * 10, bfloat2float_avx512(_ra)); - _mm512_store_ps(pp + 16 * 11, bfloat2float_avx512(_rb)); - _mm512_store_ps(pp + 16 * 12, bfloat2float_avx512(_rc)); - _mm512_store_ps(pp + 16 * 13, bfloat2float_avx512(_rd)); - _mm512_store_ps(pp + 16 * 14, bfloat2float_avx512(_re)); - _mm512_store_ps(pp + 16 * 15, bfloat2float_avx512(_rf)); + __m512i _r0 = _mm512_loadu_si512((const __m512i*)p0); + __m512i _r1 = _mm512_loadu_si512((const __m512i*)(p0 + 32)); + __m512i _r2 = _mm512_loadu_si512((const __m512i*)(p0 + 64)); + __m512i _r3 = _mm512_loadu_si512((const __m512i*)(p0 + 96)); + __m512i _r4 = _mm512_loadu_si512((const __m512i*)(p0 + 128)); + __m512i _r5 = _mm512_loadu_si512((const __m512i*)(p0 + 160)); + __m512i _r6 = _mm512_loadu_si512((const __m512i*)(p0 + 192)); + __m512i _r7 = _mm512_loadu_si512((const __m512i*)(p0 + 224)); + + __m512i w0 = _mm512_shuffle_i64x2(_r0, _r1, 0x44); + __m512i w1 = _mm512_shuffle_i64x2(_r0, _r1, 0xEE); + __m512i w2 = _mm512_shuffle_i64x2(_r2, _r3, 0x44); + __m512i w3 = _mm512_shuffle_i64x2(_r2, _r3, 0xEE); + __m512i w4 = _mm512_shuffle_i64x2(_r4, _r5, 0x44); + __m512i w5 = _mm512_shuffle_i64x2(_r4, _r5, 0xEE); + __m512i w6 = _mm512_shuffle_i64x2(_r6, _r7, 0x44); + __m512i w7 = _mm512_shuffle_i64x2(_r6, _r7, 0xEE); + +#if 
__AVX512BF16__ + __m512i a0 = _mm512_unpacklo_epi32(w0, w1); + __m512i a1 = _mm512_unpackhi_epi32(w0, w1); + __m512i a2 = _mm512_unpacklo_epi32(w2, w3); + __m512i a3 = _mm512_unpackhi_epi32(w2, w3); + __m512i a4 = _mm512_unpacklo_epi32(w4, w5); + __m512i a5 = _mm512_unpackhi_epi32(w4, w5); + __m512i a6 = _mm512_unpacklo_epi32(w6, w7); + __m512i a7 = _mm512_unpackhi_epi32(w6, w7); + + __m512i b0 = _mm512_unpacklo_epi64(a0, a2); + __m512i b1 = _mm512_unpackhi_epi64(a0, a2); + __m512i b2 = _mm512_unpacklo_epi64(a1, a3); + __m512i b3 = _mm512_unpackhi_epi64(a1, a3); + __m512i b4 = _mm512_unpacklo_epi64(a4, a6); + __m512i b5 = _mm512_unpackhi_epi64(a4, a6); + __m512i b6 = _mm512_unpacklo_epi64(a5, a7); + __m512i b7 = _mm512_unpackhi_epi64(a5, a7); + + __m512i idx_l = _mm512_set_epi32(27, 26, 19, 18, 25, 24, 17, 16, 11, 10, 3, 2, 9, 8, 1, 0); + __m512i idx_r = _mm512_set_epi32(31, 30, 23, 22, 29, 28, 21, 20, 15, 14, 7, 6, 13, 12, 5, 4); + + __m512i _p0 = _mm512_permutex2var_epi32(b0, idx_l, b4); + __m512i _p1 = _mm512_permutex2var_epi32(b1, idx_l, b5); + __m512i _p2 = _mm512_permutex2var_epi32(b2, idx_l, b6); + __m512i _p3 = _mm512_permutex2var_epi32(b3, idx_l, b7); + __m512i _p4 = _mm512_permutex2var_epi32(b0, idx_r, b4); + __m512i _p5 = _mm512_permutex2var_epi32(b1, idx_r, b5); + __m512i _p6 = _mm512_permutex2var_epi32(b2, idx_r, b6); + __m512i _p7 = _mm512_permutex2var_epi32(b3, idx_r, b7); +#else // __AVX512BF16__ + __m512i a0 = _mm512_unpacklo_epi16(w0, w1); + __m512i a1 = _mm512_unpackhi_epi16(w0, w1); + __m512i a2 = _mm512_unpacklo_epi16(w2, w3); + __m512i a3 = _mm512_unpackhi_epi16(w2, w3); + __m512i a4 = _mm512_unpacklo_epi16(w4, w5); + __m512i a5 = _mm512_unpackhi_epi16(w4, w5); + __m512i a6 = _mm512_unpacklo_epi16(w6, w7); + __m512i a7 = _mm512_unpackhi_epi16(w6, w7); + + __m512i b0 = _mm512_unpacklo_epi32(a0, a2); + __m512i b1 = _mm512_unpackhi_epi32(a0, a2); + __m512i b2 = _mm512_unpacklo_epi32(a1, a3); + __m512i b3 = _mm512_unpackhi_epi32(a1, a3); + __m512i b4 = _mm512_unpacklo_epi32(a4, a6); + __m512i b5 = _mm512_unpackhi_epi32(a4, a6); + __m512i b6 = _mm512_unpacklo_epi32(a5, a7); + __m512i b7 = _mm512_unpackhi_epi32(a5, a7); + + __m512i c0 = _mm512_unpacklo_epi64(b0, b4); + __m512i c1 = _mm512_unpackhi_epi64(b0, b4); + __m512i c2 = _mm512_unpacklo_epi64(b1, b5); + __m512i c3 = _mm512_unpackhi_epi64(b1, b5); + __m512i c4 = _mm512_unpacklo_epi64(b2, b6); + __m512i c5 = _mm512_unpackhi_epi64(b2, b6); + __m512i c6 = _mm512_unpacklo_epi64(b3, b7); + __m512i c7 = _mm512_unpackhi_epi64(b3, b7); + + __m512i idx_lo = _mm512_set_epi32(27, 19, 26, 18, 25, 17, 24, 16, 11, 3, 10, 2, 9, 1, 8, 0); + __m512i idx_hi = _mm512_set_epi32(31, 23, 30, 22, 29, 21, 28, 20, 15, 7, 14, 6, 13, 5, 12, 4); + + __m512i _p0 = _mm512_permutex2var_epi32(c0, idx_lo, c1); // col 0,1 + __m512i _p1 = _mm512_permutex2var_epi32(c2, idx_lo, c3); // col 2,3 + __m512i _p2 = _mm512_permutex2var_epi32(c4, idx_lo, c5); // col 4,5 + __m512i _p3 = _mm512_permutex2var_epi32(c6, idx_lo, c7); // col 6,7 + __m512i _p4 = _mm512_permutex2var_epi32(c0, idx_hi, c1); // col 8,9 + __m512i _p5 = _mm512_permutex2var_epi32(c2, idx_hi, c3); // col A,B + __m512i _p6 = _mm512_permutex2var_epi32(c4, idx_hi, c5); // col C,D + __m512i _p7 = _mm512_permutex2var_epi32(c6, idx_hi, c7); // col E,F +#endif // __AVX512BF16__ + + _mm512_storeu_si512((__m512i*)pp, _p0); + _mm512_storeu_si512((__m512i*)(pp + 32), _p1); + _mm512_storeu_si512((__m512i*)(pp + 64), _p2); + _mm512_storeu_si512((__m512i*)(pp + 96), _p3); + _mm512_storeu_si512((__m512i*)(pp + 
128), _p4); + _mm512_storeu_si512((__m512i*)(pp + 160), _p5); + _mm512_storeu_si512((__m512i*)(pp + 192), _p6); + _mm512_storeu_si512((__m512i*)(pp + 224), _p7); pp += 256; p0 += A_hstep * 16; } } if (elempack == 8) { - const unsigned short* p0 = (const unsigned short*)A + k * A_hstep + (i + ii) * 8; - int kk = 0; for (; kk + 7 < max_kk; kk += 8) { - __m128i _r0 = _mm_loadu_si128((const __m128i*)p0); - __m128i _r1 = _mm_loadu_si128((const __m128i*)(p0 + 8 * 1)); - __m128i _r2 = _mm_loadu_si128((const __m128i*)(p0 + 8 * 2)); - __m128i _r3 = _mm_loadu_si128((const __m128i*)(p0 + 8 * 3)); - __m128i _r4 = _mm_loadu_si128((const __m128i*)(p0 + 8 * 4)); - __m128i _r5 = _mm_loadu_si128((const __m128i*)(p0 + 8 * 5)); - __m128i _r6 = _mm_loadu_si128((const __m128i*)(p0 + 8 * 6)); - __m128i _r7 = _mm_loadu_si128((const __m128i*)(p0 + 8 * 7)); - __m128i _r8 = _mm_loadu_si128((const __m128i*)(p0 + 8 * 8)); - __m128i _r9 = _mm_loadu_si128((const __m128i*)(p0 + 8 * 9)); - __m128i _ra = _mm_loadu_si128((const __m128i*)(p0 + 8 * 10)); - __m128i _rb = _mm_loadu_si128((const __m128i*)(p0 + 8 * 11)); - __m128i _rc = _mm_loadu_si128((const __m128i*)(p0 + 8 * 12)); - __m128i _rd = _mm_loadu_si128((const __m128i*)(p0 + 8 * 13)); - __m128i _re = _mm_loadu_si128((const __m128i*)(p0 + 8 * 14)); - __m128i _rf = _mm_loadu_si128((const __m128i*)(p0 + 8 * 15)); - transpose8x16_epi16(_r0, _r1, _r2, _r3, _r4, _r5, _r6, _r7, _r8, _r9, _ra, _rb, _rc, _rd, _re, _rf); - _mm512_store_ps(pp, bfloat2float_avx512(_mm256_inserti128_si256(_mm256_castsi128_si256(_r0), _r1, 1))); - _mm512_store_ps(pp + 16 * 1, bfloat2float_avx512(_mm256_inserti128_si256(_mm256_castsi128_si256(_r2), _r3, 1))); - _mm512_store_ps(pp + 16 * 2, bfloat2float_avx512(_mm256_inserti128_si256(_mm256_castsi128_si256(_r4), _r5, 1))); - _mm512_store_ps(pp + 16 * 3, bfloat2float_avx512(_mm256_inserti128_si256(_mm256_castsi128_si256(_r6), _r7, 1))); - _mm512_store_ps(pp + 16 * 4, bfloat2float_avx512(_mm256_inserti128_si256(_mm256_castsi128_si256(_r8), _r9, 1))); - _mm512_store_ps(pp + 16 * 5, bfloat2float_avx512(_mm256_inserti128_si256(_mm256_castsi128_si256(_ra), _rb, 1))); - _mm512_store_ps(pp + 16 * 6, bfloat2float_avx512(_mm256_inserti128_si256(_mm256_castsi128_si256(_rc), _rd, 1))); - _mm512_store_ps(pp + 16 * 7, bfloat2float_avx512(_mm256_inserti128_si256(_mm256_castsi128_si256(_re), _rf, 1))); + __m512i _r0 = _mm512_loadu_si512((const __m512i*)p0); + __m512i _r1 = _mm512_loadu_si512((const __m512i*)(p0 + 32)); + __m512i _r2 = _mm512_loadu_si512((const __m512i*)(p0 + 64)); + __m512i _r3 = _mm512_loadu_si512((const __m512i*)(p0 + 96)); +#if __AVX512BF16__ + __m512i idx0 = _mm512_set_epi32(0, 0, 0, 0, 0, 0, 0, 0, 28, 24, 20, 16, 12, 8, 4, 0); + __m512i idx1 = _mm512_set_epi32(0, 0, 0, 0, 0, 0, 0, 0, 29, 25, 21, 17, 13, 9, 5, 1); + __m512i idx2 = _mm512_set_epi32(0, 0, 0, 0, 0, 0, 0, 0, 30, 26, 22, 18, 14, 10, 6, 2); + __m512i idx3 = _mm512_set_epi32(0, 0, 0, 0, 0, 0, 0, 0, 31, 27, 23, 19, 15, 11, 7, 3); + + __m512i lo0 = _mm512_permutex2var_epi32(_r0, idx0, _r1); + __m512i lo1 = _mm512_permutex2var_epi32(_r0, idx1, _r1); + __m512i lo2 = _mm512_permutex2var_epi32(_r0, idx2, _r1); + __m512i lo3 = _mm512_permutex2var_epi32(_r0, idx3, _r1); + + __m512i hi0 = _mm512_permutex2var_epi32(_r2, idx0, _r3); + __m512i hi1 = _mm512_permutex2var_epi32(_r2, idx1, _r3); + __m512i hi2 = _mm512_permutex2var_epi32(_r2, idx2, _r3); + __m512i hi3 = _mm512_permutex2var_epi32(_r2, idx3, _r3); + + __m512i _p0 = _mm512_inserti64x4(lo0, _mm512_castsi512_si256(hi0), 1); + __m512i _p1 = 
_mm512_inserti64x4(lo1, _mm512_castsi512_si256(hi1), 1); + __m512i _p2 = _mm512_inserti64x4(lo2, _mm512_castsi512_si256(hi2), 1); + __m512i _p3 = _mm512_inserti64x4(lo3, _mm512_castsi512_si256(hi3), 1); +#else // __AVX512BF16__ + __m512i id0 = _mm512_set_epi32(0, 0, 0, 0, 0, 0, 0, 0, 28, 24, 20, 16, 12, 8, 4, 0); + __m512i id1 = _mm512_set_epi32(0, 0, 0, 0, 0, 0, 0, 0, 29, 25, 21, 17, 13, 9, 5, 1); + __m512i id2 = _mm512_set_epi32(0, 0, 0, 0, 0, 0, 0, 0, 30, 26, 22, 18, 14, 10, 6, 2); + __m512i id3 = _mm512_set_epi32(0, 0, 0, 0, 0, 0, 0, 0, 31, 27, 23, 19, 15, 11, 7, 3); + + __m512i p0_lo = _mm512_permutex2var_epi32(_r0, id0, _r1); + __m512i p1_lo = _mm512_permutex2var_epi32(_r0, id1, _r1); + __m512i p2_lo = _mm512_permutex2var_epi32(_r0, id2, _r1); + __m512i p3_lo = _mm512_permutex2var_epi32(_r0, id3, _r1); + + __m512i p0_hi = _mm512_permutex2var_epi32(_r2, id0, _r3); + __m512i p1_hi = _mm512_permutex2var_epi32(_r2, id1, _r3); + __m512i p2_hi = _mm512_permutex2var_epi32(_r2, id2, _r3); + __m512i p3_hi = _mm512_permutex2var_epi32(_r2, id3, _r3); + + __m512i cp0 = _mm512_inserti64x4(p0_lo, _mm512_castsi512_si256(p0_hi), 1); + __m512i cp1 = _mm512_inserti64x4(p1_lo, _mm512_castsi512_si256(p1_hi), 1); + __m512i cp2 = _mm512_inserti64x4(p2_lo, _mm512_castsi512_si256(p2_hi), 1); + __m512i cp3 = _mm512_inserti64x4(p3_lo, _mm512_castsi512_si256(p3_hi), 1); + + __m512i shuf = _mm512_set4_epi32(0x0f0e0b0a, 0x07060302, 0x0d0c0908, 0x05040100); + __m512i pq = _mm512_set_epi64(7, 5, 3, 1, 6, 4, 2, 0); + + __m512i s0 = _mm512_shuffle_epi8(cp0, shuf); + __m512i s1 = _mm512_shuffle_epi8(cp1, shuf); + __m512i s2 = _mm512_shuffle_epi8(cp2, shuf); + __m512i s3 = _mm512_shuffle_epi8(cp3, shuf); + + __m512i _p0 = _mm512_permutexvar_epi64(pq, s0); + __m512i _p1 = _mm512_permutexvar_epi64(pq, s1); + __m512i _p2 = _mm512_permutexvar_epi64(pq, s2); + __m512i _p3 = _mm512_permutexvar_epi64(pq, s3); +#endif // __AVX512BF16__ + _mm512_storeu_si512((__m512i*)pp, _p0); + _mm512_storeu_si512((__m512i*)(pp + 32), _p1); + _mm512_storeu_si512((__m512i*)(pp + 64), _p2); + _mm512_storeu_si512((__m512i*)(pp + 96), _p3); pp += 128; p0 += A_hstep * 8; } } if (elempack == 4) { - const unsigned short* p0 = (const unsigned short*)A + k * A_hstep + (i + ii) * 4; - int kk = 0; for (; kk + 3 < max_kk; kk += 4) { - __m128i _a0 = _mm_loadl_epi64((const __m128i*)p0); - __m128i _a1 = _mm_loadl_epi64((const __m128i*)(p0 + 4 * 1)); - __m128i _a2 = _mm_loadl_epi64((const __m128i*)(p0 + 4 * 2)); - __m128i _a3 = _mm_loadl_epi64((const __m128i*)(p0 + 4 * 3)); - __m128i _b0 = _mm_loadl_epi64((const __m128i*)(p0 + 4 * 4)); - __m128i _b1 = _mm_loadl_epi64((const __m128i*)(p0 + 4 * 5)); - __m128i _b2 = _mm_loadl_epi64((const __m128i*)(p0 + 4 * 6)); - __m128i _b3 = _mm_loadl_epi64((const __m128i*)(p0 + 4 * 7)); - __m128i _c0 = _mm_loadl_epi64((const __m128i*)(p0 + 4 * 8)); - __m128i _c1 = _mm_loadl_epi64((const __m128i*)(p0 + 4 * 9)); - __m128i _c2 = _mm_loadl_epi64((const __m128i*)(p0 + 4 * 10)); - __m128i _c3 = _mm_loadl_epi64((const __m128i*)(p0 + 4 * 11)); - __m128i _d0 = _mm_loadl_epi64((const __m128i*)(p0 + 4 * 12)); - __m128i _d1 = _mm_loadl_epi64((const __m128i*)(p0 + 4 * 13)); - __m128i _d2 = _mm_loadl_epi64((const __m128i*)(p0 + 4 * 14)); - __m128i _d3 = _mm_loadl_epi64((const __m128i*)(p0 + 4 * 15)); - transpose8x4_epi16(_a0, _a1, _a2, _a3); - transpose8x4_epi16(_b0, _b1, _b2, _b3); - transpose8x4_epi16(_c0, _c1, _c2, _c3); - transpose8x4_epi16(_d0, _d1, _d2, _d3); - __m256i _col0 = 
_mm256_inserti128_si256(_mm256_castsi128_si256(_mm_unpacklo_epi64(_a0, _b0)), _mm_unpacklo_epi64(_c0, _d0), 1); - __m256i _col1 = _mm256_inserti128_si256(_mm256_castsi128_si256(_mm_unpackhi_epi64(_a0, _b0)), _mm_unpackhi_epi64(_c0, _d0), 1); - __m256i _col2 = _mm256_inserti128_si256(_mm256_castsi128_si256(_mm_unpacklo_epi64(_a1, _b1)), _mm_unpacklo_epi64(_c1, _d1), 1); - __m256i _col3 = _mm256_inserti128_si256(_mm256_castsi128_si256(_mm_unpackhi_epi64(_a1, _b1)), _mm_unpackhi_epi64(_c1, _d1), 1); - _mm512_store_ps(pp, bfloat2float_avx512(_col0)); - _mm512_store_ps(pp + 16 * 1, bfloat2float_avx512(_col1)); - _mm512_store_ps(pp + 16 * 2, bfloat2float_avx512(_col2)); - _mm512_store_ps(pp + 16 * 3, bfloat2float_avx512(_col3)); + __m512i _r0 = _mm512_loadu_si512((const __m512i*)p0); + __m512i _r1 = _mm512_loadu_si512((const __m512i*)(p0 + 32)); +#if __AVX512BF16__ + __m512i idx_lo = _mm512_setr_epi32(0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30); + __m512i idx_hi = _mm512_setr_epi32(1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31); + __m512i _p0 = _mm512_permutex2var_epi32(_r0, idx_lo, _r1); + __m512i _p1 = _mm512_permutex2var_epi32(_r0, idx_hi, _r1); +#else // __AVX512BF16__ + __m512i idx_lo = _mm512_set_epi16(61, 57, 53, 49, 45, 41, 37, 33, 29, 25, 21, 17, 13, 9, 5, 1, 60, 56, 52, 48, 44, 40, 36, 32, 28, 24, 20, 16, 12, 8, 4, 0); + __m512i idx_hi = _mm512_set_epi16(63, 59, 55, 51, 47, 43, 39, 35, 31, 27, 23, 19, 15, 11, 7, 3, 62, 58, 54, 50, 46, 42, 38, 34, 30, 26, 22, 18, 14, 10, 6, 2); + __m512i _p0 = _mm512_permutex2var_epi16(_r0, idx_lo, _r1); + __m512i _p1 = _mm512_permutex2var_epi16(_r0, idx_hi, _r1); +#endif // __AVX512BF16__ + _mm512_storeu_si512((__m512i*)pp, _p0); + _mm512_storeu_si512((__m512i*)(pp + 32), _p1); pp += 64; p0 += A_hstep * 4; } } if (elempack == 1) { - const unsigned short* p0 = (const unsigned short*)A + k * A_hstep + (i + ii); - int kk = 0; +#if __AVX512BF16__ + for (; kk + 1 < max_kk; kk += 2) + { + __m256i _r0 = _mm256_loadu_si256((const __m256i*)p0); + __m256i _r1 = _mm256_loadu_si256((const __m256i*)(p0 + A_hstep)); + transpose16x2_epi16(_r0, _r1); + _mm256_storeu_si256((__m256i*)pp, _r0); + _mm256_storeu_si256((__m256i*)(pp + 16), _r1); + pp += 32; + p0 += A_hstep * 2; + } +#endif // __AVX512BF16__ for (; kk < max_kk; kk++) { - _mm512_store_ps(pp, bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)p0))); + _mm256_storeu_si256((__m256i*)pp, _mm256_loadu_si256((const __m256i*)p0)); pp += 16; p0 += A_hstep; } @@ -560,31 +634,36 @@ static void transpose_pack_A_tile_bf16s(const Mat& A, Mat& AT, int i, int max_ii #endif // __AVX512F__ for (; ii + 7 < max_ii; ii += 8) { + const unsigned short* p0 = (const unsigned short*)A + k * A_hstep + (i + ii) * elempack; + +#if __AVX__ #if __AVX512F__ if (elempack == 16) { - const unsigned short* p0 = (const unsigned short*)A + k * A_hstep + (i + ii) * 16; - int kk = 0; for (; kk + 15 < max_kk; kk += 16) { __m256i _r0 = _mm256_loadu_si256((const __m256i*)p0); __m256i _r1 = _mm256_loadu_si256((const __m256i*)(p0 + 16)); - __m256i _r2 = _mm256_loadu_si256((const __m256i*)(p0 + 16 * 2)); - __m256i _r3 = _mm256_loadu_si256((const __m256i*)(p0 + 16 * 3)); - __m256i _r4 = _mm256_loadu_si256((const __m256i*)(p0 + 16 * 4)); - __m256i _r5 = _mm256_loadu_si256((const __m256i*)(p0 + 16 * 5)); - __m256i _r6 = _mm256_loadu_si256((const __m256i*)(p0 + 16 * 6)); - __m256i _r7 = _mm256_loadu_si256((const __m256i*)(p0 + 16 * 7)); + __m256i _r2 = _mm256_loadu_si256((const __m256i*)(p0 + 32)); + __m256i _r3 = 
_mm256_loadu_si256((const __m256i*)(p0 + 48)); + __m256i _r4 = _mm256_loadu_si256((const __m256i*)(p0 + 64)); + __m256i _r5 = _mm256_loadu_si256((const __m256i*)(p0 + 80)); + __m256i _r6 = _mm256_loadu_si256((const __m256i*)(p0 + 96)); + __m256i _r7 = _mm256_loadu_si256((const __m256i*)(p0 + 112)); +#if __AVX512BF16__ + transpose8x8_epi32(_r0, _r1, _r2, _r3, _r4, _r5, _r6, _r7); +#else // __AVX512BF16__ transpose16x8_epi16(_r0, _r1, _r2, _r3, _r4, _r5, _r6, _r7); - _mm512_store_ps(pp, bfloat2float_avx512(_r0)); - _mm512_store_ps(pp + 16 * 1, bfloat2float_avx512(_r1)); - _mm512_store_ps(pp + 16 * 2, bfloat2float_avx512(_r2)); - _mm512_store_ps(pp + 16 * 3, bfloat2float_avx512(_r3)); - _mm512_store_ps(pp + 16 * 4, bfloat2float_avx512(_r4)); - _mm512_store_ps(pp + 16 * 5, bfloat2float_avx512(_r5)); - _mm512_store_ps(pp + 16 * 6, bfloat2float_avx512(_r6)); - _mm512_store_ps(pp + 16 * 7, bfloat2float_avx512(_r7)); +#endif // __AVX512BF16__ + _mm256_storeu_si256((__m256i*)pp, _r0); + _mm256_storeu_si256((__m256i*)(pp + 16), _r1); + _mm256_storeu_si256((__m256i*)(pp + 32), _r2); + _mm256_storeu_si256((__m256i*)(pp + 48), _r3); + _mm256_storeu_si256((__m256i*)(pp + 64), _r4); + _mm256_storeu_si256((__m256i*)(pp + 80), _r5); + _mm256_storeu_si256((__m256i*)(pp + 96), _r6); + _mm256_storeu_si256((__m256i*)(pp + 112), _r7); pp += 128; p0 += A_hstep * 16; } @@ -592,70 +671,98 @@ static void transpose_pack_A_tile_bf16s(const Mat& A, Mat& AT, int i, int max_ii #endif // __AVX512F__ if (elempack == 8) { - const unsigned short* p0 = (const unsigned short*)A + k * A_hstep + (i + ii) * 8; - int kk = 0; for (; kk + 7 < max_kk; kk += 8) { __m128i _r0 = _mm_loadu_si128((const __m128i*)p0); - __m128i _r1 = _mm_loadu_si128((const __m128i*)(p0 + 8 * 1)); - __m128i _r2 = _mm_loadu_si128((const __m128i*)(p0 + 8 * 2)); - __m128i _r3 = _mm_loadu_si128((const __m128i*)(p0 + 8 * 3)); - __m128i _r4 = _mm_loadu_si128((const __m128i*)(p0 + 8 * 4)); - __m128i _r5 = _mm_loadu_si128((const __m128i*)(p0 + 8 * 5)); - __m128i _r6 = _mm_loadu_si128((const __m128i*)(p0 + 8 * 6)); - __m128i _r7 = _mm_loadu_si128((const __m128i*)(p0 + 8 * 7)); + __m128i _r1 = _mm_loadu_si128((const __m128i*)(p0 + 8)); + __m128i _r2 = _mm_loadu_si128((const __m128i*)(p0 + 16)); + __m128i _r3 = _mm_loadu_si128((const __m128i*)(p0 + 24)); + __m128i _r4 = _mm_loadu_si128((const __m128i*)(p0 + 32)); + __m128i _r5 = _mm_loadu_si128((const __m128i*)(p0 + 40)); + __m128i _r6 = _mm_loadu_si128((const __m128i*)(p0 + 48)); + __m128i _r7 = _mm_loadu_si128((const __m128i*)(p0 + 56)); +#if __AVX512BF16__ + transpose4x8_epi32(_r0, _r1, _r2, _r3, _r4, _r5, _r6, _r7); +#else // __AVX512BF16__ transpose8x8_epi16(_r0, _r1, _r2, _r3, _r4, _r5, _r6, _r7); - _mm256_store_ps(pp, bfloat2float_avx(_r0)); - _mm256_store_ps(pp + 8 * 1, bfloat2float_avx(_r1)); - _mm256_store_ps(pp + 8 * 2, bfloat2float_avx(_r2)); - _mm256_store_ps(pp + 8 * 3, bfloat2float_avx(_r3)); - _mm256_store_ps(pp + 8 * 4, bfloat2float_avx(_r4)); - _mm256_store_ps(pp + 8 * 5, bfloat2float_avx(_r5)); - _mm256_store_ps(pp + 8 * 6, bfloat2float_avx(_r6)); - _mm256_store_ps(pp + 8 * 7, bfloat2float_avx(_r7)); +#endif // __AVX512BF16__ + _mm_storeu_si128((__m128i*)pp, _r0); + _mm_storeu_si128((__m128i*)(pp + 8), _r1); + _mm_storeu_si128((__m128i*)(pp + 16), _r2); + _mm_storeu_si128((__m128i*)(pp + 24), _r3); + _mm_storeu_si128((__m128i*)(pp + 32), _r4); + _mm_storeu_si128((__m128i*)(pp + 40), _r5); + _mm_storeu_si128((__m128i*)(pp + 48), _r6); + _mm_storeu_si128((__m128i*)(pp + 56), _r7); pp += 64; p0 += 
A_hstep * 8; } } +#endif // __AVX__ if (elempack == 4) { - const unsigned short* p0 = (const unsigned short*)A + k * A_hstep + (i + ii) * 4; - int kk = 0; for (; kk + 3 < max_kk; kk += 4) { - __m128i _a0 = _mm_loadl_epi64((const __m128i*)p0); - __m128i _a1 = _mm_loadl_epi64((const __m128i*)(p0 + 4 * 1)); - __m128i _a2 = _mm_loadl_epi64((const __m128i*)(p0 + 4 * 2)); - __m128i _a3 = _mm_loadl_epi64((const __m128i*)(p0 + 4 * 3)); - __m128i _b0 = _mm_loadl_epi64((const __m128i*)(p0 + 4 * 4)); - __m128i _b1 = _mm_loadl_epi64((const __m128i*)(p0 + 4 * 5)); - __m128i _b2 = _mm_loadl_epi64((const __m128i*)(p0 + 4 * 6)); - __m128i _b3 = _mm_loadl_epi64((const __m128i*)(p0 + 4 * 7)); - transpose8x4_epi16(_a0, _a1, _a2, _a3); - transpose8x4_epi16(_b0, _b1, _b2, _b3); - // _a0 = [col0_rows0-3 | col1_rows0-3], _b0 = [col0_rows4-7 | col1_rows4-7] - _mm_store_ps(pp, bfloat2float_sse(_a0)); - _mm_store_ps(pp + 4 * 1, bfloat2float_sse(_b0)); - _mm_store_ps(pp + 4 * 2, bfloat2float_sse(_mm_unpackhi_epi64(_a0, _a0))); - _mm_store_ps(pp + 4 * 3, bfloat2float_sse(_mm_unpackhi_epi64(_b0, _b0))); - _mm_store_ps(pp + 4 * 4, bfloat2float_sse(_a1)); - _mm_store_ps(pp + 4 * 5, bfloat2float_sse(_b1)); - _mm_store_ps(pp + 4 * 6, bfloat2float_sse(_mm_unpackhi_epi64(_a1, _a1))); - _mm_store_ps(pp + 4 * 7, bfloat2float_sse(_mm_unpackhi_epi64(_b1, _b1))); + __m128i _r0 = _mm_loadl_epi64((const __m128i*)p0); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)(p0 + 4)); + __m128i _r2 = _mm_loadl_epi64((const __m128i*)(p0 + 8)); + __m128i _r3 = _mm_loadl_epi64((const __m128i*)(p0 + 12)); + __m128i _r4 = _mm_loadl_epi64((const __m128i*)(p0 + 16)); + __m128i _r5 = _mm_loadl_epi64((const __m128i*)(p0 + 20)); + __m128i _r6 = _mm_loadl_epi64((const __m128i*)(p0 + 24)); + __m128i _r7 = _mm_loadl_epi64((const __m128i*)(p0 + 28)); +#if __AVX512BF16__ + __m128i _t0 = _mm_unpacklo_epi32(_r0, _r1); + __m128i _t1 = _mm_unpacklo_epi32(_r2, _r3); + __m128i _t2 = _mm_unpacklo_epi32(_r4, _r5); + __m128i _t3 = _mm_unpacklo_epi32(_r6, _r7); + __m128i _p0 = _mm_unpacklo_epi64(_t0, _t1); + __m128i _p1 = _mm_unpacklo_epi64(_t2, _t3); + __m128i _p2 = _mm_unpackhi_epi64(_t0, _t1); + __m128i _p3 = _mm_unpackhi_epi64(_t2, _t3); +#else // __AVX512BF16__ + __m128i _t0 = _mm_unpacklo_epi16(_r0, _r1); + __m128i _t1 = _mm_unpacklo_epi16(_r2, _r3); + __m128i _t2 = _mm_unpacklo_epi16(_r4, _r5); + __m128i _t3 = _mm_unpacklo_epi16(_r6, _r7); + _r0 = _mm_unpacklo_epi32(_t0, _t1); + _r1 = _mm_unpackhi_epi32(_t0, _t1); + _r2 = _mm_unpacklo_epi32(_t2, _t3); + _r3 = _mm_unpackhi_epi32(_t2, _t3); + __m128i _p0 = _mm_unpacklo_epi64(_r0, _r2); + __m128i _p1 = _mm_unpackhi_epi64(_r0, _r2); + __m128i _p2 = _mm_unpacklo_epi64(_r1, _r3); + __m128i _p3 = _mm_unpackhi_epi64(_r1, _r3); +#endif // __AVX512BF16__ + _mm_storeu_si128((__m128i*)pp, _p0); + _mm_storeu_si128((__m128i*)(pp + 8), _p1); + _mm_storeu_si128((__m128i*)(pp + 16), _p2); + _mm_storeu_si128((__m128i*)(pp + 24), _p3); pp += 32; p0 += A_hstep * 4; } } if (elempack == 1) { - const unsigned short* p0 = (const unsigned short*)A + k * A_hstep + (i + ii); - int kk = 0; +#if __AVX512BF16__ + for (; kk + 1 < max_kk; kk += 2) + { + __m128i _r0 = _mm_loadu_si128((const __m128i*)p0); + __m128i _r1 = _mm_loadu_si128((const __m128i*)(p0 + A_hstep)); + __m128i _p0 = _mm_unpacklo_epi16(_r0, _r1); + __m128i _p1 = _mm_unpackhi_epi16(_r0, _r1); + _mm_storeu_si128((__m128i*)pp, _p0); + _mm_storeu_si128((__m128i*)(pp + 8), _p1); + pp += 16; + p0 += A_hstep * 2; + } +#endif // __AVX512BF16__ for (; kk < max_kk; kk++) { - 
_mm256_store_ps(pp, bfloat2float_avx(_mm_loadu_si128((const __m128i*)p0))); + _mm_storeu_si128((__m128i*)pp, _mm_loadu_si128((const __m128i*)p0)); pp += 8; p0 += A_hstep; } @@ -664,24 +771,28 @@ static void transpose_pack_A_tile_bf16s(const Mat& A, Mat& AT, int i, int max_ii #endif // __AVX__ for (; ii + 3 < max_ii; ii += 4) { + const unsigned short* p0 = (const unsigned short*)A + k * A_hstep + (i + ii) * elempack; + #if __AVX__ #if __AVX512F__ if (elempack == 16) { - const unsigned short* p0 = (const unsigned short*)A + k * A_hstep + (i + ii) * 16; - int kk = 0; for (; kk + 15 < max_kk; kk += 16) { - __m512 _r0 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)p0)); - __m512 _r1 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(p0 + 16 * 1))); - __m512 _r2 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(p0 + 16 * 2))); - __m512 _r3 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(p0 + 16 * 3))); - transpose16x4_ps(_r0, _r1, _r2, _r3); - _mm512_store_ps(pp, _r0); - _mm512_store_ps(pp + 16 * 1, _r1); - _mm512_store_ps(pp + 16 * 2, _r2); - _mm512_store_ps(pp + 16 * 3, _r3); + __m256i _r0 = _mm256_loadu_si256((const __m256i*)p0); + __m256i _r1 = _mm256_loadu_si256((const __m256i*)(p0 + 16)); + __m256i _r2 = _mm256_loadu_si256((const __m256i*)(p0 + 32)); + __m256i _r3 = _mm256_loadu_si256((const __m256i*)(p0 + 48)); +#if __AVX512BF16__ + transpose8x4_epi32(_r0, _r1, _r2, _r3); +#else // __AVX512BF16__ + transpose16x4_epi16(_r0, _r1, _r2, _r3); +#endif // __AVX512BF16__ + _mm256_storeu_si256((__m256i*)pp, _r0); + _mm256_storeu_si256((__m256i*)(pp + 16), _r1); + _mm256_storeu_si256((__m256i*)(pp + 32), _r2); + _mm256_storeu_si256((__m256i*)(pp + 48), _r3); pp += 64; p0 += A_hstep * 16; } @@ -689,20 +800,22 @@ static void transpose_pack_A_tile_bf16s(const Mat& A, Mat& AT, int i, int max_ii #endif // __AVX512F__ if (elempack == 8) { - const unsigned short* p0 = (const unsigned short*)A + k * A_hstep + (i + ii) * 8; - int kk = 0; for (; kk + 7 < max_kk; kk += 8) { __m128i _r0 = _mm_loadu_si128((const __m128i*)p0); - __m128i _r1 = _mm_loadu_si128((const __m128i*)(p0 + 8 * 1)); - __m128i _r2 = _mm_loadu_si128((const __m128i*)(p0 + 8 * 2)); - __m128i _r3 = _mm_loadu_si128((const __m128i*)(p0 + 8 * 3)); + __m128i _r1 = _mm_loadu_si128((const __m128i*)(p0 + 8)); + __m128i _r2 = _mm_loadu_si128((const __m128i*)(p0 + 16)); + __m128i _r3 = _mm_loadu_si128((const __m128i*)(p0 + 24)); +#if __AVX512BF16__ + transpose4x4_epi32(_r0, _r1, _r2, _r3); +#else // __AVX512BF16__ transpose8x4_epi16(_r0, _r1, _r2, _r3); - _mm256_store_ps(pp, bfloat2float_avx(_r0)); - _mm256_store_ps(pp + 8 * 1, bfloat2float_avx(_r1)); - _mm256_store_ps(pp + 8 * 2, bfloat2float_avx(_r2)); - _mm256_store_ps(pp + 8 * 3, bfloat2float_avx(_r3)); +#endif // __AVX512BF16__ + _mm_storeu_si128((__m128i*)pp, _r0); + _mm_storeu_si128((__m128i*)(pp + 8), _r1); + _mm_storeu_si128((__m128i*)(pp + 16), _r2); + _mm_storeu_si128((__m128i*)(pp + 24), _r3); pp += 32; p0 += A_hstep * 8; } @@ -710,33 +823,47 @@ static void transpose_pack_A_tile_bf16s(const Mat& A, Mat& AT, int i, int max_ii #endif // __AVX__ if (elempack == 4) { - const unsigned short* p0 = (const unsigned short*)A + k * A_hstep + (i + ii) * 4; - int kk = 0; for (; kk + 3 < max_kk; kk += 4) { __m128i _r0 = _mm_loadl_epi64((const __m128i*)p0); - __m128i _r1 = _mm_loadl_epi64((const __m128i*)(p0 + 4 * 1)); - __m128i _r2 = _mm_loadl_epi64((const __m128i*)(p0 + 4 * 2)); - __m128i _r3 = _mm_loadl_epi64((const __m128i*)(p0 + 4 * 3)); - transpose8x4_epi16(_r0, 
_r1, _r2, _r3); - // _r0 = [row0_lo | row1_lo], _r1 = [row2_lo | row3_lo], _r2/_r3 = 0 - _mm_store_ps(pp, bfloat2float_sse(_r0)); - _mm_store_ps(pp + 4 * 1, bfloat2float_sse(_mm_unpackhi_epi64(_r0, _r0))); - _mm_store_ps(pp + 4 * 2, bfloat2float_sse(_r1)); - _mm_store_ps(pp + 4 * 3, bfloat2float_sse(_mm_unpackhi_epi64(_r1, _r1))); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)(p0 + 4)); + __m128i _r2 = _mm_loadl_epi64((const __m128i*)(p0 + 8)); + __m128i _r3 = _mm_loadl_epi64((const __m128i*)(p0 + 12)); +#if __AVX512BF16__ + __m128i _t0 = _mm_unpacklo_epi32(_r0, _r1); + __m128i _t1 = _mm_unpacklo_epi32(_r2, _r3); + _r0 = _mm_unpacklo_epi64(_t0, _t1); + _r1 = _mm_unpackhi_epi64(_t0, _t1); +#else // __AVX512BF16__ + __m128i _t0 = _mm_unpacklo_epi16(_r0, _r1); + __m128i _t1 = _mm_unpacklo_epi16(_r2, _r3); + _r0 = _mm_unpacklo_epi32(_t0, _t1); + _r1 = _mm_unpackhi_epi32(_t0, _t1); +#endif // __AVX512BF16__ + _mm_storeu_si128((__m128i*)pp, _r0); + _mm_storeu_si128((__m128i*)(pp + 8), _r1); pp += 16; p0 += A_hstep * 4; } } if (elempack == 1) { - const unsigned short* p0 = (const unsigned short*)A + k * A_hstep + (i + ii); - int kk = 0; +#if __AVX512BF16__ + for (; kk + 1 < max_kk; kk += 2) + { + __m128i _p0 = _mm_loadl_epi64((const __m128i*)p0); + __m128i _p1 = _mm_loadl_epi64((const __m128i*)(p0 + A_hstep)); + __m128i _p = _mm_unpacklo_epi16(_p0, _p1); + _mm_storeu_si128((__m128i*)pp, _p); + pp += 8; + p0 += A_hstep * 2; + } +#endif // __AVX512BF16__ for (; kk < max_kk; kk++) { - _mm_store_ps(pp, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)p0))); + _mm_storel_epi64((__m128i*)pp, _mm_loadl_epi64((const __m128i*)p0)); pp += 4; p0 += A_hstep; } @@ -745,21 +872,25 @@ static void transpose_pack_A_tile_bf16s(const Mat& A, Mat& AT, int i, int max_ii #endif // __SSE2__ for (; ii + 1 < max_ii; ii += 2) { + const unsigned short* p0 = (const unsigned short*)A + k * A_hstep + (i + ii) * elempack; + #if __SSE2__ #if __AVX__ #if __AVX512F__ if (elempack == 16) { - const unsigned short* p0 = (const unsigned short*)A + k * A_hstep + (i + ii) * 16; - int kk = 0; for (; kk + 15 < max_kk; kk += 16) { - __m512 _r0 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)p0)); - __m512 _r1 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(p0 + 16))); - transpose16x2_ps(_r0, _r1); - _mm512_store_ps(pp, _r0); - _mm512_store_ps(pp + 16, _r1); + __m256i _p0 = _mm256_loadu_si256((const __m256i*)p0); + __m256i _p1 = _mm256_loadu_si256((const __m256i*)(p0 + 16)); +#if __AVX512BF16__ + transpose8x2_epi32(_p0, _p1); +#else // __AVX512BF16__ + transpose16x2_epi16(_p0, _p1); +#endif // __AVX512BF16__ + _mm256_storeu_si256((__m256i*)pp, _p0); + _mm256_storeu_si256((__m256i*)(pp + 16), _p1); pp += 32; p0 += A_hstep * 16; } @@ -767,16 +898,20 @@ static void transpose_pack_A_tile_bf16s(const Mat& A, Mat& AT, int i, int max_ii #endif // __AVX512F__ if (elempack == 8) { - const unsigned short* p0 = (const unsigned short*)A + k * A_hstep + (i + ii) * 8; - int kk = 0; for (; kk + 7 < max_kk; kk += 8) { - __m256 _r0 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)p0)); - __m256 _r1 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)(p0 + 8))); - transpose8x2_ps(_r0, _r1); - _mm256_store_ps(pp, _r0); - _mm256_store_ps(pp + 8, _r1); + __m128i _p0 = _mm_loadu_si128((const __m128i*)p0); + __m128i _p1 = _mm_loadu_si128((const __m128i*)(p0 + 8)); +#if __AVX512BF16__ + __m128i _t0 = _mm_unpacklo_epi32(_p0, _p1); + __m128i _t1 = _mm_unpackhi_epi32(_p0, _p1); +#else // __AVX512BF16__ + __m128i _t0 = _mm_unpacklo_epi16(_p0, 
_p1); + __m128i _t1 = _mm_unpackhi_epi16(_p0, _p1); +#endif // __AVX512BF16__ + _mm_storeu_si128((__m128i*)pp, _t0); + _mm_storeu_si128((__m128i*)(pp + 8), _t1); pp += 16; p0 += A_hstep * 8; } @@ -784,16 +919,28 @@ static void transpose_pack_A_tile_bf16s(const Mat& A, Mat& AT, int i, int max_ii #endif // __AVX__ if (elempack == 4) { - const unsigned short* p0 = (const unsigned short*)A + k * A_hstep + (i + ii) * 4; - int kk = 0; for (; kk + 3 < max_kk; kk += 4) { - __m128i _a = _mm_loadl_epi64((const __m128i*)p0); - __m128i _b = _mm_loadl_epi64((const __m128i*)(p0 + 4)); - __m128i _tmp0 = _mm_unpacklo_epi16(_a, _b); - _mm_store_ps(pp, bfloat2float_sse(_tmp0)); - _mm_store_ps(pp + 4, bfloat2float_sse(_mm_unpackhi_epi64(_tmp0, _tmp0))); +#if __AVX512BF16__ + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p0[4]; + pp[3] = p0[5]; + pp[4] = p0[2]; + pp[5] = p0[3]; + pp[6] = p0[6]; + pp[7] = p0[7]; +#else // __AVX512BF16__ + pp[0] = p0[0]; + pp[1] = p0[4]; + pp[2] = p0[1]; + pp[3] = p0[5]; + pp[4] = p0[2]; + pp[5] = p0[6]; + pp[6] = p0[3]; + pp[7] = p0[7]; +#endif // __AVX512BF16__ pp += 8; p0 += A_hstep * 4; } @@ -801,13 +948,22 @@ static void transpose_pack_A_tile_bf16s(const Mat& A, Mat& AT, int i, int max_ii #endif // __SSE2__ if (elempack == 1) { - const unsigned short* p0 = (const unsigned short*)A + k * A_hstep + (i + ii); - int kk = 0; +#if __AVX512BF16__ + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = p0[0]; + pp[1] = p0[A_hstep]; + pp[2] = p0[1]; + pp[3] = p0[A_hstep + 1]; + pp += 4; + p0 += A_hstep * 2; + } +#endif // __AVX512BF16__ for (; kk < max_kk; kk++) { - pp[0] = bfloat16_to_float32(p0[0]); - pp[1] = bfloat16_to_float32(p0[1]); + pp[0] = p0[0]; + pp[1] = p0[1]; pp += 2; p0 += A_hstep; } @@ -815,17 +971,17 @@ static void transpose_pack_A_tile_bf16s(const Mat& A, Mat& AT, int i, int max_ii } for (; ii < max_ii; ii += 1) { + const unsigned short* p0 = (const unsigned short*)A + k * A_hstep + (i + ii) * elempack; + #if __SSE2__ #if __AVX__ #if __AVX512F__ if (elempack == 16) { - const unsigned short* p0 = (const unsigned short*)A + k * A_hstep + (i + ii) * 16; - int kk = 0; for (; kk + 15 < max_kk; kk += 16) { - _mm512_store_ps(pp, bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)p0))); + _mm256_storeu_si256((__m256i*)pp, _mm256_loadu_si256((const __m256i*)p0)); pp += 16; p0 += A_hstep * 16; } @@ -833,12 +989,10 @@ static void transpose_pack_A_tile_bf16s(const Mat& A, Mat& AT, int i, int max_ii #endif // __AVX512F__ if (elempack == 8) { - const unsigned short* p0 = (const unsigned short*)A + k * A_hstep + (i + ii) * 8; - int kk = 0; for (; kk + 7 < max_kk; kk += 8) { - _mm256_store_ps(pp, bfloat2float_avx(_mm_loadu_si128((const __m128i*)p0))); + _mm_storeu_si128((__m128i*)pp, _mm_loadu_si128((const __m128i*)p0)); pp += 8; p0 += A_hstep * 8; } @@ -846,12 +1000,10 @@ static void transpose_pack_A_tile_bf16s(const Mat& A, Mat& AT, int i, int max_ii #endif // __AVX__ if (elempack == 4) { - const unsigned short* p0 = (const unsigned short*)A + k * A_hstep + (i + ii) * 4; - int kk = 0; for (; kk + 3 < max_kk; kk += 4) { - _mm_store_ps(pp, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)p0))); + _mm_storel_epi64((__m128i*)pp, _mm_loadl_epi64((const __m128i*)p0)); pp += 4; p0 += A_hstep * 4; } @@ -859,12 +1011,10 @@ static void transpose_pack_A_tile_bf16s(const Mat& A, Mat& AT, int i, int max_ii #endif // __SSE2__ if (elempack == 1) { - const unsigned short* p0 = (const unsigned short*)A + k * A_hstep + (i + ii); - int kk = 0; for (; kk < max_kk; kk++) { - pp[0] = 
bfloat16_to_float32(p0[0]); + pp[0] = p0[0]; pp += 1; p0 += A_hstep; } @@ -872,38 +1022,74 @@ static void transpose_pack_A_tile_bf16s(const Mat& A, Mat& AT, int i, int max_ii } } -static void pack_B_tile_bf16s(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk) +static void pack_B_tile_bf16(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk) { +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ + if (ncnn::cpu_support_x86_avx512_bf16()) + { + pack_B_tile_bf16_avx512bf16(B, BT, j, max_jj, k, max_kk); + return; + } +#endif + + // NCNN_LOGE("pack_B_tile_bf16 %d %d %d %d", j, max_jj, k, max_kk); const int elempack = B.elempack; const size_t B_hstep = B.dims == 3 ? B.cstep : (size_t)B.w; - float* pp = BT; + unsigned short* pp = BT; int jj = 0; #if __SSE2__ +#if defined(__x86_64__) || defined(_M_X64) #if __AVX512F__ for (; jj + 15 < max_jj; jj += 16) { + const unsigned short* p0 = (const unsigned short*)B + (j + jj) * B_hstep + k * elempack; + if (elempack == 16) { - const unsigned short* p0 = (const unsigned short*)B + (j + jj) * B_hstep + k * 16; - - for (int kk = 0; kk < max_kk; kk++) + int kk = 0; +#if __AVX512BF16__ + __m512i _idx = _mm512_set_epi16(31, 15, 30, 14, 29, 13, 28, 12, 27, 11, 26, 10, 25, 9, 24, 8, 23, 7, 22, 6, 21, 5, 20, 4, 19, 3, 18, 2, 17, 1, 16, 0); + for (; kk + 1 < max_kk; kk += 2) + { + __m512i _p = _mm512_loadu_si512((const __m512i*)p0); + _p = _mm512_permutexvar_epi16(_idx, _p); + _mm512_storeu_si512((__m512i*)pp, _p); + pp += 32; + p0 += 32; + } +#endif // __AVX512BF16__ + for (; kk < max_kk; kk++) { - _mm512_storeu_ps(pp, bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)p0))); + _mm256_storeu_si256((__m256i*)pp, _mm256_loadu_si256((const __m256i*)p0)); pp += 16; p0 += 16; } } if (elempack == 8) { - const unsigned short* p0 = (const unsigned short*)B + (j + jj) * B_hstep + k * 8; - const unsigned short* p1 = (const unsigned short*)B + (j + jj + 8) * B_hstep + k * 8; + const unsigned short* p1 = p0 + B_hstep * 8; - for (int kk = 0; kk < max_kk; kk++) + int kk = 0; +#if __AVX512BF16__ + __m512i _idx = _mm512_set_epi16(31, 23, 30, 22, 29, 21, 28, 20, 27, 19, 26, 18, 25, 17, 24, 16, 15, 7, 14, 6, 13, 5, 12, 4, 11, 3, 10, 2, 9, 1, 8, 0); + for (; kk + 1 < max_kk; kk += 2) + { + __m256i _a = _mm256_loadu_si256((const __m256i*)p0); + __m256i _b = _mm256_loadu_si256((const __m256i*)p1); + __m512i _ab = combine8x2_epi32(_a, _b); + __m512i _p = _mm512_permutexvar_epi16(_idx, _ab); + _mm512_storeu_si512((__m512i*)pp, _p); + pp += 32; + p0 += 16; + p1 += 16; + } +#endif // __AVX512BF16__ + for (; kk < max_kk; kk++) { - _mm256_storeu_ps(pp, bfloat2float_avx(_mm_loadu_si128((const __m128i*)p0))); - _mm256_storeu_ps(pp + 8, bfloat2float_avx(_mm_loadu_si128((const __m128i*)p1))); + _mm_storeu_si128((__m128i*)pp, _mm_loadu_si128((const __m128i*)p0)); + _mm_storeu_si128((__m128i*)(pp + 8), _mm_loadu_si128((const __m128i*)p1)); pp += 16; p0 += 8; p1 += 8; @@ -911,17 +1097,39 @@ static void pack_B_tile_bf16s(const Mat& B, Mat& BT, int j, int max_jj, int k, i } if (elempack == 4) { - const unsigned short* p0 = (const unsigned short*)B + (j + jj) * B_hstep + k * 4; - const unsigned short* p1 = (const unsigned short*)B + (j + jj + 4) * B_hstep + k * 4; - const unsigned short* p2 = (const unsigned short*)B + (j + jj + 8) * B_hstep + k * 4; - const unsigned short* p3 = (const unsigned short*)B + (j + jj + 12) * B_hstep + k * 4; + const unsigned short* p1 = p0 + B_hstep * 4; + const unsigned short* p2 = p0 + B_hstep * 8; + const unsigned short* p3 
= p0 + B_hstep * 12; - for (int kk = 0; kk < max_kk; kk++) + int kk = 0; +#if __AVX512BF16__ + for (; kk + 1 < max_kk; kk += 2) + { + __m128i _a0 = _mm_loadu_si128((const __m128i*)p0); + __m128i _a1 = _mm_loadu_si128((const __m128i*)p1); + __m128i _a2 = _mm_loadu_si128((const __m128i*)p2); + __m128i _a3 = _mm_loadu_si128((const __m128i*)p3); + __m128i _t0 = _mm_unpacklo_epi16(_a0, _mm_srli_si128(_a0, 8)); + __m128i _t1 = _mm_unpacklo_epi16(_a1, _mm_srli_si128(_a1, 8)); + __m128i _t2 = _mm_unpacklo_epi16(_a2, _mm_srli_si128(_a2, 8)); + __m128i _t3 = _mm_unpacklo_epi16(_a3, _mm_srli_si128(_a3, 8)); + _mm_storeu_si128((__m128i*)pp, _t0); + _mm_storeu_si128((__m128i*)(pp + 8), _t1); + _mm_storeu_si128((__m128i*)(pp + 16), _t2); + _mm_storeu_si128((__m128i*)(pp + 24), _t3); + pp += 32; + p0 += 8; + p1 += 8; + p2 += 8; + p3 += 8; + } +#endif // __AVX512BF16__ + for (; kk < max_kk; kk++) { - _mm_store_ps(pp, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)p0))); - _mm_store_ps(pp + 4, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)p1))); - _mm_store_ps(pp + 8, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)p2))); - _mm_store_ps(pp + 12, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)p3))); + _mm_storel_epi64((__m128i*)pp, _mm_loadl_epi64((const __m128i*)p0)); + _mm_storel_epi64((__m128i*)(pp + 4), _mm_loadl_epi64((const __m128i*)p1)); + _mm_storel_epi64((__m128i*)(pp + 8), _mm_loadl_epi64((const __m128i*)p2)); + _mm_storel_epi64((__m128i*)(pp + 12), _mm_loadl_epi64((const __m128i*)p3)); pp += 16; p0 += 4; p1 += 4; @@ -931,531 +1139,289 @@ static void pack_B_tile_bf16s(const Mat& B, Mat& BT, int j, int max_jj, int k, i } if (elempack == 1) { - const unsigned short* p0 = (const unsigned short*)B + (j + jj) * B_hstep + k; - const unsigned short* p1 = (const unsigned short*)B + (j + jj + 1) * B_hstep + k; - const unsigned short* p2 = (const unsigned short*)B + (j + jj + 2) * B_hstep + k; - const unsigned short* p3 = (const unsigned short*)B + (j + jj + 3) * B_hstep + k; - const unsigned short* p4 = (const unsigned short*)B + (j + jj + 4) * B_hstep + k; - const unsigned short* p5 = (const unsigned short*)B + (j + jj + 5) * B_hstep + k; - const unsigned short* p6 = (const unsigned short*)B + (j + jj + 6) * B_hstep + k; - const unsigned short* p7 = (const unsigned short*)B + (j + jj + 7) * B_hstep + k; - const unsigned short* p8 = (const unsigned short*)B + (j + jj + 8) * B_hstep + k; - const unsigned short* p9 = (const unsigned short*)B + (j + jj + 9) * B_hstep + k; - const unsigned short* pa = (const unsigned short*)B + (j + jj + 10) * B_hstep + k; - const unsigned short* pb = (const unsigned short*)B + (j + jj + 11) * B_hstep + k; - const unsigned short* pc = (const unsigned short*)B + (j + jj + 12) * B_hstep + k; - const unsigned short* pd = (const unsigned short*)B + (j + jj + 13) * B_hstep + k; - const unsigned short* pe = (const unsigned short*)B + (j + jj + 14) * B_hstep + k; - const unsigned short* pf = (const unsigned short*)B + (j + jj + 15) * B_hstep + k; + __m512i _vindex = _mm512_mullo_epi32(_mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15), _mm512_set1_epi32(B_hstep)); int kk = 0; - for (; kk + 15 < max_kk; kk += 16) +#if __AVX512BF16__ + for (; kk + 1 < max_kk; kk += 2) { - __m512 _r0 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)p0)); - __m512 _r1 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)p1)); - __m512 _r2 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)p2)); - __m512 _r3 = bfloat2float_avx512(_mm256_loadu_si256((const 
__m256i*)p3)); - __m512 _r4 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)p4)); - __m512 _r5 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)p5)); - __m512 _r6 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)p6)); - __m512 _r7 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)p7)); - __m512 _r8 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)p8)); - __m512 _r9 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)p9)); - __m512 _ra = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)pa)); - __m512 _rb = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)pb)); - __m512 _rc = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)pc)); - __m512 _rd = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)pd)); - __m512 _re = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)pe)); - __m512 _rf = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)pf)); - transpose16x16_ps(_r0, _r1, _r2, _r3, _r4, _r5, _r6, _r7, _r8, _r9, _ra, _rb, _rc, _rd, _re, _rf); - _mm512_storeu_ps(pp, _r0); - _mm512_storeu_ps(pp + 16, _r1); - _mm512_storeu_ps(pp + 16 * 2, _r2); - _mm512_storeu_ps(pp + 16 * 3, _r3); - _mm512_storeu_ps(pp + 16 * 4, _r4); - _mm512_storeu_ps(pp + 16 * 5, _r5); - _mm512_storeu_ps(pp + 16 * 6, _r6); - _mm512_storeu_ps(pp + 16 * 7, _r7); - _mm512_storeu_ps(pp + 16 * 8, _r8); - _mm512_storeu_ps(pp + 16 * 9, _r9); - _mm512_storeu_ps(pp + 16 * 10, _ra); - _mm512_storeu_ps(pp + 16 * 11, _rb); - _mm512_storeu_ps(pp + 16 * 12, _rc); - _mm512_storeu_ps(pp + 16 * 13, _rd); - _mm512_storeu_ps(pp + 16 * 14, _re); - _mm512_storeu_ps(pp + 16 * 15, _rf); - pp += 256; - p0 += 16; - p1 += 16; - p2 += 16; - p3 += 16; - p4 += 16; - p5 += 16; - p6 += 16; - p7 += 16; - p8 += 16; - p9 += 16; - pa += 16; - pb += 16; - pc += 16; - pd += 16; - pe += 16; - pf += 16; + __m512i _p = _mm512_i32gather_epi32(_vindex, (const int*)p0, sizeof(unsigned short)); + _mm512_storeu_si512((__m512i*)pp, _p); + pp += 32; + p0 += 2; } +#endif // __AVX512BF16__ for (; kk < max_kk; kk++) { - pp[0] = bfloat16_to_float32(p0[0]); - pp[1] = bfloat16_to_float32(p1[0]); - pp[2] = bfloat16_to_float32(p2[0]); - pp[3] = bfloat16_to_float32(p3[0]); - pp[4] = bfloat16_to_float32(p4[0]); - pp[5] = bfloat16_to_float32(p5[0]); - pp[6] = bfloat16_to_float32(p6[0]); - pp[7] = bfloat16_to_float32(p7[0]); - pp[8] = bfloat16_to_float32(p8[0]); - pp[9] = bfloat16_to_float32(p9[0]); - pp[10] = bfloat16_to_float32(pa[0]); - pp[11] = bfloat16_to_float32(pb[0]); - pp[12] = bfloat16_to_float32(pc[0]); - pp[13] = bfloat16_to_float32(pd[0]); - pp[14] = bfloat16_to_float32(pe[0]); - pp[15] = bfloat16_to_float32(pf[0]); + __m512i _p = _mm512_i32gather_epi32(_vindex, (const int*)p0, sizeof(unsigned short)); + __m256i _p16 = _mm512_cvtepi32_epi16(_p); + _mm256_storeu_si256((__m256i*)pp, _p16); pp += 16; p0++; - p1++; - p2++; - p3++; - p4++; - p5++; - p6++; - p7++; - p8++; - p9++; - pa++; - pb++; - pc++; - pd++; - pe++; - pf++; } } } -#else // __AVX512F__ - for (; jj + 11 < max_jj; jj += 12) +#endif // __AVX512F__ + for (; jj + 7 < max_jj; jj += 8) { + const unsigned short* p0 = (const unsigned short*)B + (j + jj) * B_hstep + k * elempack; + #if __AVX__ if (elempack == 8) { - const unsigned short* p0 = (const unsigned short*)B + (j + jj) / 8 * 8 * B_hstep + k * 8; - const unsigned short* p1 = (const unsigned short*)B + ((j + jj) / 8 * 8 + 8) * B_hstep + k * 8; - - if ((j + jj) % 8 == 0) + int kk = 0; +#if __AVX512BF16__ + for (; kk + 1 < max_kk; kk += 2) { - for (int kk = 0; kk < max_kk; kk++) - { - 
_mm256_storeu_ps(pp, bfloat2float_avx(_mm_loadu_si128((const __m128i*)p0))); - _mm_store_ps(pp + 8, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)p1))); - pp += 12; - p0 += 8; - p1 += 8; - } + __m128i _r0 = _mm_loadu_si128((const __m128i*)p0); + __m128i _r1 = _mm_loadu_si128((const __m128i*)(p0 + 8)); + __m128i _t0 = _mm_unpacklo_epi16(_r0, _r1); + __m128i _t1 = _mm_unpackhi_epi16(_r0, _r1); + _mm_storeu_si128((__m128i*)pp, _t0); + _mm_storeu_si128((__m128i*)(pp + 8), _t1); + pp += 16; + p0 += 16; } - if ((j + jj) % 8 == 4) +#endif // __AVX512BF16__ + for (; kk < max_kk; kk++) { - for (int kk = 0; kk < max_kk; kk++) - { - _mm_store_ps(pp, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(p0 + 4)))); - _mm256_storeu_ps(pp + 4, bfloat2float_avx(_mm_loadu_si128((const __m128i*)p1))); - pp += 12; - p0 += 8; - p1 += 8; - } + _mm_storeu_si128((__m128i*)pp, _mm_loadu_si128((const __m128i*)p0)); + pp += 8; + p0 += 8; } } #endif // __AVX__ if (elempack == 4) { - const unsigned short* p0 = (const unsigned short*)B + (j + jj) * B_hstep + k * 4; - const unsigned short* p1 = (const unsigned short*)B + (j + jj + 4) * B_hstep + k * 4; - const unsigned short* p2 = (const unsigned short*)B + (j + jj + 8) * B_hstep + k * 4; + const unsigned short* p1 = p0 + B_hstep * 4; - for (int kk = 0; kk < max_kk; kk++) + int kk = 0; +#if __AVX512BF16__ + for (; kk + 1 < max_kk; kk += 2) { - _mm_store_ps(pp, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)p0))); - _mm_store_ps(pp + 4, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)p1))); - _mm_store_ps(pp + 8, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)p2))); - pp += 12; + __m128i _a0 = _mm_loadu_si128((const __m128i*)p0); + __m128i _b0 = _mm_loadu_si128((const __m128i*)p1); + __m128i _t0 = _mm_unpacklo_epi16(_a0, _mm_srli_si128(_a0, 8)); + __m128i _t1 = _mm_unpacklo_epi16(_b0, _mm_srli_si128(_b0, 8)); + _mm_storeu_si128((__m128i*)pp, _t0); + _mm_storeu_si128((__m128i*)(pp + 8), _t1); + pp += 16; + p0 += 8; + p1 += 8; + } +#endif // __AVX512BF16__ + for (; kk < max_kk; kk++) + { + _mm_storel_epi64((__m128i*)pp, _mm_loadl_epi64((const __m128i*)p0)); + _mm_storel_epi64((__m128i*)(pp + 4), _mm_loadl_epi64((const __m128i*)p1)); + pp += 8; p0 += 4; p1 += 4; - p2 += 4; } } if (elempack == 1) { - const unsigned short* p0 = (const unsigned short*)B + (j + jj) * B_hstep + k; - const unsigned short* p1 = (const unsigned short*)B + (j + jj + 1) * B_hstep + k; - const unsigned short* p2 = (const unsigned short*)B + (j + jj + 2) * B_hstep + k; - const unsigned short* p3 = (const unsigned short*)B + (j + jj + 3) * B_hstep + k; - const unsigned short* p4 = (const unsigned short*)B + (j + jj + 4) * B_hstep + k; - const unsigned short* p5 = (const unsigned short*)B + (j + jj + 5) * B_hstep + k; - const unsigned short* p6 = (const unsigned short*)B + (j + jj + 6) * B_hstep + k; - const unsigned short* p7 = (const unsigned short*)B + (j + jj + 7) * B_hstep + k; - const unsigned short* p8 = (const unsigned short*)B + (j + jj + 8) * B_hstep + k; - const unsigned short* p9 = (const unsigned short*)B + (j + jj + 9) * B_hstep + k; - const unsigned short* pa = (const unsigned short*)B + (j + jj + 10) * B_hstep + k; - const unsigned short* pb = (const unsigned short*)B + (j + jj + 11) * B_hstep + k; +#if __AVX2__ + __m256i _vindex = _mm256_mullo_epi32(_mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7), _mm256_set1_epi32((int)B_hstep)); +#endif int kk = 0; -#if __AVX__ - for (; kk + 7 < max_kk; kk += 8) +#if __AVX512BF16__ + for (; kk + 1 < max_kk; kk += 2) { - __m256 _r0 = 
bfloat2float_avx(_mm_loadu_si128((const __m128i*)p0)); - __m256 _r1 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)p1)); - __m256 _r2 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)p2)); - __m256 _r3 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)p3)); - __m256 _r4 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)p4)); - __m256 _r5 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)p5)); - __m256 _r6 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)p6)); - __m256 _r7 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)p7)); - __m256 _r8 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)p8)); - __m256 _r9 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)p9)); - __m256 _ra = bfloat2float_avx(_mm_loadu_si128((const __m128i*)pa)); - __m256 _rb = bfloat2float_avx(_mm_loadu_si128((const __m128i*)pb)); - transpose8x12_ps(_r0, _r1, _r2, _r3, _r4, _r5, _r6, _r7, _r8, _r9, _ra, _rb); - _mm256_storeu_ps(pp, _r0); - _mm256_storeu_ps(pp + 8, _r1); - _mm256_storeu_ps(pp + 8 * 2, _r2); - _mm256_storeu_ps(pp + 8 * 3, _r3); - _mm256_storeu_ps(pp + 8 * 4, _r4); - _mm256_storeu_ps(pp + 8 * 5, _r5); - _mm256_storeu_ps(pp + 8 * 6, _r6); - _mm256_storeu_ps(pp + 8 * 7, _r7); - _mm256_storeu_ps(pp + 8 * 8, _r8); - _mm256_storeu_ps(pp + 8 * 9, _r9); - _mm256_storeu_ps(pp + 8 * 10, _ra); - _mm256_storeu_ps(pp + 8 * 11, _rb); - pp += 96; - p0 += 8; - p1 += 8; - p2 += 8; - p3 += 8; - p4 += 8; - p5 += 8; - p6 += 8; - p7 += 8; - p8 += 8; - p9 += 8; - pa += 8; - pb += 8; - } -#endif // __AVX__ - for (; kk + 3 < max_kk; kk += 4) - { - __m128 _r0 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)p0)); - __m128 _r1 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)p1)); - __m128 _r2 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)p2)); - __m128 _r3 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)p3)); - __m128 _r4 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)p4)); - __m128 _r5 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)p5)); - __m128 _r6 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)p6)); - __m128 _r7 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)p7)); - __m128 _r8 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)p8)); - __m128 _r9 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)p9)); - __m128 _ra = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)pa)); - __m128 _rb = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)pb)); - _MM_TRANSPOSE4_PS(_r0, _r1, _r2, _r3); - _MM_TRANSPOSE4_PS(_r4, _r5, _r6, _r7); - _MM_TRANSPOSE4_PS(_r8, _r9, _ra, _rb); - _mm_store_ps(pp, _r0); - _mm_store_ps(pp + 4, _r4); - _mm_store_ps(pp + 4 * 2, _r8); - _mm_store_ps(pp + 4 * 3, _r1); - _mm_store_ps(pp + 4 * 4, _r5); - _mm_store_ps(pp + 4 * 5, _r9); - _mm_store_ps(pp + 4 * 6, _r2); - _mm_store_ps(pp + 4 * 7, _r6); - _mm_store_ps(pp + 4 * 8, _ra); - _mm_store_ps(pp + 4 * 9, _r3); - _mm_store_ps(pp + 4 * 10, _r7); - _mm_store_ps(pp + 4 * 11, _rb); - pp += 48; - p0 += 4; - p1 += 4; - p2 += 4; - p3 += 4; - p4 += 4; - p5 += 4; - p6 += 4; - p7 += 4; - p8 += 4; - p9 += 4; - pa += 4; - pb += 4; + __m256i _p = _mm256_i32gather_epi32((const int*)p0, _vindex, sizeof(unsigned short)); + _mm256_storeu_si256((__m256i*)pp, _p); + pp += 16; + p0 += 2; } +#endif // __AVX512BF16__ for (; kk < max_kk; kk++) { - pp[0] = bfloat16_to_float32(p0[0]); - pp[1] = bfloat16_to_float32(p1[0]); - pp[2] = bfloat16_to_float32(p2[0]); - pp[3] = bfloat16_to_float32(p3[0]); - pp[4] = bfloat16_to_float32(p4[0]); - pp[5] = bfloat16_to_float32(p5[0]); - pp[6] = bfloat16_to_float32(p6[0]); - pp[7] = bfloat16_to_float32(p7[0]); - 
pp[8] = bfloat16_to_float32(p8[0]); - pp[9] = bfloat16_to_float32(p9[0]); - pp[10] = bfloat16_to_float32(pa[0]); - pp[11] = bfloat16_to_float32(pb[0]); - pp += 12; +#if __AVX2__ + __m256i _p = _mm256_i32gather_epi32((const int*)p0, _vindex, sizeof(unsigned short)); + __m128i _p16 = _mm256_comp_cvtepi32_epi16(_p); + _mm_storeu_si128((__m128i*)pp, _p16); +#else + pp[0] = p0[0]; + pp[1] = p0[B_hstep]; + pp[2] = p0[B_hstep * 2]; + pp[3] = p0[B_hstep * 3]; + pp[4] = p0[B_hstep * 4]; + pp[5] = p0[B_hstep * 5]; + pp[6] = p0[B_hstep * 6]; + pp[7] = p0[B_hstep * 7]; +#endif + pp += 8; p0++; - p1++; - p2++; - p3++; - p4++; - p5++; - p6++; - p7++; - p8++; - p9++; - pa++; - pb++; } } } -#endif // __AVX512F__ - for (; jj + 7 < max_jj; jj += 8) - { +#else // defined(__x86_64__) || defined(_M_X64) #if __AVX__ - if (elempack == 8) - { #if __AVX512F__ - const unsigned short* p0 = (const unsigned short*)B + (j + jj) * B_hstep + k * 8; -#else - const unsigned short* p0 = (const unsigned short*)B + (j + jj) / 8 * 8 * B_hstep + k * 8; - const unsigned short* p1 = (const unsigned short*)B + (j + jj + 8) / 8 * 8 * B_hstep + k * 8; + if (elempack == 16) + { + for (; jj + 15 < max_jj; jj += 16) + { + const unsigned short* p0 = (const unsigned short*)B + (j + jj) * B_hstep + k * elempack; - if ((j + jj) % 8 == 0) -#endif + unsigned short* pp1 = pp + max_kk * 4; + unsigned short* pp2 = pp + max_kk * 8; + unsigned short* pp3 = pp + max_kk * 12; + + int kk = 0; +#if __AVX512BF16__ + for (; kk + 1 < max_kk; kk += 2) { - for (int kk = 0; kk < max_kk; kk++) - { - _mm256_storeu_ps(pp, bfloat2float_avx(_mm_loadu_si128((const __m128i*)p0))); - pp += 8; - p0 += 8; - } + __m128i _p0 = _mm_loadl_epi64((const __m128i*)p0); + __m128i _p1 = _mm_loadl_epi64((const __m128i*)(p0 + 4)); + __m128i _p2 = _mm_loadl_epi64((const __m128i*)(p0 + 8)); + __m128i _p3 = _mm_loadl_epi64((const __m128i*)(p0 + 12)); + __m128i _p4 = _mm_loadl_epi64((const __m128i*)(p0 + 16)); + __m128i _p5 = _mm_loadl_epi64((const __m128i*)(p0 + 20)); + __m128i _p6 = _mm_loadl_epi64((const __m128i*)(p0 + 24)); + __m128i _p7 = _mm_loadl_epi64((const __m128i*)(p0 + 28)); + + __m128i _t0 = _mm_unpacklo_epi16(_p0, _p1); + __m128i _t1 = _mm_unpacklo_epi16(_p2, _p3); + __m128i _t2 = _mm_unpacklo_epi16(_p4, _p5); + __m128i _t3 = _mm_unpacklo_epi16(_p6, _p7); + + _mm_storeu_si128((__m128i*)pp, _t0); + _mm_storeu_si128((__m128i*)pp1, _t1); + _mm_storeu_si128((__m128i*)pp2, _t2); + _mm_storeu_si128((__m128i*)pp3, _t3); + + pp += 8; + pp1 += 8; + pp2 += 8; + pp3 += 8; + p0 += 32; } -#if !__AVX512F__ - if ((j + jj) % 8 == 4) +#endif // __AVX512BF16__ + for (; kk < max_kk; kk++) { - for (int kk = 0; kk < max_kk; kk++) - { - _mm_store_ps(pp, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(p0 + 4)))); - _mm_store_ps(pp + 4, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)p1))); - pp += 8; - p0 += 8; - p1 += 8; - } - } -#endif // !__AVX512F__ - } -#endif // __AVX__ - if (elempack == 4) - { - const unsigned short* p0 = (const unsigned short*)B + (j + jj) * B_hstep + k * 4; - const unsigned short* p1 = (const unsigned short*)B + (j + jj + 4) * B_hstep + k * 4; + __m128i _p0 = _mm_loadu_si128((const __m128i*)p0); + __m128i _p1 = _mm_loadu_si128((const __m128i*)(p0 + 8)); - for (int kk = 0; kk < max_kk; kk++) - { - _mm_store_ps(pp, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)p0))); - _mm_store_ps(pp + 4, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)p1))); - pp += 8; - p0 += 4; - p1 += 4; + _mm_storel_pd((double*)pp, _mm_castsi128_pd(_p0)); + _mm_storeh_pd((double*)pp1, 
_mm_castsi128_pd(_p0)); + _mm_storel_pd((double*)pp2, _mm_castsi128_pd(_p1)); + _mm_storeh_pd((double*)pp3, _mm_castsi128_pd(_p1)); + + pp += 4; + pp1 += 4; + pp2 += 4; + pp3 += 4; + p0 += 16; } + + pp = pp3; } - if (elempack == 1) + } +#endif // __AVX512F__ + if (elempack == 8) + { + for (; jj + 7 < max_jj; jj += 8) { - const unsigned short* p0 = (const unsigned short*)B + (j + jj) * B_hstep + k; - const unsigned short* p1 = (const unsigned short*)B + (j + jj + 1) * B_hstep + k; - const unsigned short* p2 = (const unsigned short*)B + (j + jj + 2) * B_hstep + k; - const unsigned short* p3 = (const unsigned short*)B + (j + jj + 3) * B_hstep + k; - const unsigned short* p4 = (const unsigned short*)B + (j + jj + 4) * B_hstep + k; - const unsigned short* p5 = (const unsigned short*)B + (j + jj + 5) * B_hstep + k; - const unsigned short* p6 = (const unsigned short*)B + (j + jj + 6) * B_hstep + k; - const unsigned short* p7 = (const unsigned short*)B + (j + jj + 7) * B_hstep + k; + const unsigned short* p0 = (const unsigned short*)B + (j + jj) * B_hstep + k * elempack; + + unsigned short* pp1 = pp + max_kk * 4; int kk = 0; -#if __AVX__ - for (; kk + 7 < max_kk; kk += 8) - { - __m256 _r0 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)p0)); - __m256 _r1 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)p1)); - __m256 _r2 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)p2)); - __m256 _r3 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)p3)); - __m256 _r4 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)p4)); - __m256 _r5 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)p5)); - __m256 _r6 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)p6)); - __m256 _r7 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)p7)); - transpose8x8_ps(_r0, _r1, _r2, _r3, _r4, _r5, _r6, _r7); - _mm256_storeu_ps(pp, _r0); - _mm256_storeu_ps(pp + 8, _r1); - _mm256_storeu_ps(pp + 8 * 2, _r2); - _mm256_storeu_ps(pp + 8 * 3, _r3); - _mm256_storeu_ps(pp + 8 * 4, _r4); - _mm256_storeu_ps(pp + 8 * 5, _r5); - _mm256_storeu_ps(pp + 8 * 6, _r6); - _mm256_storeu_ps(pp + 8 * 7, _r7); - pp += 64; - p0 += 8; - p1 += 8; - p2 += 8; - p3 += 8; - p4 += 8; - p5 += 8; - p6 += 8; - p7 += 8; - } -#endif // __AVX__ - for (; kk + 3 < max_kk; kk += 4) +#if __AVX512BF16__ + for (; kk + 1 < max_kk; kk += 2) { - __m128 _r0 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)p0)); - __m128 _r1 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)p1)); - __m128 _r2 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)p2)); - __m128 _r3 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)p3)); - __m128 _r4 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)p4)); - __m128 _r5 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)p5)); - __m128 _r6 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)p6)); - __m128 _r7 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)p7)); - _MM_TRANSPOSE4_PS(_r0, _r1, _r2, _r3); - _MM_TRANSPOSE4_PS(_r4, _r5, _r6, _r7); - _mm_store_ps(pp, _r0); - _mm_store_ps(pp + 4, _r4); - _mm_store_ps(pp + 4 * 2, _r1); - _mm_store_ps(pp + 4 * 3, _r5); - _mm_store_ps(pp + 4 * 4, _r2); - _mm_store_ps(pp + 4 * 5, _r6); - _mm_store_ps(pp + 4 * 6, _r3); - _mm_store_ps(pp + 4 * 7, _r7); - pp += 32; - p0 += 4; - p1 += 4; - p2 += 4; - p3 += 4; - p4 += 4; - p5 += 4; - p6 += 4; - p7 += 4; + __m128i _p0 = _mm_loadl_epi64((const __m128i*)p0); + __m128i _p1 = _mm_loadl_epi64((const __m128i*)(p0 + 4)); + __m128i _p2 = _mm_loadl_epi64((const __m128i*)(p0 + 8)); + __m128i _p3 = _mm_loadl_epi64((const __m128i*)(p0 + 12)); + + __m128i _t0 = 
_mm_unpacklo_epi16(_p0, _p1); + __m128i _t1 = _mm_unpacklo_epi16(_p2, _p3); + + _mm_storeu_si128((__m128i*)pp, _t0); + _mm_storeu_si128((__m128i*)pp1, _t1); + + pp += 8; + pp1 += 8; + p0 += 16; } +#endif // __AVX512BF16__ for (; kk < max_kk; kk++) { - pp[0] = bfloat16_to_float32(p0[0]); - pp[1] = bfloat16_to_float32(p1[0]); - pp[2] = bfloat16_to_float32(p2[0]); - pp[3] = bfloat16_to_float32(p3[0]); - pp[4] = bfloat16_to_float32(p4[0]); - pp[5] = bfloat16_to_float32(p5[0]); - pp[6] = bfloat16_to_float32(p6[0]); - pp[7] = bfloat16_to_float32(p7[0]); - pp += 8; - p0++; - p1++; - p2++; - p3++; - p4++; - p5++; - p6++; - p7++; + __m128i _p0 = _mm_loadu_si128((const __m128i*)p0); + + _mm_storel_pd((double*)pp, _mm_castsi128_pd(_p0)); + _mm_storeh_pd((double*)pp1, _mm_castsi128_pd(_p0)); + + pp += 4; + pp1 += 4; + p0 += 8; } + + pp = pp1; } } +#endif // __AVX__ +#endif // defined(__x86_64__) || defined(_M_X64) for (; jj + 3 < max_jj; jj += 4) { -#if __AVX__ && !__AVX512F__ - if (elempack == 8) - { - const unsigned short* p0 = (const unsigned short*)B + (j + jj) / 8 * 8 * B_hstep + k * 8; + const unsigned short* p0 = (const unsigned short*)B + (j + jj) * B_hstep + k * elempack; - if ((j + jj) % 8 == 0) - { - for (int kk = 0; kk < max_kk; kk++) - { - _mm_store_ps(pp, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)p0))); - pp += 4; - p0 += 8; - } - } - if ((j + jj) % 8 == 4) - { - for (int kk = 0; kk < max_kk; kk++) - { - _mm_store_ps(pp, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(p0 + 4)))); - pp += 4; - p0 += 8; - } - } - } -#endif // __AVX__ && !__AVX512F__ if (elempack == 4) { - const unsigned short* p0 = (const unsigned short*)B + (j + jj) * B_hstep + k * 4; - - for (int kk = 0; kk < max_kk; kk++) + int kk = 0; +#if __AVX512BF16__ + for (; kk + 1 < max_kk; kk += 2) + { + __m128i _r0 = _mm_loadu_si128((const __m128i*)p0); + __m128i _t0 = _mm_unpacklo_epi16(_r0, _mm_srli_si128(_r0, 8)); + __m128i _t1 = _mm_unpackhi_epi16(_mm_slli_si128(_r0, 8), _r0); + (void)_t1; + _mm_storeu_si128((__m128i*)pp, _t0); + pp += 8; + p0 += 8; + } +#endif // __AVX512BF16__ + for (; kk < max_kk; kk++) { - _mm_store_ps(pp, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)p0))); + _mm_storel_epi64((__m128i*)pp, _mm_loadl_epi64((const __m128i*)p0)); pp += 4; p0 += 4; } } if (elempack == 1) { - const unsigned short* p0 = (const unsigned short*)B + (j + jj) * B_hstep + k; - const unsigned short* p1 = (const unsigned short*)B + (j + jj + 1) * B_hstep + k; - const unsigned short* p2 = (const unsigned short*)B + (j + jj + 2) * B_hstep + k; - const unsigned short* p3 = (const unsigned short*)B + (j + jj + 3) * B_hstep + k; + const unsigned short* p1 = p0 + B_hstep * 1; + const unsigned short* p2 = p0 + B_hstep * 2; + const unsigned short* p3 = p0 + B_hstep * 3; int kk = 0; -#if __AVX__ - for (; kk + 7 < max_kk; kk += 8) - { - __m256 _r0 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)p0)); - __m256 _r1 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)p1)); - __m256 _r2 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)p2)); - __m256 _r3 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)p3)); - transpose8x4_ps(_r0, _r1, _r2, _r3); - _mm256_storeu_ps(pp, _r0); - _mm256_storeu_ps(pp + 8, _r1); - _mm256_storeu_ps(pp + 16, _r2); - _mm256_storeu_ps(pp + 24, _r3); - pp += 32; - p0 += 8; - p1 += 8; - p2 += 8; - p3 += 8; - } -#endif // __AVX__ - for (; kk + 3 < max_kk; kk += 4) +#if __AVX512BF16__ + for (; kk + 1 < max_kk; kk += 2) { - __m128 _r0 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)p0)); - __m128 _r1 = 
bfloat2float_sse(_mm_loadl_epi64((const __m128i*)p1)); - __m128 _r2 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)p2)); - __m128 _r3 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)p3)); - _MM_TRANSPOSE4_PS(_r0, _r1, _r2, _r3); - _mm_store_ps(pp, _r0); - _mm_store_ps(pp + 4, _r1); - _mm_store_ps(pp + 8, _r2); - _mm_store_ps(pp + 12, _r3); - pp += 16; - p0 += 4; - p1 += 4; - p2 += 4; - p3 += 4; + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p1[0]; + pp[3] = p1[1]; + pp[4] = p2[0]; + pp[5] = p2[1]; + pp[6] = p3[0]; + pp[7] = p3[1]; + pp += 8; + p0 += 2; + p1 += 2; + p2 += 2; + p3 += 2; } +#endif // __AVX512BF16__ for (; kk < max_kk; kk++) { - pp[0] = bfloat16_to_float32(p0[0]); - pp[1] = bfloat16_to_float32(p1[0]); - pp[2] = bfloat16_to_float32(p2[0]); - pp[3] = bfloat16_to_float32(p3[0]); + pp[0] = p0[0]; + pp[1] = p1[0]; + pp[2] = p2[0]; + pp[3] = p3[0]; pp += 4; p0++; p1++; @@ -1467,43 +1433,28 @@ static void pack_B_tile_bf16s(const Mat& B, Mat& BT, int j, int max_jj, int k, i #endif // __SSE2__ for (; jj + 1 < max_jj; jj += 2) { + const unsigned short* p0 = (const unsigned short*)B + (j + jj) * B_hstep + k; + const unsigned short* p1 = p0 + B_hstep; + // if (elempack == 1) { - const unsigned short* p0 = (const unsigned short*)B + (j + jj) * B_hstep + k; - const unsigned short* p1 = (const unsigned short*)B + (j + jj + 1) * B_hstep + k; - int kk = 0; -#if __SSE2__ -#if __AVX__ - for (; kk + 7 < max_kk; kk += 8) - { - __m256 _r0 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)p0)); - __m256 _r1 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)p1)); - transpose8x2_ps(_r0, _r1); - _mm256_storeu_ps(pp, _r0); - _mm256_storeu_ps(pp + 8, _r1); - pp += 16; - p0 += 8; - p1 += 8; - } -#endif // __AVX__ - for (; kk + 3 < max_kk; kk += 4) +#if __AVX512BF16__ + for (; kk + 1 < max_kk; kk += 2) { - __m128 _r0 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)p0)); - __m128 _r1 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)p1)); - __m128 _tmp0 = _mm_unpacklo_ps(_r0, _r1); - __m128 _tmp1 = _mm_unpackhi_ps(_r0, _r1); - _mm_store_ps(pp, _tmp0); - _mm_store_ps(pp + 4, _tmp1); - pp += 8; - p0 += 4; - p1 += 4; + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p1[0]; + pp[3] = p1[1]; + pp += 4; + p0 += 2; + p1 += 2; } -#endif // __SSE2__ +#endif // __AVX512BF16__ for (; kk < max_kk; kk++) { - pp[0] = bfloat16_to_float32(p0[0]); - pp[1] = bfloat16_to_float32(p1[0]); + pp[0] = p0[0]; + pp[1] = p1[0]; pp += 2; p0++; p1++; @@ -1512,30 +1463,23 @@ static void pack_B_tile_bf16s(const Mat& B, Mat& BT, int j, int max_jj, int k, i } for (; jj < max_jj; jj += 1) { + const unsigned short* p0 = (const unsigned short*)B + (j + jj) * B_hstep + k; + // if (elempack == 1) { - const unsigned short* p0 = (const unsigned short*)B + (j + jj) * B_hstep + k; - int kk = 0; -#if __SSE2__ -#if __AVX__ - for (; kk + 7 < max_kk; kk += 8) - { - _mm256_storeu_ps(pp, bfloat2float_avx(_mm_loadu_si128((const __m128i*)p0))); - pp += 8; - p0 += 8; - } -#endif // __AVX__ - for (; kk + 3 < max_kk; kk += 4) +#if __AVX512BF16__ + for (; kk + 1 < max_kk; kk += 2) { - _mm_storeu_ps(pp, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)p0))); - pp += 4; - p0 += 4; + pp[0] = p0[0]; + pp[1] = p0[1]; + pp += 2; + p0 += 2; } -#endif // __SSE2__ +#endif // __AVX512BF16__ for (; kk < max_kk; kk++) { - pp[0] = bfloat16_to_float32(p0[0]); + pp[0] = p0[0]; pp += 1; p0++; } @@ -1543,248 +1487,250 @@ static void pack_B_tile_bf16s(const Mat& B, Mat& BT, int j, int max_jj, int k, i } } -static void transpose_pack_B_tile_bf16s(const Mat& B, Mat& BT, int 
j, int max_jj, int k, int max_kk) +static void transpose_pack_B_tile_bf16(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk) { +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ + if (ncnn::cpu_support_x86_avx512_bf16()) + { + transpose_pack_B_tile_bf16_avx512bf16(B, BT, j, max_jj, k, max_kk); + return; + } +#endif + + // NCNN_LOGE("transpose_pack_B_tile_bf16 %d %d %d %d", j, max_jj, k, max_kk); const int elempack = B.elempack; const size_t B_hstep = B.dims == 3 ? B.cstep : (size_t)B.w; - float* pp = BT; + unsigned short* pp = (unsigned short*)BT; int jj = 0; #if __SSE2__ +#if defined(__x86_64__) || defined(_M_X64) #if __AVX512F__ for (; jj + 15 < max_jj; jj += 16) { + const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj) * elempack; + if (elempack == 16) { - const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj) * 16; - int kk = 0; for (; kk + 15 < max_kk; kk += 16) { - __m256i _r0 = _mm256_loadu_si256((const __m256i*)p0); - __m256i _r1 = _mm256_loadu_si256((const __m256i*)(p0 + 16 * 1)); - __m256i _r2 = _mm256_loadu_si256((const __m256i*)(p0 + 16 * 2)); - __m256i _r3 = _mm256_loadu_si256((const __m256i*)(p0 + 16 * 3)); - __m256i _r4 = _mm256_loadu_si256((const __m256i*)(p0 + 16 * 4)); - __m256i _r5 = _mm256_loadu_si256((const __m256i*)(p0 + 16 * 5)); - __m256i _r6 = _mm256_loadu_si256((const __m256i*)(p0 + 16 * 6)); - __m256i _r7 = _mm256_loadu_si256((const __m256i*)(p0 + 16 * 7)); - __m256i _r8 = _mm256_loadu_si256((const __m256i*)(p0 + 16 * 8)); - __m256i _r9 = _mm256_loadu_si256((const __m256i*)(p0 + 16 * 9)); - __m256i _ra = _mm256_loadu_si256((const __m256i*)(p0 + 16 * 10)); - __m256i _rb = _mm256_loadu_si256((const __m256i*)(p0 + 16 * 11)); - __m256i _rc = _mm256_loadu_si256((const __m256i*)(p0 + 16 * 12)); - __m256i _rd = _mm256_loadu_si256((const __m256i*)(p0 + 16 * 13)); - __m256i _re = _mm256_loadu_si256((const __m256i*)(p0 + 16 * 14)); - __m256i _rf = _mm256_loadu_si256((const __m256i*)(p0 + 16 * 15)); - transpose16x16_epi16(_r0, _r1, _r2, _r3, _r4, _r5, _r6, _r7, _r8, _r9, _ra, _rb, _rc, _rd, _re, _rf); - _mm512_store_ps(pp, bfloat2float_avx512(_r0)); - _mm512_store_ps(pp + 16 * 1, bfloat2float_avx512(_r1)); - _mm512_store_ps(pp + 16 * 2, bfloat2float_avx512(_r2)); - _mm512_store_ps(pp + 16 * 3, bfloat2float_avx512(_r3)); - _mm512_store_ps(pp + 16 * 4, bfloat2float_avx512(_r4)); - _mm512_store_ps(pp + 16 * 5, bfloat2float_avx512(_r5)); - _mm512_store_ps(pp + 16 * 6, bfloat2float_avx512(_r6)); - _mm512_store_ps(pp + 16 * 7, bfloat2float_avx512(_r7)); - _mm512_store_ps(pp + 16 * 8, bfloat2float_avx512(_r8)); - _mm512_store_ps(pp + 16 * 9, bfloat2float_avx512(_r9)); - _mm512_store_ps(pp + 16 * 10, bfloat2float_avx512(_ra)); - _mm512_store_ps(pp + 16 * 11, bfloat2float_avx512(_rb)); - _mm512_store_ps(pp + 16 * 12, bfloat2float_avx512(_rc)); - _mm512_store_ps(pp + 16 * 13, bfloat2float_avx512(_rd)); - _mm512_store_ps(pp + 16 * 14, bfloat2float_avx512(_re)); - _mm512_store_ps(pp + 16 * 15, bfloat2float_avx512(_rf)); + __m512i _r0 = _mm512_loadu_si512((const __m512i*)p0); + __m512i _r1 = _mm512_loadu_si512((const __m512i*)(p0 + 32)); + __m512i _r2 = _mm512_loadu_si512((const __m512i*)(p0 + 64)); + __m512i _r3 = _mm512_loadu_si512((const __m512i*)(p0 + 96)); + __m512i _r4 = _mm512_loadu_si512((const __m512i*)(p0 + 128)); + __m512i _r5 = _mm512_loadu_si512((const __m512i*)(p0 + 160)); + __m512i _r6 = _mm512_loadu_si512((const __m512i*)(p0 + 192)); + __m512i _r7 = _mm512_loadu_si512((const __m512i*)(p0 + 
224)); + + __m512i w0 = _mm512_shuffle_i64x2(_r0, _r1, 0x44); + __m512i w1 = _mm512_shuffle_i64x2(_r0, _r1, 0xEE); + __m512i w2 = _mm512_shuffle_i64x2(_r2, _r3, 0x44); + __m512i w3 = _mm512_shuffle_i64x2(_r2, _r3, 0xEE); + __m512i w4 = _mm512_shuffle_i64x2(_r4, _r5, 0x44); + __m512i w5 = _mm512_shuffle_i64x2(_r4, _r5, 0xEE); + __m512i w6 = _mm512_shuffle_i64x2(_r6, _r7, 0x44); + __m512i w7 = _mm512_shuffle_i64x2(_r6, _r7, 0xEE); + +#if __AVX512BF16__ + __m512i a0 = _mm512_unpacklo_epi32(w0, w1); + __m512i a1 = _mm512_unpackhi_epi32(w0, w1); + __m512i a2 = _mm512_unpacklo_epi32(w2, w3); + __m512i a3 = _mm512_unpackhi_epi32(w2, w3); + __m512i a4 = _mm512_unpacklo_epi32(w4, w5); + __m512i a5 = _mm512_unpackhi_epi32(w4, w5); + __m512i a6 = _mm512_unpacklo_epi32(w6, w7); + __m512i a7 = _mm512_unpackhi_epi32(w6, w7); + + __m512i b0 = _mm512_unpacklo_epi64(a0, a2); + __m512i b1 = _mm512_unpackhi_epi64(a0, a2); + __m512i b2 = _mm512_unpacklo_epi64(a1, a3); + __m512i b3 = _mm512_unpackhi_epi64(a1, a3); + __m512i b4 = _mm512_unpacklo_epi64(a4, a6); + __m512i b5 = _mm512_unpackhi_epi64(a4, a6); + __m512i b6 = _mm512_unpacklo_epi64(a5, a7); + __m512i b7 = _mm512_unpackhi_epi64(a5, a7); + + __m512i idx_l = _mm512_set_epi32(27, 26, 19, 18, 25, 24, 17, 16, 11, 10, 3, 2, 9, 8, 1, 0); + __m512i idx_r = _mm512_set_epi32(31, 30, 23, 22, 29, 28, 21, 20, 15, 14, 7, 6, 13, 12, 5, 4); + + __m512i _p0 = _mm512_permutex2var_epi32(b0, idx_l, b4); + __m512i _p1 = _mm512_permutex2var_epi32(b1, idx_l, b5); + __m512i _p2 = _mm512_permutex2var_epi32(b2, idx_l, b6); + __m512i _p3 = _mm512_permutex2var_epi32(b3, idx_l, b7); + __m512i _p4 = _mm512_permutex2var_epi32(b0, idx_r, b4); + __m512i _p5 = _mm512_permutex2var_epi32(b1, idx_r, b5); + __m512i _p6 = _mm512_permutex2var_epi32(b2, idx_r, b6); + __m512i _p7 = _mm512_permutex2var_epi32(b3, idx_r, b7); +#else // __AVX512BF16__ + __m512i a0 = _mm512_unpacklo_epi16(w0, w1); + __m512i a1 = _mm512_unpackhi_epi16(w0, w1); + __m512i a2 = _mm512_unpacklo_epi16(w2, w3); + __m512i a3 = _mm512_unpackhi_epi16(w2, w3); + __m512i a4 = _mm512_unpacklo_epi16(w4, w5); + __m512i a5 = _mm512_unpackhi_epi16(w4, w5); + __m512i a6 = _mm512_unpacklo_epi16(w6, w7); + __m512i a7 = _mm512_unpackhi_epi16(w6, w7); + + __m512i b0 = _mm512_unpacklo_epi32(a0, a2); + __m512i b1 = _mm512_unpackhi_epi32(a0, a2); + __m512i b2 = _mm512_unpacklo_epi32(a1, a3); + __m512i b3 = _mm512_unpackhi_epi32(a1, a3); + __m512i b4 = _mm512_unpacklo_epi32(a4, a6); + __m512i b5 = _mm512_unpackhi_epi32(a4, a6); + __m512i b6 = _mm512_unpacklo_epi32(a5, a7); + __m512i b7 = _mm512_unpackhi_epi32(a5, a7); + + __m512i c0 = _mm512_unpacklo_epi64(b0, b4); + __m512i c1 = _mm512_unpackhi_epi64(b0, b4); + __m512i c2 = _mm512_unpacklo_epi64(b1, b5); + __m512i c3 = _mm512_unpackhi_epi64(b1, b5); + __m512i c4 = _mm512_unpacklo_epi64(b2, b6); + __m512i c5 = _mm512_unpackhi_epi64(b2, b6); + __m512i c6 = _mm512_unpacklo_epi64(b3, b7); + __m512i c7 = _mm512_unpackhi_epi64(b3, b7); + + __m512i idx_lo = _mm512_set_epi32(27, 19, 26, 18, 25, 17, 24, 16, 11, 3, 10, 2, 9, 1, 8, 0); + __m512i idx_hi = _mm512_set_epi32(31, 23, 30, 22, 29, 21, 28, 20, 15, 7, 14, 6, 13, 5, 12, 4); + + __m512i _p0 = _mm512_permutex2var_epi32(c0, idx_lo, c1); // col 0,1 + __m512i _p1 = _mm512_permutex2var_epi32(c2, idx_lo, c3); // col 2,3 + __m512i _p2 = _mm512_permutex2var_epi32(c4, idx_lo, c5); // col 4,5 + __m512i _p3 = _mm512_permutex2var_epi32(c6, idx_lo, c7); // col 6,7 + __m512i _p4 = _mm512_permutex2var_epi32(c0, idx_hi, c1); // col 8,9 + __m512i _p5 = 
_mm512_permutex2var_epi32(c2, idx_hi, c3); // col A,B + __m512i _p6 = _mm512_permutex2var_epi32(c4, idx_hi, c5); // col C,D + __m512i _p7 = _mm512_permutex2var_epi32(c6, idx_hi, c7); // col E,F +#endif // __AVX512BF16__ + + _mm512_storeu_si512((__m512i*)pp, _p0); + _mm512_storeu_si512((__m512i*)(pp + 32), _p1); + _mm512_storeu_si512((__m512i*)(pp + 64), _p2); + _mm512_storeu_si512((__m512i*)(pp + 96), _p3); + _mm512_storeu_si512((__m512i*)(pp + 128), _p4); + _mm512_storeu_si512((__m512i*)(pp + 160), _p5); + _mm512_storeu_si512((__m512i*)(pp + 192), _p6); + _mm512_storeu_si512((__m512i*)(pp + 224), _p7); pp += 256; p0 += B_hstep * 16; } } if (elempack == 8) { - const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj) * 8; - int kk = 0; for (; kk + 7 < max_kk; kk += 8) { - __m128i _r0 = _mm_loadu_si128((const __m128i*)p0); - __m128i _r1 = _mm_loadu_si128((const __m128i*)(p0 + 8 * 1)); - __m128i _r2 = _mm_loadu_si128((const __m128i*)(p0 + 8 * 2)); - __m128i _r3 = _mm_loadu_si128((const __m128i*)(p0 + 8 * 3)); - __m128i _r4 = _mm_loadu_si128((const __m128i*)(p0 + 8 * 4)); - __m128i _r5 = _mm_loadu_si128((const __m128i*)(p0 + 8 * 5)); - __m128i _r6 = _mm_loadu_si128((const __m128i*)(p0 + 8 * 6)); - __m128i _r7 = _mm_loadu_si128((const __m128i*)(p0 + 8 * 7)); - __m128i _r8 = _mm_loadu_si128((const __m128i*)(p0 + 8 * 8)); - __m128i _r9 = _mm_loadu_si128((const __m128i*)(p0 + 8 * 9)); - __m128i _ra = _mm_loadu_si128((const __m128i*)(p0 + 8 * 10)); - __m128i _rb = _mm_loadu_si128((const __m128i*)(p0 + 8 * 11)); - __m128i _rc = _mm_loadu_si128((const __m128i*)(p0 + 8 * 12)); - __m128i _rd = _mm_loadu_si128((const __m128i*)(p0 + 8 * 13)); - __m128i _re = _mm_loadu_si128((const __m128i*)(p0 + 8 * 14)); - __m128i _rf = _mm_loadu_si128((const __m128i*)(p0 + 8 * 15)); - transpose8x16_epi16(_r0, _r1, _r2, _r3, _r4, _r5, _r6, _r7, _r8, _r9, _ra, _rb, _rc, _rd, _re, _rf); - _mm256_store_ps(pp, bfloat2float_avx(_r0)); - _mm256_store_ps(pp + 8 * 1, bfloat2float_avx(_r1)); - _mm256_store_ps(pp + 8 * 2, bfloat2float_avx(_r2)); - _mm256_store_ps(pp + 8 * 3, bfloat2float_avx(_r3)); - _mm256_store_ps(pp + 8 * 4, bfloat2float_avx(_r4)); - _mm256_store_ps(pp + 8 * 5, bfloat2float_avx(_r5)); - _mm256_store_ps(pp + 8 * 6, bfloat2float_avx(_r6)); - _mm256_store_ps(pp + 8 * 7, bfloat2float_avx(_r7)); - _mm256_store_ps(pp + 8 * 8, bfloat2float_avx(_r8)); - _mm256_store_ps(pp + 8 * 9, bfloat2float_avx(_r9)); - _mm256_store_ps(pp + 8 * 10, bfloat2float_avx(_ra)); - _mm256_store_ps(pp + 8 * 11, bfloat2float_avx(_rb)); - _mm256_store_ps(pp + 8 * 12, bfloat2float_avx(_rc)); - _mm256_store_ps(pp + 8 * 13, bfloat2float_avx(_rd)); - _mm256_store_ps(pp + 8 * 14, bfloat2float_avx(_re)); - _mm256_store_ps(pp + 8 * 15, bfloat2float_avx(_rf)); + __m512i _r0 = _mm512_loadu_si512((const __m512i*)p0); + __m512i _r1 = _mm512_loadu_si512((const __m512i*)(p0 + 32)); + __m512i _r2 = _mm512_loadu_si512((const __m512i*)(p0 + 64)); + __m512i _r3 = _mm512_loadu_si512((const __m512i*)(p0 + 96)); +#if __AVX512BF16__ + __m512i idx0 = _mm512_set_epi32(0, 0, 0, 0, 0, 0, 0, 0, 28, 24, 20, 16, 12, 8, 4, 0); + __m512i idx1 = _mm512_set_epi32(0, 0, 0, 0, 0, 0, 0, 0, 29, 25, 21, 17, 13, 9, 5, 1); + __m512i idx2 = _mm512_set_epi32(0, 0, 0, 0, 0, 0, 0, 0, 30, 26, 22, 18, 14, 10, 6, 2); + __m512i idx3 = _mm512_set_epi32(0, 0, 0, 0, 0, 0, 0, 0, 31, 27, 23, 19, 15, 11, 7, 3); + + __m512i lo0 = _mm512_permutex2var_epi32(_r0, idx0, _r1); + __m512i lo1 = _mm512_permutex2var_epi32(_r0, idx1, _r1); + __m512i lo2 = _mm512_permutex2var_epi32(_r0, 
idx2, _r1); + __m512i lo3 = _mm512_permutex2var_epi32(_r0, idx3, _r1); + + __m512i hi0 = _mm512_permutex2var_epi32(_r2, idx0, _r3); + __m512i hi1 = _mm512_permutex2var_epi32(_r2, idx1, _r3); + __m512i hi2 = _mm512_permutex2var_epi32(_r2, idx2, _r3); + __m512i hi3 = _mm512_permutex2var_epi32(_r2, idx3, _r3); + + __m512i _p0 = _mm512_inserti64x4(lo0, _mm512_castsi512_si256(hi0), 1); + __m512i _p1 = _mm512_inserti64x4(lo1, _mm512_castsi512_si256(hi1), 1); + __m512i _p2 = _mm512_inserti64x4(lo2, _mm512_castsi512_si256(hi2), 1); + __m512i _p3 = _mm512_inserti64x4(lo3, _mm512_castsi512_si256(hi3), 1); +#else // __AVX512BF16__ + __m512i id0 = _mm512_set_epi32(0, 0, 0, 0, 0, 0, 0, 0, 28, 24, 20, 16, 12, 8, 4, 0); + __m512i id1 = _mm512_set_epi32(0, 0, 0, 0, 0, 0, 0, 0, 29, 25, 21, 17, 13, 9, 5, 1); + __m512i id2 = _mm512_set_epi32(0, 0, 0, 0, 0, 0, 0, 0, 30, 26, 22, 18, 14, 10, 6, 2); + __m512i id3 = _mm512_set_epi32(0, 0, 0, 0, 0, 0, 0, 0, 31, 27, 23, 19, 15, 11, 7, 3); + + __m512i p0_lo = _mm512_permutex2var_epi32(_r0, id0, _r1); + __m512i p1_lo = _mm512_permutex2var_epi32(_r0, id1, _r1); + __m512i p2_lo = _mm512_permutex2var_epi32(_r0, id2, _r1); + __m512i p3_lo = _mm512_permutex2var_epi32(_r0, id3, _r1); + + __m512i p0_hi = _mm512_permutex2var_epi32(_r2, id0, _r3); + __m512i p1_hi = _mm512_permutex2var_epi32(_r2, id1, _r3); + __m512i p2_hi = _mm512_permutex2var_epi32(_r2, id2, _r3); + __m512i p3_hi = _mm512_permutex2var_epi32(_r2, id3, _r3); + + __m512i cp0 = _mm512_inserti64x4(p0_lo, _mm512_castsi512_si256(p0_hi), 1); + __m512i cp1 = _mm512_inserti64x4(p1_lo, _mm512_castsi512_si256(p1_hi), 1); + __m512i cp2 = _mm512_inserti64x4(p2_lo, _mm512_castsi512_si256(p2_hi), 1); + __m512i cp3 = _mm512_inserti64x4(p3_lo, _mm512_castsi512_si256(p3_hi), 1); + + __m512i shuf = _mm512_set4_epi32(0x0f0e0b0a, 0x07060302, 0x0d0c0908, 0x05040100); + __m512i pq = _mm512_set_epi64(7, 5, 3, 1, 6, 4, 2, 0); + + __m512i s0 = _mm512_shuffle_epi8(cp0, shuf); + __m512i s1 = _mm512_shuffle_epi8(cp1, shuf); + __m512i s2 = _mm512_shuffle_epi8(cp2, shuf); + __m512i s3 = _mm512_shuffle_epi8(cp3, shuf); + + __m512i _p0 = _mm512_permutexvar_epi64(pq, s0); + __m512i _p1 = _mm512_permutexvar_epi64(pq, s1); + __m512i _p2 = _mm512_permutexvar_epi64(pq, s2); + __m512i _p3 = _mm512_permutexvar_epi64(pq, s3); +#endif // __AVX512BF16__ + _mm512_storeu_si512((__m512i*)pp, _p0); + _mm512_storeu_si512((__m512i*)(pp + 32), _p1); + _mm512_storeu_si512((__m512i*)(pp + 64), _p2); + _mm512_storeu_si512((__m512i*)(pp + 96), _p3); pp += 128; p0 += B_hstep * 8; } } if (elempack == 4) { - const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj) * 4; - int kk = 0; for (; kk + 3 < max_kk; kk += 4) { - __m128i _a0 = _mm_loadl_epi64((const __m128i*)p0); - __m128i _a1 = _mm_loadl_epi64((const __m128i*)(p0 + 4 * 1)); - __m128i _a2 = _mm_loadl_epi64((const __m128i*)(p0 + 4 * 2)); - __m128i _a3 = _mm_loadl_epi64((const __m128i*)(p0 + 4 * 3)); - __m128i _b0 = _mm_loadl_epi64((const __m128i*)(p0 + 4 * 4)); - __m128i _b1 = _mm_loadl_epi64((const __m128i*)(p0 + 4 * 5)); - __m128i _b2 = _mm_loadl_epi64((const __m128i*)(p0 + 4 * 6)); - __m128i _b3 = _mm_loadl_epi64((const __m128i*)(p0 + 4 * 7)); - __m128i _c0 = _mm_loadl_epi64((const __m128i*)(p0 + 4 * 8)); - __m128i _c1 = _mm_loadl_epi64((const __m128i*)(p0 + 4 * 9)); - __m128i _c2 = _mm_loadl_epi64((const __m128i*)(p0 + 4 * 10)); - __m128i _c3 = _mm_loadl_epi64((const __m128i*)(p0 + 4 * 11)); - __m128i _d0 = _mm_loadl_epi64((const __m128i*)(p0 + 4 * 12)); - __m128i _d1 = 
_mm_loadl_epi64((const __m128i*)(p0 + 4 * 13)); - __m128i _d2 = _mm_loadl_epi64((const __m128i*)(p0 + 4 * 14)); - __m128i _d3 = _mm_loadl_epi64((const __m128i*)(p0 + 4 * 15)); - transpose8x4_epi16(_a0, _a1, _a2, _a3); - transpose8x4_epi16(_b0, _b1, _b2, _b3); - transpose8x4_epi16(_c0, _c1, _c2, _c3); - transpose8x4_epi16(_d0, _d1, _d2, _d3); - __m256i _col0 = _mm256_inserti128_si256(_mm256_castsi128_si256(_mm_unpacklo_epi64(_a0, _b0)), _mm_unpacklo_epi64(_c0, _d0), 1); - __m256i _col1 = _mm256_inserti128_si256(_mm256_castsi128_si256(_mm_unpackhi_epi64(_a0, _b0)), _mm_unpackhi_epi64(_c0, _d0), 1); - __m256i _col2 = _mm256_inserti128_si256(_mm256_castsi128_si256(_mm_unpacklo_epi64(_a1, _b1)), _mm_unpacklo_epi64(_c1, _d1), 1); - __m256i _col3 = _mm256_inserti128_si256(_mm256_castsi128_si256(_mm_unpackhi_epi64(_a1, _b1)), _mm_unpackhi_epi64(_c1, _d1), 1); - _mm512_store_ps(pp, bfloat2float_avx512(_col0)); - _mm512_store_ps(pp + 16 * 1, bfloat2float_avx512(_col1)); - _mm512_store_ps(pp + 16 * 2, bfloat2float_avx512(_col2)); - _mm512_store_ps(pp + 16 * 3, bfloat2float_avx512(_col3)); + __m512i _r0 = _mm512_loadu_si512((const __m512i*)p0); + __m512i _r1 = _mm512_loadu_si512((const __m512i*)(p0 + 32)); +#if __AVX512BF16__ + __m512i idx_lo = _mm512_setr_epi32(0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30); + __m512i idx_hi = _mm512_setr_epi32(1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31); + __m512i _p0 = _mm512_permutex2var_epi32(_r0, idx_lo, _r1); + __m512i _p1 = _mm512_permutex2var_epi32(_r0, idx_hi, _r1); +#else // __AVX512BF16__ + __m512i idx_lo = _mm512_set_epi16(61, 57, 53, 49, 45, 41, 37, 33, 29, 25, 21, 17, 13, 9, 5, 1, 60, 56, 52, 48, 44, 40, 36, 32, 28, 24, 20, 16, 12, 8, 4, 0); + __m512i idx_hi = _mm512_set_epi16(63, 59, 55, 51, 47, 43, 39, 35, 31, 27, 23, 19, 15, 11, 7, 3, 62, 58, 54, 50, 46, 42, 38, 34, 30, 26, 22, 18, 14, 10, 6, 2); + __m512i _p0 = _mm512_permutex2var_epi16(_r0, idx_lo, _r1); + __m512i _p1 = _mm512_permutex2var_epi16(_r0, idx_hi, _r1); +#endif // __AVX512BF16__ + _mm512_storeu_si512((__m512i*)pp, _p0); + _mm512_storeu_si512((__m512i*)(pp + 32), _p1); pp += 64; p0 += B_hstep * 4; } } if (elempack == 1) { - const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj); - - int kk = 0; - for (; kk < max_kk; kk++) - { - _mm512_storeu_ps(pp, bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)p0))); - pp += 16; - p0 += B_hstep; - } - } - } -#else // __AVX512F__ - for (; jj + 11 < max_jj; jj += 12) - { -#if __AVX__ - if (elempack == 8) - { - const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj) * 8; - - int kk = 0; - for (; kk + 7 < max_kk; kk += 8) - { - __m256 _r0 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)p0)); - __m256 _r1 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)(p0 + 8 * 1))); - __m256 _r2 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)(p0 + 8 * 2))); - __m256 _r3 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)(p0 + 8 * 3))); - __m256 _r4 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)(p0 + 8 * 4))); - __m256 _r5 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)(p0 + 8 * 5))); - __m256 _r6 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)(p0 + 8 * 6))); - __m256 _r7 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)(p0 + 8 * 7))); - __m256 _r8 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)(p0 + 8 * 8))); - __m256 _r9 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)(p0 + 8 * 9))); - __m256 _ra = bfloat2float_avx(_mm_loadu_si128((const __m128i*)(p0 + 8 * 10))); - 
__m256 _rb = bfloat2float_avx(_mm_loadu_si128((const __m128i*)(p0 + 8 * 11))); - transpose8x12_ps(_r0, _r1, _r2, _r3, _r4, _r5, _r6, _r7, _r8, _r9, _ra, _rb); - _mm256_store_ps(pp, _r0); - _mm256_store_ps(pp + 8 * 1, _r1); - _mm256_store_ps(pp + 8 * 2, _r2); - _mm256_store_ps(pp + 8 * 3, _r3); - _mm256_store_ps(pp + 8 * 4, _r4); - _mm256_store_ps(pp + 8 * 5, _r5); - _mm256_store_ps(pp + 8 * 6, _r6); - _mm256_store_ps(pp + 8 * 7, _r7); - _mm256_store_ps(pp + 8 * 8, _r8); - _mm256_store_ps(pp + 8 * 9, _r9); - _mm256_store_ps(pp + 8 * 10, _ra); - _mm256_store_ps(pp + 8 * 11, _rb); - pp += 96; - p0 += B_hstep * 8; - } - } -#endif // __AVX__ - if (elempack == 4) - { - const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj) * 4; - int kk = 0; - for (; kk + 3 < max_kk; kk += 4) +#if __AVX512BF16__ + for (; kk + 1 < max_kk; kk += 2) { - __m128 _r0 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)p0)); - __m128 _r1 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(p0 + 4 * 1))); - __m128 _r2 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(p0 + 4 * 2))); - __m128 _r3 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(p0 + 4 * 3))); - __m128 _r4 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(p0 + 4 * 4))); - __m128 _r5 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(p0 + 4 * 5))); - __m128 _r6 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(p0 + 4 * 6))); - __m128 _r7 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(p0 + 4 * 7))); - __m128 _r8 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(p0 + 4 * 8))); - __m128 _r9 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(p0 + 4 * 9))); - __m128 _ra = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(p0 + 4 * 10))); - __m128 _rb = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(p0 + 4 * 11))); - _MM_TRANSPOSE4_PS(_r0, _r1, _r2, _r3); - _MM_TRANSPOSE4_PS(_r4, _r5, _r6, _r7); - _MM_TRANSPOSE4_PS(_r8, _r9, _ra, _rb); - _mm_store_ps(pp, _r0); - _mm_store_ps(pp + 4 * 1, _r4); - _mm_store_ps(pp + 4 * 2, _r8); - _mm_store_ps(pp + 4 * 3, _r1); - _mm_store_ps(pp + 4 * 4, _r5); - _mm_store_ps(pp + 4 * 5, _r9); - _mm_store_ps(pp + 4 * 6, _r2); - _mm_store_ps(pp + 4 * 7, _r6); - _mm_store_ps(pp + 4 * 8, _ra); - _mm_store_ps(pp + 4 * 9, _r3); - _mm_store_ps(pp + 4 * 10, _r7); - _mm_store_ps(pp + 4 * 11, _rb); - pp += 48; - p0 += B_hstep * 4; + __m256i _r0 = _mm256_loadu_si256((const __m256i*)p0); + __m256i _r1 = _mm256_loadu_si256((const __m256i*)(p0 + B_hstep)); + transpose16x2_epi16(_r0, _r1); + _mm256_storeu_si256((__m256i*)pp, _r0); + _mm256_storeu_si256((__m256i*)(pp + 16), _r1); + pp += 32; + p0 += B_hstep * 2; } - } - if (elempack == 1) - { - const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj); - - int kk = 0; +#endif // __AVX512BF16__ for (; kk < max_kk; kk++) { - _mm_store_ps(pp, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)p0))); - _mm_store_ps(pp + 4, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(p0 + 4)))); - _mm_store_ps(pp + 8, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(p0 + 8)))); - pp += 12; + _mm256_storeu_si256((__m256i*)pp, _mm256_loadu_si256((const __m256i*)p0)); + pp += 16; p0 += B_hstep; } } @@ -1792,32 +1738,36 @@ static void transpose_pack_B_tile_bf16s(const Mat& B, Mat& BT, int j, int max_jj #endif // __AVX512F__ for (; jj + 7 < max_jj; jj += 8) { + const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj) * elempack; + #if __AVX__ #if __AVX512F__ if (elempack == 16) { - const unsigned short* p0 = (const unsigned short*)B + k * 
B_hstep + (j + jj) * 16; - int kk = 0; for (; kk + 15 < max_kk; kk += 16) { __m256i _r0 = _mm256_loadu_si256((const __m256i*)p0); __m256i _r1 = _mm256_loadu_si256((const __m256i*)(p0 + 16)); - __m256i _r2 = _mm256_loadu_si256((const __m256i*)(p0 + 16 * 2)); - __m256i _r3 = _mm256_loadu_si256((const __m256i*)(p0 + 16 * 3)); - __m256i _r4 = _mm256_loadu_si256((const __m256i*)(p0 + 16 * 4)); - __m256i _r5 = _mm256_loadu_si256((const __m256i*)(p0 + 16 * 5)); - __m256i _r6 = _mm256_loadu_si256((const __m256i*)(p0 + 16 * 6)); - __m256i _r7 = _mm256_loadu_si256((const __m256i*)(p0 + 16 * 7)); + __m256i _r2 = _mm256_loadu_si256((const __m256i*)(p0 + 32)); + __m256i _r3 = _mm256_loadu_si256((const __m256i*)(p0 + 48)); + __m256i _r4 = _mm256_loadu_si256((const __m256i*)(p0 + 64)); + __m256i _r5 = _mm256_loadu_si256((const __m256i*)(p0 + 80)); + __m256i _r6 = _mm256_loadu_si256((const __m256i*)(p0 + 96)); + __m256i _r7 = _mm256_loadu_si256((const __m256i*)(p0 + 112)); +#if __AVX512BF16__ + transpose8x8_epi32(_r0, _r1, _r2, _r3, _r4, _r5, _r6, _r7); +#else // __AVX512BF16__ transpose16x8_epi16(_r0, _r1, _r2, _r3, _r4, _r5, _r6, _r7); - _mm512_store_ps(pp, bfloat2float_avx512(_r0)); - _mm512_store_ps(pp + 16 * 1, bfloat2float_avx512(_r1)); - _mm512_store_ps(pp + 16 * 2, bfloat2float_avx512(_r2)); - _mm512_store_ps(pp + 16 * 3, bfloat2float_avx512(_r3)); - _mm512_store_ps(pp + 16 * 4, bfloat2float_avx512(_r4)); - _mm512_store_ps(pp + 16 * 5, bfloat2float_avx512(_r5)); - _mm512_store_ps(pp + 16 * 6, bfloat2float_avx512(_r6)); - _mm512_store_ps(pp + 16 * 7, bfloat2float_avx512(_r7)); +#endif // __AVX512BF16__ + _mm256_storeu_si256((__m256i*)pp, _r0); + _mm256_storeu_si256((__m256i*)(pp + 16), _r1); + _mm256_storeu_si256((__m256i*)(pp + 32), _r2); + _mm256_storeu_si256((__m256i*)(pp + 48), _r3); + _mm256_storeu_si256((__m256i*)(pp + 64), _r4); + _mm256_storeu_si256((__m256i*)(pp + 80), _r5); + _mm256_storeu_si256((__m256i*)(pp + 96), _r6); + _mm256_storeu_si256((__m256i*)(pp + 112), _r7); pp += 128; p0 += B_hstep * 16; } @@ -1825,28 +1775,30 @@ static void transpose_pack_B_tile_bf16s(const Mat& B, Mat& BT, int j, int max_jj #endif // __AVX512F__ if (elempack == 8) { - const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj) * 8; - int kk = 0; for (; kk + 7 < max_kk; kk += 8) { __m128i _r0 = _mm_loadu_si128((const __m128i*)p0); - __m128i _r1 = _mm_loadu_si128((const __m128i*)(p0 + 8 * 1)); - __m128i _r2 = _mm_loadu_si128((const __m128i*)(p0 + 8 * 2)); - __m128i _r3 = _mm_loadu_si128((const __m128i*)(p0 + 8 * 3)); - __m128i _r4 = _mm_loadu_si128((const __m128i*)(p0 + 8 * 4)); - __m128i _r5 = _mm_loadu_si128((const __m128i*)(p0 + 8 * 5)); - __m128i _r6 = _mm_loadu_si128((const __m128i*)(p0 + 8 * 6)); - __m128i _r7 = _mm_loadu_si128((const __m128i*)(p0 + 8 * 7)); + __m128i _r1 = _mm_loadu_si128((const __m128i*)(p0 + 8)); + __m128i _r2 = _mm_loadu_si128((const __m128i*)(p0 + 16)); + __m128i _r3 = _mm_loadu_si128((const __m128i*)(p0 + 24)); + __m128i _r4 = _mm_loadu_si128((const __m128i*)(p0 + 32)); + __m128i _r5 = _mm_loadu_si128((const __m128i*)(p0 + 40)); + __m128i _r6 = _mm_loadu_si128((const __m128i*)(p0 + 48)); + __m128i _r7 = _mm_loadu_si128((const __m128i*)(p0 + 56)); +#if __AVX512BF16__ + transpose4x8_epi32(_r0, _r1, _r2, _r3, _r4, _r5, _r6, _r7); +#else // __AVX512BF16__ transpose8x8_epi16(_r0, _r1, _r2, _r3, _r4, _r5, _r6, _r7); - _mm256_store_ps(pp, bfloat2float_avx(_r0)); - _mm256_store_ps(pp + 8 * 1, bfloat2float_avx(_r1)); - _mm256_store_ps(pp + 8 * 2, bfloat2float_avx(_r2)); - 
_mm256_store_ps(pp + 8 * 3, bfloat2float_avx(_r3)); - _mm256_store_ps(pp + 8 * 4, bfloat2float_avx(_r4)); - _mm256_store_ps(pp + 8 * 5, bfloat2float_avx(_r5)); - _mm256_store_ps(pp + 8 * 6, bfloat2float_avx(_r6)); - _mm256_store_ps(pp + 8 * 7, bfloat2float_avx(_r7)); +#endif // __AVX512BF16__ + _mm_storeu_si128((__m128i*)pp, _r0); + _mm_storeu_si128((__m128i*)(pp + 8), _r1); + _mm_storeu_si128((__m128i*)(pp + 16), _r2); + _mm_storeu_si128((__m128i*)(pp + 24), _r3); + _mm_storeu_si128((__m128i*)(pp + 32), _r4); + _mm_storeu_si128((__m128i*)(pp + 40), _r5); + _mm_storeu_si128((__m128i*)(pp + 48), _r6); + _mm_storeu_si128((__m128i*)(pp + 56), _r7); pp += 64; p0 += B_hstep * 8; } @@ -1854,68 +1806,97 @@ static void transpose_pack_B_tile_bf16s(const Mat& B, Mat& BT, int j, int max_jj #endif // __AVX__ if (elempack == 4) { - const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj) * 4; - int kk = 0; for (; kk + 3 < max_kk; kk += 4) { - __m128i _a0 = _mm_loadl_epi64((const __m128i*)p0); - __m128i _a1 = _mm_loadl_epi64((const __m128i*)(p0 + 4 * 1)); - __m128i _a2 = _mm_loadl_epi64((const __m128i*)(p0 + 4 * 2)); - __m128i _a3 = _mm_loadl_epi64((const __m128i*)(p0 + 4 * 3)); - __m128i _b0 = _mm_loadl_epi64((const __m128i*)(p0 + 4 * 4)); - __m128i _b1 = _mm_loadl_epi64((const __m128i*)(p0 + 4 * 5)); - __m128i _b2 = _mm_loadl_epi64((const __m128i*)(p0 + 4 * 6)); - __m128i _b3 = _mm_loadl_epi64((const __m128i*)(p0 + 4 * 7)); - transpose8x4_epi16(_a0, _a1, _a2, _a3); - transpose8x4_epi16(_b0, _b1, _b2, _b3); - // _a0 = [col0_rows0-3 | col1_rows0-3], _b0 = [col0_rows4-7 | col1_rows4-7] - _mm_store_ps(pp, bfloat2float_sse(_a0)); - _mm_store_ps(pp + 4 * 1, bfloat2float_sse(_b0)); - _mm_store_ps(pp + 4 * 2, bfloat2float_sse(_mm_unpackhi_epi64(_a0, _a0))); - _mm_store_ps(pp + 4 * 3, bfloat2float_sse(_mm_unpackhi_epi64(_b0, _b0))); - _mm_store_ps(pp + 4 * 4, bfloat2float_sse(_a1)); - _mm_store_ps(pp + 4 * 5, bfloat2float_sse(_b1)); - _mm_store_ps(pp + 4 * 6, bfloat2float_sse(_mm_unpackhi_epi64(_a1, _a1))); - _mm_store_ps(pp + 4 * 7, bfloat2float_sse(_mm_unpackhi_epi64(_b1, _b1))); + __m128i _r0 = _mm_loadl_epi64((const __m128i*)p0); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)(p0 + 4)); + __m128i _r2 = _mm_loadl_epi64((const __m128i*)(p0 + 8)); + __m128i _r3 = _mm_loadl_epi64((const __m128i*)(p0 + 12)); + __m128i _r4 = _mm_loadl_epi64((const __m128i*)(p0 + 16)); + __m128i _r5 = _mm_loadl_epi64((const __m128i*)(p0 + 20)); + __m128i _r6 = _mm_loadl_epi64((const __m128i*)(p0 + 24)); + __m128i _r7 = _mm_loadl_epi64((const __m128i*)(p0 + 28)); +#if __AVX512BF16__ + __m128i _t0 = _mm_unpacklo_epi32(_r0, _r1); + __m128i _t1 = _mm_unpacklo_epi32(_r2, _r3); + __m128i _t2 = _mm_unpacklo_epi32(_r4, _r5); + __m128i _t3 = _mm_unpacklo_epi32(_r6, _r7); + __m128i _p0 = _mm_unpacklo_epi64(_t0, _t1); + __m128i _p1 = _mm_unpacklo_epi64(_t2, _t3); + __m128i _p2 = _mm_unpackhi_epi64(_t0, _t1); + __m128i _p3 = _mm_unpackhi_epi64(_t2, _t3); +#else // __AVX512BF16__ + __m128i _t0 = _mm_unpacklo_epi16(_r0, _r1); + __m128i _t1 = _mm_unpacklo_epi16(_r2, _r3); + __m128i _t2 = _mm_unpacklo_epi16(_r4, _r5); + __m128i _t3 = _mm_unpacklo_epi16(_r6, _r7); + _r0 = _mm_unpacklo_epi32(_t0, _t1); + _r1 = _mm_unpackhi_epi32(_t0, _t1); + _r2 = _mm_unpacklo_epi32(_t2, _t3); + _r3 = _mm_unpackhi_epi32(_t2, _t3); + __m128i _p0 = _mm_unpacklo_epi64(_r0, _r2); + __m128i _p1 = _mm_unpackhi_epi64(_r0, _r2); + __m128i _p2 = _mm_unpacklo_epi64(_r1, _r3); + __m128i _p3 = _mm_unpackhi_epi64(_r1, _r3); +#endif // __AVX512BF16__ + 
_mm_storeu_si128((__m128i*)pp, _p0); + _mm_storeu_si128((__m128i*)(pp + 8), _p1); + _mm_storeu_si128((__m128i*)(pp + 16), _p2); + _mm_storeu_si128((__m128i*)(pp + 24), _p3); pp += 32; p0 += B_hstep * 4; } } if (elempack == 1) { - const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj); - int kk = 0; +#if __AVX512BF16__ + for (; kk + 1 < max_kk; kk += 2) + { + __m128i _r0 = _mm_loadu_si128((const __m128i*)p0); + __m128i _r1 = _mm_loadu_si128((const __m128i*)(p0 + B_hstep)); + __m128i _p0 = _mm_unpacklo_epi16(_r0, _r1); + __m128i _p1 = _mm_unpackhi_epi16(_r0, _r1); + _mm_storeu_si128((__m128i*)pp, _p0); + _mm_storeu_si128((__m128i*)(pp + 8), _p1); + pp += 16; + p0 += B_hstep * 2; + } +#endif // __AVX512BF16__ for (; kk < max_kk; kk++) { - _mm_store_ps(pp, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)p0))); - _mm_store_ps(pp + 4, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(p0 + 4)))); + _mm_storeu_si128((__m128i*)pp, _mm_loadu_si128((const __m128i*)p0)); pp += 8; p0 += B_hstep; } } } +#endif // defined(__x86_64__) || defined(_M_X64) for (; jj + 3 < max_jj; jj += 4) { + const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj) * elempack; + #if __AVX__ #if __AVX512F__ if (elempack == 16) { - const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj) * 16; - int kk = 0; for (; kk + 15 < max_kk; kk += 16) { - __m512 _r0 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)p0)); - __m512 _r1 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(p0 + 16 * 1))); - __m512 _r2 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(p0 + 16 * 2))); - __m512 _r3 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(p0 + 16 * 3))); - transpose16x4_ps(_r0, _r1, _r2, _r3); - _mm512_store_ps(pp, _r0); - _mm512_store_ps(pp + 16 * 1, _r1); - _mm512_store_ps(pp + 16 * 2, _r2); - _mm512_store_ps(pp + 16 * 3, _r3); + __m256i _r0 = _mm256_loadu_si256((const __m256i*)p0); + __m256i _r1 = _mm256_loadu_si256((const __m256i*)(p0 + 16)); + __m256i _r2 = _mm256_loadu_si256((const __m256i*)(p0 + 32)); + __m256i _r3 = _mm256_loadu_si256((const __m256i*)(p0 + 48)); +#if __AVX512BF16__ + transpose8x4_epi32(_r0, _r1, _r2, _r3); +#else // __AVX512BF16__ + transpose16x4_epi16(_r0, _r1, _r2, _r3); +#endif // __AVX512BF16__ + _mm256_storeu_si256((__m256i*)pp, _r0); + _mm256_storeu_si256((__m256i*)(pp + 16), _r1); + _mm256_storeu_si256((__m256i*)(pp + 32), _r2); + _mm256_storeu_si256((__m256i*)(pp + 48), _r3); pp += 64; p0 += B_hstep * 16; } @@ -1923,20 +1904,22 @@ static void transpose_pack_B_tile_bf16s(const Mat& B, Mat& BT, int j, int max_jj #endif // __AVX512F__ if (elempack == 8) { - const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj) * 8; - int kk = 0; for (; kk + 7 < max_kk; kk += 8) { __m128i _r0 = _mm_loadu_si128((const __m128i*)p0); - __m128i _r1 = _mm_loadu_si128((const __m128i*)(p0 + 8 * 1)); - __m128i _r2 = _mm_loadu_si128((const __m128i*)(p0 + 8 * 2)); - __m128i _r3 = _mm_loadu_si128((const __m128i*)(p0 + 8 * 3)); + __m128i _r1 = _mm_loadu_si128((const __m128i*)(p0 + 8)); + __m128i _r2 = _mm_loadu_si128((const __m128i*)(p0 + 16)); + __m128i _r3 = _mm_loadu_si128((const __m128i*)(p0 + 24)); +#if __AVX512BF16__ + transpose4x4_epi32(_r0, _r1, _r2, _r3); +#else // __AVX512BF16__ transpose8x4_epi16(_r0, _r1, _r2, _r3); - _mm256_store_ps(pp, bfloat2float_avx(_r0)); - _mm256_store_ps(pp + 8 * 1, bfloat2float_avx(_r1)); - _mm256_store_ps(pp + 8 * 2, bfloat2float_avx(_r2)); - _mm256_store_ps(pp + 8 * 3, 
bfloat2float_avx(_r3)); +#endif // __AVX512BF16__ + _mm_storeu_si128((__m128i*)pp, _r0); + _mm_storeu_si128((__m128i*)(pp + 8), _r1); + _mm_storeu_si128((__m128i*)(pp + 16), _r2); + _mm_storeu_si128((__m128i*)(pp + 24), _r3); pp += 32; p0 += B_hstep * 8; } @@ -1944,33 +1927,47 @@ static void transpose_pack_B_tile_bf16s(const Mat& B, Mat& BT, int j, int max_jj #endif // __AVX__ if (elempack == 4) { - const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj) * 4; - int kk = 0; for (; kk + 3 < max_kk; kk += 4) { __m128i _r0 = _mm_loadl_epi64((const __m128i*)p0); - __m128i _r1 = _mm_loadl_epi64((const __m128i*)(p0 + 4 * 1)); - __m128i _r2 = _mm_loadl_epi64((const __m128i*)(p0 + 4 * 2)); - __m128i _r3 = _mm_loadl_epi64((const __m128i*)(p0 + 4 * 3)); - transpose8x4_epi16(_r0, _r1, _r2, _r3); - // _r0 = [col0_rows0-3 | col1_rows0-3], _r1 = [col2_rows0-3 | col3_rows0-3] - _mm_store_ps(pp, bfloat2float_sse(_r0)); - _mm_store_ps(pp + 4 * 1, bfloat2float_sse(_mm_unpackhi_epi64(_r0, _r0))); - _mm_store_ps(pp + 4 * 2, bfloat2float_sse(_r1)); - _mm_store_ps(pp + 4 * 3, bfloat2float_sse(_mm_unpackhi_epi64(_r1, _r1))); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)(p0 + 4)); + __m128i _r2 = _mm_loadl_epi64((const __m128i*)(p0 + 8)); + __m128i _r3 = _mm_loadl_epi64((const __m128i*)(p0 + 12)); +#if __AVX512BF16__ + __m128i _t0 = _mm_unpacklo_epi32(_r0, _r1); + __m128i _t1 = _mm_unpacklo_epi32(_r2, _r3); + _r0 = _mm_unpacklo_epi64(_t0, _t1); + _r1 = _mm_unpackhi_epi64(_t0, _t1); +#else // __AVX512BF16__ + __m128i _t0 = _mm_unpacklo_epi16(_r0, _r1); + __m128i _t1 = _mm_unpacklo_epi16(_r2, _r3); + _r0 = _mm_unpacklo_epi32(_t0, _t1); + _r1 = _mm_unpackhi_epi32(_t0, _t1); +#endif // __AVX512BF16__ + _mm_storeu_si128((__m128i*)pp, _r0); + _mm_storeu_si128((__m128i*)(pp + 8), _r1); pp += 16; p0 += B_hstep * 4; } } if (elempack == 1) { - const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj); - int kk = 0; +#if __AVX512BF16__ + for (; kk + 1 < max_kk; kk += 2) + { + __m128i _p0 = _mm_loadl_epi64((const __m128i*)p0); + __m128i _p1 = _mm_loadl_epi64((const __m128i*)(p0 + B_hstep)); + __m128i _p = _mm_unpacklo_epi16(_p0, _p1); + _mm_storeu_si128((__m128i*)pp, _p); + pp += 8; + p0 += B_hstep * 2; + } +#endif // __AVX512BF16__ for (; kk < max_kk; kk++) { - _mm_store_ps(pp, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)p0))); + _mm_storel_epi64((__m128i*)pp, _mm_loadl_epi64((const __m128i*)p0)); pp += 4; p0 += B_hstep; } @@ -1979,21 +1976,25 @@ static void transpose_pack_B_tile_bf16s(const Mat& B, Mat& BT, int j, int max_jj #endif // __SSE2__ for (; jj + 1 < max_jj; jj += 2) { + const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj) * elempack; + #if __SSE2__ #if __AVX__ #if __AVX512F__ if (elempack == 16) { - const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj) * 16; - int kk = 0; for (; kk + 15 < max_kk; kk += 16) { - __m512 _r0 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)p0)); - __m512 _r1 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(p0 + 16))); - transpose16x2_ps(_r0, _r1); - _mm512_store_ps(pp, _r0); - _mm512_store_ps(pp + 16, _r1); + __m256i _p0 = _mm256_loadu_si256((const __m256i*)p0); + __m256i _p1 = _mm256_loadu_si256((const __m256i*)(p0 + 16)); +#if __AVX512BF16__ + transpose8x2_epi32(_p0, _p1); +#else // __AVX512BF16__ + transpose16x2_epi16(_p0, _p1); +#endif // __AVX512BF16__ + _mm256_storeu_si256((__m256i*)pp, _p0); + _mm256_storeu_si256((__m256i*)(pp + 16), _p1); pp += 32; p0 += B_hstep * 
16; } @@ -2001,16 +2002,20 @@ static void transpose_pack_B_tile_bf16s(const Mat& B, Mat& BT, int j, int max_jj #endif // __AVX512F__ if (elempack == 8) { - const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj) * 8; - int kk = 0; for (; kk + 7 < max_kk; kk += 8) { - __m256 _r0 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)p0)); - __m256 _r1 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)(p0 + 8))); - transpose8x2_ps(_r0, _r1); - _mm256_store_ps(pp, _r0); - _mm256_store_ps(pp + 8, _r1); + __m128i _p0 = _mm_loadu_si128((const __m128i*)p0); + __m128i _p1 = _mm_loadu_si128((const __m128i*)(p0 + 8)); +#if __AVX512BF16__ + __m128i _t0 = _mm_unpacklo_epi32(_p0, _p1); + __m128i _t1 = _mm_unpackhi_epi32(_p0, _p1); +#else // __AVX512BF16__ + __m128i _t0 = _mm_unpacklo_epi16(_p0, _p1); + __m128i _t1 = _mm_unpackhi_epi16(_p0, _p1); +#endif // __AVX512BF16__ + _mm_storeu_si128((__m128i*)pp, _t0); + _mm_storeu_si128((__m128i*)(pp + 8), _t1); pp += 16; p0 += B_hstep * 8; } @@ -2018,16 +2023,28 @@ static void transpose_pack_B_tile_bf16s(const Mat& B, Mat& BT, int j, int max_jj #endif // __AVX__ if (elempack == 4) { - const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj) * 4; - int kk = 0; for (; kk + 3 < max_kk; kk += 4) { - __m128i _a = _mm_loadl_epi64((const __m128i*)p0); - __m128i _b = _mm_loadl_epi64((const __m128i*)(p0 + 4)); - __m128i _tmp0 = _mm_unpacklo_epi16(_a, _b); - _mm_store_ps(pp, bfloat2float_sse(_tmp0)); - _mm_store_ps(pp + 4, bfloat2float_sse(_mm_unpackhi_epi64(_tmp0, _tmp0))); +#if __AVX512BF16__ + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p0[4]; + pp[3] = p0[5]; + pp[4] = p0[2]; + pp[5] = p0[3]; + pp[6] = p0[6]; + pp[7] = p0[7]; +#else // __AVX512BF16__ + pp[0] = p0[0]; + pp[1] = p0[4]; + pp[2] = p0[1]; + pp[3] = p0[5]; + pp[4] = p0[2]; + pp[5] = p0[6]; + pp[6] = p0[3]; + pp[7] = p0[7]; +#endif // __AVX512BF16__ pp += 8; p0 += B_hstep * 4; } @@ -2035,13 +2052,22 @@ static void transpose_pack_B_tile_bf16s(const Mat& B, Mat& BT, int j, int max_jj #endif // __SSE2__ if (elempack == 1) { - const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj); - int kk = 0; +#if __AVX512BF16__ + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = p0[0]; + pp[1] = p0[B_hstep]; + pp[2] = p0[1]; + pp[3] = p0[B_hstep + 1]; + pp += 4; + p0 += B_hstep * 2; + } +#endif // __AVX512BF16__ for (; kk < max_kk; kk++) { - pp[0] = bfloat16_to_float32(p0[0]); - pp[1] = bfloat16_to_float32(p0[1]); + pp[0] = p0[0]; + pp[1] = p0[1]; pp += 2; p0 += B_hstep; } @@ -2049,17 +2075,17 @@ static void transpose_pack_B_tile_bf16s(const Mat& B, Mat& BT, int j, int max_jj } for (; jj < max_jj; jj += 1) { + const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj) * elempack; + #if __SSE2__ #if __AVX__ #if __AVX512F__ if (elempack == 16) { - const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj) * 16; - int kk = 0; for (; kk + 15 < max_kk; kk += 16) { - _mm512_store_ps(pp, bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)p0))); + _mm256_storeu_si256((__m256i*)pp, _mm256_loadu_si256((const __m256i*)p0)); pp += 16; p0 += B_hstep * 16; } @@ -2067,12 +2093,10 @@ static void transpose_pack_B_tile_bf16s(const Mat& B, Mat& BT, int j, int max_jj #endif // __AVX512F__ if (elempack == 8) { - const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj) * 8; - int kk = 0; for (; kk + 7 < max_kk; kk += 8) { - _mm256_store_ps(pp, bfloat2float_avx(_mm_loadu_si128((const __m128i*)p0))); + _mm_storeu_si128((__m128i*)pp, 
_mm_loadu_si128((const __m128i*)p0)); pp += 8; p0 += B_hstep * 8; } @@ -2080,12 +2104,10 @@ static void transpose_pack_B_tile_bf16s(const Mat& B, Mat& BT, int j, int max_jj #endif // __AVX__ if (elempack == 4) { - const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj) * 4; - int kk = 0; for (; kk + 3 < max_kk; kk += 4) { - _mm_store_ps(pp, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)p0))); + _mm_storel_epi64((__m128i*)pp, _mm_loadl_epi64((const __m128i*)p0)); pp += 4; p0 += B_hstep * 4; } @@ -2093,12 +2115,10 @@ static void transpose_pack_B_tile_bf16s(const Mat& B, Mat& BT, int j, int max_jj #endif // __SSE2__ if (elempack == 1) { - const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj); - int kk = 0; for (; kk < max_kk; kk++) { - pp[0] = bfloat16_to_float32(p0[0]); + pp[0] = p0[0]; pp += 1; p0 += B_hstep; } @@ -2106,158 +2126,6669 @@ static void transpose_pack_B_tile_bf16s(const Mat& B, Mat& BT, int j, int max_jj } } -static void unpack_output_tile_bf16s(const Mat& topT, Mat& top_blob, int i, int max_ii, int j, int max_jj, int output_transpose) +static void gemm_transB_packed_tile_bf16s(const Mat& AT_tile, const Mat& BT_tile, Mat& topT_tile, int i, int max_ii, int j, int max_jj, int k, int max_kk) { - // topT is fp32 packed tile data - // top_blob output is bf16 - const int out_elempack = top_blob.elempack; - const size_t out_hstep = top_blob.dims == 3 ? top_blob.cstep : (size_t)top_blob.w; - +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ + if (ncnn::cpu_support_x86_avx512_bf16()) + { + gemm_transB_packed_tile_bf16s_avx512bf16(AT_tile, BT_tile, topT_tile, i, max_ii, j, max_jj, k, max_kk); + return; + } +#endif + + // NCNN_LOGE("gemm_transB_packed_tile_bf16s %d %d %d %d %d %d", i, max_ii, j, max_jj, k, max_kk); + // actually we only depend the global k==0 condition + (void)i; + (void)j; + + const unsigned short* pAT = AT_tile; + const unsigned short* pBT = BT_tile; + + float* outptr = topT_tile; + + int ii = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + for (; ii + 15 < max_ii; ii += 16) + { + const unsigned short* pB = pBT; + + int jj = 0; +#if defined(__x86_64__) || defined(_M_X64) + for (; jj + 15 < max_jj; jj += 16) + { + const unsigned short* pA = pAT; + + __m512 _sum0 = _mm512_setzero_ps(); + __m512 _sum1 = _mm512_setzero_ps(); + __m512 _sum2 = _mm512_setzero_ps(); + __m512 _sum3 = _mm512_setzero_ps(); + __m512 _sum4 = _mm512_setzero_ps(); + __m512 _sum5 = _mm512_setzero_ps(); + __m512 _sum6 = _mm512_setzero_ps(); + __m512 _sum7 = _mm512_setzero_ps(); + __m512 _sum8 = _mm512_setzero_ps(); + __m512 _sum9 = _mm512_setzero_ps(); + __m512 _suma = _mm512_setzero_ps(); + __m512 _sumb = _mm512_setzero_ps(); + __m512 _sumc = _mm512_setzero_ps(); + __m512 _sumd = _mm512_setzero_ps(); + __m512 _sume = _mm512_setzero_ps(); + __m512 _sumf = _mm512_setzero_ps(); + + if (k != 0) + { + _sum0 = _mm512_load_ps(outptr); + _sum1 = _mm512_load_ps(outptr + 16); + _sum2 = _mm512_load_ps(outptr + 32); + _sum3 = _mm512_load_ps(outptr + 48); + _sum4 = _mm512_load_ps(outptr + 64); + _sum5 = _mm512_load_ps(outptr + 80); + _sum6 = _mm512_load_ps(outptr + 96); + _sum7 = _mm512_load_ps(outptr + 112); + _sum8 = _mm512_load_ps(outptr + 128); + _sum9 = _mm512_load_ps(outptr + 128 + 16); + _suma = _mm512_load_ps(outptr + 128 + 32); + _sumb = _mm512_load_ps(outptr + 128 + 48); + _sumc = _mm512_load_ps(outptr + 128 + 64); + _sumd = _mm512_load_ps(outptr + 128 + 80); + _sume = _mm512_load_ps(outptr + 128 + 96); + _sumf = 
_mm512_load_ps(outptr + 128 + 112); + } + + int kk = 0; +#if __AVX512BF16__ + for (; kk + 1 < max_kk; kk += 2) + { + __m512i _pA0 = _mm512_loadu_si512((const __m512i*)pA); + __m512i _pB0 = _mm512_loadu_si512((const __m512i*)pB); + + __m512i _pA1 = _mm512_alignr_epi8(_pA0, _pA0, 8); + __m512i _pA2 = _mm512_shuffle_i32x4(_pA0, _pA0, _MM_SHUFFLE(2, 3, 0, 1)); + __m512i _pA3 = _mm512_alignr_epi8(_pA2, _pA2, 8); + + __m512i _pB1 = _mm512_alignr_epi8(_pB0, _pB0, 4); + __m512i _pB2 = _mm512_shuffle_i32x4(_pB0, _pB0, _MM_SHUFFLE(1, 0, 3, 2)); + __m512i _pB3 = _mm512_alignr_epi8(_pB2, _pB2, 4); + + _sum0 = _mm512_dpbf16_ps(_sum0, (__m512bh)_pA0, (__m512bh)_pB0); + _sum1 = _mm512_dpbf16_ps(_sum1, (__m512bh)_pA0, (__m512bh)_pB1); + _sum2 = _mm512_dpbf16_ps(_sum2, (__m512bh)_pA1, (__m512bh)_pB0); + _sum3 = _mm512_dpbf16_ps(_sum3, (__m512bh)_pA1, (__m512bh)_pB1); + _sum4 = _mm512_dpbf16_ps(_sum4, (__m512bh)_pA0, (__m512bh)_pB2); + _sum5 = _mm512_dpbf16_ps(_sum5, (__m512bh)_pA0, (__m512bh)_pB3); + _sum6 = _mm512_dpbf16_ps(_sum6, (__m512bh)_pA1, (__m512bh)_pB2); + _sum7 = _mm512_dpbf16_ps(_sum7, (__m512bh)_pA1, (__m512bh)_pB3); + _sum8 = _mm512_dpbf16_ps(_sum8, (__m512bh)_pA2, (__m512bh)_pB0); + _sum9 = _mm512_dpbf16_ps(_sum9, (__m512bh)_pA2, (__m512bh)_pB1); + _suma = _mm512_dpbf16_ps(_suma, (__m512bh)_pA3, (__m512bh)_pB0); + _sumb = _mm512_dpbf16_ps(_sumb, (__m512bh)_pA3, (__m512bh)_pB1); + _sumc = _mm512_dpbf16_ps(_sumc, (__m512bh)_pA2, (__m512bh)_pB2); + _sumd = _mm512_dpbf16_ps(_sumd, (__m512bh)_pA2, (__m512bh)_pB3); + _sume = _mm512_dpbf16_ps(_sume, (__m512bh)_pA3, (__m512bh)_pB2); + _sumf = _mm512_dpbf16_ps(_sumf, (__m512bh)_pA3, (__m512bh)_pB3); + + pA += 32; + pB += 32; + } +#endif // __AVX512BF16__ + for (; kk < max_kk; kk++) + { + __m512 _pA0 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)pA)); + __m512 _pB0 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)pB)); + + __m512 _pA1 = _mm512_permute_ps(_pA0, _MM_SHUFFLE(1, 0, 3, 2)); + __m512 _pB1 = _mm512_permute_ps(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); + __m512 _pA2 = _mm512_shuffle_f32x4(_pA0, _pA0, _MM_SHUFFLE(2, 3, 0, 1)); + __m512 _pB2 = _mm512_shuffle_f32x4(_pB0, _pB0, _MM_SHUFFLE(1, 0, 3, 2)); + __m512 _pA3 = _mm512_permute_ps(_pA2, _MM_SHUFFLE(1, 0, 3, 2)); + __m512 _pB3 = _mm512_permute_ps(_pB2, _MM_SHUFFLE(0, 3, 2, 1)); + + _sum0 = _mm512_fmadd_ps(_pA0, _pB0, _sum0); + _sum1 = _mm512_fmadd_ps(_pA0, _pB1, _sum1); + _sum2 = _mm512_fmadd_ps(_pA1, _pB0, _sum2); + _sum3 = _mm512_fmadd_ps(_pA1, _pB1, _sum3); + _sum4 = _mm512_fmadd_ps(_pA0, _pB2, _sum4); + _sum5 = _mm512_fmadd_ps(_pA0, _pB3, _sum5); + _sum6 = _mm512_fmadd_ps(_pA1, _pB2, _sum6); + _sum7 = _mm512_fmadd_ps(_pA1, _pB3, _sum7); + _sum8 = _mm512_fmadd_ps(_pA2, _pB0, _sum8); + _sum9 = _mm512_fmadd_ps(_pA2, _pB1, _sum9); + _suma = _mm512_fmadd_ps(_pA3, _pB0, _suma); + _sumb = _mm512_fmadd_ps(_pA3, _pB1, _sumb); + _sumc = _mm512_fmadd_ps(_pA2, _pB2, _sumc); + _sumd = _mm512_fmadd_ps(_pA2, _pB3, _sumd); + _sume = _mm512_fmadd_ps(_pA3, _pB2, _sume); + _sumf = _mm512_fmadd_ps(_pA3, _pB3, _sumf); + + pA += 16; + pB += 16; + } + + _mm512_store_ps(outptr, _sum0); + _mm512_store_ps(outptr + 16, _sum1); + _mm512_store_ps(outptr + 32, _sum2); + _mm512_store_ps(outptr + 48, _sum3); + _mm512_store_ps(outptr + 64, _sum4); + _mm512_store_ps(outptr + 80, _sum5); + _mm512_store_ps(outptr + 96, _sum6); + _mm512_store_ps(outptr + 112, _sum7); + _mm512_store_ps(outptr + 128, _sum8); + _mm512_store_ps(outptr + 128 + 16, _sum9); + _mm512_store_ps(outptr + 128 + 32, _suma); + 
_mm512_store_ps(outptr + 128 + 48, _sumb); + _mm512_store_ps(outptr + 128 + 64, _sumc); + _mm512_store_ps(outptr + 128 + 80, _sumd); + _mm512_store_ps(outptr + 128 + 96, _sume); + _mm512_store_ps(outptr + 128 + 112, _sumf); + outptr += 256; + } + for (; jj + 7 < max_jj; jj += 8) + { + const unsigned short* pA = pAT; + + __m512 _sum0 = _mm512_setzero_ps(); + __m512 _sum1 = _mm512_setzero_ps(); + __m512 _sum2 = _mm512_setzero_ps(); + __m512 _sum3 = _mm512_setzero_ps(); + __m512 _sum4 = _mm512_setzero_ps(); + __m512 _sum5 = _mm512_setzero_ps(); + __m512 _sum6 = _mm512_setzero_ps(); + __m512 _sum7 = _mm512_setzero_ps(); + + if (k != 0) + { + _sum0 = _mm512_load_ps(outptr); + _sum1 = _mm512_load_ps(outptr + 16); + _sum2 = _mm512_load_ps(outptr + 32); + _sum3 = _mm512_load_ps(outptr + 48); + _sum4 = _mm512_load_ps(outptr + 64); + _sum5 = _mm512_load_ps(outptr + 80); + _sum6 = _mm512_load_ps(outptr + 96); + _sum7 = _mm512_load_ps(outptr + 112); + } + + int kk = 0; +#if __AVX512BF16__ + for (; kk + 1 < max_kk; kk += 2) + { + __m512i _pA0 = _mm512_loadu_si512((const __m512i*)pA); + __m256i _pBB = _mm256_loadu_si256((const __m256i*)pB); + __m512i _pB0 = combine8x2_epi32(_pBB, _pBB); + + __m512i _pA1 = _mm512_alignr_epi8(_pA0, _pA0, 8); + + __m512i _pB1 = _mm512_alignr_epi8(_pB0, _pB0, 4); + __m512i _pB2 = _mm512_permutex_epi64(_pB0, _MM_SHUFFLE(1, 0, 3, 2)); + __m512i _pB3 = _mm512_alignr_epi8(_pB2, _pB2, 4); + + _sum0 = _mm512_dpbf16_ps(_sum0, (__m512bh)_pA0, (__m512bh)_pB0); + _sum1 = _mm512_dpbf16_ps(_sum1, (__m512bh)_pA0, (__m512bh)_pB1); + _sum2 = _mm512_dpbf16_ps(_sum2, (__m512bh)_pA1, (__m512bh)_pB0); + _sum3 = _mm512_dpbf16_ps(_sum3, (__m512bh)_pA1, (__m512bh)_pB1); + _sum4 = _mm512_dpbf16_ps(_sum4, (__m512bh)_pA0, (__m512bh)_pB2); + _sum5 = _mm512_dpbf16_ps(_sum5, (__m512bh)_pA0, (__m512bh)_pB3); + _sum6 = _mm512_dpbf16_ps(_sum6, (__m512bh)_pA1, (__m512bh)_pB2); + _sum7 = _mm512_dpbf16_ps(_sum7, (__m512bh)_pA1, (__m512bh)_pB3); + + pA += 32; + pB += 16; + } +#endif // __AVX512BF16__ + for (; kk < max_kk; kk++) + { + __m512 _pA0 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)pA)); + __m256 _pBB = bfloat2float_avx(_mm_loadu_si128((const __m128i*)pB)); + __m512 _pB0 = _mm512_castsi512_ps(combine8x2_epi32(_mm256_castps_si256(_pBB), _mm256_castps_si256(_pBB))); + + __m512 _pA1 = _mm512_permute_ps(_pA0, _MM_SHUFFLE(1, 0, 3, 2)); + __m512 _pB1 = _mm512_permute_ps(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); + __m512 _pB2 = _mm512_castsi512_ps(_mm512_permutex_epi64(_mm512_castps_si512(_pB0), _MM_SHUFFLE(1, 0, 3, 2))); + __m512 _pB3 = _mm512_permute_ps(_pB2, _MM_SHUFFLE(0, 3, 2, 1)); + + _sum0 = _mm512_fmadd_ps(_pA0, _pB0, _sum0); + _sum1 = _mm512_fmadd_ps(_pA0, _pB1, _sum1); + _sum2 = _mm512_fmadd_ps(_pA1, _pB0, _sum2); + _sum3 = _mm512_fmadd_ps(_pA1, _pB1, _sum3); + _sum4 = _mm512_fmadd_ps(_pA0, _pB2, _sum4); + _sum5 = _mm512_fmadd_ps(_pA0, _pB3, _sum5); + _sum6 = _mm512_fmadd_ps(_pA1, _pB2, _sum6); + _sum7 = _mm512_fmadd_ps(_pA1, _pB3, _sum7); + + pA += 16; + pB += 8; + } + + _mm512_store_ps(outptr, _sum0); + _mm512_store_ps(outptr + 16, _sum1); + _mm512_store_ps(outptr + 32, _sum2); + _mm512_store_ps(outptr + 48, _sum3); + _mm512_store_ps(outptr + 64, _sum4); + _mm512_store_ps(outptr + 80, _sum5); + _mm512_store_ps(outptr + 96, _sum6); + _mm512_store_ps(outptr + 112, _sum7); + outptr += 128; + } +#endif // defined(__x86_64__) || defined(_M_X64) + for (; jj + 3 < max_jj; jj += 4) + { + const unsigned short* pA = pAT; + + __m512 _sum0 = _mm512_setzero_ps(); + __m512 _sum1 = _mm512_setzero_ps(); + 
__m512 _sum2 = _mm512_setzero_ps(); + __m512 _sum3 = _mm512_setzero_ps(); + + if (k != 0) + { + _sum0 = _mm512_load_ps(outptr); + _sum1 = _mm512_load_ps(outptr + 16); + _sum2 = _mm512_load_ps(outptr + 32); + _sum3 = _mm512_load_ps(outptr + 48); + } + + int kk = 0; +#if __AVX512BF16__ + for (; kk + 1 < max_kk; kk += 2) + { + __m512i _pA0 = _mm512_loadu_si512((const __m512i*)pA); + __m512i _pB0 = _mm512_broadcast_i32x4(_mm_loadu_si128((const __m128i*)pB)); + __m512i _pA1 = _mm512_alignr_epi8(_pA0, _pA0, 8); + __m512i _pB1 = _mm512_alignr_epi8(_pB0, _pB0, 4); + _sum0 = _mm512_dpbf16_ps(_sum0, (__m512bh)_pA0, (__m512bh)_pB0); + _sum1 = _mm512_dpbf16_ps(_sum1, (__m512bh)_pA0, (__m512bh)_pB1); + _sum2 = _mm512_dpbf16_ps(_sum2, (__m512bh)_pA1, (__m512bh)_pB0); + _sum3 = _mm512_dpbf16_ps(_sum3, (__m512bh)_pA1, (__m512bh)_pB1); + pA += 32; + pB += 8; + } +#endif // __AVX512BF16__ + for (; kk < max_kk; kk++) + { + __m512 _pA0 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)pA)); + __m128 _pBs = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)pB)); + __m512 _pB0 = _mm512_broadcast_f32x4(_pBs); + + __m512 _pA1 = _mm512_permute_ps(_pA0, _MM_SHUFFLE(1, 0, 3, 2)); + __m512 _pB1 = _mm512_permute_ps(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); + + _sum0 = _mm512_fmadd_ps(_pA0, _pB0, _sum0); + _sum1 = _mm512_fmadd_ps(_pA0, _pB1, _sum1); + _sum2 = _mm512_fmadd_ps(_pA1, _pB0, _sum2); + _sum3 = _mm512_fmadd_ps(_pA1, _pB1, _sum3); + + pA += 16; + pB += 4; + } + + _mm512_store_ps(outptr, _sum0); + _mm512_store_ps(outptr + 16, _sum1); + _mm512_store_ps(outptr + 32, _sum2); + _mm512_store_ps(outptr + 48, _sum3); + outptr += 64; + } + for (; jj + 1 < max_jj; jj += 2) + { + const unsigned short* pA = pAT; + + __m512 _sum0 = _mm512_setzero_ps(); + __m512 _sum1 = _mm512_setzero_ps(); + + if (k != 0) + { + _sum0 = _mm512_load_ps(outptr); + _sum1 = _mm512_load_ps(outptr + 16); + } + + int kk = 0; +#if __AVX512BF16__ + for (; kk + 1 < max_kk; kk += 2) + { + __m512i _pA = _mm512_loadu_si512((const __m512i*)pA); + __m512i _pB0 = _mm512_castpd_si512(_mm512_set1_pd(((const double*)pB)[0])); + __m512i _pB1 = _mm512_alignr_epi8(_pB0, _pB0, 4); + _sum0 = _mm512_dpbf16_ps(_sum0, (__m512bh)_pA, (__m512bh)_pB0); + _sum1 = _mm512_dpbf16_ps(_sum1, (__m512bh)_pA, (__m512bh)_pB1); + pA += 32; + pB += 4; + } +#endif // __AVX512BF16__ + for (; kk < max_kk; kk++) + { + __m512 _pA0 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)pA)); + __m512 _pB0 = bfloat2float_avx512(_mm256_set1_epi32(((const int*)pB)[0])); + __m512 _pB1 = _mm512_permute_ps(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); + + _sum0 = _mm512_fmadd_ps(_pA0, _pB0, _sum0); + _sum1 = _mm512_fmadd_ps(_pA0, _pB1, _sum1); + + pA += 16; + pB += 2; + } + + _mm512_store_ps(outptr, _sum0); + _mm512_store_ps(outptr + 16, _sum1); + outptr += 32; + } + for (; jj < max_jj; jj++) + { + const unsigned short* pA = pAT; + + __m512 _sum0 = _mm512_setzero_ps(); + + if (k != 0) + { + _sum0 = _mm512_load_ps(outptr); + } + + int kk = 0; +#if __AVX512BF16__ + for (; kk + 1 < max_kk; kk += 2) + { + __m512i _pA = _mm512_loadu_si512((const __m512i*)pA); + _sum0 = _mm512_dpbf16_ps(_sum0, (__m512bh)_pA, (__m512bh)_mm512_set1_epi32(((const int*)pB)[0])); + pA += 32; + pB += 2; + } +#endif // __AVX512BF16__ + for (; kk < max_kk; kk++) + { + __m512 _pA0 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)pA)); + __m512 _pB0 = _mm512_set1_ps(bfloat16_to_float32(pB[0])); + + _sum0 = _mm512_fmadd_ps(_pA0, _pB0, _sum0); + + pA += 16; + pB += 1; + } + + _mm512_store_ps(outptr, _sum0); + outptr += 16; + } + + pAT 
+= max_kk * 16; + } +#endif // __AVX512F__ + for (; ii + 7 < max_ii; ii += 8) + { + const unsigned short* pB = pBT; + + int jj = 0; +#if defined(__x86_64__) || defined(_M_X64) +#if __AVX512F__ + for (; jj + 15 < max_jj; jj += 16) + { + const unsigned short* pA = pAT; + + __m512 _sum0 = _mm512_setzero_ps(); + __m512 _sum1 = _mm512_setzero_ps(); + __m512 _sum2 = _mm512_setzero_ps(); + __m512 _sum3 = _mm512_setzero_ps(); + __m512 _sum4 = _mm512_setzero_ps(); + __m512 _sum5 = _mm512_setzero_ps(); + __m512 _sum6 = _mm512_setzero_ps(); + __m512 _sum7 = _mm512_setzero_ps(); + + if (k != 0) + { + _sum0 = _mm512_load_ps(outptr); + _sum1 = _mm512_load_ps(outptr + 16); + _sum2 = _mm512_load_ps(outptr + 32); + _sum3 = _mm512_load_ps(outptr + 48); + _sum4 = _mm512_load_ps(outptr + 64); + _sum5 = _mm512_load_ps(outptr + 80); + _sum6 = _mm512_load_ps(outptr + 96); + _sum7 = _mm512_load_ps(outptr + 112); + } + + int kk = 0; +#if __AVX512BF16__ + for (; kk + 1 < max_kk; kk += 2) + { + __m256i _pA0 = _mm256_loadu_si256((const __m256i*)pA); + __m512i _pB0 = _mm512_loadu_si512((const __m512i*)pB); + __m512i _pA00 = combine8x2_epi32(_pA0, _pA0); + __m512i _pA11 = _mm512_alignr_epi8(_pA00, _pA00, 8); + __m512i _pB1 = _mm512_alignr_epi8(_pB0, _pB0, 4); + __m512i _pB2 = _mm512_permutex_epi64(_pB0, _MM_SHUFFLE(1, 0, 3, 2)); + __m512i _pB3 = _mm512_alignr_epi8(_pB2, _pB2, 4); + _sum0 = _mm512_dpbf16_ps(_sum0, (__m512bh)_pA00, (__m512bh)_pB0); + _sum1 = _mm512_dpbf16_ps(_sum1, (__m512bh)_pA00, (__m512bh)_pB1); + _sum2 = _mm512_dpbf16_ps(_sum2, (__m512bh)_pA11, (__m512bh)_pB0); + _sum3 = _mm512_dpbf16_ps(_sum3, (__m512bh)_pA11, (__m512bh)_pB1); + _sum4 = _mm512_dpbf16_ps(_sum4, (__m512bh)_pA00, (__m512bh)_pB2); + _sum5 = _mm512_dpbf16_ps(_sum5, (__m512bh)_pA00, (__m512bh)_pB3); + _sum6 = _mm512_dpbf16_ps(_sum6, (__m512bh)_pA11, (__m512bh)_pB2); + _sum7 = _mm512_dpbf16_ps(_sum7, (__m512bh)_pA11, (__m512bh)_pB3); + pA += 16; + pB += 32; + } +#endif // __AVX512BF16__ + for (; kk < max_kk; kk++) + { + __m256 _pAA = bfloat2float_avx(_mm_loadu_si128((const __m128i*)pA)); + __m512 _pA0 = _mm512_castsi512_ps(combine8x2_epi32(_mm256_castps_si256(_pAA), _mm256_castps_si256(_pAA))); + __m512 _pB0 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)pB)); + + __m512 _pA1 = _mm512_permute_ps(_pA0, _MM_SHUFFLE(1, 0, 3, 2)); + __m512 _pB1 = _mm512_permute_ps(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); + __m512 _pB2 = _mm512_castsi512_ps(_mm512_permutex_epi64(_mm512_castps_si512(_pB0), _MM_SHUFFLE(1, 0, 3, 2))); + __m512 _pB3 = _mm512_permute_ps(_pB2, _MM_SHUFFLE(0, 3, 2, 1)); + + _sum0 = _mm512_fmadd_ps(_pA0, _pB0, _sum0); + _sum1 = _mm512_fmadd_ps(_pA0, _pB1, _sum1); + _sum2 = _mm512_fmadd_ps(_pA1, _pB0, _sum2); + _sum3 = _mm512_fmadd_ps(_pA1, _pB1, _sum3); + _sum4 = _mm512_fmadd_ps(_pA0, _pB2, _sum4); + _sum5 = _mm512_fmadd_ps(_pA0, _pB3, _sum5); + _sum6 = _mm512_fmadd_ps(_pA1, _pB2, _sum6); + _sum7 = _mm512_fmadd_ps(_pA1, _pB3, _sum7); + + pA += 8; + pB += 16; + } + + _mm512_store_ps(outptr, _sum0); + _mm512_store_ps(outptr + 16, _sum1); + _mm512_store_ps(outptr + 32, _sum2); + _mm512_store_ps(outptr + 48, _sum3); + _mm512_store_ps(outptr + 64, _sum4); + _mm512_store_ps(outptr + 80, _sum5); + _mm512_store_ps(outptr + 96, _sum6); + _mm512_store_ps(outptr + 112, _sum7); + outptr += 128; + } +#endif // __AVX512F__ + for (; jj + 7 < max_jj; jj += 8) + { + const unsigned short* pA = pAT; + + __m256 _sum0 = _mm256_setzero_ps(); + __m256 _sum1 = _mm256_setzero_ps(); + __m256 _sum2 = _mm256_setzero_ps(); + __m256 _sum3 = _mm256_setzero_ps(); + 
__m256 _sum4 = _mm256_setzero_ps(); + __m256 _sum5 = _mm256_setzero_ps(); + __m256 _sum6 = _mm256_setzero_ps(); + __m256 _sum7 = _mm256_setzero_ps(); + + if (k != 0) + { + _sum0 = _mm256_load_ps(outptr); + _sum1 = _mm256_load_ps(outptr + 8); + _sum2 = _mm256_load_ps(outptr + 16); + _sum3 = _mm256_load_ps(outptr + 24); + _sum4 = _mm256_load_ps(outptr + 32); + _sum5 = _mm256_load_ps(outptr + 40); + _sum6 = _mm256_load_ps(outptr + 48); + _sum7 = _mm256_load_ps(outptr + 56); + } + + int kk = 0; +#if __AVX512BF16__ + for (; kk + 1 < max_kk; kk += 2) + { + __m256i _pA0 = _mm256_loadu_si256((const __m256i*)pA); + __m256i _pB0 = _mm256_loadu_si256((const __m256i*)pB); + __m256i _pA1 = _mm256_alignr_epi8(_pA0, _pA0, 8); + __m256i _pB1 = _mm256_alignr_epi8(_pB0, _pB0, 4); + __m256i _pB2 = _mm256_permute4x64_epi64(_pB0, _MM_SHUFFLE(1, 0, 3, 2)); + __m256i _pB3 = _mm256_alignr_epi8(_pB2, _pB2, 4); + _sum0 = _mm256_dpbf16_ps(_sum0, (__m256bh)_pA0, (__m256bh)_pB0); + _sum1 = _mm256_dpbf16_ps(_sum1, (__m256bh)_pA0, (__m256bh)_pB1); + _sum2 = _mm256_dpbf16_ps(_sum2, (__m256bh)_pA1, (__m256bh)_pB0); + _sum3 = _mm256_dpbf16_ps(_sum3, (__m256bh)_pA1, (__m256bh)_pB1); + _sum4 = _mm256_dpbf16_ps(_sum4, (__m256bh)_pA0, (__m256bh)_pB2); + _sum5 = _mm256_dpbf16_ps(_sum5, (__m256bh)_pA0, (__m256bh)_pB3); + _sum6 = _mm256_dpbf16_ps(_sum6, (__m256bh)_pA1, (__m256bh)_pB2); + _sum7 = _mm256_dpbf16_ps(_sum7, (__m256bh)_pA1, (__m256bh)_pB3); + pA += 16; + pB += 16; + } +#endif // __AVX512BF16__ + for (; kk < max_kk; kk++) + { + __m256 _pA0 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)pA)); + __m256 _pB0 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)pB)); + + __m256 _pA1 = _mm256_permute_ps(_pA0, _MM_SHUFFLE(1, 0, 3, 2)); + __m256 _pB1 = _mm256_permute_ps(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); + __m256 _pB2 = _mm256_permute2f128_ps(_pB0, _pB0, _MM_SHUFFLE(0, 0, 0, 1)); + __m256 _pB3 = _mm256_permute_ps(_pB2, _MM_SHUFFLE(0, 3, 2, 1)); + + _sum0 = _mm256_comp_fmadd_ps(_pA0, _pB0, _sum0); + _sum1 = _mm256_comp_fmadd_ps(_pA0, _pB1, _sum1); + _sum2 = _mm256_comp_fmadd_ps(_pA1, _pB0, _sum2); + _sum3 = _mm256_comp_fmadd_ps(_pA1, _pB1, _sum3); + _sum4 = _mm256_comp_fmadd_ps(_pA0, _pB2, _sum4); + _sum5 = _mm256_comp_fmadd_ps(_pA0, _pB3, _sum5); + _sum6 = _mm256_comp_fmadd_ps(_pA1, _pB2, _sum6); + _sum7 = _mm256_comp_fmadd_ps(_pA1, _pB3, _sum7); + + pA += 8; + pB += 8; + } + + _mm256_store_ps(outptr, _sum0); + _mm256_store_ps(outptr + 8, _sum1); + _mm256_store_ps(outptr + 16, _sum2); + _mm256_store_ps(outptr + 24, _sum3); + _mm256_store_ps(outptr + 32, _sum4); + _mm256_store_ps(outptr + 40, _sum5); + _mm256_store_ps(outptr + 48, _sum6); + _mm256_store_ps(outptr + 56, _sum7); + outptr += 64; + } +#endif // defined(__x86_64__) || defined(_M_X64) + for (; jj + 3 < max_jj; jj += 4) + { + const unsigned short* pA = pAT; + + __m256 _sum0 = _mm256_setzero_ps(); + __m256 _sum1 = _mm256_setzero_ps(); + __m256 _sum2 = _mm256_setzero_ps(); + __m256 _sum3 = _mm256_setzero_ps(); + + if (k != 0) + { + _sum0 = _mm256_load_ps(outptr); + _sum1 = _mm256_load_ps(outptr + 8); + _sum2 = _mm256_load_ps(outptr + 16); + _sum3 = _mm256_load_ps(outptr + 24); + } + + int kk = 0; +#if __AVX512BF16__ + for (; kk + 1 < max_kk; kk += 2) + { + __m256i _pA0 = _mm256_loadu_si256((const __m256i*)pA); + __m128i _pB = _mm_loadu_si128((const __m128i*)pB); + __m256i _pB0 = combine4x2_epi32(_pB, _pB); + __m256i _pA1 = _mm256_alignr_epi8(_pA0, _pA0, 8); + __m256i _pB1 = _mm256_alignr_epi8(_pB0, _pB0, 4); + _sum0 = _mm256_dpbf16_ps(_sum0, (__m256bh)_pA0, 
(__m256bh)_pB0); + _sum1 = _mm256_dpbf16_ps(_sum1, (__m256bh)_pA0, (__m256bh)_pB1); + _sum2 = _mm256_dpbf16_ps(_sum2, (__m256bh)_pA1, (__m256bh)_pB0); + _sum3 = _mm256_dpbf16_ps(_sum3, (__m256bh)_pA1, (__m256bh)_pB1); + pA += 16; + pB += 8; + } +#endif // __AVX512BF16__ + for (; kk < max_kk; kk++) + { + __m256 _pA0 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)pA)); + __m128 _pBs = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)pB)); + __m256 _pB0 = combine4x2_ps(_pBs, _pBs); + + __m256 _pA1 = _mm256_permute_ps(_pA0, _MM_SHUFFLE(1, 0, 3, 2)); + __m256 _pB1 = _mm256_permute_ps(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); + + _sum0 = _mm256_comp_fmadd_ps(_pA0, _pB0, _sum0); + _sum1 = _mm256_comp_fmadd_ps(_pA0, _pB1, _sum1); + _sum2 = _mm256_comp_fmadd_ps(_pA1, _pB0, _sum2); + _sum3 = _mm256_comp_fmadd_ps(_pA1, _pB1, _sum3); + + pA += 8; + pB += 4; + } + + _mm256_store_ps(outptr, _sum0); + _mm256_store_ps(outptr + 8, _sum1); + _mm256_store_ps(outptr + 16, _sum2); + _mm256_store_ps(outptr + 24, _sum3); + outptr += 32; + } + for (; jj + 1 < max_jj; jj += 2) + { + const unsigned short* pA = pAT; + + __m256 _sum0 = _mm256_setzero_ps(); + __m256 _sum1 = _mm256_setzero_ps(); + + if (k != 0) + { + _sum0 = _mm256_load_ps(outptr); + _sum1 = _mm256_load_ps(outptr + 8); + } + + int kk = 0; +#if __AVX512BF16__ + for (; kk + 1 < max_kk; kk += 2) + { + __m256i _pA = _mm256_loadu_si256((const __m256i*)pA); + __m256i _pB0 = _mm256_castpd_si256(_mm256_broadcast_sd((const double*)pB)); + __m256i _pB1 = _mm256_alignr_epi8(_pB0, _pB0, 4); + _sum0 = _mm256_dpbf16_ps(_sum0, (__m256bh)_pA, (__m256bh)_pB0); + _sum1 = _mm256_dpbf16_ps(_sum1, (__m256bh)_pA, (__m256bh)_pB1); + pA += 16; + pB += 4; + } +#endif // __AVX512BF16__ + for (; kk < max_kk; kk++) + { + __m256 _pA0 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)pA)); + __m256 _pB0 = bfloat2float_avx(_mm_castps_si128(_mm_load1_ps((const float*)pB))); + __m256 _pB1 = _mm256_permute_ps(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); + + _sum0 = _mm256_comp_fmadd_ps(_pA0, _pB0, _sum0); + _sum1 = _mm256_comp_fmadd_ps(_pA0, _pB1, _sum1); + + pA += 8; + pB += 2; + } + + _mm256_store_ps(outptr, _sum0); + _mm256_store_ps(outptr + 8, _sum1); + outptr += 16; + } + for (; jj < max_jj; jj++) + { + const unsigned short* pA = pAT; + + __m256 _sum0 = _mm256_setzero_ps(); + + if (k != 0) + { + _sum0 = _mm256_load_ps(outptr); + } + + int kk = 0; +#if __AVX512BF16__ + for (; kk + 1 < max_kk; kk += 2) + { + __m256i _pA = _mm256_loadu_si256((const __m256i*)pA); + _sum0 = _mm256_dpbf16_ps(_sum0, (__m256bh)_pA, (__m256bh)_mm256_set1_epi32(((const int*)pB)[0])); + pA += 16; + pB += 2; + } +#endif // __AVX512BF16__ + for (; kk < max_kk; kk++) + { + __m256 _pA0 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)pA)); + __m256 _pB0 = _mm256_set1_ps(bfloat16_to_float32(pB[0])); + + _sum0 = _mm256_comp_fmadd_ps(_pA0, _pB0, _sum0); + + pA += 8; + pB += 1; + } + + _mm256_store_ps(outptr, _sum0); + outptr += 8; + } + + pAT += max_kk * 8; + } +#endif // __AVX__ + for (; ii + 3 < max_ii; ii += 4) + { + const unsigned short* pB = pBT; + + int jj = 0; +#if defined(__x86_64__) || defined(_M_X64) +#if __AVX512F__ + for (; jj + 15 < max_jj; jj += 16) + { + const unsigned short* pA = pAT; + + __m512 _sum0 = _mm512_setzero_ps(); + __m512 _sum1 = _mm512_setzero_ps(); + __m512 _sum2 = _mm512_setzero_ps(); + __m512 _sum3 = _mm512_setzero_ps(); + + if (k != 0) + { + _sum0 = _mm512_loadu_ps(outptr); + _sum1 = _mm512_loadu_ps(outptr + 16); + _sum2 = _mm512_loadu_ps(outptr + 32); + _sum3 = _mm512_loadu_ps(outptr + 48); + } + + 
int kk = 0; +#if __AVX512BF16__ + for (; kk + 1 < max_kk; kk += 2) + { + __m512i _pA0 = _mm512_broadcast_i32x4(_mm_loadu_si128((const __m128i*)pA)); + __m512i _pB0 = _mm512_loadu_si512((const __m512i*)pB); + __m512i _pA1 = _mm512_alignr_epi8(_pA0, _pA0, 8); + __m512i _pB1 = _mm512_alignr_epi8(_pB0, _pB0, 4); + _sum0 = _mm512_dpbf16_ps(_sum0, (__m512bh)_pA0, (__m512bh)_pB0); + _sum1 = _mm512_dpbf16_ps(_sum1, (__m512bh)_pA0, (__m512bh)_pB1); + _sum2 = _mm512_dpbf16_ps(_sum2, (__m512bh)_pA1, (__m512bh)_pB0); + _sum3 = _mm512_dpbf16_ps(_sum3, (__m512bh)_pA1, (__m512bh)_pB1); + pA += 8; + pB += 32; + } +#endif // __AVX512BF16__ + for (; kk < max_kk; kk++) + { + __m128 _pAs = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)pA)); + __m512 _pA0 = _mm512_broadcast_f32x4(_pAs); + __m512 _pB0 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)pB)); + + __m512 _pA1 = _mm512_permute_ps(_pA0, _MM_SHUFFLE(1, 0, 3, 2)); + __m512 _pB1 = _mm512_permute_ps(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); + + _sum0 = _mm512_fmadd_ps(_pA0, _pB0, _sum0); + _sum1 = _mm512_fmadd_ps(_pA0, _pB1, _sum1); + _sum2 = _mm512_fmadd_ps(_pA1, _pB0, _sum2); + _sum3 = _mm512_fmadd_ps(_pA1, _pB1, _sum3); + + pA += 4; + pB += 16; + } + + _mm512_storeu_ps(outptr, _sum0); + _mm512_storeu_ps(outptr + 16, _sum1); + _mm512_storeu_ps(outptr + 32, _sum2); + _mm512_storeu_ps(outptr + 48, _sum3); + outptr += 64; + } +#endif // __AVX512F__ + for (; jj + 7 < max_jj; jj += 8) + { + const unsigned short* pA = pAT; + +#if __AVX__ + __m256 _sum0 = _mm256_setzero_ps(); + __m256 _sum1 = _mm256_setzero_ps(); + __m256 _sum2 = _mm256_setzero_ps(); + __m256 _sum3 = _mm256_setzero_ps(); +#else + __m128 _sum0 = _mm_setzero_ps(); + __m128 _sum1 = _mm_setzero_ps(); + __m128 _sum2 = _mm_setzero_ps(); + __m128 _sum3 = _mm_setzero_ps(); + __m128 _sum4 = _mm_setzero_ps(); + __m128 _sum5 = _mm_setzero_ps(); + __m128 _sum6 = _mm_setzero_ps(); + __m128 _sum7 = _mm_setzero_ps(); +#endif + + if (k != 0) + { +#if __AVX__ + _sum0 = _mm256_load_ps(outptr); + _sum1 = _mm256_load_ps(outptr + 8); + _sum2 = _mm256_load_ps(outptr + 16); + _sum3 = _mm256_load_ps(outptr + 24); +#else + _sum0 = _mm_load_ps(outptr); + _sum1 = _mm_load_ps(outptr + 4); + _sum2 = _mm_load_ps(outptr + 8); + _sum3 = _mm_load_ps(outptr + 12); + _sum4 = _mm_load_ps(outptr + 16); + _sum5 = _mm_load_ps(outptr + 20); + _sum6 = _mm_load_ps(outptr + 24); + _sum7 = _mm_load_ps(outptr + 28); +#endif + } + + int kk = 0; +#if __AVX512BF16__ + for (; kk + 1 < max_kk; kk += 2) + { + __m128i _pA0 = _mm_loadu_si128((const __m128i*)pA); + __m256i _pA00 = combine4x2_epi32(_pA0, _pA0); + __m256i _pB01 = _mm256_loadu_si256((const __m256i*)pB); + __m256i _pA11 = _mm256_alignr_epi8(_pA00, _pA00, 8); + __m256i _pB23 = _mm256_alignr_epi8(_pB01, _pB01, 4); + _sum0 = _mm256_dpbf16_ps(_sum0, (__m256bh)_pA00, (__m256bh)_pB01); + _sum1 = _mm256_dpbf16_ps(_sum1, (__m256bh)_pA11, (__m256bh)_pB01); + _sum2 = _mm256_dpbf16_ps(_sum2, (__m256bh)_pA00, (__m256bh)_pB23); + _sum3 = _mm256_dpbf16_ps(_sum3, (__m256bh)_pA11, (__m256bh)_pB23); + pA += 8; + pB += 16; + } +#endif // __AVX512BF16__ + for (; kk < max_kk; kk++) + { +#if __AVX__ + __m128 _pA = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)pA)); + __m256 _pA0 = combine4x2_ps(_pA, _pA); + __m256 _pB0 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)pB)); + + __m256 _pA1 = _mm256_permute_ps(_pA0, _MM_SHUFFLE(1, 0, 3, 2)); + __m256 _pB1 = _mm256_permute_ps(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); + + _sum0 = _mm256_comp_fmadd_ps(_pA0, _pB0, _sum0); + _sum1 = _mm256_comp_fmadd_ps(_pA1, _pB0, 
_sum1); + _sum2 = _mm256_comp_fmadd_ps(_pA0, _pB1, _sum2); + _sum3 = _mm256_comp_fmadd_ps(_pA1, _pB1, _sum3); +#else // __AVX__ + __m128 _pA0 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)pA)); + __m128 _pB0 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)pB)); + __m128 _pB1 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(pB + 4))); + + __m128 _pA1 = _mm_shuffle_ps(_pA0, _pA0, _MM_SHUFFLE(1, 0, 3, 2)); + __m128 _pB0s = _mm_shuffle_ps(_pB0, _pB0, _MM_SHUFFLE(0, 3, 2, 1)); + __m128 _pB1s = _mm_shuffle_ps(_pB1, _pB1, _MM_SHUFFLE(0, 3, 2, 1)); + + _sum0 = _mm_comp_fmadd_ps(_pA0, _pB0, _sum0); + _sum1 = _mm_comp_fmadd_ps(_pA0, _pB1, _sum1); + _sum2 = _mm_comp_fmadd_ps(_pA1, _pB0, _sum2); + _sum3 = _mm_comp_fmadd_ps(_pA1, _pB1, _sum3); + _sum4 = _mm_comp_fmadd_ps(_pA0, _pB0s, _sum4); + _sum5 = _mm_comp_fmadd_ps(_pA0, _pB1s, _sum5); + _sum6 = _mm_comp_fmadd_ps(_pA1, _pB0s, _sum6); + _sum7 = _mm_comp_fmadd_ps(_pA1, _pB1s, _sum7); +#endif // __AVX__ + + pA += 4; + pB += 8; + } + +#if __AVX__ + _mm256_store_ps(outptr, _sum0); + _mm256_store_ps(outptr + 8, _sum1); + _mm256_store_ps(outptr + 16, _sum2); + _mm256_store_ps(outptr + 24, _sum3); +#else + _mm_store_ps(outptr, _sum0); + _mm_store_ps(outptr + 4, _sum1); + _mm_store_ps(outptr + 8, _sum2); + _mm_store_ps(outptr + 12, _sum3); + _mm_store_ps(outptr + 16, _sum4); + _mm_store_ps(outptr + 20, _sum5); + _mm_store_ps(outptr + 24, _sum6); + _mm_store_ps(outptr + 28, _sum7); +#endif + outptr += 32; + } +#endif // defined(__x86_64__) || defined(_M_X64) + for (; jj + 3 < max_jj; jj += 4) + { + const unsigned short* pA = pAT; + + __m128 _sum0 = _mm_setzero_ps(); + __m128 _sum1 = _mm_setzero_ps(); + __m128 _sum2 = _mm_setzero_ps(); + __m128 _sum3 = _mm_setzero_ps(); + + if (k != 0) + { + _sum0 = _mm_load_ps(outptr); + _sum1 = _mm_load_ps(outptr + 4); + _sum2 = _mm_load_ps(outptr + 8); + _sum3 = _mm_load_ps(outptr + 12); + } + + int kk = 0; +#if __AVX512BF16__ + for (; kk + 1 < max_kk; kk += 2) + { + __m128i _pA0 = _mm_loadu_si128((const __m128i*)pA); + __m128i _pB0 = _mm_loadu_si128((const __m128i*)pB); + __m128i _pA1 = _mm_alignr_epi8(_pA0, _pA0, 8); + __m128i _pB1 = _mm_alignr_epi8(_pB0, _pB0, 4); + _sum0 = _mm_dpbf16_ps(_sum0, (__m128bh)_pA0, (__m128bh)_pB0); + _sum1 = _mm_dpbf16_ps(_sum1, (__m128bh)_pA0, (__m128bh)_pB1); + _sum2 = _mm_dpbf16_ps(_sum2, (__m128bh)_pA1, (__m128bh)_pB0); + _sum3 = _mm_dpbf16_ps(_sum3, (__m128bh)_pA1, (__m128bh)_pB1); + pA += 8; + pB += 8; + } +#endif // __AVX512BF16__ + for (; kk < max_kk; kk++) + { + __m128 _pA0 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)pA)); + __m128 _pB0 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)pB)); + + __m128 _pA1 = _mm_shuffle_ps(_pA0, _pA0, _MM_SHUFFLE(1, 0, 3, 2)); + __m128 _pB1 = _mm_shuffle_ps(_pB0, _pB0, _MM_SHUFFLE(0, 3, 2, 1)); + + _sum0 = _mm_comp_fmadd_ps(_pA0, _pB0, _sum0); + _sum1 = _mm_comp_fmadd_ps(_pA0, _pB1, _sum1); + _sum2 = _mm_comp_fmadd_ps(_pA1, _pB0, _sum2); + _sum3 = _mm_comp_fmadd_ps(_pA1, _pB1, _sum3); + + pA += 4; + pB += 4; + } + + _mm_store_ps(outptr, _sum0); + _mm_store_ps(outptr + 4, _sum1); + _mm_store_ps(outptr + 8, _sum2); + _mm_store_ps(outptr + 12, _sum3); + outptr += 16; + } + for (; jj + 1 < max_jj; jj += 2) + { + const unsigned short* pA = pAT; + + __m128 _sum0 = _mm_setzero_ps(); + __m128 _sum1 = _mm_setzero_ps(); + + if (k != 0) + { + _sum0 = _mm_load_ps(outptr); + _sum1 = _mm_load_ps(outptr + 4); + } + + int kk = 0; +#if __AVX512BF16__ + for (; kk + 1 < max_kk; kk += 2) + { + __m128i _pA = _mm_loadu_si128((const __m128i*)pA); + 
__m128i _pB0 = _mm_castpd_si128(_mm_load1_pd((const double*)pB)); + __m128i _pB1 = _mm_alignr_epi8(_pB0, _pB0, 4); + _sum0 = _mm_dpbf16_ps(_sum0, (__m128bh)_pA, (__m128bh)_pB0); + _sum1 = _mm_dpbf16_ps(_sum1, (__m128bh)_pA, (__m128bh)_pB1); + pA += 8; + pB += 4; + } +#endif // __AVX512BF16__ + for (; kk < max_kk; kk++) + { + __m128 _pA = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)pA)); + __m128 _pB0 = bfloat2float_sse(_mm_castps_si128(_mm_load1_ps((const float*)pB))); + __m128 _pB1 = _mm_shuffle_ps(_pB0, _pB0, _MM_SHUFFLE(0, 3, 2, 1)); + + _sum0 = _mm_comp_fmadd_ps(_pA, _pB0, _sum0); + _sum1 = _mm_comp_fmadd_ps(_pA, _pB1, _sum1); + + pA += 4; + pB += 2; + } + + _mm_store_ps(outptr, _sum0); + _mm_store_ps(outptr + 4, _sum1); + outptr += 8; + } + for (; jj < max_jj; jj++) + { + const unsigned short* pA = pAT; + + __m128 _sum0 = _mm_setzero_ps(); + + if (k != 0) + { + _sum0 = _mm_load_ps(outptr); + } + + int kk = 0; +#if __AVX512BF16__ + for (; kk + 1 < max_kk; kk += 2) + { + __m128i _pA = _mm_loadu_si128((const __m128i*)pA); + _sum0 = _mm_dpbf16_ps(_sum0, (__m128bh)_pA, (__m128bh)_mm_set1_epi32(((const int*)pB)[0])); + pA += 8; + pB += 2; + } +#endif // __AVX512BF16__ + for (; kk < max_kk; kk++) + { + __m128 _pA = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)pA)); + __m128 _pB0 = _mm_set1_ps(bfloat16_to_float32(pB[0])); + + _sum0 = _mm_comp_fmadd_ps(_pA, _pB0, _sum0); + + pA += 4; + pB += 1; + } + + _mm_store_ps(outptr, _sum0); + outptr += 4; + } + + pAT += max_kk * 4; + } +#endif // __SSE2__ + for (; ii + 1 < max_ii; ii += 2) + { + const unsigned short* pB = pBT; + + int jj = 0; +#if __SSE2__ +#if defined(__x86_64__) || defined(_M_X64) +#if __AVX512F__ + for (; jj + 15 < max_jj; jj += 16) + { + __m512 _sum0 = _mm512_setzero_ps(); + __m512 _sum1 = _mm512_setzero_ps(); + + if (k != 0) + { + _sum0 = _mm512_loadu_ps(outptr); + _sum1 = _mm512_loadu_ps(outptr + 16); + } + + const unsigned short* pA = pAT; + int kk = 0; +#if __AVX512BF16__ + for (; kk + 1 < max_kk; kk += 2) + { + __m512i _pA = _mm512_castpd_si512(_mm512_set1_pd(((const double*)pA)[0])); + __m512i _pB0 = _mm512_loadu_si512((const __m512i*)pB); + __m512i _pB1 = _mm512_alignr_epi8(_pB0, _pB0, 4); + _sum0 = _mm512_dpbf16_ps(_sum0, (__m512bh)_pA, (__m512bh)_pB0); + _sum1 = _mm512_dpbf16_ps(_sum1, (__m512bh)_pA, (__m512bh)_pB1); + pA += 4; + pB += 32; + } +#endif // __AVX512BF16__ + for (; kk < max_kk; kk++) + { + __m512 _pA0 = bfloat2float_avx512(_mm256_set1_epi32(((const int*)pA)[0])); + __m512 _pB0 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)pB)); + __m512 _pB1 = _mm512_permute_ps(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); + + _sum0 = _mm512_fmadd_ps(_pA0, _pB0, _sum0); + _sum1 = _mm512_fmadd_ps(_pA0, _pB1, _sum1); + + pA += 2; + pB += 16; + } + + _mm512_storeu_ps(outptr, _sum0); + _mm512_storeu_ps(outptr + 16, _sum1); + outptr += 32; + } +#endif // __AVX512F__ + for (; jj + 7 < max_jj; jj += 8) + { +#if __AVX__ + __m256 _sum0 = _mm256_setzero_ps(); + __m256 _sum1 = _mm256_setzero_ps(); +#else + __m128 _sum0 = _mm_setzero_ps(); + __m128 _sum1 = _mm_setzero_ps(); + __m128 _sum2 = _mm_setzero_ps(); + __m128 _sum3 = _mm_setzero_ps(); +#endif + + if (k != 0) + { +#if __AVX__ + _sum0 = _mm256_loadu_ps(outptr); + _sum1 = _mm256_loadu_ps(outptr + 8); +#else + _sum0 = _mm_load_ps(outptr); + _sum1 = _mm_load_ps(outptr + 4); + _sum2 = _mm_load_ps(outptr + 8); + _sum3 = _mm_load_ps(outptr + 12); +#endif + } + + const unsigned short* pA = pAT; + int kk = 0; +#if __AVX512BF16__ + for (; kk + 1 < max_kk; kk += 2) + { + __m256i _pA0 = 
_mm256_castpd_si256(_mm256_broadcast_sd((const double*)pA)); + __m256i _pB = _mm256_loadu_si256((const __m256i*)pB); + __m256i _pA1 = _mm256_alignr_epi8(_pA0, _pA0, 4); + _sum0 = _mm256_dpbf16_ps(_sum0, (__m256bh)_pA0, (__m256bh)_pB); + _sum1 = _mm256_dpbf16_ps(_sum1, (__m256bh)_pA1, (__m256bh)_pB); + pA += 4; + pB += 16; + } +#endif // __AVX512BF16__ + for (; kk < max_kk; kk++) + { + __m128i _pA = _mm_castps_si128(_mm_load1_ps((const float*)pA)); + __m128i _pB = _mm_loadu_si128((const __m128i*)pB); +#if __AVX__ + __m256 _pA0 = bfloat2float_avx(_pA); + __m256 _pA1 = _mm256_permute_ps(_pA0, _MM_SHUFFLE(0, 3, 2, 1)); + __m256 _pB0 = bfloat2float_avx(_pB); + + _sum0 = _mm256_comp_fmadd_ps(_pA0, _pB0, _sum0); + _sum1 = _mm256_comp_fmadd_ps(_pA1, _pB0, _sum1); +#else // __AVX__ + __m128 _pA0 = bfloat2float_sse(_pA); + __m128 _pA1 = _mm_shuffle_ps(_pA0, _pA0, _MM_SHUFFLE(0, 3, 2, 1)); + __m128 _pB0 = bfloat2float_sse(_pB); + __m128 _pB1 = bfloat2float_sse(_mm_srli_si128(_pB, 8)); + + _sum0 = _mm_comp_fmadd_ps(_pA0, _pB0, _sum0); + _sum1 = _mm_comp_fmadd_ps(_pA0, _pB1, _sum1); + _sum2 = _mm_comp_fmadd_ps(_pA1, _pB0, _sum2); + _sum3 = _mm_comp_fmadd_ps(_pA1, _pB1, _sum3); +#endif // __AVX__ + + pA += 2; + pB += 8; + } + +#if __AVX__ + _mm256_storeu_ps(outptr, _sum0); + _mm256_storeu_ps(outptr + 8, _sum1); +#else + _mm_store_ps(outptr, _sum0); + _mm_store_ps(outptr + 4, _sum1); + _mm_store_ps(outptr + 8, _sum2); + _mm_store_ps(outptr + 12, _sum3); +#endif + outptr += 16; + } +#endif // defined(__x86_64__) || defined(_M_X64) + for (; jj + 3 < max_jj; jj += 4) + { + __m128 _sum0 = _mm_setzero_ps(); + __m128 _sum1 = _mm_setzero_ps(); + + if (k != 0) + { + _sum0 = _mm_load_ps(outptr); + _sum1 = _mm_load_ps(outptr + 4); + } + + const unsigned short* pA = pAT; + int kk = 0; +#if __AVX512BF16__ + for (; kk + 1 < max_kk; kk += 2) + { + __m128i _pA = _mm_castpd_si128(_mm_load1_pd((const double*)pA)); + __m128i _pB0 = _mm_loadu_si128((const __m128i*)pB); + __m128i _pB1 = _mm_alignr_epi8(_pB0, _pB0, 4); + _sum0 = _mm_dpbf16_ps(_sum0, (__m128bh)_pA, (__m128bh)_pB0); + _sum1 = _mm_dpbf16_ps(_sum1, (__m128bh)_pA, (__m128bh)_pB1); + pA += 4; + pB += 8; + } +#endif // __AVX512BF16__ + for (; kk < max_kk; kk++) + { + __m128 _pA = bfloat2float_sse(_mm_castps_si128(_mm_load1_ps((const float*)pA))); + __m128 _pB0 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)pB)); + __m128 _pB1 = _mm_shuffle_ps(_pB0, _pB0, _MM_SHUFFLE(0, 3, 2, 1)); + _sum0 = _mm_comp_fmadd_ps(_pA, _pB0, _sum0); + _sum1 = _mm_comp_fmadd_ps(_pA, _pB1, _sum1); + pA += 2; + pB += 4; + } + + _mm_store_ps(outptr, _sum0); + _mm_store_ps(outptr + 4, _sum1); + outptr += 8; + } +#endif // __SSE2__ + for (; jj + 1 < max_jj; jj += 2) + { + float sum00 = 0.f; + float sum01 = 0.f; + float sum10 = 0.f; + float sum11 = 0.f; + + if (k != 0) + { + sum00 = outptr[0]; + sum01 = outptr[1]; + sum10 = outptr[2]; + sum11 = outptr[3]; + } + + const unsigned short* pA = pAT; + int kk = 0; +#if __AVX512BF16__ + for (; kk + 1 < max_kk; kk += 2) + { + float a00 = bfloat16_to_float32(pA[0]); + float a01 = bfloat16_to_float32(pA[1]); + float a10 = bfloat16_to_float32(pA[2]); + float a11 = bfloat16_to_float32(pA[3]); + float b00 = bfloat16_to_float32(pB[0]); + float b01 = bfloat16_to_float32(pB[1]); + float b10 = bfloat16_to_float32(pB[2]); + float b11 = bfloat16_to_float32(pB[3]); + sum00 += a00 * b00 + a01 * b01; + sum01 += a00 * b10 + a01 * b11; + sum10 += a10 * b00 + a11 * b01; + sum11 += a10 * b10 + a11 * b11; + pA += 4; + pB += 4; + } +#endif // __AVX512BF16__ + for (; kk 
< max_kk; kk++) + { + float a0 = bfloat16_to_float32(pA[0]); + float a1 = bfloat16_to_float32(pA[1]); + float b0 = bfloat16_to_float32(pB[0]); + float b1 = bfloat16_to_float32(pB[1]); + sum00 += a0 * b0; + sum01 += a0 * b1; + sum10 += a1 * b0; + sum11 += a1 * b1; + pA += 2; + pB += 2; + } + + outptr[0] = sum00; + outptr[1] = sum01; + outptr[2] = sum10; + outptr[3] = sum11; + outptr += 4; + } + for (; jj < max_jj; jj++) + { + float sum0 = 0.f; + float sum1 = 0.f; + + if (k != 0) + { + sum0 = outptr[0]; + sum1 = outptr[1]; + } + + const unsigned short* pA = pAT; + int kk = 0; +#if __AVX512BF16__ + for (; kk + 1 < max_kk; kk += 2) + { + float a00 = bfloat16_to_float32(pA[0]); + float a01 = bfloat16_to_float32(pA[1]); + float a10 = bfloat16_to_float32(pA[2]); + float a11 = bfloat16_to_float32(pA[3]); + float b0 = bfloat16_to_float32(pB[0]); + float b1 = bfloat16_to_float32(pB[1]); + sum0 += a00 * b0 + a01 * b1; + sum1 += a10 * b0 + a11 * b1; + pA += 4; + pB += 2; + } +#endif // __AVX512BF16__ + for (; kk < max_kk; kk++) + { + sum0 += bfloat16_to_float32(pA[0]) * bfloat16_to_float32(pB[0]); + sum1 += bfloat16_to_float32(pA[1]) * bfloat16_to_float32(pB[0]); + pA += 2; + pB += 1; + } + + outptr[0] = sum0; + outptr[1] = sum1; + outptr += 2; + } + + pAT += max_kk * 2; + } + for (; ii < max_ii; ii++) + { + const unsigned short* pB = pBT; + + int jj = 0; +#if __SSE2__ +#if defined(__x86_64__) || defined(_M_X64) +#if __AVX512F__ + for (; jj + 15 < max_jj; jj += 16) + { + __m512 _sum0 = _mm512_setzero_ps(); + + if (k != 0) + { + _sum0 = _mm512_loadu_ps(outptr); + } + + const unsigned short* pA = pAT; + int kk = 0; +#if __AVX512BF16__ + for (; kk + 1 < max_kk; kk += 2) + { + __m512i _pA = _mm512_set1_epi32(((const int*)pA)[0]); + __m512i _pB = _mm512_loadu_si512((const __m512i*)pB); + _sum0 = _mm512_dpbf16_ps(_sum0, (__m512bh)_pA, (__m512bh)_pB); + pA += 2; + pB += 32; + } +#endif // __AVX512BF16__ + for (; kk < max_kk; kk++) + { + __m512 _pA0 = _mm512_set1_ps(bfloat16_to_float32(pA[0])); + __m512 _pB0 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)pB)); + + _sum0 = _mm512_fmadd_ps(_pA0, _pB0, _sum0); + + pA += 1; + pB += 16; + } + + _mm512_storeu_ps(outptr, _sum0); + outptr += 16; + } +#endif // __AVX512F__ + for (; jj + 7 < max_jj; jj += 8) + { +#if __AVX__ + __m256 _sum0 = _mm256_setzero_ps(); +#else + __m128 _sum0 = _mm_setzero_ps(); + __m128 _sum1 = _mm_setzero_ps(); +#endif + + if (k != 0) + { +#if __AVX__ + _sum0 = _mm256_loadu_ps(outptr); +#else + _sum0 = _mm_loadu_ps(outptr); + _sum1 = _mm_loadu_ps(outptr + 4); +#endif + } + + const unsigned short* pA = pAT; + int kk = 0; +#if __AVX512BF16__ +#if _MSC_VER + __m256 _sum1 = _mm256_setzero_ps(); + __m256i _mask = _mm256_set1_epi32(0xffff0000); +#endif + for (; kk + 1 < max_kk; kk += 2) + { + __m256i _pA = _mm256_set1_epi32(((const int*)pA)[0]); + __m256i _pB = _mm256_loadu_si256((const __m256i*)pB); +#if _MSC_VER + // msvc crash here --- nihui + __m256 _pA0 = _mm256_castsi256_ps(_mm256_slli_epi32(_pA, 16)); + __m256 _pB0 = _mm256_castsi256_ps(_mm256_slli_epi32(_pB, 16)); + __m256 _pA1 = _mm256_castsi256_ps(_mm256_and_si256(_pA, _mask)); + __m256 _pB1 = _mm256_castsi256_ps(_mm256_and_si256(_pB, _mask)); + _sum0 = _mm256_fmadd_ps(_pA0, _pB0, _sum0); + _sum1 = _mm256_fmadd_ps(_pA1, _pB1, _sum1); +#else + _sum0 = _mm256_dpbf16_ps(_sum0, (__m256bh)_pA, (__m256bh)_pB); +#endif + pA += 2; + pB += 16; + } +#if _MSC_VER + _sum0 = _mm256_add_ps(_sum0, _sum1); +#endif +#endif // __AVX512BF16__ + for (; kk < max_kk; kk++) + { +#if __AVX__ + __m256 
_pA0 = _mm256_set1_ps(bfloat16_to_float32(pA[0])); + __m256 _pB0 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)pB)); + + _sum0 = _mm256_comp_fmadd_ps(_pA0, _pB0, _sum0); +#else + __m128 _pA = _mm_set1_ps(bfloat16_to_float32(pA[0])); + __m128i _pB = _mm_loadu_si128((const __m128i*)pB); + __m128i _zero = _mm_setzero_si128(); + __m128 _pB0 = _mm_castsi128_ps(_mm_unpacklo_epi16(_zero, _pB)); + __m128 _pB1 = _mm_castsi128_ps(_mm_unpackhi_epi16(_zero, _pB)); + + _sum0 = _mm_comp_fmadd_ps(_pA, _pB0, _sum0); + _sum1 = _mm_comp_fmadd_ps(_pA, _pB1, _sum1); +#endif + + pA += 1; + pB += 8; + } + +#if __AVX__ + _mm256_storeu_ps(outptr, _sum0); +#else + _mm_storeu_ps(outptr, _sum0); + _mm_storeu_ps(outptr + 4, _sum1); +#endif + outptr += 8; + } +#endif // defined(__x86_64__) || defined(_M_X64) + for (; jj + 3 < max_jj; jj += 4) + { + __m128 _sum0 = _mm_setzero_ps(); + + if (k != 0) + { + _sum0 = _mm_loadu_ps(outptr); + } + + const unsigned short* pA = pAT; + int kk = 0; +#if __AVX512BF16__ +#if _MSC_VER + __m128 _sum1 = _mm_setzero_ps(); + __m128i _mask = _mm_set1_epi32(0xffff0000); +#endif + for (; kk + 1 < max_kk; kk += 2) + { + __m128i _pA = _mm_set1_epi32(((const int*)pA)[0]); + __m128i _pB = _mm_loadu_si128((const __m128i*)pB); +#if _MSC_VER + // msvc crash here --- nihui + __m128 _pA0 = _mm_castsi128_ps(_mm_slli_epi32(_pA, 16)); + __m128 _pB0 = _mm_castsi128_ps(_mm_slli_epi32(_pB, 16)); + __m128 _pA1 = _mm_castsi128_ps(_mm_and_si128(_pA, _mask)); + __m128 _pB1 = _mm_castsi128_ps(_mm_and_si128(_pB, _mask)); + _sum0 = _mm_fmadd_ps(_pA0, _pB0, _sum0); + _sum1 = _mm_fmadd_ps(_pA1, _pB1, _sum1); +#else + _sum0 = _mm_dpbf16_ps(_sum0, (__m128bh)_pA, (__m128bh)_pB); +#endif + pA += 2; + pB += 8; + } +#if _MSC_VER + _sum0 = _mm_add_ps(_sum0, _sum1); +#endif +#endif // __AVX512BF16__ + for (; kk < max_kk; kk++) + { + __m128 _pA0 = _mm_set1_ps(bfloat16_to_float32(pA[0])); + __m128 _pB0 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)pB)); + + _sum0 = _mm_comp_fmadd_ps(_pA0, _pB0, _sum0); + + pA += 1; + pB += 4; + } + + _mm_storeu_ps(outptr, _sum0); + outptr += 4; + } +#endif // __SSE2__ + for (; jj + 1 < max_jj; jj += 2) + { + float sum0 = 0.f; + float sum1 = 0.f; + + if (k != 0) + { + sum0 = outptr[0]; + sum1 = outptr[1]; + } + + const unsigned short* pA = pAT; + int kk = 0; +#if __AVX512BF16__ + for (; kk + 1 < max_kk; kk += 2) + { + float a0 = bfloat16_to_float32(pA[0]); + float a1 = bfloat16_to_float32(pA[1]); + float b00 = bfloat16_to_float32(pB[0]); + float b01 = bfloat16_to_float32(pB[1]); + float b10 = bfloat16_to_float32(pB[2]); + float b11 = bfloat16_to_float32(pB[3]); + sum0 += a0 * b00 + a1 * b01; + sum1 += a0 * b10 + a1 * b11; + pA += 2; + pB += 4; + } +#endif // __AVX512BF16__ + for (; kk < max_kk; kk++) + { + float a0 = bfloat16_to_float32(pA[0]); + sum0 += a0 * bfloat16_to_float32(pB[0]); + sum1 += a0 * bfloat16_to_float32(pB[1]); + pA += 1; + pB += 2; + } + + outptr[0] = sum0; + outptr[1] = sum1; + outptr += 2; + } + for (; jj < max_jj; jj++) + { + const unsigned short* pA = pAT; + + float sum = 0.f; + + if (k != 0) + { + sum = outptr[0]; + } + + int kk = 0; +#if __AVX512BF16__ + for (; kk + 1 < max_kk; kk += 2) + { + float a0 = bfloat16_to_float32(pA[0]); + float a1 = bfloat16_to_float32(pA[1]); + float b00 = bfloat16_to_float32(pB[0]); + float b01 = bfloat16_to_float32(pB[1]); + sum += a0 * b00 + a1 * b01; + pA += 2; + pB += 2; + } +#endif // __AVX512BF16__ + for (; kk < max_kk; kk++) + { + sum += bfloat16_to_float32(pA[0]) * bfloat16_to_float32(pB[0]); + pA += 1; + pB += 1; + } + 
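+            // remaining 1x1 tail: store the accumulated dot product for this output element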
+ outptr[0] = sum; + outptr += 1; + } + + pAT += max_kk; + } +} + +static void unpack_output_tile_fp32_to_bf16(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, float alpha, float beta, int output_transpose) +{ + // NCNN_LOGE("unpack_output_tile_fp32_to_bf16 %d %d %d %d", i, max_ii, j, max_jj); + const int out_elempack = top_blob.elempack; + const size_t out_hstep = top_blob.dims == 3 ? top_blob.cstep : (size_t)top_blob.w; + + const size_t c_hstep = C.dims == 3 ? C.cstep : (size_t)C.w; + const int c_elempack = C.elempack; + const float* pC = C; + const float* pp = topT; - if (output_transpose) - { - // transpose_unpack: topT layout is [ii][jj] with ii values contiguous for each jj - // output to top_blob which is transposed (j is row, i is col) - int ii = 0; -#if __SSE2__ -#if __AVX__ -#if __AVX512F__ - for (; ii + 15 < max_ii; ii += 16) - { - unsigned short* p0 = (unsigned short*)top_blob + j * out_hstep + (i + ii) * out_elempack; + int ii = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + for (; ii + 15 < max_ii; ii += 16) + { + unsigned short* p0; + if (output_transpose) + { + p0 = (unsigned short*)top_blob + j * out_hstep + (i + ii) * out_elempack; + } + else + { + p0 = (unsigned short*)top_blob + (i + ii) * out_hstep + j * out_elempack; + } + + __m512 _c0 = _mm512_set1_ps(0.f); + if (pC) + { + if (broadcast_type_C == 0) + { + _c0 = _mm512_set1_ps(pC[0] * beta); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + pC = (const float*)C + i + ii; + _c0 = _mm512_loadu_ps(pC); + _c0 = _mm512_mul_ps(_c0, _mm512_set1_ps(beta)); + } + if (broadcast_type_C == 3) + { + pC = (const float*)C + (i + ii) * c_hstep + j * c_elempack; + } + if (broadcast_type_C == 4) + { + pC = (const float*)C + j; + } + } + + int jj = 0; +#if defined(__x86_64__) || defined(_M_X64) + for (; jj + 15 < max_jj; jj += 16) + { + __m512 _f0 = _mm512_load_ps(pp); + __m512 _f1 = _mm512_load_ps(pp + 16); + __m512 _f2 = _mm512_load_ps(pp + 32); + __m512 _f3 = _mm512_load_ps(pp + 48); + __m512 _f4 = _mm512_load_ps(pp + 64); + __m512 _f5 = _mm512_load_ps(pp + 80); + __m512 _f6 = _mm512_load_ps(pp + 96); + __m512 _f7 = _mm512_load_ps(pp + 112); + __m512 _f8 = _mm512_load_ps(pp + 128); + __m512 _f9 = _mm512_load_ps(pp + 128 + 16); + __m512 _fa = _mm512_load_ps(pp + 128 + 32); + __m512 _fb = _mm512_load_ps(pp + 128 + 48); + __m512 _fc = _mm512_load_ps(pp + 128 + 64); + __m512 _fd = _mm512_load_ps(pp + 128 + 80); + __m512 _fe = _mm512_load_ps(pp + 128 + 96); + __m512 _ff = _mm512_load_ps(pp + 128 + 112); + pp += 256; + + // deshuffle from the shuffle-based 16x16 dpbf16_ps kernel + { + _f1 = _mm512_permute_ps(_f1, _MM_SHUFFLE(2, 1, 0, 3)); + _f3 = _mm512_permute_ps(_f3, _MM_SHUFFLE(2, 1, 0, 3)); + _f5 = _mm512_permute_ps(_f5, _MM_SHUFFLE(2, 1, 0, 3)); + _f7 = _mm512_permute_ps(_f7, _MM_SHUFFLE(2, 1, 0, 3)); + _f9 = _mm512_permute_ps(_f9, _MM_SHUFFLE(2, 1, 0, 3)); + _fb = _mm512_permute_ps(_fb, _MM_SHUFFLE(2, 1, 0, 3)); + _fd = _mm512_permute_ps(_fd, _MM_SHUFFLE(2, 1, 0, 3)); + _ff = _mm512_permute_ps(_ff, _MM_SHUFFLE(2, 1, 0, 3)); + + __m512 _tmp0 = _mm512_unpacklo_ps(_f0, _f3); + __m512 _tmp1 = _mm512_unpackhi_ps(_f0, _f3); + __m512 _tmp2 = _mm512_unpacklo_ps(_f2, _f1); + __m512 _tmp3 = _mm512_unpackhi_ps(_f2, _f1); + __m512 _tmp4 = _mm512_unpacklo_ps(_f4, _f7); + __m512 _tmp5 = _mm512_unpackhi_ps(_f4, _f7); + __m512 _tmp6 = _mm512_unpacklo_ps(_f6, _f5); + __m512 _tmp7 = _mm512_unpackhi_ps(_f6, _f5); + __m512 _tmp8 = _mm512_unpacklo_ps(_f8, _fb); + __m512 _tmp9 = 
_mm512_unpackhi_ps(_f8, _fb); + __m512 _tmpa = _mm512_unpacklo_ps(_fa, _f9); + __m512 _tmpb = _mm512_unpackhi_ps(_fa, _f9); + __m512 _tmpc = _mm512_unpacklo_ps(_fc, _ff); + __m512 _tmpd = _mm512_unpackhi_ps(_fc, _ff); + __m512 _tmpe = _mm512_unpacklo_ps(_fe, _fd); + __m512 _tmpf = _mm512_unpackhi_ps(_fe, _fd); + + _f0 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp0), _mm512_castps_pd(_tmp2))); + _f1 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp0), _mm512_castps_pd(_tmp2))); + _f2 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp3), _mm512_castps_pd(_tmp1))); + _f3 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp3), _mm512_castps_pd(_tmp1))); + _f4 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp4), _mm512_castps_pd(_tmp6))); + _f5 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp4), _mm512_castps_pd(_tmp6))); + _f6 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp7), _mm512_castps_pd(_tmp5))); + _f7 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp7), _mm512_castps_pd(_tmp5))); + _f8 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp8), _mm512_castps_pd(_tmpa))); + _f9 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp8), _mm512_castps_pd(_tmpa))); + _fa = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmpb), _mm512_castps_pd(_tmp9))); + _fb = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmpb), _mm512_castps_pd(_tmp9))); + _fc = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmpc), _mm512_castps_pd(_tmpe))); + _fd = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmpc), _mm512_castps_pd(_tmpe))); + _fe = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmpf), _mm512_castps_pd(_tmpd))); + _ff = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmpf), _mm512_castps_pd(_tmpd))); + + _f1 = _mm512_permute_ps(_f1, _MM_SHUFFLE(2, 1, 0, 3)); + _f3 = _mm512_permute_ps(_f3, _MM_SHUFFLE(2, 1, 0, 3)); + _f5 = _mm512_permute_ps(_f5, _MM_SHUFFLE(2, 1, 0, 3)); + _f7 = _mm512_permute_ps(_f7, _MM_SHUFFLE(2, 1, 0, 3)); + _f9 = _mm512_permute_ps(_f9, _MM_SHUFFLE(2, 1, 0, 3)); + _fb = _mm512_permute_ps(_fb, _MM_SHUFFLE(2, 1, 0, 3)); + _fd = _mm512_permute_ps(_fd, _MM_SHUFFLE(2, 1, 0, 3)); + _ff = _mm512_permute_ps(_ff, _MM_SHUFFLE(2, 1, 0, 3)); + + _tmp0 = _mm512_shuffle_f32x4(_f0, _f8, _MM_SHUFFLE(2, 0, 2, 0)); + _tmp1 = _mm512_shuffle_f32x4(_f1, _f9, _MM_SHUFFLE(2, 0, 2, 0)); + _tmp2 = _mm512_shuffle_f32x4(_f2, _fa, _MM_SHUFFLE(2, 0, 2, 0)); + _tmp3 = _mm512_shuffle_f32x4(_f3, _fb, _MM_SHUFFLE(2, 0, 2, 0)); + _tmp4 = _mm512_shuffle_f32x4(_f8, _f0, _MM_SHUFFLE(3, 1, 3, 1)); + _tmp5 = _mm512_shuffle_f32x4(_f9, _f1, _MM_SHUFFLE(3, 1, 3, 1)); + _tmp6 = _mm512_shuffle_f32x4(_fa, _f2, _MM_SHUFFLE(3, 1, 3, 1)); + _tmp7 = _mm512_shuffle_f32x4(_fb, _f3, _MM_SHUFFLE(3, 1, 3, 1)); + _tmp8 = _mm512_shuffle_f32x4(_f4, _fc, _MM_SHUFFLE(2, 0, 2, 0)); + _tmp9 = _mm512_shuffle_f32x4(_f5, _fd, _MM_SHUFFLE(2, 0, 2, 0)); + _tmpa = _mm512_shuffle_f32x4(_f6, _fe, _MM_SHUFFLE(2, 0, 2, 0)); + _tmpb = _mm512_shuffle_f32x4(_f7, _ff, _MM_SHUFFLE(2, 0, 2, 0)); + _tmpc = _mm512_shuffle_f32x4(_fc, _f4, _MM_SHUFFLE(3, 1, 3, 1)); + _tmpd = _mm512_shuffle_f32x4(_fd, _f5, _MM_SHUFFLE(3, 1, 3, 1)); + _tmpe = _mm512_shuffle_f32x4(_fe, _f6, _MM_SHUFFLE(3, 1, 3, 1)); + _tmpf = _mm512_shuffle_f32x4(_ff, _f7, _MM_SHUFFLE(3, 1, 3, 1)); + + _f0 = _mm512_shuffle_f32x4(_tmp0, _tmp8, _MM_SHUFFLE(3, 1, 2, 0)); + _f1 = _mm512_shuffle_f32x4(_tmp1, _tmp9, _MM_SHUFFLE(3, 1, 2, 0)); + _f2 = 
_mm512_shuffle_f32x4(_tmp2, _tmpa, _MM_SHUFFLE(3, 1, 2, 0)); + _f3 = _mm512_shuffle_f32x4(_tmp3, _tmpb, _MM_SHUFFLE(3, 1, 2, 0)); + _f4 = _mm512_shuffle_f32x4(_tmp4, _tmpc, _MM_SHUFFLE(3, 1, 2, 0)); + _f5 = _mm512_shuffle_f32x4(_tmp5, _tmpd, _MM_SHUFFLE(3, 1, 2, 0)); + _f6 = _mm512_shuffle_f32x4(_tmp6, _tmpe, _MM_SHUFFLE(3, 1, 2, 0)); + _f7 = _mm512_shuffle_f32x4(_tmp7, _tmpf, _MM_SHUFFLE(3, 1, 2, 0)); + _f8 = _mm512_shuffle_f32x4(_tmp8, _tmp0, _MM_SHUFFLE(3, 1, 2, 0)); + _f9 = _mm512_shuffle_f32x4(_tmp9, _tmp1, _MM_SHUFFLE(3, 1, 2, 0)); + _fa = _mm512_shuffle_f32x4(_tmpa, _tmp2, _MM_SHUFFLE(3, 1, 2, 0)); + _fb = _mm512_shuffle_f32x4(_tmpb, _tmp3, _MM_SHUFFLE(3, 1, 2, 0)); + _fc = _mm512_shuffle_f32x4(_tmpc, _tmp4, _MM_SHUFFLE(3, 1, 2, 0)); + _fd = _mm512_shuffle_f32x4(_tmpd, _tmp5, _MM_SHUFFLE(3, 1, 2, 0)); + _fe = _mm512_shuffle_f32x4(_tmpe, _tmp6, _MM_SHUFFLE(3, 1, 2, 0)); + _ff = _mm512_shuffle_f32x4(_tmpf, _tmp7, _MM_SHUFFLE(3, 1, 2, 0)); + } + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = _mm512_add_ps(_f0, _c0); + _f1 = _mm512_add_ps(_f1, _c0); + _f2 = _mm512_add_ps(_f2, _c0); + _f3 = _mm512_add_ps(_f3, _c0); + _f4 = _mm512_add_ps(_f4, _c0); + _f5 = _mm512_add_ps(_f5, _c0); + _f6 = _mm512_add_ps(_f6, _c0); + _f7 = _mm512_add_ps(_f7, _c0); + _f8 = _mm512_add_ps(_f8, _c0); + _f9 = _mm512_add_ps(_f9, _c0); + _fa = _mm512_add_ps(_fa, _c0); + _fb = _mm512_add_ps(_fb, _c0); + _fc = _mm512_add_ps(_fc, _c0); + _fd = _mm512_add_ps(_fd, _c0); + _fe = _mm512_add_ps(_fe, _c0); + _ff = _mm512_add_ps(_ff, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = _mm512_add_ps(_f0, _c0); + _f1 = _mm512_add_ps(_f1, _c0); + _f2 = _mm512_add_ps(_f2, _c0); + _f3 = _mm512_add_ps(_f3, _c0); + _f4 = _mm512_add_ps(_f4, _c0); + _f5 = _mm512_add_ps(_f5, _c0); + _f6 = _mm512_add_ps(_f6, _c0); + _f7 = _mm512_add_ps(_f7, _c0); + _f8 = _mm512_add_ps(_f8, _c0); + _f9 = _mm512_add_ps(_f9, _c0); + _fa = _mm512_add_ps(_fa, _c0); + _fb = _mm512_add_ps(_fb, _c0); + _fc = _mm512_add_ps(_fc, _c0); + _fd = _mm512_add_ps(_fd, _c0); + _fe = _mm512_add_ps(_fe, _c0); + _ff = _mm512_add_ps(_ff, _c0); + } + if (broadcast_type_C == 3) + { + __m512 _c1; + __m512 _c2; + __m512 _c3; + __m512 _c4; + __m512 _c5; + __m512 _c6; + __m512 _c7; + __m512 _c8; + __m512 _c9; + __m512 _ca; + __m512 _cb; + __m512 _cc; + __m512 _cd; + __m512 _ce; + __m512 _cf; + if (c_elempack == 16) + { + _c0 = _mm512_loadu_ps(pC); + _c1 = _mm512_loadu_ps(pC + 16); + _c2 = _mm512_loadu_ps(pC + 32); + _c3 = _mm512_loadu_ps(pC + 48); + _c4 = _mm512_loadu_ps(pC + 64); + _c5 = _mm512_loadu_ps(pC + 80); + _c6 = _mm512_loadu_ps(pC + 96); + _c7 = _mm512_loadu_ps(pC + 112); + _c8 = _mm512_loadu_ps(pC + 128); + _c9 = _mm512_loadu_ps(pC + 128 + 16); + _ca = _mm512_loadu_ps(pC + 128 + 32); + _cb = _mm512_loadu_ps(pC + 128 + 48); + _cc = _mm512_loadu_ps(pC + 128 + 64); + _cd = _mm512_loadu_ps(pC + 128 + 80); + _ce = _mm512_loadu_ps(pC + 128 + 96); + _cf = _mm512_loadu_ps(pC + 128 + 112); + pC += 256; + } + else if (c_elempack == 8) + { + __m512 _tmp0 = _mm512_loadu_ps(pC); + __m512 _tmp1 = _mm512_loadu_ps(pC + 16); + __m512 _tmp2 = _mm512_loadu_ps(pC + 32); + __m512 _tmp3 = _mm512_loadu_ps(pC + 48); + __m512 _tmp4 = _mm512_loadu_ps(pC + 64); + __m512 _tmp5 = _mm512_loadu_ps(pC + 80); + __m512 _tmp6 = _mm512_loadu_ps(pC + 96); + __m512 _tmp7 = _mm512_loadu_ps(pC + 112); + __m512 _tmp8 = _mm512_loadu_ps(pC + c_hstep * 8); + __m512 _tmp9 = _mm512_loadu_ps(pC + c_hstep * 8 + 16); + __m512 _tmpa = _mm512_loadu_ps(pC + c_hstep * 8 + 32); + __m512 
_tmpb = _mm512_loadu_ps(pC + c_hstep * 8 + 48); + __m512 _tmpc = _mm512_loadu_ps(pC + c_hstep * 8 + 64); + __m512 _tmpd = _mm512_loadu_ps(pC + c_hstep * 8 + 80); + __m512 _tmpe = _mm512_loadu_ps(pC + c_hstep * 8 + 96); + __m512 _tmpf = _mm512_loadu_ps(pC + c_hstep * 8 + 112); + + _c0 = _mm512_shuffle_f32x4(_tmp0, _tmp8, _MM_SHUFFLE(1, 0, 1, 0)); + _c1 = _mm512_shuffle_f32x4(_tmp0, _tmp8, _MM_SHUFFLE(3, 2, 3, 2)); + _c2 = _mm512_shuffle_f32x4(_tmp1, _tmp9, _MM_SHUFFLE(1, 0, 1, 0)); + _c3 = _mm512_shuffle_f32x4(_tmp1, _tmp9, _MM_SHUFFLE(3, 2, 3, 2)); + _c4 = _mm512_shuffle_f32x4(_tmp2, _tmpa, _MM_SHUFFLE(1, 0, 1, 0)); + _c5 = _mm512_shuffle_f32x4(_tmp2, _tmpa, _MM_SHUFFLE(3, 2, 3, 2)); + _c6 = _mm512_shuffle_f32x4(_tmp3, _tmpb, _MM_SHUFFLE(1, 0, 1, 0)); + _c7 = _mm512_shuffle_f32x4(_tmp3, _tmpb, _MM_SHUFFLE(3, 2, 3, 2)); + _c8 = _mm512_shuffle_f32x4(_tmp4, _tmpc, _MM_SHUFFLE(1, 0, 1, 0)); + _c9 = _mm512_shuffle_f32x4(_tmp4, _tmpc, _MM_SHUFFLE(3, 2, 3, 2)); + _ca = _mm512_shuffle_f32x4(_tmp5, _tmpd, _MM_SHUFFLE(1, 0, 1, 0)); + _cb = _mm512_shuffle_f32x4(_tmp5, _tmpd, _MM_SHUFFLE(3, 2, 3, 2)); + _cc = _mm512_shuffle_f32x4(_tmp6, _tmpe, _MM_SHUFFLE(1, 0, 1, 0)); + _cd = _mm512_shuffle_f32x4(_tmp6, _tmpe, _MM_SHUFFLE(3, 2, 3, 2)); + _ce = _mm512_shuffle_f32x4(_tmp7, _tmpf, _MM_SHUFFLE(1, 0, 1, 0)); + _cf = _mm512_shuffle_f32x4(_tmp7, _tmpf, _MM_SHUFFLE(3, 2, 3, 2)); + + pC += 128; + } + else if (c_elempack == 4) + { + _c0 = _mm512_loadu_ps(pC); + _c1 = _mm512_loadu_ps(pC + 16); + _c2 = _mm512_loadu_ps(pC + 32); + _c3 = _mm512_loadu_ps(pC + 48); + _c4 = _mm512_loadu_ps(pC + c_hstep * 4); + _c5 = _mm512_loadu_ps(pC + c_hstep * 4 + 16); + _c6 = _mm512_loadu_ps(pC + c_hstep * 4 + 32); + _c7 = _mm512_loadu_ps(pC + c_hstep * 4 + 48); + _c8 = _mm512_loadu_ps(pC + c_hstep * 8); + _c9 = _mm512_loadu_ps(pC + c_hstep * 8 + 16); + _ca = _mm512_loadu_ps(pC + c_hstep * 8 + 32); + _cb = _mm512_loadu_ps(pC + c_hstep * 8 + 48); + _cc = _mm512_loadu_ps(pC + c_hstep * 12); + _cd = _mm512_loadu_ps(pC + c_hstep * 12 + 16); + _ce = _mm512_loadu_ps(pC + c_hstep * 12 + 32); + _cf = _mm512_loadu_ps(pC + c_hstep * 12 + 48); + + __m512 _tmp0 = _mm512_shuffle_f32x4(_c0, _c4, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp1 = _mm512_shuffle_f32x4(_c8, _cc, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp2 = _mm512_shuffle_f32x4(_c0, _c4, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp3 = _mm512_shuffle_f32x4(_c8, _cc, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp4 = _mm512_shuffle_f32x4(_c1, _c5, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp5 = _mm512_shuffle_f32x4(_c9, _cd, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp6 = _mm512_shuffle_f32x4(_c1, _c5, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp7 = _mm512_shuffle_f32x4(_c9, _cd, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp8 = _mm512_shuffle_f32x4(_c2, _c6, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp9 = _mm512_shuffle_f32x4(_ca, _ce, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmpa = _mm512_shuffle_f32x4(_c2, _c6, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmpb = _mm512_shuffle_f32x4(_ca, _ce, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmpc = _mm512_shuffle_f32x4(_c3, _c7, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmpd = _mm512_shuffle_f32x4(_cb, _cf, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmpe = _mm512_shuffle_f32x4(_c3, _c7, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmpf = _mm512_shuffle_f32x4(_cb, _cf, _MM_SHUFFLE(3, 2, 3, 2)); + + _c0 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(2, 0, 2, 0)); + _c1 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(3, 1, 3, 1)); + _c2 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(2, 0, 2, 0)); + _c3 = 
_mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(3, 1, 3, 1)); + _c4 = _mm512_shuffle_f32x4(_tmp4, _tmp5, _MM_SHUFFLE(2, 0, 2, 0)); + _c5 = _mm512_shuffle_f32x4(_tmp4, _tmp5, _MM_SHUFFLE(3, 1, 3, 1)); + _c6 = _mm512_shuffle_f32x4(_tmp6, _tmp7, _MM_SHUFFLE(2, 0, 2, 0)); + _c7 = _mm512_shuffle_f32x4(_tmp6, _tmp7, _MM_SHUFFLE(3, 1, 3, 1)); + _c8 = _mm512_shuffle_f32x4(_tmp8, _tmp9, _MM_SHUFFLE(2, 0, 2, 0)); + _c9 = _mm512_shuffle_f32x4(_tmp8, _tmp9, _MM_SHUFFLE(3, 1, 3, 1)); + _ca = _mm512_shuffle_f32x4(_tmpa, _tmpb, _MM_SHUFFLE(2, 0, 2, 0)); + _cb = _mm512_shuffle_f32x4(_tmpa, _tmpb, _MM_SHUFFLE(3, 1, 3, 1)); + _cc = _mm512_shuffle_f32x4(_tmpc, _tmpd, _MM_SHUFFLE(2, 0, 2, 0)); + _cd = _mm512_shuffle_f32x4(_tmpc, _tmpd, _MM_SHUFFLE(3, 1, 3, 1)); + _ce = _mm512_shuffle_f32x4(_tmpe, _tmpf, _MM_SHUFFLE(2, 0, 2, 0)); + _cf = _mm512_shuffle_f32x4(_tmpe, _tmpf, _MM_SHUFFLE(3, 1, 3, 1)); + + pC += 64; + } + else // if (c_elempack == 1) + { + _c0 = _mm512_loadu_ps(pC); + _c1 = _mm512_loadu_ps(pC + c_hstep); + _c2 = _mm512_loadu_ps(pC + c_hstep * 2); + _c3 = _mm512_loadu_ps(pC + c_hstep * 3); + _c4 = _mm512_loadu_ps(pC + c_hstep * 4); + _c5 = _mm512_loadu_ps(pC + c_hstep * 5); + _c6 = _mm512_loadu_ps(pC + c_hstep * 6); + _c7 = _mm512_loadu_ps(pC + c_hstep * 7); + _c8 = _mm512_loadu_ps(pC + c_hstep * 8); + _c9 = _mm512_loadu_ps(pC + c_hstep * 9); + _ca = _mm512_loadu_ps(pC + c_hstep * 10); + _cb = _mm512_loadu_ps(pC + c_hstep * 11); + _cc = _mm512_loadu_ps(pC + c_hstep * 12); + _cd = _mm512_loadu_ps(pC + c_hstep * 13); + _ce = _mm512_loadu_ps(pC + c_hstep * 14); + _cf = _mm512_loadu_ps(pC + c_hstep * 15); + transpose16x16_ps(_c0, _c1, _c2, _c3, _c4, _c5, _c6, _c7, _c8, _c9, _ca, _cb, _cc, _cd, _ce, _cf); + pC += 16; + } + if (beta == 1.f) + { + _f0 = _mm512_add_ps(_f0, _c0); + _f1 = _mm512_add_ps(_f1, _c1); + _f2 = _mm512_add_ps(_f2, _c2); + _f3 = _mm512_add_ps(_f3, _c3); + _f4 = _mm512_add_ps(_f4, _c4); + _f5 = _mm512_add_ps(_f5, _c5); + _f6 = _mm512_add_ps(_f6, _c6); + _f7 = _mm512_add_ps(_f7, _c7); + _f8 = _mm512_add_ps(_f8, _c8); + _f9 = _mm512_add_ps(_f9, _c9); + _fa = _mm512_add_ps(_fa, _ca); + _fb = _mm512_add_ps(_fb, _cb); + _fc = _mm512_add_ps(_fc, _cc); + _fd = _mm512_add_ps(_fd, _cd); + _fe = _mm512_add_ps(_fe, _ce); + _ff = _mm512_add_ps(_ff, _cf); + } + else + { + __m512 _beta = _mm512_set1_ps(beta); + _f0 = _mm512_fmadd_ps(_c0, _beta, _f0); + _f1 = _mm512_fmadd_ps(_c1, _beta, _f1); + _f2 = _mm512_fmadd_ps(_c2, _beta, _f2); + _f3 = _mm512_fmadd_ps(_c3, _beta, _f3); + _f4 = _mm512_fmadd_ps(_c4, _beta, _f4); + _f5 = _mm512_fmadd_ps(_c5, _beta, _f5); + _f6 = _mm512_fmadd_ps(_c6, _beta, _f6); + _f7 = _mm512_fmadd_ps(_c7, _beta, _f7); + _f8 = _mm512_fmadd_ps(_c8, _beta, _f8); + _f9 = _mm512_fmadd_ps(_c9, _beta, _f9); + _fa = _mm512_fmadd_ps(_ca, _beta, _fa); + _fb = _mm512_fmadd_ps(_cb, _beta, _fb); + _fc = _mm512_fmadd_ps(_cc, _beta, _fc); + _fd = _mm512_fmadd_ps(_cd, _beta, _fd); + _fe = _mm512_fmadd_ps(_ce, _beta, _fe); + _ff = _mm512_fmadd_ps(_cf, _beta, _ff); + } + } + if (broadcast_type_C == 4) + { + _c0 = _mm512_set1_ps(pC[0] * beta); + __m512 _c1 = _mm512_set1_ps(pC[1] * beta); + __m512 _c2 = _mm512_set1_ps(pC[2] * beta); + __m512 _c3 = _mm512_set1_ps(pC[3] * beta); + + _f0 = _mm512_add_ps(_f0, _c0); + _f1 = _mm512_add_ps(_f1, _c1); + _f2 = _mm512_add_ps(_f2, _c2); + _f3 = _mm512_add_ps(_f3, _c3); + + _c0 = _mm512_set1_ps(pC[4] * beta); + _c1 = _mm512_set1_ps(pC[5] * beta); + _c2 = _mm512_set1_ps(pC[6] * beta); + _c3 = _mm512_set1_ps(pC[7] * beta); + + _f4 = _mm512_add_ps(_f4, _c0); + _f5 
= _mm512_add_ps(_f5, _c1); + _f6 = _mm512_add_ps(_f6, _c2); + _f7 = _mm512_add_ps(_f7, _c3); + + _c0 = _mm512_set1_ps(pC[8] * beta); + _c1 = _mm512_set1_ps(pC[9] * beta); + _c2 = _mm512_set1_ps(pC[10] * beta); + _c3 = _mm512_set1_ps(pC[11] * beta); + + _f8 = _mm512_add_ps(_f8, _c0); + _f9 = _mm512_add_ps(_f9, _c1); + _fa = _mm512_add_ps(_fa, _c2); + _fb = _mm512_add_ps(_fb, _c3); + + _c0 = _mm512_set1_ps(pC[12] * beta); + _c1 = _mm512_set1_ps(pC[13] * beta); + _c2 = _mm512_set1_ps(pC[14] * beta); + _c3 = _mm512_set1_ps(pC[15] * beta); + + _fc = _mm512_add_ps(_fc, _c0); + _fd = _mm512_add_ps(_fd, _c1); + _fe = _mm512_add_ps(_fe, _c2); + _ff = _mm512_add_ps(_ff, _c3); + pC += 16; + } + } + + if (alpha != 1.f) + { + __m512 _alpha = _mm512_set1_ps(alpha); + _f0 = _mm512_mul_ps(_f0, _alpha); + _f1 = _mm512_mul_ps(_f1, _alpha); + _f2 = _mm512_mul_ps(_f2, _alpha); + _f3 = _mm512_mul_ps(_f3, _alpha); + _f4 = _mm512_mul_ps(_f4, _alpha); + _f5 = _mm512_mul_ps(_f5, _alpha); + _f6 = _mm512_mul_ps(_f6, _alpha); + _f7 = _mm512_mul_ps(_f7, _alpha); + _f8 = _mm512_mul_ps(_f8, _alpha); + _f9 = _mm512_mul_ps(_f9, _alpha); + _fa = _mm512_mul_ps(_fa, _alpha); + _fb = _mm512_mul_ps(_fb, _alpha); + _fc = _mm512_mul_ps(_fc, _alpha); + _fd = _mm512_mul_ps(_fd, _alpha); + _fe = _mm512_mul_ps(_fe, _alpha); + _ff = _mm512_mul_ps(_ff, _alpha); + } + + __m256i _bf0 = float2bfloat_avx512(_f0); + __m256i _bf1 = float2bfloat_avx512(_f1); + __m256i _bf2 = float2bfloat_avx512(_f2); + __m256i _bf3 = float2bfloat_avx512(_f3); + __m256i _bf4 = float2bfloat_avx512(_f4); + __m256i _bf5 = float2bfloat_avx512(_f5); + __m256i _bf6 = float2bfloat_avx512(_f6); + __m256i _bf7 = float2bfloat_avx512(_f7); + __m256i _bf8 = float2bfloat_avx512(_f8); + __m256i _bf9 = float2bfloat_avx512(_f9); + __m256i _bfa = float2bfloat_avx512(_fa); + __m256i _bfb = float2bfloat_avx512(_fb); + __m256i _bfc = float2bfloat_avx512(_fc); + __m256i _bfd = float2bfloat_avx512(_fd); + __m256i _bfe = float2bfloat_avx512(_fe); + __m256i _bff = float2bfloat_avx512(_ff); + + // store bf16 + if (output_transpose) + { + if (out_elempack == 16) + { + transpose16x16_epi16(_bf0, _bf1, _bf2, _bf3, _bf4, _bf5, _bf6, _bf7, _bf8, _bf9, _bfa, _bfb, _bfc, _bfd, _bfe, _bff); + + _mm256_storeu_si256((__m256i*)p0, _bf0); + _mm256_storeu_si256((__m256i*)(p0 + 16), _bf1); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 2), _bf2); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 3), _bf3); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 4), _bf4); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 5), _bf5); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 6), _bf6); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 7), _bf7); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 8), _bf8); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 9), _bf9); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 10), _bfa); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 11), _bfb); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 12), _bfc); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 13), _bfd); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 14), _bfe); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 15), _bff); + } + if (out_elempack == 8) + { + transpose16x8_epi16(_bf0, _bf1, _bf2, _bf3, _bf4, _bf5, _bf6, _bf7); + transpose16x8_epi16(_bf8, _bf9, _bfa, _bfb, _bfc, _bfd, _bfe, _bff); + + _mm_storeu_si128((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8), _mm256_extractf128_si256(_bf0, 1)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 2), _mm256_extractf128_si256(_bf1, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 3), 
_mm256_extractf128_si256(_bf1, 1)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 4), _mm256_extractf128_si256(_bf2, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 5), _mm256_extractf128_si256(_bf2, 1)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 6), _mm256_extractf128_si256(_bf3, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 7), _mm256_extractf128_si256(_bf3, 1)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 8), _mm256_extractf128_si256(_bf4, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 9), _mm256_extractf128_si256(_bf4, 1)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 10), _mm256_extractf128_si256(_bf5, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 11), _mm256_extractf128_si256(_bf5, 1)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 12), _mm256_extractf128_si256(_bf6, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 13), _mm256_extractf128_si256(_bf6, 1)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 14), _mm256_extractf128_si256(_bf7, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 15), _mm256_extractf128_si256(_bf7, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8), _mm256_extractf128_si256(_bf8, 0)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8), _mm256_extractf128_si256(_bf8, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 2), _mm256_extractf128_si256(_bf9, 0)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 3), _mm256_extractf128_si256(_bf9, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 4), _mm256_extractf128_si256(_bfa, 0)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 5), _mm256_extractf128_si256(_bfa, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 6), _mm256_extractf128_si256(_bfb, 0)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 7), _mm256_extractf128_si256(_bfb, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 8), _mm256_extractf128_si256(_bfc, 0)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 9), _mm256_extractf128_si256(_bfc, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 10), _mm256_extractf128_si256(_bfd, 0)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 11), _mm256_extractf128_si256(_bfd, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 12), _mm256_extractf128_si256(_bfe, 0)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 13), _mm256_extractf128_si256(_bfe, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 14), _mm256_extractf128_si256(_bff, 0)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 15), _mm256_extractf128_si256(_bff, 1)); + } + if (out_elempack == 4) + { + transpose16x4_epi16(_bf0, _bf1, _bf2, _bf3); + transpose16x4_epi16(_bf4, _bf5, _bf6, _bf7); + transpose16x4_epi16(_bf8, _bf9, _bfa, _bfb); + transpose16x4_epi16(_bfc, _bfd, _bfe, _bff); + + _mm_storel_epi64((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); + _mm_storeh_pd((double*)(p0 + 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 0))); + _mm_storel_epi64((__m128i*)(p0 + 8), _mm256_extractf128_si256(_bf0, 1)); + _mm_storeh_pd((double*)(p0 + 12), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 1))); + _mm_storel_epi64((__m128i*)(p0 + 16), _mm256_extractf128_si256(_bf1, 0)); + _mm_storeh_pd((double*)(p0 + 20), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 0))); + _mm_storel_epi64((__m128i*)(p0 + 24), _mm256_extractf128_si256(_bf1, 1)); + _mm_storeh_pd((double*)(p0 + 28), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 1))); + _mm_storel_epi64((__m128i*)(p0 + 32), _mm256_extractf128_si256(_bf2, 0)); + _mm_storeh_pd((double*)(p0 + 36), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 0))); + 
_mm_storel_epi64((__m128i*)(p0 + 40), _mm256_extractf128_si256(_bf2, 1)); + _mm_storeh_pd((double*)(p0 + 44), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 1))); + _mm_storel_epi64((__m128i*)(p0 + 48), _mm256_extractf128_si256(_bf3, 0)); + _mm_storeh_pd((double*)(p0 + 52), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 0))); + _mm_storel_epi64((__m128i*)(p0 + 56), _mm256_extractf128_si256(_bf3, 1)); + _mm_storeh_pd((double*)(p0 + 60), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 1))); + + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 4), _mm256_extractf128_si256(_bf4, 0)); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf4, 0))); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 4 + 8), _mm256_extractf128_si256(_bf4, 1)); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 12), _mm_castsi128_pd(_mm256_extractf128_si256(_bf4, 1))); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 4 + 16), _mm256_extractf128_si256(_bf5, 0)); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 20), _mm_castsi128_pd(_mm256_extractf128_si256(_bf5, 0))); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 4 + 24), _mm256_extractf128_si256(_bf5, 1)); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 28), _mm_castsi128_pd(_mm256_extractf128_si256(_bf5, 1))); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 4 + 32), _mm256_extractf128_si256(_bf6, 0)); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 36), _mm_castsi128_pd(_mm256_extractf128_si256(_bf6, 0))); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 4 + 40), _mm256_extractf128_si256(_bf6, 1)); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 44), _mm_castsi128_pd(_mm256_extractf128_si256(_bf6, 1))); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 4 + 48), _mm256_extractf128_si256(_bf7, 0)); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 52), _mm_castsi128_pd(_mm256_extractf128_si256(_bf7, 0))); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 4 + 56), _mm256_extractf128_si256(_bf7, 1)); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 60), _mm_castsi128_pd(_mm256_extractf128_si256(_bf7, 1))); + + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8), _mm256_extractf128_si256(_bf8, 0)); + _mm_storeh_pd((double*)(p0 + out_hstep * 8 + 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf8, 0))); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 8), _mm256_extractf128_si256(_bf8, 1)); + _mm_storeh_pd((double*)(p0 + out_hstep * 8 + 12), _mm_castsi128_pd(_mm256_extractf128_si256(_bf8, 1))); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 16), _mm256_extractf128_si256(_bf9, 0)); + _mm_storeh_pd((double*)(p0 + out_hstep * 8 + 20), _mm_castsi128_pd(_mm256_extractf128_si256(_bf9, 0))); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 24), _mm256_extractf128_si256(_bf9, 1)); + _mm_storeh_pd((double*)(p0 + out_hstep * 8 + 28), _mm_castsi128_pd(_mm256_extractf128_si256(_bf9, 1))); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 32), _mm256_extractf128_si256(_bfa, 0)); + _mm_storeh_pd((double*)(p0 + out_hstep * 8 + 36), _mm_castsi128_pd(_mm256_extractf128_si256(_bfa, 0))); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 40), _mm256_extractf128_si256(_bfa, 1)); + _mm_storeh_pd((double*)(p0 + out_hstep * 8 + 44), _mm_castsi128_pd(_mm256_extractf128_si256(_bfa, 1))); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 48), _mm256_extractf128_si256(_bfb, 0)); + _mm_storeh_pd((double*)(p0 + out_hstep * 8 + 52), _mm_castsi128_pd(_mm256_extractf128_si256(_bfb, 0))); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 56), _mm256_extractf128_si256(_bfb, 1)); + 
_mm_storeh_pd((double*)(p0 + out_hstep * 8 + 60), _mm_castsi128_pd(_mm256_extractf128_si256(_bfb, 1))); + + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 12), _mm256_extractf128_si256(_bfc, 0)); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bfc, 0))); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 12 + 8), _mm256_extractf128_si256(_bfc, 1)); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 12), _mm_castsi128_pd(_mm256_extractf128_si256(_bfc, 1))); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 12 + 16), _mm256_extractf128_si256(_bfd, 0)); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 20), _mm_castsi128_pd(_mm256_extractf128_si256(_bfd, 0))); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 12 + 24), _mm256_extractf128_si256(_bfd, 1)); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 28), _mm_castsi128_pd(_mm256_extractf128_si256(_bfd, 1))); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 12 + 32), _mm256_extractf128_si256(_bfe, 0)); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 36), _mm_castsi128_pd(_mm256_extractf128_si256(_bfe, 0))); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 12 + 40), _mm256_extractf128_si256(_bfe, 1)); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 44), _mm_castsi128_pd(_mm256_extractf128_si256(_bfe, 1))); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 12 + 48), _mm256_extractf128_si256(_bff, 0)); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 52), _mm_castsi128_pd(_mm256_extractf128_si256(_bff, 0))); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 12 + 56), _mm256_extractf128_si256(_bff, 1)); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 60), _mm_castsi128_pd(_mm256_extractf128_si256(_bff, 1))); + } + if (out_elempack == 1) + { + _mm256_storeu_si256((__m256i*)p0, _bf0); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep), _bf1); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 2), _bf2); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 3), _bf3); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 4), _bf4); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 5), _bf5); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 6), _bf6); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 7), _bf7); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 8), _bf8); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 9), _bf9); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 10), _bfa); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 11), _bfb); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 12), _bfc); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 13), _bfd); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 14), _bfe); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 15), _bff); + } + p0 += out_hstep * 16; + } + else + { + if (out_elempack == 16) + { + _mm256_storeu_si256((__m256i*)p0, _bf0); + _mm256_storeu_si256((__m256i*)(p0 + 16), _bf1); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 2), _bf2); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 3), _bf3); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 4), _bf4); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 5), _bf5); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 6), _bf6); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 7), _bf7); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 8), _bf8); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 9), _bf9); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 10), _bfa); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 11), _bfb); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 12), _bfc); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 13), _bfd); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 
14), _bfe); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 15), _bff); + p0 += 256; + } + if (out_elempack == 8) + { + _mm_storeu_si128((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8), _mm256_extractf128_si256(_bf1, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 2), _mm256_extractf128_si256(_bf2, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 3), _mm256_extractf128_si256(_bf3, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 4), _mm256_extractf128_si256(_bf4, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 5), _mm256_extractf128_si256(_bf5, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 6), _mm256_extractf128_si256(_bf6, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 7), _mm256_extractf128_si256(_bf7, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 8), _mm256_extractf128_si256(_bf8, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 9), _mm256_extractf128_si256(_bf9, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 10), _mm256_extractf128_si256(_bfa, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 11), _mm256_extractf128_si256(_bfb, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 12), _mm256_extractf128_si256(_bfc, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 13), _mm256_extractf128_si256(_bfd, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 14), _mm256_extractf128_si256(_bfe, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 15), _mm256_extractf128_si256(_bff, 0)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8), _mm256_extractf128_si256(_bf0, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8), _mm256_extractf128_si256(_bf1, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 2), _mm256_extractf128_si256(_bf2, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 3), _mm256_extractf128_si256(_bf3, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 4), _mm256_extractf128_si256(_bf4, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 5), _mm256_extractf128_si256(_bf5, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 6), _mm256_extractf128_si256(_bf6, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 7), _mm256_extractf128_si256(_bf7, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 8), _mm256_extractf128_si256(_bf8, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 9), _mm256_extractf128_si256(_bf9, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 10), _mm256_extractf128_si256(_bfa, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 11), _mm256_extractf128_si256(_bfb, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 12), _mm256_extractf128_si256(_bfc, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 13), _mm256_extractf128_si256(_bfd, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 14), _mm256_extractf128_si256(_bfe, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 15), _mm256_extractf128_si256(_bff, 1)); + p0 += 128; + } + if (out_elempack == 4) + { + _mm_storel_epi64((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4), _mm256_extractf128_si256(_bf1, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 2), _mm256_extractf128_si256(_bf2, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 3), _mm256_extractf128_si256(_bf3, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 4), _mm256_extractf128_si256(_bf4, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 5), _mm256_extractf128_si256(_bf5, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 6), _mm256_extractf128_si256(_bf6, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 7), _mm256_extractf128_si256(_bf7, 0)); + 
_mm_storel_epi64((__m128i*)(p0 + 4 * 8), _mm256_extractf128_si256(_bf8, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 9), _mm256_extractf128_si256(_bf9, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 10), _mm256_extractf128_si256(_bfa, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 11), _mm256_extractf128_si256(_bfb, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 12), _mm256_extractf128_si256(_bfc, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 13), _mm256_extractf128_si256(_bfd, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 14), _mm256_extractf128_si256(_bfe, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 15), _mm256_extractf128_si256(_bff, 0)); + _mm_storeh_pd((double*)(p0 + out_hstep * 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 2), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 3), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf4, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 5), _mm_castsi128_pd(_mm256_extractf128_si256(_bf5, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 6), _mm_castsi128_pd(_mm256_extractf128_si256(_bf6, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 7), _mm_castsi128_pd(_mm256_extractf128_si256(_bf7, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 8), _mm_castsi128_pd(_mm256_extractf128_si256(_bf8, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 9), _mm_castsi128_pd(_mm256_extractf128_si256(_bf9, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 10), _mm_castsi128_pd(_mm256_extractf128_si256(_bfa, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 11), _mm_castsi128_pd(_mm256_extractf128_si256(_bfb, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 12), _mm_castsi128_pd(_mm256_extractf128_si256(_bfc, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 13), _mm_castsi128_pd(_mm256_extractf128_si256(_bfd, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 14), _mm_castsi128_pd(_mm256_extractf128_si256(_bfe, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 15), _mm_castsi128_pd(_mm256_extractf128_si256(_bff, 0))); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8), _mm256_extractf128_si256(_bf0, 1)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4), _mm256_extractf128_si256(_bf1, 1)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 2), _mm256_extractf128_si256(_bf2, 1)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 3), _mm256_extractf128_si256(_bf3, 1)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 4), _mm256_extractf128_si256(_bf4, 1)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 5), _mm256_extractf128_si256(_bf5, 1)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 6), _mm256_extractf128_si256(_bf6, 1)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 7), _mm256_extractf128_si256(_bf7, 1)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 8), _mm256_extractf128_si256(_bf8, 1)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 9), _mm256_extractf128_si256(_bf9, 1)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 10), _mm256_extractf128_si256(_bfa, 1)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 11), _mm256_extractf128_si256(_bfb, 1)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 
4 * 12), _mm256_extractf128_si256(_bfc, 1)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 13), _mm256_extractf128_si256(_bfd, 1)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 14), _mm256_extractf128_si256(_bfe, 1)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 15), _mm256_extractf128_si256(_bff, 1)); + _mm_storeh_pd((double*)(p0 + out_hstep * 12), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 2), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 3), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf4, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 5), _mm_castsi128_pd(_mm256_extractf128_si256(_bf5, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 6), _mm_castsi128_pd(_mm256_extractf128_si256(_bf6, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 7), _mm_castsi128_pd(_mm256_extractf128_si256(_bf7, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 8), _mm_castsi128_pd(_mm256_extractf128_si256(_bf8, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 9), _mm_castsi128_pd(_mm256_extractf128_si256(_bf9, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 10), _mm_castsi128_pd(_mm256_extractf128_si256(_bfa, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 11), _mm_castsi128_pd(_mm256_extractf128_si256(_bfb, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 12), _mm_castsi128_pd(_mm256_extractf128_si256(_bfc, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 13), _mm_castsi128_pd(_mm256_extractf128_si256(_bfd, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 14), _mm_castsi128_pd(_mm256_extractf128_si256(_bfe, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 15), _mm_castsi128_pd(_mm256_extractf128_si256(_bff, 1))); + p0 += 64; + } + if (out_elempack == 1) + { + transpose16x16_epi16(_bf0, _bf1, _bf2, _bf3, _bf4, _bf5, _bf6, _bf7, _bf8, _bf9, _bfa, _bfb, _bfc, _bfd, _bfe, _bff); + + _mm256_storeu_si256((__m256i*)p0, _bf0); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep), _bf1); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 2), _bf2); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 3), _bf3); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 4), _bf4); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 5), _bf5); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 6), _bf6); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 7), _bf7); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 8), _bf8); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 9), _bf9); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 10), _bfa); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 11), _bfb); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 12), _bfc); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 13), _bfd); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 14), _bfe); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 15), _bff); + p0 += 16; + } + } + } + for (; jj + 7 < max_jj; jj += 8) + { + __m512 _f0 = _mm512_load_ps(pp); + __m512 _f1 = _mm512_load_ps(pp + 16); + __m512 _f2 = _mm512_load_ps(pp + 32); + __m512 _f3 = _mm512_load_ps(pp + 48); + __m512 _f4 = _mm512_load_ps(pp + 64); + __m512 _f5 = _mm512_load_ps(pp + 80); + __m512 _f6 = _mm512_load_ps(pp + 96); + 
__m512 _f7 = _mm512_load_ps(pp + 112); + pp += 128; + + // deshuffle from the shuffle-based 16x8 dpbf16_ps kernel + { + _f1 = _mm512_permute_ps(_f1, _MM_SHUFFLE(2, 1, 0, 3)); + _f3 = _mm512_permute_ps(_f3, _MM_SHUFFLE(2, 1, 0, 3)); + _f5 = _mm512_permute_ps(_f5, _MM_SHUFFLE(2, 1, 0, 3)); + _f7 = _mm512_permute_ps(_f7, _MM_SHUFFLE(2, 1, 0, 3)); + + __m512 _tmp0 = _mm512_unpacklo_ps(_f0, _f3); + __m512 _tmp1 = _mm512_unpackhi_ps(_f0, _f3); + __m512 _tmp2 = _mm512_unpacklo_ps(_f2, _f1); + __m512 _tmp3 = _mm512_unpackhi_ps(_f2, _f1); + __m512 _tmp4 = _mm512_unpacklo_ps(_f4, _f7); + __m512 _tmp5 = _mm512_unpackhi_ps(_f4, _f7); + __m512 _tmp6 = _mm512_unpacklo_ps(_f6, _f5); + __m512 _tmp7 = _mm512_unpackhi_ps(_f6, _f5); + + _f0 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp0), _mm512_castps_pd(_tmp2))); + _f1 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp0), _mm512_castps_pd(_tmp2))); + _f2 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp3), _mm512_castps_pd(_tmp1))); + _f3 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp3), _mm512_castps_pd(_tmp1))); + _f4 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp4), _mm512_castps_pd(_tmp6))); + _f5 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp4), _mm512_castps_pd(_tmp6))); + _f6 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp7), _mm512_castps_pd(_tmp5))); + _f7 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp7), _mm512_castps_pd(_tmp5))); + + _f1 = _mm512_permute_ps(_f1, _MM_SHUFFLE(2, 1, 0, 3)); + _f3 = _mm512_permute_ps(_f3, _MM_SHUFFLE(2, 1, 0, 3)); + _f5 = _mm512_permute_ps(_f5, _MM_SHUFFLE(2, 1, 0, 3)); + _f7 = _mm512_permute_ps(_f7, _MM_SHUFFLE(2, 1, 0, 3)); + + _tmp0 = _mm512_shuffle_f32x4(_f0, _f4, _MM_SHUFFLE(0, 1, 1, 0)); + _tmp1 = _mm512_shuffle_f32x4(_f1, _f5, _MM_SHUFFLE(0, 1, 1, 0)); + _tmp2 = _mm512_shuffle_f32x4(_f2, _f6, _MM_SHUFFLE(0, 1, 1, 0)); + _tmp3 = _mm512_shuffle_f32x4(_f3, _f7, _MM_SHUFFLE(0, 1, 1, 0)); + _tmp4 = _mm512_shuffle_f32x4(_f0, _f4, _MM_SHUFFLE(2, 3, 3, 2)); + _tmp5 = _mm512_shuffle_f32x4(_f1, _f5, _MM_SHUFFLE(2, 3, 3, 2)); + _tmp6 = _mm512_shuffle_f32x4(_f2, _f6, _MM_SHUFFLE(2, 3, 3, 2)); + _tmp7 = _mm512_shuffle_f32x4(_f3, _f7, _MM_SHUFFLE(2, 3, 3, 2)); + + _f0 = _mm512_shuffle_f32x4(_tmp0, _tmp4, _MM_SHUFFLE(2, 0, 2, 0)); + _f1 = _mm512_shuffle_f32x4(_tmp1, _tmp5, _MM_SHUFFLE(2, 0, 2, 0)); + _f2 = _mm512_shuffle_f32x4(_tmp2, _tmp6, _MM_SHUFFLE(2, 0, 2, 0)); + _f3 = _mm512_shuffle_f32x4(_tmp3, _tmp7, _MM_SHUFFLE(2, 0, 2, 0)); + _f4 = _mm512_shuffle_f32x4(_tmp0, _tmp4, _MM_SHUFFLE(1, 3, 1, 3)); + _f5 = _mm512_shuffle_f32x4(_tmp1, _tmp5, _MM_SHUFFLE(1, 3, 1, 3)); + _f6 = _mm512_shuffle_f32x4(_tmp2, _tmp6, _MM_SHUFFLE(1, 3, 1, 3)); + _f7 = _mm512_shuffle_f32x4(_tmp3, _tmp7, _MM_SHUFFLE(1, 3, 1, 3)); + } + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = _mm512_add_ps(_f0, _c0); + _f1 = _mm512_add_ps(_f1, _c0); + _f2 = _mm512_add_ps(_f2, _c0); + _f3 = _mm512_add_ps(_f3, _c0); + _f4 = _mm512_add_ps(_f4, _c0); + _f5 = _mm512_add_ps(_f5, _c0); + _f6 = _mm512_add_ps(_f6, _c0); + _f7 = _mm512_add_ps(_f7, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = _mm512_add_ps(_f0, _c0); + _f1 = _mm512_add_ps(_f1, _c0); + _f2 = _mm512_add_ps(_f2, _c0); + _f3 = _mm512_add_ps(_f3, _c0); + _f4 = _mm512_add_ps(_f4, _c0); + _f5 = _mm512_add_ps(_f5, _c0); + _f6 = _mm512_add_ps(_f6, _c0); + _f7 = _mm512_add_ps(_f7, _c0); + } + if (broadcast_type_C == 3) + { + __m512 _c1; + __m512 _c2; + __m512 _c3; + __m512 _c4; + 
__m512 _c5; + __m512 _c6; + __m512 _c7; + if (c_elempack == 16) + { + _c0 = _mm512_loadu_ps(pC); + _c1 = _mm512_loadu_ps(pC + 16); + _c2 = _mm512_loadu_ps(pC + 32); + _c3 = _mm512_loadu_ps(pC + 48); + _c4 = _mm512_loadu_ps(pC + 64); + _c5 = _mm512_loadu_ps(pC + 80); + _c6 = _mm512_loadu_ps(pC + 96); + _c7 = _mm512_loadu_ps(pC + 112); + pC += 128; + } + else if (c_elempack == 8) + { + __m512 _tmp0 = _mm512_loadu_ps(pC); + __m512 _tmp1 = _mm512_loadu_ps(pC + 16); + __m512 _tmp2 = _mm512_loadu_ps(pC + 32); + __m512 _tmp3 = _mm512_loadu_ps(pC + 48); + __m512 _tmp4 = _mm512_loadu_ps(pC + c_hstep * 8); + __m512 _tmp5 = _mm512_loadu_ps(pC + c_hstep * 8 + 16); + __m512 _tmp6 = _mm512_loadu_ps(pC + c_hstep * 8 + 32); + __m512 _tmp7 = _mm512_loadu_ps(pC + c_hstep * 8 + 48); + + _c0 = _mm512_shuffle_f32x4(_tmp0, _tmp4, _MM_SHUFFLE(1, 0, 1, 0)); + _c1 = _mm512_shuffle_f32x4(_tmp0, _tmp4, _MM_SHUFFLE(3, 2, 3, 2)); + _c2 = _mm512_shuffle_f32x4(_tmp1, _tmp5, _MM_SHUFFLE(1, 0, 1, 0)); + _c3 = _mm512_shuffle_f32x4(_tmp1, _tmp5, _MM_SHUFFLE(3, 2, 3, 2)); + _c4 = _mm512_shuffle_f32x4(_tmp2, _tmp6, _MM_SHUFFLE(1, 0, 1, 0)); + _c5 = _mm512_shuffle_f32x4(_tmp2, _tmp6, _MM_SHUFFLE(3, 2, 3, 2)); + _c6 = _mm512_shuffle_f32x4(_tmp3, _tmp7, _MM_SHUFFLE(1, 0, 1, 0)); + _c7 = _mm512_shuffle_f32x4(_tmp3, _tmp7, _MM_SHUFFLE(3, 2, 3, 2)); + + pC += 64; + } + else if (c_elempack == 4) + { + _c0 = _mm512_loadu_ps(pC); + _c1 = _mm512_loadu_ps(pC + 16); + _c2 = _mm512_loadu_ps(pC + c_hstep * 4); + _c3 = _mm512_loadu_ps(pC + c_hstep * 4 + 16); + _c4 = _mm512_loadu_ps(pC + c_hstep * 8); + _c5 = _mm512_loadu_ps(pC + c_hstep * 8 + 16); + _c6 = _mm512_loadu_ps(pC + c_hstep * 12); + _c7 = _mm512_loadu_ps(pC + c_hstep * 12 + 16); + + __m512 _tmp0 = _mm512_shuffle_f32x4(_c0, _c2, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp1 = _mm512_shuffle_f32x4(_c4, _c6, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp2 = _mm512_shuffle_f32x4(_c0, _c2, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp3 = _mm512_shuffle_f32x4(_c4, _c6, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp4 = _mm512_shuffle_f32x4(_c1, _c3, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp5 = _mm512_shuffle_f32x4(_c5, _c7, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp6 = _mm512_shuffle_f32x4(_c1, _c3, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp7 = _mm512_shuffle_f32x4(_c5, _c7, _MM_SHUFFLE(3, 2, 3, 2)); + + _c0 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(2, 0, 2, 0)); + _c1 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(3, 1, 3, 1)); + _c2 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(2, 0, 2, 0)); + _c3 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(3, 1, 3, 1)); + _c4 = _mm512_shuffle_f32x4(_tmp4, _tmp5, _MM_SHUFFLE(2, 0, 2, 0)); + _c5 = _mm512_shuffle_f32x4(_tmp4, _tmp5, _MM_SHUFFLE(3, 1, 3, 1)); + _c6 = _mm512_shuffle_f32x4(_tmp6, _tmp7, _MM_SHUFFLE(2, 0, 2, 0)); + _c7 = _mm512_shuffle_f32x4(_tmp6, _tmp7, _MM_SHUFFLE(3, 1, 3, 1)); + + pC += 32; + } + else // if (c_elempack == 1) + { + __m256 _cc0 = _mm256_loadu_ps(pC); + __m256 _cc1 = _mm256_loadu_ps(pC + c_hstep); + __m256 _cc2 = _mm256_loadu_ps(pC + c_hstep * 2); + __m256 _cc3 = _mm256_loadu_ps(pC + c_hstep * 3); + __m256 _cc4 = _mm256_loadu_ps(pC + c_hstep * 4); + __m256 _cc5 = _mm256_loadu_ps(pC + c_hstep * 5); + __m256 _cc6 = _mm256_loadu_ps(pC + c_hstep * 6); + __m256 _cc7 = _mm256_loadu_ps(pC + c_hstep * 7); + __m256 _cc8 = _mm256_loadu_ps(pC + c_hstep * 8); + __m256 _cc9 = _mm256_loadu_ps(pC + c_hstep * 9); + __m256 _cca = _mm256_loadu_ps(pC + c_hstep * 10); + __m256 _ccb = _mm256_loadu_ps(pC + c_hstep * 11); + __m256 _ccc = _mm256_loadu_ps(pC + 
c_hstep * 12); + __m256 _ccd = _mm256_loadu_ps(pC + c_hstep * 13); + __m256 _cce = _mm256_loadu_ps(pC + c_hstep * 14); + __m256 _ccf = _mm256_loadu_ps(pC + c_hstep * 15); + transpose8x8_ps(_cc0, _cc1, _cc2, _cc3, _cc4, _cc5, _cc6, _cc7); + transpose8x8_ps(_cc8, _cc9, _cca, _ccb, _ccc, _ccd, _cce, _ccf); + _c0 = combine8x2_ps(_cc0, _cc8); + _c1 = combine8x2_ps(_cc1, _cc9); + _c2 = combine8x2_ps(_cc2, _cca); + _c3 = combine8x2_ps(_cc3, _ccb); + _c4 = combine8x2_ps(_cc4, _ccc); + _c5 = combine8x2_ps(_cc5, _ccd); + _c6 = combine8x2_ps(_cc6, _cce); + _c7 = combine8x2_ps(_cc7, _ccf); + pC += 8; + } + if (beta == 1.f) + { + _f0 = _mm512_add_ps(_f0, _c0); + _f1 = _mm512_add_ps(_f1, _c1); + _f2 = _mm512_add_ps(_f2, _c2); + _f3 = _mm512_add_ps(_f3, _c3); + _f4 = _mm512_add_ps(_f4, _c4); + _f5 = _mm512_add_ps(_f5, _c5); + _f6 = _mm512_add_ps(_f6, _c6); + _f7 = _mm512_add_ps(_f7, _c7); + } + else + { + __m512 _beta = _mm512_set1_ps(beta); + _f0 = _mm512_fmadd_ps(_c0, _beta, _f0); + _f1 = _mm512_fmadd_ps(_c1, _beta, _f1); + _f2 = _mm512_fmadd_ps(_c2, _beta, _f2); + _f3 = _mm512_fmadd_ps(_c3, _beta, _f3); + _f4 = _mm512_fmadd_ps(_c4, _beta, _f4); + _f5 = _mm512_fmadd_ps(_c5, _beta, _f5); + _f6 = _mm512_fmadd_ps(_c6, _beta, _f6); + _f7 = _mm512_fmadd_ps(_c7, _beta, _f7); + } + } + if (broadcast_type_C == 4) + { + _c0 = _mm512_set1_ps(pC[0] * beta); + __m512 _c1 = _mm512_set1_ps(pC[1] * beta); + __m512 _c2 = _mm512_set1_ps(pC[2] * beta); + __m512 _c3 = _mm512_set1_ps(pC[3] * beta); + + _f0 = _mm512_add_ps(_f0, _c0); + _f1 = _mm512_add_ps(_f1, _c1); + _f2 = _mm512_add_ps(_f2, _c2); + _f3 = _mm512_add_ps(_f3, _c3); + + _c0 = _mm512_set1_ps(pC[4] * beta); + _c1 = _mm512_set1_ps(pC[5] * beta); + _c2 = _mm512_set1_ps(pC[6] * beta); + _c3 = _mm512_set1_ps(pC[7] * beta); + + _f4 = _mm512_add_ps(_f4, _c0); + _f5 = _mm512_add_ps(_f5, _c1); + _f6 = _mm512_add_ps(_f6, _c2); + _f7 = _mm512_add_ps(_f7, _c3); + pC += 8; + } + } + + if (alpha != 1.f) + { + __m512 _alpha = _mm512_set1_ps(alpha); + _f0 = _mm512_mul_ps(_f0, _alpha); + _f1 = _mm512_mul_ps(_f1, _alpha); + _f2 = _mm512_mul_ps(_f2, _alpha); + _f3 = _mm512_mul_ps(_f3, _alpha); + _f4 = _mm512_mul_ps(_f4, _alpha); + _f5 = _mm512_mul_ps(_f5, _alpha); + _f6 = _mm512_mul_ps(_f6, _alpha); + _f7 = _mm512_mul_ps(_f7, _alpha); + } + + __m256i _bf0 = float2bfloat_avx512(_f0); + __m256i _bf1 = float2bfloat_avx512(_f1); + __m256i _bf2 = float2bfloat_avx512(_f2); + __m256i _bf3 = float2bfloat_avx512(_f3); + __m256i _bf4 = float2bfloat_avx512(_f4); + __m256i _bf5 = float2bfloat_avx512(_f5); + __m256i _bf6 = float2bfloat_avx512(_f6); + __m256i _bf7 = float2bfloat_avx512(_f7); + + if (output_transpose) + { + if (out_elempack == 8) + { + transpose16x8_epi16(_bf0, _bf1, _bf2, _bf3, _bf4, _bf5, _bf6, _bf7); + + _mm_storeu_si128((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8), _mm256_extractf128_si256(_bf0, 1)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 2), _mm256_extractf128_si256(_bf1, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 3), _mm256_extractf128_si256(_bf1, 1)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 4), _mm256_extractf128_si256(_bf2, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 5), _mm256_extractf128_si256(_bf2, 1)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 6), _mm256_extractf128_si256(_bf3, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 7), _mm256_extractf128_si256(_bf3, 1)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 8), _mm256_extractf128_si256(_bf4, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 9), _mm256_extractf128_si256(_bf4, 1)); + 
_mm_storeu_si128((__m128i*)(p0 + 8 * 10), _mm256_extractf128_si256(_bf5, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 11), _mm256_extractf128_si256(_bf5, 1)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 12), _mm256_extractf128_si256(_bf6, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 13), _mm256_extractf128_si256(_bf6, 1)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 14), _mm256_extractf128_si256(_bf7, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 15), _mm256_extractf128_si256(_bf7, 1)); + } + if (out_elempack == 4) + { + transpose16x4_epi16(_bf0, _bf1, _bf2, _bf3); + transpose16x4_epi16(_bf4, _bf5, _bf6, _bf7); + + _mm_storel_epi64((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); + _mm_storeh_pd((double*)(p0 + 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 0))); + _mm_storel_epi64((__m128i*)(p0 + 4 * 2), _mm256_extractf128_si256(_bf0, 1)); + _mm_storeh_pd((double*)(p0 + 4 * 3), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 1))); + _mm_storel_epi64((__m128i*)(p0 + 4 * 4), _mm256_extractf128_si256(_bf1, 0)); + _mm_storeh_pd((double*)(p0 + 4 * 5), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 0))); + _mm_storel_epi64((__m128i*)(p0 + 4 * 6), _mm256_extractf128_si256(_bf1, 1)); + _mm_storeh_pd((double*)(p0 + 4 * 7), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 1))); + _mm_storel_epi64((__m128i*)(p0 + 4 * 8), _mm256_extractf128_si256(_bf2, 0)); + _mm_storeh_pd((double*)(p0 + 4 * 9), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 0))); + _mm_storel_epi64((__m128i*)(p0 + 4 * 10), _mm256_extractf128_si256(_bf2, 1)); + _mm_storeh_pd((double*)(p0 + 4 * 11), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 1))); + _mm_storel_epi64((__m128i*)(p0 + 4 * 12), _mm256_extractf128_si256(_bf3, 0)); + _mm_storeh_pd((double*)(p0 + 4 * 13), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 0))); + _mm_storel_epi64((__m128i*)(p0 + 4 * 14), _mm256_extractf128_si256(_bf3, 1)); + _mm_storeh_pd((double*)(p0 + 4 * 15), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 1))); + + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 4), _mm256_extractf128_si256(_bf4, 0)); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf4, 0))); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 4 + 4 * 2), _mm256_extractf128_si256(_bf4, 1)); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 3), _mm_castsi128_pd(_mm256_extractf128_si256(_bf4, 1))); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 4 + 4 * 4), _mm256_extractf128_si256(_bf5, 0)); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 5), _mm_castsi128_pd(_mm256_extractf128_si256(_bf5, 0))); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 4 + 4 * 6), _mm256_extractf128_si256(_bf5, 1)); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 7), _mm_castsi128_pd(_mm256_extractf128_si256(_bf5, 1))); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 4 + 4 * 8), _mm256_extractf128_si256(_bf6, 0)); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 9), _mm_castsi128_pd(_mm256_extractf128_si256(_bf6, 0))); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 4 + 4 * 10), _mm256_extractf128_si256(_bf6, 1)); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 11), _mm_castsi128_pd(_mm256_extractf128_si256(_bf6, 1))); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 4 + 4 * 12), _mm256_extractf128_si256(_bf7, 0)); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 13), _mm_castsi128_pd(_mm256_extractf128_si256(_bf7, 0))); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 4 + 4 * 14), _mm256_extractf128_si256(_bf7, 1)); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 15), 
_mm_castsi128_pd(_mm256_extractf128_si256(_bf7, 1))); + } + if (out_elempack == 1) + { + _mm256_storeu_si256((__m256i*)p0, _bf0); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep), _bf1); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 2), _bf2); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 3), _bf3); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 4), _bf4); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 5), _bf5); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 6), _bf6); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 7), _bf7); + } + p0 += out_hstep * 8; + } + else + { + if (out_elempack == 16) + { + _mm256_storeu_si256((__m256i*)p0, _bf0); + _mm256_storeu_si256((__m256i*)(p0 + 16), _bf1); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 2), _bf2); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 3), _bf3); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 4), _bf4); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 5), _bf5); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 6), _bf6); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 7), _bf7); + p0 += 128; + } + if (out_elempack == 8) + { + _mm_storeu_si128((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8), _mm256_extractf128_si256(_bf1, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 2), _mm256_extractf128_si256(_bf2, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 3), _mm256_extractf128_si256(_bf3, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 4), _mm256_extractf128_si256(_bf4, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 5), _mm256_extractf128_si256(_bf5, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 6), _mm256_extractf128_si256(_bf6, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 7), _mm256_extractf128_si256(_bf7, 0)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8), _mm256_extractf128_si256(_bf0, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8), _mm256_extractf128_si256(_bf1, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 2), _mm256_extractf128_si256(_bf2, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 3), _mm256_extractf128_si256(_bf3, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 4), _mm256_extractf128_si256(_bf4, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 5), _mm256_extractf128_si256(_bf5, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 6), _mm256_extractf128_si256(_bf6, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 7), _mm256_extractf128_si256(_bf7, 1)); + p0 += 64; + } + if (out_elempack == 4) + { + _mm_storel_epi64((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4), _mm256_extractf128_si256(_bf1, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 2), _mm256_extractf128_si256(_bf2, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 3), _mm256_extractf128_si256(_bf3, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 4), _mm256_extractf128_si256(_bf4, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 5), _mm256_extractf128_si256(_bf5, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 6), _mm256_extractf128_si256(_bf6, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 7), _mm256_extractf128_si256(_bf7, 0)); + _mm_storeh_pd((double*)(p0 + out_hstep * 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 2), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 3), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 0))); + 
_mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf4, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 5), _mm_castsi128_pd(_mm256_extractf128_si256(_bf5, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 6), _mm_castsi128_pd(_mm256_extractf128_si256(_bf6, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 7), _mm_castsi128_pd(_mm256_extractf128_si256(_bf7, 0))); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8), _mm256_extractf128_si256(_bf0, 1)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4), _mm256_extractf128_si256(_bf1, 1)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 2), _mm256_extractf128_si256(_bf2, 1)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 3), _mm256_extractf128_si256(_bf3, 1)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 4), _mm256_extractf128_si256(_bf4, 1)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 5), _mm256_extractf128_si256(_bf5, 1)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 6), _mm256_extractf128_si256(_bf6, 1)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 7), _mm256_extractf128_si256(_bf7, 1)); + _mm_storeh_pd((double*)(p0 + out_hstep * 12), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 2), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 3), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf4, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 5), _mm_castsi128_pd(_mm256_extractf128_si256(_bf5, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 6), _mm_castsi128_pd(_mm256_extractf128_si256(_bf6, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 7), _mm_castsi128_pd(_mm256_extractf128_si256(_bf7, 1))); + p0 += 32; + } + if (out_elempack == 1) + { + transpose16x8_epi16(_bf0, _bf1, _bf2, _bf3, _bf4, _bf5, _bf6, _bf7); + _mm_storeu_si128((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep), _mm256_extractf128_si256(_bf0, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 2), _mm256_extractf128_si256(_bf1, 0)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 3), _mm256_extractf128_si256(_bf1, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 4), _mm256_extractf128_si256(_bf2, 0)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 5), _mm256_extractf128_si256(_bf2, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 6), _mm256_extractf128_si256(_bf3, 0)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 7), _mm256_extractf128_si256(_bf3, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8), _mm256_extractf128_si256(_bf4, 0)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 9), _mm256_extractf128_si256(_bf4, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 10), _mm256_extractf128_si256(_bf5, 0)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 11), _mm256_extractf128_si256(_bf5, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 12), _mm256_extractf128_si256(_bf6, 0)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 13), _mm256_extractf128_si256(_bf6, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 14), _mm256_extractf128_si256(_bf7, 0)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 15), _mm256_extractf128_si256(_bf7, 1)); + p0 += 8; + 
} + } + } +#endif // defined(__x86_64__) || defined(_M_X64) + for (; jj + 3 < max_jj; jj += 4) + { + __m512 _f0 = _mm512_load_ps(pp); + __m512 _f1 = _mm512_load_ps(pp + 16); + __m512 _f2 = _mm512_load_ps(pp + 32); + __m512 _f3 = _mm512_load_ps(pp + 48); + pp += 64; + + // deshuffle from the shuffle-based 16x4 dpbf16_ps kernel + { + _f1 = _mm512_permute_ps(_f1, _MM_SHUFFLE(2, 1, 0, 3)); + _f3 = _mm512_permute_ps(_f3, _MM_SHUFFLE(2, 1, 0, 3)); + __m512 _tmp0 = _mm512_unpacklo_ps(_f0, _f3); + __m512 _tmp1 = _mm512_unpackhi_ps(_f0, _f3); + __m512 _tmp2 = _mm512_unpacklo_ps(_f2, _f1); + __m512 _tmp3 = _mm512_unpackhi_ps(_f2, _f1); + _f0 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp0), _mm512_castps_pd(_tmp2))); + _f1 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp0), _mm512_castps_pd(_tmp2))); + _f2 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp3), _mm512_castps_pd(_tmp1))); + _f3 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp3), _mm512_castps_pd(_tmp1))); + _f1 = _mm512_permute_ps(_f1, _MM_SHUFFLE(2, 1, 0, 3)); + _f3 = _mm512_permute_ps(_f3, _MM_SHUFFLE(2, 1, 0, 3)); + } + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = _mm512_add_ps(_f0, _c0); + _f1 = _mm512_add_ps(_f1, _c0); + _f2 = _mm512_add_ps(_f2, _c0); + _f3 = _mm512_add_ps(_f3, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = _mm512_add_ps(_f0, _c0); + _f1 = _mm512_add_ps(_f1, _c0); + _f2 = _mm512_add_ps(_f2, _c0); + _f3 = _mm512_add_ps(_f3, _c0); + } + if (broadcast_type_C == 3) + { + __m512 _c1; + __m512 _c2; + __m512 _c3; + if (c_elempack == 16) + { + _c0 = _mm512_loadu_ps(pC); + _c1 = _mm512_loadu_ps(pC + 16); + _c2 = _mm512_loadu_ps(pC + 32); + _c3 = _mm512_loadu_ps(pC + 48); + pC += 64; + } + else if (c_elempack == 8) + { + __m512 _cc0 = _mm512_loadu_ps(pC); + __m512 _cc1 = _mm512_loadu_ps(pC + 16); + __m512 _cc2 = _mm512_loadu_ps(pC + c_hstep * 8); + __m512 _cc3 = _mm512_loadu_ps(pC + c_hstep * 8 + 16); + _c0 = _mm512_shuffle_f32x4(_cc0, _cc2, _MM_SHUFFLE(1, 0, 1, 0)); + _c1 = _mm512_shuffle_f32x4(_cc0, _cc2, _MM_SHUFFLE(3, 2, 3, 2)); + _c2 = _mm512_shuffle_f32x4(_cc1, _cc3, _MM_SHUFFLE(1, 0, 1, 0)); + _c3 = _mm512_shuffle_f32x4(_cc1, _cc3, _MM_SHUFFLE(3, 2, 3, 2)); + pC += 32; + } + else if (c_elempack == 4) + { + _c0 = _mm512_loadu_ps(pC); + _c1 = _mm512_loadu_ps(pC + c_hstep * 4); + _c2 = _mm512_loadu_ps(pC + c_hstep * 8); + _c3 = _mm512_loadu_ps(pC + c_hstep * 12); + __m512 _tmp0 = _mm512_shuffle_f32x4(_c0, _c1, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp1 = _mm512_shuffle_f32x4(_c2, _c3, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp2 = _mm512_shuffle_f32x4(_c0, _c1, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp3 = _mm512_shuffle_f32x4(_c2, _c3, _MM_SHUFFLE(3, 2, 3, 2)); + _c0 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(2, 0, 2, 0)); + _c1 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(3, 1, 3, 1)); + _c2 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(2, 0, 2, 0)); + _c3 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(3, 1, 3, 1)); + pC += 16; + } + else // if (c_elempack == 1) + { + __m128 _cc0 = _mm_loadu_ps(pC); + __m128 _cc1 = _mm_loadu_ps(pC + c_hstep); + __m128 _cc2 = _mm_loadu_ps(pC + c_hstep * 2); + __m128 _cc3 = _mm_loadu_ps(pC + c_hstep * 3); + __m128 _cc4 = _mm_loadu_ps(pC + c_hstep * 4); + __m128 _cc5 = _mm_loadu_ps(pC + c_hstep * 5); + __m128 _cc6 = _mm_loadu_ps(pC + c_hstep * 6); + __m128 _cc7 = _mm_loadu_ps(pC + c_hstep * 7); + __m128 _cc8 = _mm_loadu_ps(pC + c_hstep * 8); + __m128 _cc9 = _mm_loadu_ps(pC + c_hstep * 9); + __m128 
_cca = _mm_loadu_ps(pC + c_hstep * 10); + __m128 _ccb = _mm_loadu_ps(pC + c_hstep * 11); + __m128 _ccc = _mm_loadu_ps(pC + c_hstep * 12); + __m128 _ccd = _mm_loadu_ps(pC + c_hstep * 13); + __m128 _cce = _mm_loadu_ps(pC + c_hstep * 14); + __m128 _ccf = _mm_loadu_ps(pC + c_hstep * 15); + _MM_TRANSPOSE4_PS(_cc0, _cc1, _cc2, _cc3); + _MM_TRANSPOSE4_PS(_cc4, _cc5, _cc6, _cc7); + _MM_TRANSPOSE4_PS(_cc8, _cc9, _cca, _ccb); + _MM_TRANSPOSE4_PS(_ccc, _ccd, _cce, _ccf); + + _c0 = combine4x4_ps(_cc0, _cc4, _cc8, _ccc); + _c1 = combine4x4_ps(_cc1, _cc5, _cc9, _ccd); + _c2 = combine4x4_ps(_cc2, _cc6, _cca, _cce); + _c3 = combine4x4_ps(_cc3, _cc7, _ccb, _ccf); + + pC += 4; + } + if (beta == 1.f) + { + _f0 = _mm512_add_ps(_f0, _c0); + _f1 = _mm512_add_ps(_f1, _c1); + _f2 = _mm512_add_ps(_f2, _c2); + _f3 = _mm512_add_ps(_f3, _c3); + } + else + { + __m512 _beta = _mm512_set1_ps(beta); + _f0 = _mm512_fmadd_ps(_c0, _beta, _f0); + _f1 = _mm512_fmadd_ps(_c1, _beta, _f1); + _f2 = _mm512_fmadd_ps(_c2, _beta, _f2); + _f3 = _mm512_fmadd_ps(_c3, _beta, _f3); + } + } + if (broadcast_type_C == 4) + { + _c0 = _mm512_set1_ps(pC[0] * beta); + __m512 _c1 = _mm512_set1_ps(pC[1] * beta); + __m512 _c2 = _mm512_set1_ps(pC[2] * beta); + __m512 _c3 = _mm512_set1_ps(pC[3] * beta); + + _f0 = _mm512_add_ps(_f0, _c0); + _f1 = _mm512_add_ps(_f1, _c1); + _f2 = _mm512_add_ps(_f2, _c2); + _f3 = _mm512_add_ps(_f3, _c3); + pC += 4; + } + } + + if (alpha != 1.f) + { + __m512 _alpha = _mm512_set1_ps(alpha); + _f0 = _mm512_mul_ps(_f0, _alpha); + _f1 = _mm512_mul_ps(_f1, _alpha); + _f2 = _mm512_mul_ps(_f2, _alpha); + _f3 = _mm512_mul_ps(_f3, _alpha); + } + + __m256i _bf0 = float2bfloat_avx512(_f0); + __m256i _bf1 = float2bfloat_avx512(_f1); + __m256i _bf2 = float2bfloat_avx512(_f2); + __m256i _bf3 = float2bfloat_avx512(_f3); + + if (output_transpose) + { +#if !(defined(__x86_64__) || defined(_M_X64)) +#if __AVX__ +#if __AVX512F__ + if (out_elempack == 16) + { + transpose16x4_epi16(_bf0, _bf1, _bf2, _bf3); + const int jj_m16 = jj % 16; + unsigned short* p1 = p0 - out_hstep * jj_m16 + jj_m16; + _mm_storel_epi64((__m128i*)p1, _mm256_extractf128_si256(_bf0, 0)); + _mm_storeh_pd((double*)(p1 + 16), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 0))); + _mm_storel_epi64((__m128i*)(p1 + 32), _mm256_extractf128_si256(_bf0, 1)); + _mm_storeh_pd((double*)(p1 + 48), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 1))); + _mm_storel_epi64((__m128i*)(p1 + 64), _mm256_extractf128_si256(_bf1, 0)); + _mm_storeh_pd((double*)(p1 + 80), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 0))); + _mm_storel_epi64((__m128i*)(p1 + 96), _mm256_extractf128_si256(_bf1, 1)); + _mm_storeh_pd((double*)(p1 + 112), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 1))); + _mm_storel_epi64((__m128i*)(p1 + 128), _mm256_extractf128_si256(_bf2, 0)); + _mm_storeh_pd((double*)(p1 + 144), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 0))); + _mm_storel_epi64((__m128i*)(p1 + 160), _mm256_extractf128_si256(_bf2, 1)); + _mm_storeh_pd((double*)(p1 + 176), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 1))); + _mm_storel_epi64((__m128i*)(p1 + 192), _mm256_extractf128_si256(_bf3, 0)); + _mm_storeh_pd((double*)(p1 + 208), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 0))); + _mm_storel_epi64((__m128i*)(p1 + 224), _mm256_extractf128_si256(_bf3, 1)); + _mm_storeh_pd((double*)(p1 + 240), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 1))); + } +#endif // __AVX512F__ + if (out_elempack == 8) + { + transpose16x4_epi16(_bf0, _bf1, _bf2, _bf3); + const int jj_m8 = jj % 8; + unsigned short* 
p1 = p0 - out_hstep * jj_m8 + jj_m8; + _mm_storel_epi64((__m128i*)p1, _mm256_extractf128_si256(_bf0, 0)); + _mm_storeh_pd((double*)(p1 + 8), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 0))); + _mm_storel_epi64((__m128i*)(p1 + 16), _mm256_extractf128_si256(_bf0, 1)); + _mm_storeh_pd((double*)(p1 + 24), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 1))); + _mm_storel_epi64((__m128i*)(p1 + 32), _mm256_extractf128_si256(_bf1, 0)); + _mm_storeh_pd((double*)(p1 + 40), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 0))); + _mm_storel_epi64((__m128i*)(p1 + 48), _mm256_extractf128_si256(_bf1, 1)); + _mm_storeh_pd((double*)(p1 + 56), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 1))); + _mm_storel_epi64((__m128i*)(p1 + 64), _mm256_extractf128_si256(_bf2, 0)); + _mm_storeh_pd((double*)(p1 + 72), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 0))); + _mm_storel_epi64((__m128i*)(p1 + 80), _mm256_extractf128_si256(_bf2, 1)); + _mm_storeh_pd((double*)(p1 + 88), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 1))); + _mm_storel_epi64((__m128i*)(p1 + 96), _mm256_extractf128_si256(_bf3, 0)); + _mm_storeh_pd((double*)(p1 + 104), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 0))); + _mm_storel_epi64((__m128i*)(p1 + 112), _mm256_extractf128_si256(_bf3, 1)); + _mm_storeh_pd((double*)(p1 + 120), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 1))); + } +#endif // __AVX__ +#endif // !(defined(__x86_64__) || defined(_M_X64)) + if (out_elempack == 4) + { + transpose16x4_epi16(_bf0, _bf1, _bf2, _bf3); + + _mm_storel_epi64((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); + _mm_storeh_pd((double*)(p0 + 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 0))); + _mm_storel_epi64((__m128i*)(p0 + 4 * 2), _mm256_extractf128_si256(_bf0, 1)); + _mm_storeh_pd((double*)(p0 + 4 * 3), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 1))); + _mm_storel_epi64((__m128i*)(p0 + 4 * 4), _mm256_extractf128_si256(_bf1, 0)); + _mm_storeh_pd((double*)(p0 + 4 * 5), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 0))); + _mm_storel_epi64((__m128i*)(p0 + 4 * 6), _mm256_extractf128_si256(_bf1, 1)); + _mm_storeh_pd((double*)(p0 + 4 * 7), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 1))); + _mm_storel_epi64((__m128i*)(p0 + 4 * 8), _mm256_extractf128_si256(_bf2, 0)); + _mm_storeh_pd((double*)(p0 + 4 * 9), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 0))); + _mm_storel_epi64((__m128i*)(p0 + 4 * 10), _mm256_extractf128_si256(_bf2, 1)); + _mm_storeh_pd((double*)(p0 + 4 * 11), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 1))); + _mm_storel_epi64((__m128i*)(p0 + 4 * 12), _mm256_extractf128_si256(_bf3, 0)); + _mm_storeh_pd((double*)(p0 + 4 * 13), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 0))); + _mm_storel_epi64((__m128i*)(p0 + 4 * 14), _mm256_extractf128_si256(_bf3, 1)); + _mm_storeh_pd((double*)(p0 + 4 * 15), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 1))); + } + if (out_elempack == 1) + { + _mm256_storeu_si256((__m256i*)p0, _bf0); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep), _bf1); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 2), _bf2); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 3), _bf3); + } + p0 += out_hstep * 4; + } + else + { + if (out_elempack == 16) + { + _mm256_storeu_si256((__m256i*)p0, _bf0); + _mm256_storeu_si256((__m256i*)(p0 + 16), _bf1); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 2), _bf2); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 3), _bf3); + p0 += 64; + } + if (out_elempack == 8) + { + _mm_storeu_si128((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); + 
_mm_storeu_si128((__m128i*)(p0 + 8), _mm256_extractf128_si256(_bf1, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 2), _mm256_extractf128_si256(_bf2, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 3), _mm256_extractf128_si256(_bf3, 0)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8), _mm256_extractf128_si256(_bf0, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8), _mm256_extractf128_si256(_bf1, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 2), _mm256_extractf128_si256(_bf2, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 3), _mm256_extractf128_si256(_bf3, 1)); + p0 += 32; + } + if (out_elempack == 4) + { + _mm_storel_epi64((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4), _mm256_extractf128_si256(_bf1, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 2), _mm256_extractf128_si256(_bf2, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 3), _mm256_extractf128_si256(_bf3, 0)); + _mm_storeh_pd((double*)(p0 + out_hstep * 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 2), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 3), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 0))); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8), _mm256_extractf128_si256(_bf0, 1)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4), _mm256_extractf128_si256(_bf1, 1)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 2), _mm256_extractf128_si256(_bf2, 1)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 3), _mm256_extractf128_si256(_bf3, 1)); + _mm_storeh_pd((double*)(p0 + out_hstep * 12), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 2), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 3), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 1))); + p0 += 16; + } + if (out_elempack == 1) + { + transpose16x4_epi16(_bf0, _bf1, _bf2, _bf3); + + _mm_storel_epi64((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); + _mm_storeh_pd((double*)(p0 + out_hstep), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 0))); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 2), _mm256_extractf128_si256(_bf0, 1)); + _mm_storeh_pd((double*)(p0 + out_hstep * 3), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 1))); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 4), _mm256_extractf128_si256(_bf1, 0)); + _mm_storeh_pd((double*)(p0 + out_hstep * 5), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 0))); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 6), _mm256_extractf128_si256(_bf1, 1)); + _mm_storeh_pd((double*)(p0 + out_hstep * 7), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 1))); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8), _mm256_extractf128_si256(_bf2, 0)); + _mm_storeh_pd((double*)(p0 + out_hstep * 9), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 0))); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 10), _mm256_extractf128_si256(_bf2, 1)); + _mm_storeh_pd((double*)(p0 + out_hstep * 11), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 1))); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 12), _mm256_extractf128_si256(_bf3, 0)); + _mm_storeh_pd((double*)(p0 + out_hstep * 13), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 
0))); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 14), _mm256_extractf128_si256(_bf3, 1)); + _mm_storeh_pd((double*)(p0 + out_hstep * 15), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 1))); + p0 += 4; + } + } + } + for (; jj + 1 < max_jj; jj += 2) + { + __m512 _f0 = _mm512_load_ps(pp); + __m512 _f1 = _mm512_load_ps(pp + 16); + pp += 32; + + // deshuffle from the shuffle-based 16x2 dpbf16_ps kernel + { + __m512 _tmp0 = _mm512_permute_ps(_f0, _MM_SHUFFLE(3, 1, 2, 0)); + __m512 _tmp1 = _mm512_permute_ps(_f1, _MM_SHUFFLE(0, 2, 3, 1)); + _f0 = _mm512_unpacklo_ps(_tmp0, _tmp1); + _f1 = _mm512_unpackhi_ps(_tmp0, _tmp1); + _f1 = _mm512_permute_ps(_f1, _MM_SHUFFLE(2, 1, 0, 3)); + } + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = _mm512_add_ps(_f0, _c0); + _f1 = _mm512_add_ps(_f1, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = _mm512_add_ps(_f0, _c0); + _f1 = _mm512_add_ps(_f1, _c0); + } + if (broadcast_type_C == 3) + { + __m512 _c1; + if (c_elempack == 16) + { + _c0 = _mm512_loadu_ps(pC); + _c1 = _mm512_loadu_ps(pC + 16); + pC += 32; + } + else if (c_elempack == 8) + { + __m512 _cc0 = _mm512_loadu_ps(pC); + __m512 _cc1 = _mm512_loadu_ps(pC + c_hstep * 8); + _c0 = _mm512_shuffle_f32x4(_cc0, _cc1, _MM_SHUFFLE(1, 0, 1, 0)); + _c1 = _mm512_shuffle_f32x4(_cc0, _cc1, _MM_SHUFFLE(3, 2, 3, 2)); + pC += 16; + } + else if (c_elempack == 4) + { + __m128 _cc0 = _mm_loadu_ps(pC); + __m128 _cc1 = _mm_loadu_ps(pC + 4); + __m128 _cc2 = _mm_loadu_ps(pC + c_hstep * 4); + __m128 _cc3 = _mm_loadu_ps(pC + c_hstep * 4 + 4); + __m128 _cc4 = _mm_loadu_ps(pC + c_hstep * 8); + __m128 _cc5 = _mm_loadu_ps(pC + c_hstep * 8 + 4); + __m128 _cc6 = _mm_loadu_ps(pC + c_hstep * 12); + __m128 _cc7 = _mm_loadu_ps(pC + c_hstep * 12 + 4); + _c0 = combine4x4_ps(_cc0, _cc2, _cc4, _cc6); + _c1 = combine4x4_ps(_cc1, _cc3, _cc5, _cc7); + pC += 8; + } + else // if (c_elempack == 1) + { + __m512i _vindex = _mm512_mullo_epi32(_mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15), _mm512_set1_epi32(c_hstep)); + _c0 = _mm512_i32gather_ps(_vindex, pC, sizeof(float)); + _c1 = _mm512_i32gather_ps(_vindex, pC + 1, sizeof(float)); + pC += 2; + } + if (beta == 1.f) + { + _f0 = _mm512_add_ps(_f0, _c0); + _f1 = _mm512_add_ps(_f1, _c1); + } + else + { + __m512 _beta = _mm512_set1_ps(beta); + _f0 = _mm512_fmadd_ps(_c0, _beta, _f0); + _f1 = _mm512_fmadd_ps(_c1, _beta, _f1); + } + } + if (broadcast_type_C == 4) + { + _c0 = _mm512_set1_ps(pC[0] * beta); + __m512 _c1 = _mm512_set1_ps(pC[1] * beta); + _f0 = _mm512_add_ps(_f0, _c0); + _f1 = _mm512_add_ps(_f1, _c1); + pC += 2; + } + } + + if (alpha != 1.f) + { + __m512 _alpha = _mm512_set1_ps(alpha); + _f0 = _mm512_mul_ps(_f0, _alpha); + _f1 = _mm512_mul_ps(_f1, _alpha); + } + + __m256i _bf0 = float2bfloat_avx512(_f0); + __m256i _bf1 = float2bfloat_avx512(_f1); + + if (output_transpose) + { + _mm256_storeu_si256((__m256i*)p0, _bf0); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep), _bf1); + p0 += out_hstep * 2; + } + else + { + if (out_elempack == 16) + { + _mm256_storeu_si256((__m256i*)p0, _bf0); + _mm256_storeu_si256((__m256i*)(p0 + 16), _bf1); + p0 += 32; + } + if (out_elempack == 8) + { + _mm_storeu_si128((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8), _mm256_extractf128_si256(_bf1, 0)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8), _mm256_extractf128_si256(_bf0, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8), _mm256_extractf128_si256(_bf1, 1)); + p0 += 16; + } + if (out_elempack == 4) + { + 
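// _bf0/_bf1 hold columns jj and jj+1 for the 16 rows; split each into four pack-4 row groups stored at out_hstep * 4 strides +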
_mm_storel_epi64((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4), _mm256_extractf128_si256(_bf1, 0)); + _mm_storeh_pd((double*)(p0 + out_hstep * 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 0))); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8), _mm256_extractf128_si256(_bf0, 1)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4), _mm256_extractf128_si256(_bf1, 1)); + _mm_storeh_pd((double*)(p0 + out_hstep * 12), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 1))); + p0 += 8; + } + if (out_elempack == 1) + { + transpose16x2_epi16(_bf0, _bf1); + __m512i _bf01 = combine8x2_epi32(_bf0, _bf1); + __m512i _vindex = _mm512_mullo_epi32(_mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15), _mm512_set1_epi32(out_hstep)); + _mm512_i32scatter_epi32(p0, _vindex, _bf01, sizeof(unsigned short)); + p0 += 2; + } + } + } + for (; jj < max_jj; jj++) + { + __m512 _f0 = _mm512_load_ps(pp); + pp += 16; + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = _mm512_add_ps(_f0, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = _mm512_add_ps(_f0, _c0); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 16) + { + _c0 = _mm512_loadu_ps(pC); + pC += 16; + } + else if (c_elempack == 8) + { + __m256 _cc0 = _mm256_loadu_ps(pC); + __m256 _cc1 = _mm256_loadu_ps(pC + c_hstep * 8); + _c0 = combine8x2_ps(_cc0, _cc1); + pC += 8; + } + else if (c_elempack == 4) + { + __m128 _cc0 = _mm_loadu_ps(pC); + __m128 _cc1 = _mm_loadu_ps(pC + c_hstep * 4); + __m128 _cc2 = _mm_loadu_ps(pC + c_hstep * 8); + __m128 _cc3 = _mm_loadu_ps(pC + c_hstep * 12); + _c0 = combine4x4_ps(_cc0, _cc1, _cc2, _cc3); + pC += 4; + } + else // if (c_elempack == 1) + { + __m512i _vindex = _mm512_mullo_epi32(_mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15), _mm512_set1_epi32(c_hstep)); + _c0 = _mm512_i32gather_ps(_vindex, pC, sizeof(float)); + pC += 1; + } + _f0 = _mm512_fmadd_ps(_c0, _mm512_set1_ps(beta), _f0); + } + if (broadcast_type_C == 4) + { + _c0 = _mm512_set1_ps(pC[0] * beta); + _f0 = _mm512_add_ps(_f0, _c0); + pC += 1; + } + } + + if (alpha != 1.f) + { + _f0 = _mm512_mul_ps(_f0, _mm512_set1_ps(alpha)); + } + + __m256i _bf0 = float2bfloat_avx512(_f0); + + if (output_transpose) + { + _mm256_storeu_si256((__m256i*)p0, _bf0); + p0 += out_hstep; + } + else + { + if (out_elempack == 16) + { + _mm256_storeu_si256((__m256i*)p0, _bf0); + p0 += 16; + } + if (out_elempack == 8) + { + _mm_storeu_si128((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8), _mm256_extractf128_si256(_bf0, 1)); + p0 += 8; + } + if (out_elempack == 4) + { + _mm_storel_epi64((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); + _mm_storeh_pd((double*)(p0 + out_hstep * 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 0))); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8), _mm256_extractf128_si256(_bf0, 1)); + _mm_storeh_pd((double*)(p0 + out_hstep * 12), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 1))); + p0 += 4; + } + if (out_elempack == 1) + { + unsigned short tmp[16]; + _mm256_storeu_si256((__m256i*)tmp, _bf0); + + p0[0] = tmp[0]; + p0[out_hstep] = tmp[1]; + p0[out_hstep * 2] = tmp[2]; + p0[out_hstep * 3] = tmp[3]; + p0[out_hstep * 4] = tmp[4]; + p0[out_hstep * 5] = tmp[5]; + p0[out_hstep * 6] = 
tmp[6]; + p0[out_hstep * 7] = tmp[7]; + p0[out_hstep * 8] = tmp[8]; + p0[out_hstep * 9] = tmp[9]; + p0[out_hstep * 10] = tmp[10]; + p0[out_hstep * 11] = tmp[11]; + p0[out_hstep * 12] = tmp[12]; + p0[out_hstep * 13] = tmp[13]; + p0[out_hstep * 14] = tmp[14]; + p0[out_hstep * 15] = tmp[15]; + p0++; + } + } + } + } +#endif // __AVX512F__ + for (; ii + 7 < max_ii; ii += 8) + { + unsigned short* p0; + if (output_transpose) + { + p0 = (unsigned short*)top_blob + j * out_hstep + (i + ii) * out_elempack; + } + else + { + p0 = (unsigned short*)top_blob + (i + ii) * out_hstep + j * out_elempack; + } + + __m256 _c0 = _mm256_set1_ps(0.f); +#if __AVX512F__ + __m512 _c0_avx512 = _mm512_set1_ps(0.f); +#endif + if (pC) + { + if (broadcast_type_C == 0) + { + _c0 = _mm256_set1_ps(pC[0] * beta); +#if __AVX512F__ + _c0_avx512 = _mm512_set1_ps(pC[0] * beta); +#endif + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + pC = (const float*)C + i + ii; + _c0 = _mm256_loadu_ps(pC); + _c0 = _mm256_mul_ps(_c0, _mm256_set1_ps(beta)); +#if __AVX512F__ + _c0_avx512 = _mm512_broadcast_f32x8(_c0); +#endif + } + if (broadcast_type_C == 3) + { + pC = (const float*)C + (i + ii) * c_hstep + j * c_elempack; + } + if (broadcast_type_C == 4) + { + pC = (const float*)C + j; + } + } + + int jj = 0; +#if defined(__x86_64__) || defined(_M_X64) +#if __AVX512F__ + for (; jj + 15 < max_jj; jj += 16) + { + __m512 _f0 = _mm512_load_ps(pp); + __m512 _f1 = _mm512_load_ps(pp + 16); + __m512 _f2 = _mm512_load_ps(pp + 32); + __m512 _f3 = _mm512_load_ps(pp + 48); + __m512 _f4 = _mm512_load_ps(pp + 64); + __m512 _f5 = _mm512_load_ps(pp + 80); + __m512 _f6 = _mm512_load_ps(pp + 96); + __m512 _f7 = _mm512_load_ps(pp + 112); + pp += 128; + + // deshuffle from the shuffle-based 8x16 dpbf16_ps kernel + { + _f1 = _mm512_permute_ps(_f1, _MM_SHUFFLE(2, 1, 0, 3)); + _f3 = _mm512_permute_ps(_f3, _MM_SHUFFLE(2, 1, 0, 3)); + _f5 = _mm512_permute_ps(_f5, _MM_SHUFFLE(2, 1, 0, 3)); + _f7 = _mm512_permute_ps(_f7, _MM_SHUFFLE(2, 1, 0, 3)); + + __m512 _tmp0 = _mm512_unpacklo_ps(_f0, _f3); + __m512 _tmp1 = _mm512_unpackhi_ps(_f0, _f3); + __m512 _tmp2 = _mm512_unpacklo_ps(_f2, _f1); + __m512 _tmp3 = _mm512_unpackhi_ps(_f2, _f1); + __m512 _tmp4 = _mm512_unpacklo_ps(_f4, _f7); + __m512 _tmp5 = _mm512_unpackhi_ps(_f4, _f7); + __m512 _tmp6 = _mm512_unpacklo_ps(_f6, _f5); + __m512 _tmp7 = _mm512_unpackhi_ps(_f6, _f5); + + _f0 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp0), _mm512_castps_pd(_tmp2))); + _f1 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp0), _mm512_castps_pd(_tmp2))); + _f2 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp3), _mm512_castps_pd(_tmp1))); + _f3 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp3), _mm512_castps_pd(_tmp1))); + _f4 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp4), _mm512_castps_pd(_tmp6))); + _f5 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp4), _mm512_castps_pd(_tmp6))); + _f6 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp7), _mm512_castps_pd(_tmp5))); + _f7 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp7), _mm512_castps_pd(_tmp5))); + + _f1 = _mm512_permute_ps(_f1, _MM_SHUFFLE(2, 1, 0, 3)); + _f3 = _mm512_permute_ps(_f3, _MM_SHUFFLE(2, 1, 0, 3)); + _f5 = _mm512_permute_ps(_f5, _MM_SHUFFLE(2, 1, 0, 3)); + _f7 = _mm512_permute_ps(_f7, _MM_SHUFFLE(2, 1, 0, 3)); + + _tmp0 = _mm512_shuffle_f32x4(_f0, _f4, _MM_SHUFFLE(0, 1, 1, 0)); + _tmp1 = _mm512_shuffle_f32x4(_f0, _f4, _MM_SHUFFLE(2, 3, 3, 2)); + _tmp2 = 
_mm512_shuffle_f32x4(_f1, _f5, _MM_SHUFFLE(0, 1, 1, 0)); + _tmp3 = _mm512_shuffle_f32x4(_f1, _f5, _MM_SHUFFLE(2, 3, 3, 2)); + _tmp4 = _mm512_shuffle_f32x4(_f2, _f6, _MM_SHUFFLE(0, 1, 1, 0)); + _tmp5 = _mm512_shuffle_f32x4(_f2, _f6, _MM_SHUFFLE(2, 3, 3, 2)); + _tmp6 = _mm512_shuffle_f32x4(_f3, _f7, _MM_SHUFFLE(0, 1, 1, 0)); + _tmp7 = _mm512_shuffle_f32x4(_f3, _f7, _MM_SHUFFLE(2, 3, 3, 2)); + + _f0 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(2, 0, 2, 0)); + _f1 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(2, 0, 2, 0)); + _f2 = _mm512_shuffle_f32x4(_tmp4, _tmp5, _MM_SHUFFLE(2, 0, 2, 0)); + _f3 = _mm512_shuffle_f32x4(_tmp6, _tmp7, _MM_SHUFFLE(2, 0, 2, 0)); + _f4 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(1, 3, 1, 3)); + _f5 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(1, 3, 1, 3)); + _f6 = _mm512_shuffle_f32x4(_tmp4, _tmp5, _MM_SHUFFLE(1, 3, 1, 3)); + _f7 = _mm512_shuffle_f32x4(_tmp6, _tmp7, _MM_SHUFFLE(1, 3, 1, 3)); + } + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = _mm512_add_ps(_f0, _c0_avx512); + _f1 = _mm512_add_ps(_f1, _c0_avx512); + _f2 = _mm512_add_ps(_f2, _c0_avx512); + _f3 = _mm512_add_ps(_f3, _c0_avx512); + _f4 = _mm512_add_ps(_f4, _c0_avx512); + _f5 = _mm512_add_ps(_f5, _c0_avx512); + _f6 = _mm512_add_ps(_f6, _c0_avx512); + _f7 = _mm512_add_ps(_f7, _c0_avx512); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = _mm512_add_ps(_f0, _c0_avx512); + _f1 = _mm512_add_ps(_f1, _c0_avx512); + _f2 = _mm512_add_ps(_f2, _c0_avx512); + _f3 = _mm512_add_ps(_f3, _c0_avx512); + _f4 = _mm512_add_ps(_f4, _c0_avx512); + _f5 = _mm512_add_ps(_f5, _c0_avx512); + _f6 = _mm512_add_ps(_f6, _c0_avx512); + _f7 = _mm512_add_ps(_f7, _c0_avx512); + } + if (broadcast_type_C == 3) + { + __m512 _c1_avx512; + __m512 _c2_avx512; + __m512 _c3_avx512; + __m512 _c4_avx512; + __m512 _c5_avx512; + __m512 _c6_avx512; + __m512 _c7_avx512; + if (c_elempack == 8) + { + _c0_avx512 = _mm512_loadu_ps(pC); + _c1_avx512 = _mm512_loadu_ps(pC + 16); + _c2_avx512 = _mm512_loadu_ps(pC + 32); + _c3_avx512 = _mm512_loadu_ps(pC + 48); + _c4_avx512 = _mm512_loadu_ps(pC + 64); + _c5_avx512 = _mm512_loadu_ps(pC + 80); + _c6_avx512 = _mm512_loadu_ps(pC + 96); + _c7_avx512 = _mm512_loadu_ps(pC + 112); + + __m512 _tmp0 = _mm512_shuffle_f32x4(_c0_avx512, _c4_avx512, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp1 = _mm512_shuffle_f32x4(_c0_avx512, _c4_avx512, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp2 = _mm512_shuffle_f32x4(_c1_avx512, _c5_avx512, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp3 = _mm512_shuffle_f32x4(_c1_avx512, _c5_avx512, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp4 = _mm512_shuffle_f32x4(_c2_avx512, _c6_avx512, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp5 = _mm512_shuffle_f32x4(_c2_avx512, _c6_avx512, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp6 = _mm512_shuffle_f32x4(_c3_avx512, _c7_avx512, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp7 = _mm512_shuffle_f32x4(_c3_avx512, _c7_avx512, _MM_SHUFFLE(3, 2, 3, 2)); + + _c0_avx512 = _tmp0; + _c1_avx512 = _tmp1; + _c2_avx512 = _tmp2; + _c3_avx512 = _tmp3; + _c4_avx512 = _tmp4; + _c5_avx512 = _tmp5; + _c6_avx512 = _tmp6; + _c7_avx512 = _tmp7; + + pC += 128; + } + else if (c_elempack == 4) + { + _c0_avx512 = _mm512_loadu_ps(pC); + _c1_avx512 = _mm512_loadu_ps(pC + 16); + _c2_avx512 = _mm512_loadu_ps(pC + 32); + _c3_avx512 = _mm512_loadu_ps(pC + 48); + _c4_avx512 = _mm512_loadu_ps(pC + c_hstep * 4); + _c5_avx512 = _mm512_loadu_ps(pC + c_hstep * 4 + 16); + _c6_avx512 = _mm512_loadu_ps(pC + c_hstep * 4 + 32); + _c7_avx512 = _mm512_loadu_ps(pC + c_hstep * 4 + 48); + + 
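// c_elempack == 4: the eight loads cover rows 0-3 and rows 4-7 of C; rearrange the two pack-4 row groups so each _cN_avx512 lines up with the deshuffled accumulators before the beta add +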
__m512 _tmp0 = _mm512_shuffle_f32x4(_c0_avx512, _c2_avx512, _MM_SHUFFLE(2, 0, 2, 0)); + __m512 _tmp1 = _mm512_shuffle_f32x4(_c0_avx512, _c2_avx512, _MM_SHUFFLE(3, 1, 3, 1)); + __m512 _tmp2 = _mm512_shuffle_f32x4(_c1_avx512, _c3_avx512, _MM_SHUFFLE(2, 0, 2, 0)); + __m512 _tmp3 = _mm512_shuffle_f32x4(_c1_avx512, _c3_avx512, _MM_SHUFFLE(3, 1, 3, 1)); + __m512 _tmp4 = _mm512_shuffle_f32x4(_c4_avx512, _c6_avx512, _MM_SHUFFLE(2, 0, 2, 0)); + __m512 _tmp5 = _mm512_shuffle_f32x4(_c4_avx512, _c6_avx512, _MM_SHUFFLE(3, 1, 3, 1)); + __m512 _tmp6 = _mm512_shuffle_f32x4(_c5_avx512, _c7_avx512, _MM_SHUFFLE(2, 0, 2, 0)); + __m512 _tmp7 = _mm512_shuffle_f32x4(_c5_avx512, _c7_avx512, _MM_SHUFFLE(3, 1, 3, 1)); + + _c0_avx512 = _mm512_shuffle_f32x4(_tmp0, _tmp4, _MM_SHUFFLE(2, 0, 2, 0)); + _c1_avx512 = _mm512_shuffle_f32x4(_tmp1, _tmp5, _MM_SHUFFLE(2, 0, 2, 0)); + _c2_avx512 = _mm512_shuffle_f32x4(_tmp0, _tmp4, _MM_SHUFFLE(3, 1, 3, 1)); + _c3_avx512 = _mm512_shuffle_f32x4(_tmp1, _tmp5, _MM_SHUFFLE(3, 1, 3, 1)); + _c4_avx512 = _mm512_shuffle_f32x4(_tmp2, _tmp6, _MM_SHUFFLE(2, 0, 2, 0)); + _c5_avx512 = _mm512_shuffle_f32x4(_tmp3, _tmp7, _MM_SHUFFLE(2, 0, 2, 0)); + _c6_avx512 = _mm512_shuffle_f32x4(_tmp2, _tmp6, _MM_SHUFFLE(3, 1, 3, 1)); + _c7_avx512 = _mm512_shuffle_f32x4(_tmp3, _tmp7, _MM_SHUFFLE(3, 1, 3, 1)); + + _c0_avx512 = _mm512_shuffle_f32x4(_c0_avx512, _c0_avx512, _MM_SHUFFLE(3, 1, 2, 0)); + _c1_avx512 = _mm512_shuffle_f32x4(_c1_avx512, _c1_avx512, _MM_SHUFFLE(3, 1, 2, 0)); + _c2_avx512 = _mm512_shuffle_f32x4(_c2_avx512, _c2_avx512, _MM_SHUFFLE(3, 1, 2, 0)); + _c3_avx512 = _mm512_shuffle_f32x4(_c3_avx512, _c3_avx512, _MM_SHUFFLE(3, 1, 2, 0)); + _c4_avx512 = _mm512_shuffle_f32x4(_c4_avx512, _c4_avx512, _MM_SHUFFLE(3, 1, 2, 0)); + _c5_avx512 = _mm512_shuffle_f32x4(_c5_avx512, _c5_avx512, _MM_SHUFFLE(3, 1, 2, 0)); + _c6_avx512 = _mm512_shuffle_f32x4(_c6_avx512, _c6_avx512, _MM_SHUFFLE(3, 1, 2, 0)); + _c7_avx512 = _mm512_shuffle_f32x4(_c7_avx512, _c7_avx512, _MM_SHUFFLE(3, 1, 2, 0)); + + pC += 64; + } + else // if (c_elempack == 1) + { + _c0_avx512 = _mm512_loadu_ps(pC); + _c1_avx512 = _mm512_loadu_ps(pC + c_hstep); + _c2_avx512 = _mm512_loadu_ps(pC + c_hstep * 2); + _c3_avx512 = _mm512_loadu_ps(pC + c_hstep * 3); + _c4_avx512 = _mm512_loadu_ps(pC + c_hstep * 4); + _c5_avx512 = _mm512_loadu_ps(pC + c_hstep * 5); + _c6_avx512 = _mm512_loadu_ps(pC + c_hstep * 6); + _c7_avx512 = _mm512_loadu_ps(pC + c_hstep * 7); + + __m512 _tmp0 = _mm512_unpacklo_ps(_c0_avx512, _c1_avx512); + __m512 _tmp1 = _mm512_unpacklo_ps(_c2_avx512, _c3_avx512); + __m512 _tmp2 = _mm512_unpacklo_ps(_c4_avx512, _c5_avx512); + __m512 _tmp3 = _mm512_unpacklo_ps(_c6_avx512, _c7_avx512); + __m512 _tmp4 = _mm512_unpackhi_ps(_c0_avx512, _c1_avx512); + __m512 _tmp5 = _mm512_unpackhi_ps(_c2_avx512, _c3_avx512); + __m512 _tmp6 = _mm512_unpackhi_ps(_c4_avx512, _c5_avx512); + __m512 _tmp7 = _mm512_unpackhi_ps(_c6_avx512, _c7_avx512); + + _c0_avx512 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp0), _mm512_castps_pd(_tmp1))); + _c1_avx512 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp2), _mm512_castps_pd(_tmp3))); + _c2_avx512 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp0), _mm512_castps_pd(_tmp1))); + _c3_avx512 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp2), _mm512_castps_pd(_tmp3))); + _c4_avx512 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp4), _mm512_castps_pd(_tmp5))); + _c5_avx512 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp6), _mm512_castps_pd(_tmp7))); + 
_c6_avx512 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp4), _mm512_castps_pd(_tmp5))); + _c7_avx512 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp6), _mm512_castps_pd(_tmp7))); + + _tmp0 = _mm512_shuffle_f32x4(_c0_avx512, _c1_avx512, _MM_SHUFFLE(2, 0, 2, 0)); + _tmp1 = _mm512_shuffle_f32x4(_c2_avx512, _c3_avx512, _MM_SHUFFLE(2, 0, 2, 0)); + _tmp2 = _mm512_shuffle_f32x4(_c4_avx512, _c5_avx512, _MM_SHUFFLE(2, 0, 2, 0)); + _tmp3 = _mm512_shuffle_f32x4(_c6_avx512, _c7_avx512, _MM_SHUFFLE(2, 0, 2, 0)); + _tmp4 = _mm512_shuffle_f32x4(_c0_avx512, _c1_avx512, _MM_SHUFFLE(3, 1, 3, 1)); + _tmp5 = _mm512_shuffle_f32x4(_c2_avx512, _c3_avx512, _MM_SHUFFLE(3, 1, 3, 1)); + _tmp6 = _mm512_shuffle_f32x4(_c4_avx512, _c5_avx512, _MM_SHUFFLE(3, 1, 3, 1)); + _tmp7 = _mm512_shuffle_f32x4(_c6_avx512, _c7_avx512, _MM_SHUFFLE(3, 1, 3, 1)); + + _c0_avx512 = _mm512_shuffle_f32x4(_tmp0, _tmp0, _MM_SHUFFLE(3, 1, 2, 0)); + _c1_avx512 = _mm512_shuffle_f32x4(_tmp1, _tmp1, _MM_SHUFFLE(3, 1, 2, 0)); + _c2_avx512 = _mm512_shuffle_f32x4(_tmp2, _tmp2, _MM_SHUFFLE(3, 1, 2, 0)); + _c3_avx512 = _mm512_shuffle_f32x4(_tmp3, _tmp3, _MM_SHUFFLE(3, 1, 2, 0)); + _c4_avx512 = _mm512_shuffle_f32x4(_tmp4, _tmp4, _MM_SHUFFLE(3, 1, 2, 0)); + _c5_avx512 = _mm512_shuffle_f32x4(_tmp5, _tmp5, _MM_SHUFFLE(3, 1, 2, 0)); + _c6_avx512 = _mm512_shuffle_f32x4(_tmp6, _tmp6, _MM_SHUFFLE(3, 1, 2, 0)); + _c7_avx512 = _mm512_shuffle_f32x4(_tmp7, _tmp7, _MM_SHUFFLE(3, 1, 2, 0)); + + pC += 16; + } + if (beta == 1.f) + { + _f0 = _mm512_add_ps(_f0, _c0_avx512); + _f1 = _mm512_add_ps(_f1, _c1_avx512); + _f2 = _mm512_add_ps(_f2, _c2_avx512); + _f3 = _mm512_add_ps(_f3, _c3_avx512); + _f4 = _mm512_add_ps(_f4, _c4_avx512); + _f5 = _mm512_add_ps(_f5, _c5_avx512); + _f6 = _mm512_add_ps(_f6, _c6_avx512); + _f7 = _mm512_add_ps(_f7, _c7_avx512); + } + else + { + __m512 _beta = _mm512_set1_ps(beta); + _f0 = _mm512_fmadd_ps(_c0_avx512, _beta, _f0); + _f1 = _mm512_fmadd_ps(_c1_avx512, _beta, _f1); + _f2 = _mm512_fmadd_ps(_c2_avx512, _beta, _f2); + _f3 = _mm512_fmadd_ps(_c3_avx512, _beta, _f3); + _f4 = _mm512_fmadd_ps(_c4_avx512, _beta, _f4); + _f5 = _mm512_fmadd_ps(_c5_avx512, _beta, _f5); + _f6 = _mm512_fmadd_ps(_c6_avx512, _beta, _f6); + _f7 = _mm512_fmadd_ps(_c7_avx512, _beta, _f7); + } + } + if (broadcast_type_C == 4) + { + __m512 _cc = _mm512_loadu_ps(pC); + _cc = _mm512_mul_ps(_cc, _mm512_set1_ps(beta)); + __m512 _cc0 = _mm512_permute_ps(_cc, _MM_SHUFFLE(0, 0, 0, 0)); + __m512 _cc1 = _mm512_permute_ps(_cc, _MM_SHUFFLE(1, 1, 1, 1)); + __m512 _cc2 = _mm512_permute_ps(_cc, _MM_SHUFFLE(2, 2, 2, 2)); + __m512 _cc3 = _mm512_permute_ps(_cc, _MM_SHUFFLE(3, 3, 3, 3)); + + _c0_avx512 = _mm512_shuffle_f32x4(_cc0, _cc0, _MM_SHUFFLE(2, 2, 0, 0)); + __m512 _c1_avx512 = _mm512_shuffle_f32x4(_cc1, _cc1, _MM_SHUFFLE(2, 2, 0, 0)); + __m512 _c2_avx512 = _mm512_shuffle_f32x4(_cc2, _cc2, _MM_SHUFFLE(2, 2, 0, 0)); + __m512 _c3_avx512 = _mm512_shuffle_f32x4(_cc3, _cc3, _MM_SHUFFLE(2, 2, 0, 0)); + __m512 _c4_avx512 = _mm512_shuffle_f32x4(_cc0, _cc0, _MM_SHUFFLE(3, 3, 1, 1)); + __m512 _c5_avx512 = _mm512_shuffle_f32x4(_cc1, _cc1, _MM_SHUFFLE(3, 3, 1, 1)); + __m512 _c6_avx512 = _mm512_shuffle_f32x4(_cc2, _cc2, _MM_SHUFFLE(3, 3, 1, 1)); + __m512 _c7_avx512 = _mm512_shuffle_f32x4(_cc3, _cc3, _MM_SHUFFLE(3, 3, 1, 1)); + + _f0 = _mm512_add_ps(_f0, _c0_avx512); + _f1 = _mm512_add_ps(_f1, _c1_avx512); + _f2 = _mm512_add_ps(_f2, _c2_avx512); + _f3 = _mm512_add_ps(_f3, _c3_avx512); + _f4 = _mm512_add_ps(_f4, _c4_avx512); + _f5 = _mm512_add_ps(_f5, _c5_avx512); + _f6 = 
_mm512_add_ps(_f6, _c6_avx512); + _f7 = _mm512_add_ps(_f7, _c7_avx512); + + pC += 16; + } + } + + if (alpha != 1.f) + { + __m512 _alpha = _mm512_set1_ps(alpha); + _f0 = _mm512_mul_ps(_f0, _alpha); + _f1 = _mm512_mul_ps(_f1, _alpha); + _f2 = _mm512_mul_ps(_f2, _alpha); + _f3 = _mm512_mul_ps(_f3, _alpha); + _f4 = _mm512_mul_ps(_f4, _alpha); + _f5 = _mm512_mul_ps(_f5, _alpha); + _f6 = _mm512_mul_ps(_f6, _alpha); + _f7 = _mm512_mul_ps(_f7, _alpha); + } + + __m256i _bf0 = float2bfloat_avx512(_f0); + __m256i _bf1 = float2bfloat_avx512(_f1); + __m256i _bf2 = float2bfloat_avx512(_f2); + __m256i _bf3 = float2bfloat_avx512(_f3); + __m256i _bf4 = float2bfloat_avx512(_f4); + __m256i _bf5 = float2bfloat_avx512(_f5); + __m256i _bf6 = float2bfloat_avx512(_f6); + __m256i _bf7 = float2bfloat_avx512(_f7); + + if (output_transpose) + { + if (out_elempack == 16) + { + transpose16x8_epi16(_bf0, _bf1, _bf2, _bf3, _bf4, _bf5, _bf6, _bf7); + + _mm_store_si128((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); + _mm_store_si128((__m128i*)(p0 + 8), _mm256_extractf128_si256(_bf4, 0)); + _mm_store_si128((__m128i*)(p0 + 16), _mm256_extractf128_si256(_bf0, 1)); + _mm_store_si128((__m128i*)(p0 + 16 + 8), _mm256_extractf128_si256(_bf4, 1)); + _mm_store_si128((__m128i*)(p0 + 32), _mm256_extractf128_si256(_bf1, 0)); + _mm_store_si128((__m128i*)(p0 + 32 + 8), _mm256_extractf128_si256(_bf5, 0)); + _mm_store_si128((__m128i*)(p0 + 48), _mm256_extractf128_si256(_bf1, 1)); + _mm_store_si128((__m128i*)(p0 + 48 + 8), _mm256_extractf128_si256(_bf5, 1)); + _mm_store_si128((__m128i*)(p0 + 64), _mm256_extractf128_si256(_bf2, 0)); + _mm_store_si128((__m128i*)(p0 + 64 + 8), _mm256_extractf128_si256(_bf6, 0)); + _mm_store_si128((__m128i*)(p0 + 80), _mm256_extractf128_si256(_bf2, 1)); + _mm_store_si128((__m128i*)(p0 + 80 + 8), _mm256_extractf128_si256(_bf6, 1)); + _mm_store_si128((__m128i*)(p0 + 96), _mm256_extractf128_si256(_bf3, 0)); + _mm_store_si128((__m128i*)(p0 + 96 + 8), _mm256_extractf128_si256(_bf7, 0)); + _mm_store_si128((__m128i*)(p0 + 112), _mm256_extractf128_si256(_bf3, 1)); + _mm_store_si128((__m128i*)(p0 + 112 + 8), _mm256_extractf128_si256(_bf7, 1)); + } + if (out_elempack == 8) + { + __m128i _bf0l = _mm256_extractf128_si256(_bf0, 0); + __m128i _bf1l = _mm256_extractf128_si256(_bf1, 0); + __m128i _bf2l = _mm256_extractf128_si256(_bf2, 0); + __m128i _bf3l = _mm256_extractf128_si256(_bf3, 0); + __m128i _bf4l = _mm256_extractf128_si256(_bf4, 0); + __m128i _bf5l = _mm256_extractf128_si256(_bf5, 0); + __m128i _bf6l = _mm256_extractf128_si256(_bf6, 0); + __m128i _bf7l = _mm256_extractf128_si256(_bf7, 0); + transpose8x8_epi16(_bf0l, _bf1l, _bf2l, _bf3l, _bf4l, _bf5l, _bf6l, _bf7l); + _mm_storeu_si128((__m128i*)p0, _bf0l); + _mm_storeu_si128((__m128i*)(p0 + 8), _bf1l); + _mm_storeu_si128((__m128i*)(p0 + 8 * 2), _bf2l); + _mm_storeu_si128((__m128i*)(p0 + 8 * 3), _bf3l); + _mm_storeu_si128((__m128i*)(p0 + 8 * 4), _bf4l); + _mm_storeu_si128((__m128i*)(p0 + 8 * 5), _bf5l); + _mm_storeu_si128((__m128i*)(p0 + 8 * 6), _bf6l); + _mm_storeu_si128((__m128i*)(p0 + 8 * 7), _bf7l); + __m128i _bf0h = _mm256_extractf128_si256(_bf0, 1); + __m128i _bf1h = _mm256_extractf128_si256(_bf1, 1); + __m128i _bf2h = _mm256_extractf128_si256(_bf2, 1); + __m128i _bf3h = _mm256_extractf128_si256(_bf3, 1); + __m128i _bf4h = _mm256_extractf128_si256(_bf4, 1); + __m128i _bf5h = _mm256_extractf128_si256(_bf5, 1); + __m128i _bf6h = _mm256_extractf128_si256(_bf6, 1); + __m128i _bf7h = _mm256_extractf128_si256(_bf7, 1); + transpose8x8_epi16(_bf0h, _bf1h, _bf2h, 
_bf3h, _bf4h, _bf5h, _bf6h, _bf7h); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8), _bf0h); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8), _bf1h); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 2), _bf2h); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 3), _bf3h); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 4), _bf4h); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 5), _bf5h); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 6), _bf6h); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 7), _bf7h); + } + if (out_elempack == 4) + { + __m128i _bf0l = _mm256_extractf128_si256(_bf0, 0); + __m128i _bf1l = _mm256_extractf128_si256(_bf1, 0); + __m128i _bf2l = _mm256_extractf128_si256(_bf2, 0); + __m128i _bf3l = _mm256_extractf128_si256(_bf3, 0); + __m128i _bf4l = _mm256_extractf128_si256(_bf4, 0); + __m128i _bf5l = _mm256_extractf128_si256(_bf5, 0); + __m128i _bf6l = _mm256_extractf128_si256(_bf6, 0); + __m128i _bf7l = _mm256_extractf128_si256(_bf7, 0); + transpose8x4_epi16(_bf0l, _bf1l, _bf2l, _bf3l); + transpose8x4_epi16(_bf4l, _bf5l, _bf6l, _bf7l); + _mm_storel_epi64((__m128i*)p0, _bf0l); + _mm_storeh_pd((double*)(p0 + 4), _mm_castsi128_pd(_bf0l)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 2), _bf1l); + _mm_storeh_pd((double*)(p0 + 4 * 3), _mm_castsi128_pd(_bf1l)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 4), _bf2l); + _mm_storeh_pd((double*)(p0 + 4 * 5), _mm_castsi128_pd(_bf2l)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 6), _bf3l); + _mm_storeh_pd((double*)(p0 + 4 * 7), _mm_castsi128_pd(_bf3l)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 4), _bf4l); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4), _mm_castsi128_pd(_bf4l)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 4 + 4 * 2), _bf5l); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 3), _mm_castsi128_pd(_bf5l)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 4 + 4 * 4), _bf6l); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 5), _mm_castsi128_pd(_bf6l)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 4 + 4 * 6), _bf7l); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 7), _mm_castsi128_pd(_bf7l)); + __m128i _bf0h = _mm256_extractf128_si256(_bf0, 1); + __m128i _bf1h = _mm256_extractf128_si256(_bf1, 1); + __m128i _bf2h = _mm256_extractf128_si256(_bf2, 1); + __m128i _bf3h = _mm256_extractf128_si256(_bf3, 1); + __m128i _bf4h = _mm256_extractf128_si256(_bf4, 1); + __m128i _bf5h = _mm256_extractf128_si256(_bf5, 1); + __m128i _bf6h = _mm256_extractf128_si256(_bf6, 1); + __m128i _bf7h = _mm256_extractf128_si256(_bf7, 1); + transpose8x4_epi16(_bf0h, _bf1h, _bf2h, _bf3h); + transpose8x4_epi16(_bf4h, _bf5h, _bf6h, _bf7h); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8), _bf0h); + _mm_storeh_pd((double*)(p0 + out_hstep * 8 + 4), _mm_castsi128_pd(_bf0h)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 2), _bf1h); + _mm_storeh_pd((double*)(p0 + out_hstep * 8 + 4 * 3), _mm_castsi128_pd(_bf1h)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 4), _bf2h); + _mm_storeh_pd((double*)(p0 + out_hstep * 8 + 4 * 5), _mm_castsi128_pd(_bf2h)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 6), _bf3h); + _mm_storeh_pd((double*)(p0 + out_hstep * 8 + 4 * 7), _mm_castsi128_pd(_bf3h)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 12), _bf4h); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4), _mm_castsi128_pd(_bf4h)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 12 + 4 * 2), _bf5h); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 3), _mm_castsi128_pd(_bf5h)); + 
_mm_storel_epi64((__m128i*)(p0 + out_hstep * 12 + 4 * 4), _bf6h); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 5), _mm_castsi128_pd(_bf6h)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 12 + 4 * 6), _bf7h); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 7), _mm_castsi128_pd(_bf7h)); + } + if (out_elempack == 1) + { + _mm_storeu_si128((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep), _mm256_extractf128_si256(_bf1, 0)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 2), _mm256_extractf128_si256(_bf2, 0)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 3), _mm256_extractf128_si256(_bf3, 0)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 4), _mm256_extractf128_si256(_bf4, 0)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 5), _mm256_extractf128_si256(_bf5, 0)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 6), _mm256_extractf128_si256(_bf6, 0)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 7), _mm256_extractf128_si256(_bf7, 0)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8), _mm256_extractf128_si256(_bf0, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 9), _mm256_extractf128_si256(_bf1, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 10), _mm256_extractf128_si256(_bf2, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 11), _mm256_extractf128_si256(_bf3, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 12), _mm256_extractf128_si256(_bf4, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 13), _mm256_extractf128_si256(_bf5, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 14), _mm256_extractf128_si256(_bf6, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 15), _mm256_extractf128_si256(_bf7, 1)); + } + p0 += out_hstep * 16; + } + else + { + if (out_elempack == 8) + { + _mm_storeu_si128((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8), _mm256_extractf128_si256(_bf1, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 2), _mm256_extractf128_si256(_bf2, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 3), _mm256_extractf128_si256(_bf3, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 4), _mm256_extractf128_si256(_bf4, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 5), _mm256_extractf128_si256(_bf5, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 6), _mm256_extractf128_si256(_bf6, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 7), _mm256_extractf128_si256(_bf7, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 8), _mm256_extractf128_si256(_bf0, 1)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 9), _mm256_extractf128_si256(_bf1, 1)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 10), _mm256_extractf128_si256(_bf2, 1)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 11), _mm256_extractf128_si256(_bf3, 1)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 12), _mm256_extractf128_si256(_bf4, 1)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 13), _mm256_extractf128_si256(_bf5, 1)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 14), _mm256_extractf128_si256(_bf6, 1)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 15), _mm256_extractf128_si256(_bf7, 1)); + p0 += 128; + } + if (out_elempack == 4) + { + _mm_storel_epi64((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4), _mm256_extractf128_si256(_bf1, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 2), _mm256_extractf128_si256(_bf2, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 3), _mm256_extractf128_si256(_bf3, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 4), _mm256_extractf128_si256(_bf4, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 5), _mm256_extractf128_si256(_bf5, 0)); + 
_mm_storel_epi64((__m128i*)(p0 + 4 * 6), _mm256_extractf128_si256(_bf6, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 7), _mm256_extractf128_si256(_bf7, 0)); + _mm_storeh_pd((double*)(p0 + out_hstep * 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 2), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 3), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf4, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 5), _mm_castsi128_pd(_mm256_extractf128_si256(_bf5, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 6), _mm_castsi128_pd(_mm256_extractf128_si256(_bf6, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 7), _mm_castsi128_pd(_mm256_extractf128_si256(_bf7, 0))); + _mm_storel_epi64((__m128i*)(p0 + 4 * 8), _mm256_extractf128_si256(_bf0, 1)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 9), _mm256_extractf128_si256(_bf1, 1)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 10), _mm256_extractf128_si256(_bf2, 1)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 11), _mm256_extractf128_si256(_bf3, 1)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 12), _mm256_extractf128_si256(_bf4, 1)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 13), _mm256_extractf128_si256(_bf5, 1)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 14), _mm256_extractf128_si256(_bf6, 1)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 15), _mm256_extractf128_si256(_bf7, 1)); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 8), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 9), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 10), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 11), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 12), _mm_castsi128_pd(_mm256_extractf128_si256(_bf4, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 13), _mm_castsi128_pd(_mm256_extractf128_si256(_bf5, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 14), _mm_castsi128_pd(_mm256_extractf128_si256(_bf6, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 15), _mm_castsi128_pd(_mm256_extractf128_si256(_bf7, 1))); + p0 += 64; + } + if (out_elempack == 1) + { + __m512 _tmp0 = _mm512_unpacklo_ps(_f0, _f1); + __m512 _tmp1 = _mm512_unpacklo_ps(_f2, _f3); + __m512 _tmp2 = _mm512_unpacklo_ps(_f4, _f5); + __m512 _tmp3 = _mm512_unpacklo_ps(_f6, _f7); + __m512 _tmp4 = _mm512_unpackhi_ps(_f0, _f1); + __m512 _tmp5 = _mm512_unpackhi_ps(_f2, _f3); + __m512 _tmp6 = _mm512_unpackhi_ps(_f4, _f5); + __m512 _tmp7 = _mm512_unpackhi_ps(_f6, _f7); + + _f0 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp0), _mm512_castps_pd(_tmp1))); + _f1 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp2), _mm512_castps_pd(_tmp3))); + _f2 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp0), _mm512_castps_pd(_tmp1))); + _f3 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp2), _mm512_castps_pd(_tmp3))); + _f4 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp4), _mm512_castps_pd(_tmp5))); + _f5 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp6), _mm512_castps_pd(_tmp7))); + _f6 = 
_mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp4), _mm512_castps_pd(_tmp5))); + _f7 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp6), _mm512_castps_pd(_tmp7))); + + _tmp0 = _mm512_shuffle_f32x4(_f0, _f1, _MM_SHUFFLE(2, 0, 2, 0)); + _tmp1 = _mm512_shuffle_f32x4(_f2, _f3, _MM_SHUFFLE(2, 0, 2, 0)); + _tmp2 = _mm512_shuffle_f32x4(_f4, _f5, _MM_SHUFFLE(2, 0, 2, 0)); + _tmp3 = _mm512_shuffle_f32x4(_f6, _f7, _MM_SHUFFLE(2, 0, 2, 0)); + _tmp4 = _mm512_shuffle_f32x4(_f0, _f1, _MM_SHUFFLE(3, 1, 3, 1)); + _tmp5 = _mm512_shuffle_f32x4(_f2, _f3, _MM_SHUFFLE(3, 1, 3, 1)); + _tmp6 = _mm512_shuffle_f32x4(_f4, _f5, _MM_SHUFFLE(3, 1, 3, 1)); + _tmp7 = _mm512_shuffle_f32x4(_f6, _f7, _MM_SHUFFLE(3, 1, 3, 1)); + + _f0 = _mm512_shuffle_f32x4(_tmp0, _tmp0, _MM_SHUFFLE(3, 1, 2, 0)); + _f1 = _mm512_shuffle_f32x4(_tmp1, _tmp1, _MM_SHUFFLE(3, 1, 2, 0)); + _f2 = _mm512_shuffle_f32x4(_tmp2, _tmp2, _MM_SHUFFLE(3, 1, 2, 0)); + _f3 = _mm512_shuffle_f32x4(_tmp3, _tmp3, _MM_SHUFFLE(3, 1, 2, 0)); + _f4 = _mm512_shuffle_f32x4(_tmp4, _tmp4, _MM_SHUFFLE(3, 1, 2, 0)); + _f5 = _mm512_shuffle_f32x4(_tmp5, _tmp5, _MM_SHUFFLE(3, 1, 2, 0)); + _f6 = _mm512_shuffle_f32x4(_tmp6, _tmp6, _MM_SHUFFLE(3, 1, 2, 0)); + _f7 = _mm512_shuffle_f32x4(_tmp7, _tmp7, _MM_SHUFFLE(3, 1, 2, 0)); + + _mm256_storeu_si256((__m256i*)p0, float2bfloat_avx512(_f0)); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep), float2bfloat_avx512(_f1)); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 2), float2bfloat_avx512(_f2)); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 3), float2bfloat_avx512(_f3)); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 4), float2bfloat_avx512(_f4)); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 5), float2bfloat_avx512(_f5)); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 6), float2bfloat_avx512(_f6)); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 7), float2bfloat_avx512(_f7)); + p0 += 16; + } + } + } +#endif // __AVX512F__ + for (; jj + 7 < max_jj; jj += 8) + { + __m256 _f0 = _mm256_load_ps(pp); + __m256 _f1 = _mm256_load_ps(pp + 8); + __m256 _f2 = _mm256_load_ps(pp + 16); + __m256 _f3 = _mm256_load_ps(pp + 24); + __m256 _f4 = _mm256_load_ps(pp + 32); + __m256 _f5 = _mm256_load_ps(pp + 40); + __m256 _f6 = _mm256_load_ps(pp + 48); + __m256 _f7 = _mm256_load_ps(pp + 56); + pp += 64; + + // deshuffle from the shuffle-based 8x8 dpbf16_ps kernel + // from + // 00 11 22 33 04 15 26 37 + // 20 31 02 13 24 35 06 17 + // 01 12 23 30 05 16 27 34 + // 21 32 03 10 25 36 07 14 + // 40 51 62 73 44 55 66 77 + // 60 71 42 53 64 75 46 57 + // 41 52 63 70 45 56 67 74 + // 61 72 43 50 65 76 47 54 + // to + // 00 10 20 30 40 50 60 70 + // 01 11 21 31 41 51 61 71 + // 02 12 22 32 42 52 62 72 + // 03 13 23 33 43 53 63 73 + // 04 14 24 34 44 54 64 74 + // 05 15 25 35 45 55 65 75 + // 06 16 26 36 46 56 66 76 + // 07 17 27 37 47 57 67 77 + { + __m256 _tmp0 = _f0; + __m256 _tmp1 = _mm256_shuffle_ps(_f1, _f1, _MM_SHUFFLE(2, 1, 0, 3)); + __m256 _tmp2 = _f2; + __m256 _tmp3 = _mm256_shuffle_ps(_f3, _f3, _MM_SHUFFLE(2, 1, 0, 3)); + __m256 _tmp4 = _f4; + __m256 _tmp5 = _mm256_shuffle_ps(_f5, _f5, _MM_SHUFFLE(2, 1, 0, 3)); + __m256 _tmp6 = _f6; + __m256 _tmp7 = _mm256_shuffle_ps(_f7, _f7, _MM_SHUFFLE(2, 1, 0, 3)); + + _f0 = _mm256_unpacklo_ps(_tmp0, _tmp3); + _f1 = _mm256_unpackhi_ps(_tmp0, _tmp3); + _f2 = _mm256_unpacklo_ps(_tmp2, _tmp1); + _f3 = _mm256_unpackhi_ps(_tmp2, _tmp1); + _f4 = _mm256_unpacklo_ps(_tmp4, _tmp7); + _f5 = _mm256_unpackhi_ps(_tmp4, _tmp7); + _f6 = _mm256_unpacklo_ps(_tmp6, _tmp5); + _f7 = 
_mm256_unpackhi_ps(_tmp6, _tmp5); + + _tmp0 = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(_f0), _mm256_castps_pd(_f2))); + _tmp1 = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(_f0), _mm256_castps_pd(_f2))); + _tmp2 = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(_f3), _mm256_castps_pd(_f1))); + _tmp3 = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(_f3), _mm256_castps_pd(_f1))); + _tmp4 = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(_f4), _mm256_castps_pd(_f6))); + _tmp5 = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(_f4), _mm256_castps_pd(_f6))); + _tmp6 = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(_f7), _mm256_castps_pd(_f5))); + _tmp7 = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(_f7), _mm256_castps_pd(_f5))); + + _tmp1 = _mm256_shuffle_ps(_tmp1, _tmp1, _MM_SHUFFLE(2, 1, 0, 3)); + _tmp3 = _mm256_shuffle_ps(_tmp3, _tmp3, _MM_SHUFFLE(2, 1, 0, 3)); + _tmp5 = _mm256_shuffle_ps(_tmp5, _tmp5, _MM_SHUFFLE(2, 1, 0, 3)); + _tmp7 = _mm256_shuffle_ps(_tmp7, _tmp7, _MM_SHUFFLE(2, 1, 0, 3)); + + _f0 = _mm256_permute2f128_ps(_tmp0, _tmp4, _MM_SHUFFLE(0, 3, 0, 0)); + _f1 = _mm256_permute2f128_ps(_tmp1, _tmp5, _MM_SHUFFLE(0, 3, 0, 0)); + _f2 = _mm256_permute2f128_ps(_tmp2, _tmp6, _MM_SHUFFLE(0, 3, 0, 0)); + _f3 = _mm256_permute2f128_ps(_tmp3, _tmp7, _MM_SHUFFLE(0, 3, 0, 0)); + _f4 = _mm256_permute2f128_ps(_tmp4, _tmp0, _MM_SHUFFLE(0, 3, 0, 0)); + _f5 = _mm256_permute2f128_ps(_tmp5, _tmp1, _MM_SHUFFLE(0, 3, 0, 0)); + _f6 = _mm256_permute2f128_ps(_tmp6, _tmp2, _MM_SHUFFLE(0, 3, 0, 0)); + _f7 = _mm256_permute2f128_ps(_tmp7, _tmp3, _MM_SHUFFLE(0, 3, 0, 0)); + } + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = _mm256_add_ps(_f0, _c0); + _f1 = _mm256_add_ps(_f1, _c0); + _f2 = _mm256_add_ps(_f2, _c0); + _f3 = _mm256_add_ps(_f3, _c0); + _f4 = _mm256_add_ps(_f4, _c0); + _f5 = _mm256_add_ps(_f5, _c0); + _f6 = _mm256_add_ps(_f6, _c0); + _f7 = _mm256_add_ps(_f7, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = _mm256_add_ps(_f0, _c0); + _f1 = _mm256_add_ps(_f1, _c0); + _f2 = _mm256_add_ps(_f2, _c0); + _f3 = _mm256_add_ps(_f3, _c0); + _f4 = _mm256_add_ps(_f4, _c0); + _f5 = _mm256_add_ps(_f5, _c0); + _f6 = _mm256_add_ps(_f6, _c0); + _f7 = _mm256_add_ps(_f7, _c0); + } + if (broadcast_type_C == 3) + { + __m256 _c1; + __m256 _c2; + __m256 _c3; + __m256 _c4; + __m256 _c5; + __m256 _c6; + __m256 _c7; + if (c_elempack == 8) + { + _c0 = _mm256_loadu_ps(pC); + _c1 = _mm256_loadu_ps(pC + 8); + _c2 = _mm256_loadu_ps(pC + 16); + _c3 = _mm256_loadu_ps(pC + 24); + _c4 = _mm256_loadu_ps(pC + 32); + _c5 = _mm256_loadu_ps(pC + 40); + _c6 = _mm256_loadu_ps(pC + 48); + _c7 = _mm256_loadu_ps(pC + 56); + pC += 64; + } + else if (c_elempack == 4) + { + __m256 _tmp0 = _mm256_loadu_ps(pC); + __m256 _tmp1 = _mm256_loadu_ps(pC + 8); + __m256 _tmp2 = _mm256_loadu_ps(pC + 16); + __m256 _tmp3 = _mm256_loadu_ps(pC + 24); + __m256 _tmp4 = _mm256_loadu_ps(pC + c_hstep * 4); + __m256 _tmp5 = _mm256_loadu_ps(pC + c_hstep * 4 + 8); + __m256 _tmp6 = _mm256_loadu_ps(pC + c_hstep * 4 + 16); + __m256 _tmp7 = _mm256_loadu_ps(pC + c_hstep * 4 + 24); + _c0 = _mm256_permute2f128_ps(_tmp0, _tmp4, _MM_SHUFFLE(0, 2, 0, 0)); + _c1 = _mm256_permute2f128_ps(_tmp0, _tmp4, _MM_SHUFFLE(0, 3, 0, 1)); + _c2 = _mm256_permute2f128_ps(_tmp1, _tmp5, _MM_SHUFFLE(0, 2, 0, 0)); + _c3 = _mm256_permute2f128_ps(_tmp1, _tmp5, _MM_SHUFFLE(0, 3, 0, 1)); + _c4 = _mm256_permute2f128_ps(_tmp2, _tmp6, _MM_SHUFFLE(0, 2, 0, 0)); + _c5 = _mm256_permute2f128_ps(_tmp2, _tmp6, _MM_SHUFFLE(0, 
3, 0, 1)); + _c6 = _mm256_permute2f128_ps(_tmp3, _tmp7, _MM_SHUFFLE(0, 2, 0, 0)); + _c7 = _mm256_permute2f128_ps(_tmp3, _tmp7, _MM_SHUFFLE(0, 3, 0, 1)); + pC += 32; + } + else // if (c_elempack == 1) + { + _c0 = _mm256_loadu_ps(pC); + _c1 = _mm256_loadu_ps(pC + c_hstep); + _c2 = _mm256_loadu_ps(pC + c_hstep * 2); + _c3 = _mm256_loadu_ps(pC + c_hstep * 3); + _c4 = _mm256_loadu_ps(pC + c_hstep * 4); + _c5 = _mm256_loadu_ps(pC + c_hstep * 5); + _c6 = _mm256_loadu_ps(pC + c_hstep * 6); + _c7 = _mm256_loadu_ps(pC + c_hstep * 7); + transpose8x8_ps(_c0, _c1, _c2, _c3, _c4, _c5, _c6, _c7); + pC += 8; + } + if (beta == 1.f) + { + _f0 = _mm256_add_ps(_f0, _c0); + _f1 = _mm256_add_ps(_f1, _c1); + _f2 = _mm256_add_ps(_f2, _c2); + _f3 = _mm256_add_ps(_f3, _c3); + _f4 = _mm256_add_ps(_f4, _c4); + _f5 = _mm256_add_ps(_f5, _c5); + _f6 = _mm256_add_ps(_f6, _c6); + _f7 = _mm256_add_ps(_f7, _c7); + } + else + { + __m256 _beta = _mm256_set1_ps(beta); + _f0 = _mm256_comp_fmadd_ps(_c0, _beta, _f0); + _f1 = _mm256_comp_fmadd_ps(_c1, _beta, _f1); + _f2 = _mm256_comp_fmadd_ps(_c2, _beta, _f2); + _f3 = _mm256_comp_fmadd_ps(_c3, _beta, _f3); + _f4 = _mm256_comp_fmadd_ps(_c4, _beta, _f4); + _f5 = _mm256_comp_fmadd_ps(_c5, _beta, _f5); + _f6 = _mm256_comp_fmadd_ps(_c6, _beta, _f6); + _f7 = _mm256_comp_fmadd_ps(_c7, _beta, _f7); + } + } + if (broadcast_type_C == 4) + { + _c0 = _mm256_set1_ps(pC[0] * beta); + __m256 _c1 = _mm256_set1_ps(pC[1] * beta); + __m256 _c2 = _mm256_set1_ps(pC[2] * beta); + __m256 _c3 = _mm256_set1_ps(pC[3] * beta); + + _f0 = _mm256_add_ps(_f0, _c0); + _f1 = _mm256_add_ps(_f1, _c1); + _f2 = _mm256_add_ps(_f2, _c2); + _f3 = _mm256_add_ps(_f3, _c3); + + _c0 = _mm256_set1_ps(pC[4] * beta); + _c1 = _mm256_set1_ps(pC[5] * beta); + _c2 = _mm256_set1_ps(pC[6] * beta); + _c3 = _mm256_set1_ps(pC[7] * beta); + + _f4 = _mm256_add_ps(_f4, _c0); + _f5 = _mm256_add_ps(_f5, _c1); + _f6 = _mm256_add_ps(_f6, _c2); + _f7 = _mm256_add_ps(_f7, _c3); + pC += 8; + } + } + + if (alpha != 1.f) + { + __m256 _alpha = _mm256_set1_ps(alpha); + _f0 = _mm256_mul_ps(_f0, _alpha); + _f1 = _mm256_mul_ps(_f1, _alpha); + _f2 = _mm256_mul_ps(_f2, _alpha); + _f3 = _mm256_mul_ps(_f3, _alpha); + _f4 = _mm256_mul_ps(_f4, _alpha); + _f5 = _mm256_mul_ps(_f5, _alpha); + _f6 = _mm256_mul_ps(_f6, _alpha); + _f7 = _mm256_mul_ps(_f7, _alpha); + } + + __m128i _bf0 = float2bfloat_avx(_f0); + __m128i _bf1 = float2bfloat_avx(_f1); + __m128i _bf2 = float2bfloat_avx(_f2); + __m128i _bf3 = float2bfloat_avx(_f3); + __m128i _bf4 = float2bfloat_avx(_f4); + __m128i _bf5 = float2bfloat_avx(_f5); + __m128i _bf6 = float2bfloat_avx(_f6); + __m128i _bf7 = float2bfloat_avx(_f7); + + if (output_transpose) + { + if (out_elempack == 8) + { + transpose8x8_epi16(_bf0, _bf1, _bf2, _bf3, _bf4, _bf5, _bf6, _bf7); + + _mm_storeu_si128((__m128i*)p0, _bf0); + _mm_storeu_si128((__m128i*)(p0 + 8), _bf1); + _mm_storeu_si128((__m128i*)(p0 + 8 * 2), _bf2); + _mm_storeu_si128((__m128i*)(p0 + 8 * 3), _bf3); + _mm_storeu_si128((__m128i*)(p0 + 8 * 4), _bf4); + _mm_storeu_si128((__m128i*)(p0 + 8 * 5), _bf5); + _mm_storeu_si128((__m128i*)(p0 + 8 * 6), _bf6); + _mm_storeu_si128((__m128i*)(p0 + 8 * 7), _bf7); + } + if (out_elempack == 4) + { + transpose8x4_epi16(_bf0, _bf1, _bf2, _bf3); + transpose8x4_epi16(_bf4, _bf5, _bf6, _bf7); + + _mm_storel_epi64((__m128i*)p0, _bf0); + _mm_storeh_pd((double*)(p0 + 4), _mm_castsi128_pd(_bf0)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 2), _bf1); + _mm_storeh_pd((double*)(p0 + 4 * 3), _mm_castsi128_pd(_bf1)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 
4), _bf2); + _mm_storeh_pd((double*)(p0 + 4 * 5), _mm_castsi128_pd(_bf2)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 6), _bf3); + _mm_storeh_pd((double*)(p0 + 4 * 7), _mm_castsi128_pd(_bf3)); + + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 4), _bf4); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4), _mm_castsi128_pd(_bf4)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 4 + 4 * 2), _bf5); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 3), _mm_castsi128_pd(_bf5)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 4 + 4 * 4), _bf6); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 5), _mm_castsi128_pd(_bf6)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 4 + 4 * 6), _bf7); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 7), _mm_castsi128_pd(_bf7)); + } + if (out_elempack == 1) + { + _mm_storeu_si128((__m128i*)p0, _bf0); + _mm_storeu_si128((__m128i*)(p0 + out_hstep), _bf1); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 2), _bf2); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 3), _bf3); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 4), _bf4); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 5), _bf5); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 6), _bf6); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 7), _bf7); + } + p0 += out_hstep * 8; + } + else + { + if (out_elempack == 8) + { + _mm_storeu_si128((__m128i*)p0, _bf0); + _mm_storeu_si128((__m128i*)(p0 + 8), _bf1); + _mm_storeu_si128((__m128i*)(p0 + 8 * 2), _bf2); + _mm_storeu_si128((__m128i*)(p0 + 8 * 3), _bf3); + _mm_storeu_si128((__m128i*)(p0 + 8 * 4), _bf4); + _mm_storeu_si128((__m128i*)(p0 + 8 * 5), _bf5); + _mm_storeu_si128((__m128i*)(p0 + 8 * 6), _bf6); + _mm_storeu_si128((__m128i*)(p0 + 8 * 7), _bf7); + p0 += 64; + } + if (out_elempack == 4) + { + _mm_storel_epi64((__m128i*)p0, _bf0); + _mm_storel_epi64((__m128i*)(p0 + 4), _bf1); + _mm_storel_epi64((__m128i*)(p0 + 4 * 2), _bf2); + _mm_storel_epi64((__m128i*)(p0 + 4 * 3), _bf3); + _mm_storel_epi64((__m128i*)(p0 + 4 * 4), _bf4); + _mm_storel_epi64((__m128i*)(p0 + 4 * 5), _bf5); + _mm_storel_epi64((__m128i*)(p0 + 4 * 6), _bf6); + _mm_storel_epi64((__m128i*)(p0 + 4 * 7), _bf7); + _mm_storeh_pd((double*)(p0 + out_hstep * 4), _mm_castsi128_pd(_bf0)); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4), _mm_castsi128_pd(_bf1)); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 2), _mm_castsi128_pd(_bf2)); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 3), _mm_castsi128_pd(_bf3)); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 4), _mm_castsi128_pd(_bf4)); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 5), _mm_castsi128_pd(_bf5)); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 6), _mm_castsi128_pd(_bf6)); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 7), _mm_castsi128_pd(_bf7)); + p0 += 32; + } + if (out_elempack == 1) + { + transpose8x8_epi16(_bf0, _bf1, _bf2, _bf3, _bf4, _bf5, _bf6, _bf7); + + _mm_storeu_si128((__m128i*)p0, _bf0); + _mm_storeu_si128((__m128i*)(p0 + out_hstep), _bf1); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 2), _bf2); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 3), _bf3); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 4), _bf4); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 5), _bf5); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 6), _bf6); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 7), _bf7); + p0 += 8; + } + } + } +#endif // defined(__x86_64__) || defined(_M_X64) + for (; jj + 3 < max_jj; jj += 4) + { + __m256 _f0 = _mm256_load_ps(pp); + __m256 _f1 = _mm256_load_ps(pp + 8); + __m256 _f2 = _mm256_load_ps(pp + 16); + __m256 _f3 = 
_mm256_load_ps(pp + 24); + pp += 32; + + // deshuffle from the shuffle-based 8x4 dpbf16_ps kernel + { + _f1 = _mm256_shuffle_ps(_f1, _f1, _MM_SHUFFLE(2, 1, 0, 3)); + _f3 = _mm256_shuffle_ps(_f3, _f3, _MM_SHUFFLE(2, 1, 0, 3)); + __m256 _tmp0 = _mm256_unpacklo_ps(_f0, _f3); + __m256 _tmp1 = _mm256_unpackhi_ps(_f0, _f3); + __m256 _tmp2 = _mm256_unpacklo_ps(_f2, _f1); + __m256 _tmp3 = _mm256_unpackhi_ps(_f2, _f1); + _f0 = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(_tmp0), _mm256_castps_pd(_tmp2))); + _f1 = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(_tmp0), _mm256_castps_pd(_tmp2))); + _f2 = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(_tmp3), _mm256_castps_pd(_tmp1))); + _f3 = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(_tmp3), _mm256_castps_pd(_tmp1))); + _f1 = _mm256_shuffle_ps(_f1, _f1, _MM_SHUFFLE(2, 1, 0, 3)); + _f3 = _mm256_shuffle_ps(_f3, _f3, _MM_SHUFFLE(2, 1, 0, 3)); + } + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = _mm256_add_ps(_f0, _c0); + _f1 = _mm256_add_ps(_f1, _c0); + _f2 = _mm256_add_ps(_f2, _c0); + _f3 = _mm256_add_ps(_f3, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = _mm256_add_ps(_f0, _c0); + _f1 = _mm256_add_ps(_f1, _c0); + _f2 = _mm256_add_ps(_f2, _c0); + _f3 = _mm256_add_ps(_f3, _c0); + } + if (broadcast_type_C == 3) + { + __m256 _c1; + __m256 _c2; + __m256 _c3; + if (c_elempack == 8) + { + _c0 = _mm256_loadu_ps(pC); + _c1 = _mm256_loadu_ps(pC + 8); + _c2 = _mm256_loadu_ps(pC + 16); + _c3 = _mm256_loadu_ps(pC + 24); + pC += 32; + } + else if (c_elempack == 4) + { + __m256 _cc0 = _mm256_loadu_ps(pC); + __m256 _cc1 = _mm256_loadu_ps(pC + 8); + __m256 _cc2 = _mm256_loadu_ps(pC + c_hstep * 4); + __m256 _cc3 = _mm256_loadu_ps(pC + c_hstep * 4 + 8); + _c0 = _mm256_permute2f128_ps(_cc0, _cc2, _MM_SHUFFLE(0, 2, 0, 0)); + _c1 = _mm256_permute2f128_ps(_cc0, _cc2, _MM_SHUFFLE(0, 3, 0, 1)); + _c2 = _mm256_permute2f128_ps(_cc1, _cc3, _MM_SHUFFLE(0, 2, 0, 0)); + _c3 = _mm256_permute2f128_ps(_cc1, _cc3, _MM_SHUFFLE(0, 3, 0, 1)); + pC += 16; + } + else // if (c_elempack == 1) + { + // __m256i _vindex = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); + // _c0 = _mm256_i32gather_ps(pC, _vindex, c_hstep * sizeof(float)); + // _c1 = _mm256_i32gather_ps(pC + 1, _vindex, c_hstep * sizeof(float)); + // _c2 = _mm256_i32gather_ps(pC + 2, _vindex, c_hstep * sizeof(float)); + // _c3 = _mm256_i32gather_ps(pC + 3, _vindex, c_hstep * sizeof(float)); + + __m128 _cc0 = _mm_loadu_ps(pC); + __m128 _cc1 = _mm_loadu_ps(pC + c_hstep); + __m128 _cc2 = _mm_loadu_ps(pC + c_hstep * 2); + __m128 _cc3 = _mm_loadu_ps(pC + c_hstep * 3); + __m128 _cc4 = _mm_loadu_ps(pC + c_hstep * 4); + __m128 _cc5 = _mm_loadu_ps(pC + c_hstep * 5); + __m128 _cc6 = _mm_loadu_ps(pC + c_hstep * 6); + __m128 _cc7 = _mm_loadu_ps(pC + c_hstep * 7); + _MM_TRANSPOSE4_PS(_cc0, _cc1, _cc2, _cc3); + _MM_TRANSPOSE4_PS(_cc4, _cc5, _cc6, _cc7); + + _c0 = combine4x2_ps(_cc0, _cc4); + _c1 = combine4x2_ps(_cc1, _cc5); + _c2 = combine4x2_ps(_cc2, _cc6); + _c3 = combine4x2_ps(_cc3, _cc7); + + pC += 4; + } + if (beta == 1.f) + { + _f0 = _mm256_add_ps(_f0, _c0); + _f1 = _mm256_add_ps(_f1, _c1); + _f2 = _mm256_add_ps(_f2, _c2); + _f3 = _mm256_add_ps(_f3, _c3); + } + else + { + __m256 _beta = _mm256_set1_ps(beta); + _f0 = _mm256_comp_fmadd_ps(_c0, _beta, _f0); + _f1 = _mm256_comp_fmadd_ps(_c1, _beta, _f1); + _f2 = _mm256_comp_fmadd_ps(_c2, _beta, _f2); + _f3 = _mm256_comp_fmadd_ps(_c3, _beta, _f3); + } + } + if (broadcast_type_C == 4) + { + _c0 = _mm256_set1_ps(pC[0] * beta); + 
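// [editor's note] broadcast_type_C == 4 means the C blob carries one value per
// output column, so each of the four columns in this tile gets its own splatted
// bias. A scalar sketch of this step, assuming f[row][col] are the deshuffled
// accumulators of the 8x4 tile and pC points at column j + jj of C (names are
// illustrative, not from the patch):
//
//   for (int col = 0; col < 4; col++)
//       for (int row = 0; row < 8; row++)
//           f[row][col] += pC[col] * beta;
//
// The surrounding intrinsics do exactly this with one _mm256_set1_ps per column.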
__m256 _c1 = _mm256_set1_ps(pC[1] * beta); + __m256 _c2 = _mm256_set1_ps(pC[2] * beta); + __m256 _c3 = _mm256_set1_ps(pC[3] * beta); + + _f0 = _mm256_add_ps(_f0, _c0); + _f1 = _mm256_add_ps(_f1, _c1); + _f2 = _mm256_add_ps(_f2, _c2); + _f3 = _mm256_add_ps(_f3, _c3); + pC += 4; + } + } + + if (alpha != 1.f) + { + __m256 _alpha = _mm256_set1_ps(alpha); + _f0 = _mm256_mul_ps(_f0, _alpha); + _f1 = _mm256_mul_ps(_f1, _alpha); + _f2 = _mm256_mul_ps(_f2, _alpha); + _f3 = _mm256_mul_ps(_f3, _alpha); + } + + __m128i _bf0 = float2bfloat_avx(_f0); + __m128i _bf1 = float2bfloat_avx(_f1); + __m128i _bf2 = float2bfloat_avx(_f2); + __m128i _bf3 = float2bfloat_avx(_f3); + + if (output_transpose) + { +#if !(defined(__x86_64__) || defined(_M_X64)) +#if __AVX__ +#if __AVX512F__ + if (out_elempack == 16) + { + transpose8x4_epi16(_bf0, _bf1, _bf2, _bf3); + const int jj_m16 = jj % 16; + unsigned short* p1 = p0 - out_hstep * jj_m16 + jj_m16; + _mm_storel_epi64((__m128i*)p1, _bf0); + _mm_storeh_pd((double*)(p1 + 16), _mm_castsi128_pd(_bf0)); + _mm_storel_epi64((__m128i*)(p1 + 32), _bf1); + _mm_storeh_pd((double*)(p1 + 48), _mm_castsi128_pd(_bf1)); + _mm_storel_epi64((__m128i*)(p1 + 64), _bf2); + _mm_storeh_pd((double*)(p1 + 80), _mm_castsi128_pd(_bf2)); + _mm_storel_epi64((__m128i*)(p1 + 96), _bf3); + _mm_storeh_pd((double*)(p1 + 112), _mm_castsi128_pd(_bf3)); + } +#endif // __AVX512F__ + if (out_elempack == 8) + { + transpose8x4_epi16(_bf0, _bf1, _bf2, _bf3); + const int jj_m8 = jj % 8; + unsigned short* p1 = p0 - out_hstep * jj_m8 + jj_m8; + _mm_storel_epi64((__m128i*)p1, _bf0); + _mm_storeh_pd((double*)(p1 + 8), _mm_castsi128_pd(_bf0)); + _mm_storel_epi64((__m128i*)(p1 + 16), _bf1); + _mm_storeh_pd((double*)(p1 + 24), _mm_castsi128_pd(_bf1)); + _mm_storel_epi64((__m128i*)(p1 + 32), _bf2); + _mm_storeh_pd((double*)(p1 + 40), _mm_castsi128_pd(_bf2)); + _mm_storel_epi64((__m128i*)(p1 + 48), _bf3); + _mm_storeh_pd((double*)(p1 + 56), _mm_castsi128_pd(_bf3)); + } +#endif // __AVX__ +#endif // !(defined(__x86_64__) || defined(_M_X64)) + if (out_elempack == 4) + { + transpose8x4_epi16(_bf0, _bf1, _bf2, _bf3); + _mm_storeu_si128((__m128i*)p0, _bf0); + _mm_storeu_si128((__m128i*)(p0 + 8), _bf1); + _mm_storeu_si128((__m128i*)(p0 + 16), _bf2); + _mm_storeu_si128((__m128i*)(p0 + 24), _bf3); + } + if (out_elempack == 1) + { + _mm_storeu_si128((__m128i*)p0, _bf0); + _mm_storeu_si128((__m128i*)(p0 + out_hstep), _bf1); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 2), _bf2); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 3), _bf3); + } + p0 += out_hstep * 4; + } + else + { + if (out_elempack == 8) + { + _mm_storeu_si128((__m128i*)p0, _bf0); + _mm_storeu_si128((__m128i*)(p0 + 8), _bf1); + _mm_storeu_si128((__m128i*)(p0 + 8 * 2), _bf2); + _mm_storeu_si128((__m128i*)(p0 + 8 * 3), _bf3); + p0 += 32; + } + if (out_elempack == 4) + { + _mm_storel_epi64((__m128i*)p0, _bf0); + _mm_storel_epi64((__m128i*)(p0 + 4), _bf1); + _mm_storel_epi64((__m128i*)(p0 + 4 * 2), _bf2); + _mm_storel_epi64((__m128i*)(p0 + 4 * 3), _bf3); + _mm_storeh_pd((double*)(p0 + out_hstep * 4), _mm_castsi128_pd(_bf0)); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4), _mm_castsi128_pd(_bf1)); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 2), _mm_castsi128_pd(_bf2)); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 3), _mm_castsi128_pd(_bf3)); + p0 += 16; + } + if (out_elempack == 1) + { + transpose8x4_epi16(_bf0, _bf1, _bf2, _bf3); + + _mm_storel_epi64((__m128i*)p0, _bf0); + _mm_storeh_pd((double*)(p0 + out_hstep), _mm_castsi128_pd(_bf0)); + 
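// [editor's note] For out_elempack == 1 the destination is plain row-major bf16
// with a stride of out_hstep elements per row. After transpose8x4_epi16 each
// __m128i appears to carry two consecutive output rows (4 columns each), so the
// storel/storeh pairs around this point scatter the 8x4 tile roughly like this
// scalar sketch (tile_bf16 is an illustrative name, not from the patch):
//
//   for (int row = 0; row < 8; row++)
//       for (int col = 0; col < 4; col++)
//           p0[row * out_hstep + col] = tile_bf16[row][col];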
_mm_storel_epi64((__m128i*)(p0 + out_hstep * 2), _bf1); + _mm_storeh_pd((double*)(p0 + out_hstep * 3), _mm_castsi128_pd(_bf1)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 4), _bf2); + _mm_storeh_pd((double*)(p0 + out_hstep * 5), _mm_castsi128_pd(_bf2)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 6), _bf3); + _mm_storeh_pd((double*)(p0 + out_hstep * 7), _mm_castsi128_pd(_bf3)); + p0 += 4; + } + } + } + for (; jj + 1 < max_jj; jj += 2) + { + __m256 _f0 = _mm256_load_ps(pp); + __m256 _f1 = _mm256_load_ps(pp + 8); + pp += 16; + + // deshuffle from the shuffle-based 8x2 dpbf16_ps kernel + { + __m256 _tmp0 = _mm256_shuffle_ps(_f0, _f0, _MM_SHUFFLE(3, 1, 2, 0)); + __m256 _tmp1 = _mm256_shuffle_ps(_f1, _f1, _MM_SHUFFLE(0, 2, 3, 1)); + _f0 = _mm256_unpacklo_ps(_tmp0, _tmp1); + _f1 = _mm256_unpackhi_ps(_tmp0, _tmp1); + _f1 = _mm256_shuffle_ps(_f1, _f1, _MM_SHUFFLE(2, 1, 0, 3)); + } + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = _mm256_add_ps(_f0, _c0); + _f1 = _mm256_add_ps(_f1, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = _mm256_add_ps(_f0, _c0); + _f1 = _mm256_add_ps(_f1, _c0); + } + if (broadcast_type_C == 3) + { + __m256 _c1; + if (c_elempack == 8) + { + _c0 = _mm256_loadu_ps(pC); + _c1 = _mm256_loadu_ps(pC + 8); + pC += 16; + } + else if (c_elempack == 4) + { + __m256 _cc0 = _mm256_loadu_ps(pC); + __m256 _cc1 = _mm256_loadu_ps(pC + c_hstep * 4); + _c0 = _mm256_permute2f128_ps(_cc0, _cc1, _MM_SHUFFLE(0, 2, 0, 0)); + _c1 = _mm256_permute2f128_ps(_cc0, _cc1, _MM_SHUFFLE(0, 3, 0, 1)); + pC += 8; + } + else // if (c_elempack == 1) + { +#if __AVX2__ + __m256i _vindex = _mm256_mullo_epi32(_mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7), _mm256_set1_epi32((int)c_hstep)); + _c0 = _mm256_i32gather_ps(pC, _vindex, sizeof(float)); + _c1 = _mm256_i32gather_ps(pC + 1, _vindex, sizeof(float)); +#else + _c0 = _mm256_setr_ps(pC[0], pC[c_hstep], pC[c_hstep * 2], pC[c_hstep * 3], pC[c_hstep * 4], pC[c_hstep * 5], pC[c_hstep * 6], pC[c_hstep * 7]); + _c1 = _mm256_setr_ps(pC[1], pC[c_hstep + 1], pC[c_hstep * 2 + 1], pC[c_hstep * 3 + 1], pC[c_hstep * 4 + 1], pC[c_hstep * 5 + 1], pC[c_hstep * 6 + 1], pC[c_hstep * 7 + 1]); +#endif + pC += 2; + } + if (beta == 1.f) + { + _f0 = _mm256_add_ps(_f0, _c0); + _f1 = _mm256_add_ps(_f1, _c1); + } + else + { + __m256 _beta = _mm256_set1_ps(beta); + _f0 = _mm256_comp_fmadd_ps(_c0, _beta, _f0); + _f1 = _mm256_comp_fmadd_ps(_c1, _beta, _f1); + } + } + if (broadcast_type_C == 4) + { + _c0 = _mm256_set1_ps(pC[0] * beta); + __m256 _c1 = _mm256_set1_ps(pC[1] * beta); + _f0 = _mm256_add_ps(_f0, _c0); + _f1 = _mm256_add_ps(_f1, _c1); + pC += 2; + } + } + + if (alpha != 1.f) + { + __m256 _alpha = _mm256_set1_ps(alpha); + _f0 = _mm256_mul_ps(_f0, _alpha); + _f1 = _mm256_mul_ps(_f1, _alpha); + } + + __m128i _bf0 = float2bfloat_avx(_f0); + __m128i _bf1 = float2bfloat_avx(_f1); + + if (output_transpose) + { + if (out_elempack == 8) + { + _mm_storeu_si128((__m128i*)p0, _bf0); + _mm_storeu_si128((__m128i*)(p0 + 8), _bf1); + } + if (out_elempack == 4) + { + _mm_storel_epi64((__m128i*)p0, _bf0); + _mm_storel_epi64((__m128i*)(p0 + 4), _bf1); + _mm_storeh_pd((double*)(p0 + out_hstep * 4), _mm_castsi128_pd(_bf0)); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4), _mm_castsi128_pd(_bf1)); + } + if (out_elempack == 1) + { + _mm_storeu_si128((__m128i*)p0, _bf0); + _mm_storeu_si128((__m128i*)(p0 + out_hstep), _bf1); + } + p0 += out_hstep * 2; + } + else + { + if (out_elempack == 8) + { + _mm_storeu_si128((__m128i*)p0, _bf0); + _mm_storeu_si128((__m128i*)(p0 + 
8), _bf1); + p0 += 16; + } + if (out_elempack == 4) + { + _mm_storel_epi64((__m128i*)p0, _bf0); + _mm_storel_epi64((__m128i*)(p0 + 4), _bf1); + _mm_storeh_pd((double*)(p0 + out_hstep * 4), _mm_castsi128_pd(_bf0)); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4), _mm_castsi128_pd(_bf1)); + p0 += 8; + } + if (out_elempack == 1) + { + unsigned short sum0[8]; + unsigned short sum1[8]; + _mm_storeu_si128((__m128i*)sum0, _bf0); + _mm_storeu_si128((__m128i*)sum1, _bf1); + + p0[0] = sum0[0]; + p0[1] = sum1[0]; + p0[out_hstep] = sum0[1]; + p0[out_hstep + 1] = sum1[1]; + p0[out_hstep * 2] = sum0[2]; + p0[out_hstep * 2 + 1] = sum1[2]; + p0[out_hstep * 3] = sum0[3]; + p0[out_hstep * 3 + 1] = sum1[3]; + p0[out_hstep * 4] = sum0[4]; + p0[out_hstep * 4 + 1] = sum1[4]; + p0[out_hstep * 5] = sum0[5]; + p0[out_hstep * 5 + 1] = sum1[5]; + p0[out_hstep * 6] = sum0[6]; + p0[out_hstep * 6 + 1] = sum1[6]; + p0[out_hstep * 7] = sum0[7]; + p0[out_hstep * 7 + 1] = sum1[7]; + p0 += 2; + } + } + } + for (; jj < max_jj; jj++) + { + __m256 _f = _mm256_load_ps(pp); + pp += 8; + + if (pC) + { + if (broadcast_type_C == 0) + { + _f = _mm256_add_ps(_f, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f = _mm256_add_ps(_f, _c0); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 8) + { + _c0 = _mm256_loadu_ps(pC); + pC += 8; + } + else if (c_elempack == 4) + { + __m128 _cc0 = _mm_loadu_ps(pC); + __m128 _cc1 = _mm_loadu_ps(pC + c_hstep * 4); + _c0 = combine4x2_ps(_cc0, _cc1); + pC += 4; + } + else // if (c_elempack == 1) + { +#if __AVX2__ + __m256i _vindex = _mm256_mullo_epi32(_mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7), _mm256_set1_epi32((int)c_hstep)); + _c0 = _mm256_i32gather_ps(pC, _vindex, sizeof(float)); +#else + _c0 = _mm256_setr_ps(pC[0], pC[c_hstep], pC[c_hstep * 2], pC[c_hstep * 3], pC[c_hstep * 4], pC[c_hstep * 5], pC[c_hstep * 6], pC[c_hstep * 7]); +#endif + pC += 1; + } + _f = _mm256_comp_fmadd_ps(_c0, _mm256_set1_ps(beta), _f); + } + if (broadcast_type_C == 4) + { + _c0 = _mm256_set1_ps(pC[0] * beta); + _f = _mm256_add_ps(_f, _c0); + pC += 1; + } + } + + if (alpha != 1.f) + { + _f = _mm256_mul_ps(_f, _mm256_set1_ps(alpha)); + } + + __m128i _bf = float2bfloat_avx(_f); + + if (output_transpose) + { + if (out_elempack == 8) + { + _mm_storeu_si128((__m128i*)p0, _bf); + } + if (out_elempack == 4) + { + _mm_storel_epi64((__m128i*)p0, _bf); + _mm_storeh_pd((double*)(p0 + out_hstep * 4), _mm_castsi128_pd(_bf)); + } + if (out_elempack == 1) + { + _mm_storeu_si128((__m128i*)p0, _bf); + } + p0 += out_hstep; + } + else + { + if (out_elempack == 8) + { + _mm_storeu_si128((__m128i*)p0, _bf); + p0 += 8; + } + if (out_elempack == 4) + { + _mm_storel_epi64((__m128i*)p0, _bf); + _mm_storeh_pd((double*)(p0 + out_hstep * 4), _mm_castsi128_pd(_bf)); + p0 += 4; + } + if (out_elempack == 1) + { + unsigned short sum0[8]; + _mm_storeu_si128((__m128i*)sum0, _bf); + + p0[0] = sum0[0]; + p0[out_hstep] = sum0[1]; + p0[out_hstep * 2] = sum0[2]; + p0[out_hstep * 3] = sum0[3]; + p0[out_hstep * 4] = sum0[4]; + p0[out_hstep * 5] = sum0[5]; + p0[out_hstep * 6] = sum0[6]; + p0[out_hstep * 7] = sum0[7]; + p0++; + } + } + } + } +#endif // __AVX__ + for (; ii + 3 < max_ii; ii += 4) + { + unsigned short* p0; + if (output_transpose) + { + p0 = (unsigned short*)top_blob + j * out_hstep + (i + ii) * out_elempack; + } + else + { + p0 = (unsigned short*)top_blob + (i + ii) * out_hstep + j * out_elempack; + } + + __m128 _c0 = _mm_set1_ps(0.f); +#if __AVX512F__ + __m512 _c0_avx512 = _mm512_set1_ps(0.f); +#endif + if (pC) + { + if 
(broadcast_type_C == 0) + { + _c0 = _mm_set1_ps(pC[0] * beta); +#if __AVX512F__ + _c0_avx512 = _mm512_set1_ps(pC[0] * beta); +#endif + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + pC = (const float*)C + i + ii; + _c0 = _mm_loadu_ps(pC); + _c0 = _mm_mul_ps(_c0, _mm_set1_ps(beta)); +#if __AVX512F__ + _c0_avx512 = _mm512_broadcast_f32x4(_c0); +#endif + } + if (broadcast_type_C == 3) + { + pC = (const float*)C + (i + ii) * c_hstep + j * c_elempack; + } + if (broadcast_type_C == 4) + { + pC = (const float*)C + j; + } + } + + int jj = 0; +#if defined(__x86_64__) || defined(_M_X64) +#if __AVX512F__ + for (; jj + 15 < max_jj; jj += 16) + { + __m512 _f0 = _mm512_loadu_ps(pp); + __m512 _f1 = _mm512_loadu_ps(pp + 16); + __m512 _f2 = _mm512_loadu_ps(pp + 32); + __m512 _f3 = _mm512_loadu_ps(pp + 48); + pp += 64; + + // from + // 00 11 22 33 04 15 26 37 08 19 2a 3b 0c 1d 2e 3f + // 01 12 23 30 05 16 27 34 09 1a 2b 38 0d 1e 2f 3c + // 20 31 02 13 24 35 06 17 28 39 0a 1b 2c 3d 0e 1f + // 21 32 03 10 25 36 07 14 29 3a 0b 18 2d 3e 0f 1c + // to + // 00 10 20 30 04 14 24 34 08 18 28 38 0c 1c 2c 3c + // 01 11 21 31 05 15 25 35 09 19 29 39 0d 1d 2d 3d + // 02 12 22 32 06 16 26 36 0a 1a 2a 3a 0e 1e 2e 3e + // 03 13 23 33 07 17 27 37 0b 1b 2b 3b 0f 1f 2f 3f + { + _f1 = _mm512_permute_ps(_f1, _MM_SHUFFLE(2, 1, 0, 3)); + _f3 = _mm512_permute_ps(_f3, _MM_SHUFFLE(2, 1, 0, 3)); + __m512 _tmp0 = _mm512_unpacklo_ps(_f0, _f3); + __m512 _tmp1 = _mm512_unpacklo_ps(_f2, _f1); + __m512 _tmp2 = _mm512_unpackhi_ps(_f0, _f3); + __m512 _tmp3 = _mm512_unpackhi_ps(_f2, _f1); + _f0 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp0), _mm512_castps_pd(_tmp1))); + _f1 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp0), _mm512_castps_pd(_tmp1))); + _f2 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp3), _mm512_castps_pd(_tmp2))); + _f3 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp3), _mm512_castps_pd(_tmp2))); + _f1 = _mm512_permute_ps(_f1, _MM_SHUFFLE(2, 1, 0, 3)); + _f3 = _mm512_permute_ps(_f3, _MM_SHUFFLE(2, 1, 0, 3)); + } + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = _mm512_add_ps(_f0, _c0_avx512); + _f1 = _mm512_add_ps(_f1, _c0_avx512); + _f2 = _mm512_add_ps(_f2, _c0_avx512); + _f3 = _mm512_add_ps(_f3, _c0_avx512); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = _mm512_add_ps(_f0, _c0_avx512); + _f1 = _mm512_add_ps(_f1, _c0_avx512); + _f2 = _mm512_add_ps(_f2, _c0_avx512); + _f3 = _mm512_add_ps(_f3, _c0_avx512); + } + if (broadcast_type_C == 3) + { + __m512 _c1_avx512; + __m512 _c2_avx512; + __m512 _c3_avx512; + if (c_elempack == 4) + { + _c0_avx512 = _mm512_loadu_ps(pC); + _c1_avx512 = _mm512_loadu_ps(pC + 16); + _c2_avx512 = _mm512_loadu_ps(pC + 32); + _c3_avx512 = _mm512_loadu_ps(pC + 48); + pC += 64; + + __m512 _tmp0 = _mm512_shuffle_f32x4(_c0_avx512, _c1_avx512, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp1 = _mm512_shuffle_f32x4(_c2_avx512, _c3_avx512, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp2 = _mm512_shuffle_f32x4(_c0_avx512, _c1_avx512, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp3 = _mm512_shuffle_f32x4(_c2_avx512, _c3_avx512, _MM_SHUFFLE(3, 2, 3, 2)); + _c0_avx512 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(2, 0, 2, 0)); + _c1_avx512 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(3, 1, 3, 1)); + _c2_avx512 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(2, 0, 2, 0)); + _c3_avx512 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(3, 1, 3, 1)); + } + else // if (c_elempack == 1) + { + _c0_avx512 = _mm512_loadu_ps(pC); + _c1_avx512 
= _mm512_loadu_ps(pC + c_hstep); + _c2_avx512 = _mm512_loadu_ps(pC + c_hstep * 2); + _c3_avx512 = _mm512_loadu_ps(pC + c_hstep * 3); + pC += 16; + + __m512 _tmp0 = _mm512_unpacklo_ps(_c0_avx512, _c1_avx512); + __m512 _tmp1 = _mm512_unpacklo_ps(_c2_avx512, _c3_avx512); + __m512 _tmp2 = _mm512_unpackhi_ps(_c0_avx512, _c1_avx512); + __m512 _tmp3 = _mm512_unpackhi_ps(_c2_avx512, _c3_avx512); + _c0_avx512 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp0), _mm512_castps_pd(_tmp1))); + _c1_avx512 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp0), _mm512_castps_pd(_tmp1))); + _c2_avx512 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp2), _mm512_castps_pd(_tmp3))); + _c3_avx512 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp2), _mm512_castps_pd(_tmp3))); + } + if (beta == 1.f) + { + _f0 = _mm512_add_ps(_f0, _c0_avx512); + _f1 = _mm512_add_ps(_f1, _c1_avx512); + _f2 = _mm512_add_ps(_f2, _c2_avx512); + _f3 = _mm512_add_ps(_f3, _c3_avx512); + } + else + { + __m512 _beta = _mm512_set1_ps(beta); + _f0 = _mm512_fmadd_ps(_c0_avx512, _beta, _f0); + _f1 = _mm512_fmadd_ps(_c1_avx512, _beta, _f1); + _f2 = _mm512_fmadd_ps(_c2_avx512, _beta, _f2); + _f3 = _mm512_fmadd_ps(_c3_avx512, _beta, _f3); + } + } + if (broadcast_type_C == 4) + { + __m512 _cc = _mm512_loadu_ps(pC); + _cc = _mm512_mul_ps(_cc, _mm512_set1_ps(beta)); + _c0_avx512 = _mm512_permute_ps(_cc, _MM_SHUFFLE(0, 0, 0, 0)); + __m512 _c1_avx512 = _mm512_permute_ps(_cc, _MM_SHUFFLE(1, 1, 1, 1)); + __m512 _c2_avx512 = _mm512_permute_ps(_cc, _MM_SHUFFLE(2, 2, 2, 2)); + __m512 _c3_avx512 = _mm512_permute_ps(_cc, _MM_SHUFFLE(3, 3, 3, 3)); + + _f0 = _mm512_add_ps(_f0, _c0_avx512); + _f1 = _mm512_add_ps(_f1, _c1_avx512); + _f2 = _mm512_add_ps(_f2, _c2_avx512); + _f3 = _mm512_add_ps(_f3, _c3_avx512); + + pC += 16; + } + } + + if (alpha != 1.f) + { + __m512 _alpha = _mm512_set1_ps(alpha); + _f0 = _mm512_mul_ps(_f0, _alpha); + _f1 = _mm512_mul_ps(_f1, _alpha); + _f2 = _mm512_mul_ps(_f2, _alpha); + _f3 = _mm512_mul_ps(_f3, _alpha); + } + + __m256i _bf0 = float2bfloat_avx512(_f0); + __m256i _bf1 = float2bfloat_avx512(_f1); + __m256i _bf2 = float2bfloat_avx512(_f2); + __m256i _bf3 = float2bfloat_avx512(_f3); + + if (output_transpose) + { + if (out_elempack == 16) + { + transpose16x4_epi16(_bf0, _bf1, _bf2, _bf3); + _mm_storel_epi64((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4), _mm256_extractf128_si256(_bf1, 0)); + _mm_storel_epi64((__m128i*)(p0 + 8), _mm256_extractf128_si256(_bf2, 0)); + _mm_storel_epi64((__m128i*)(p0 + 12), _mm256_extractf128_si256(_bf3, 0)); + _mm_storeh_pd((double*)(p0 + 16), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 0))); + _mm_storeh_pd((double*)(p0 + 16 + 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 0))); + _mm_storeh_pd((double*)(p0 + 16 + 8), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 0))); + _mm_storeh_pd((double*)(p0 + 16 + 12), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 0))); + _mm_storel_epi64((__m128i*)(p0 + 32), _mm256_extractf128_si256(_bf0, 1)); + _mm_storel_epi64((__m128i*)(p0 + 32 + 4), _mm256_extractf128_si256(_bf1, 1)); + _mm_storel_epi64((__m128i*)(p0 + 32 + 8), _mm256_extractf128_si256(_bf2, 1)); + _mm_storel_epi64((__m128i*)(p0 + 32 + 12), _mm256_extractf128_si256(_bf3, 1)); + _mm_storeh_pd((double*)(p0 + 48), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 1))); + _mm_storeh_pd((double*)(p0 + 48 + 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 1))); + _mm_storeh_pd((double*)(p0 + 48 + 8), 
_mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 1))); + _mm_storeh_pd((double*)(p0 + 48 + 12), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 1))); + } + if (out_elempack == 8) + { + transpose16x4_epi16(_bf0, _bf1, _bf2, _bf3); + _mm_storel_epi64((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4), _mm256_extractf128_si256(_bf1, 0)); + _mm_storeh_pd((double*)(p0 + 8), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 0))); + _mm_storeh_pd((double*)(p0 + 12), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 0))); + _mm_storel_epi64((__m128i*)(p0 + 16), _mm256_extractf128_si256(_bf0, 1)); + _mm_storel_epi64((__m128i*)(p0 + 16 + 4), _mm256_extractf128_si256(_bf1, 1)); + _mm_storeh_pd((double*)(p0 + 16 + 8), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 1))); + _mm_storeh_pd((double*)(p0 + 16 + 12), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 1))); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8), _mm256_extractf128_si256(_bf2, 0)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4), _mm256_extractf128_si256(_bf3, 0)); + _mm_storeh_pd((double*)(p0 + out_hstep * 8 + 8), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 8 + 12), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 0))); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 16), _mm256_extractf128_si256(_bf2, 1)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 16 + 4), _mm256_extractf128_si256(_bf3, 1)); + _mm_storeh_pd((double*)(p0 + out_hstep * 8 + 16 + 8), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 8 + 16 + 12), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 1))); + } + if (out_elempack == 4) + { + __m128i _bf0l = _mm256_extractf128_si256(_bf0, 0); + __m128i _bf1l = _mm256_extractf128_si256(_bf1, 0); + __m128i _bf2l = _mm256_extractf128_si256(_bf2, 0); + __m128i _bf3l = _mm256_extractf128_si256(_bf3, 0); + + __m128i _t0 = _mm_unpacklo_epi16(_bf0l, _bf1l); + __m128i _t1 = _mm_unpacklo_epi16(_bf2l, _bf3l); + __m128i _d0 = _mm_unpacklo_epi32(_t0, _t1); + __m128i _d1 = _mm_unpackhi_epi32(_t0, _t1); + _mm_storel_epi64((__m128i*)p0, _d0); + _mm_storeh_pd((double*)(p0 + 4), _mm_castsi128_pd(_d0)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 2), _d1); + _mm_storeh_pd((double*)(p0 + 4 * 3), _mm_castsi128_pd(_d1)); + + __m128i _t2 = _mm_unpackhi_epi16(_bf0l, _bf1l); + __m128i _t3 = _mm_unpackhi_epi16(_bf2l, _bf3l); + __m128i _d2 = _mm_unpacklo_epi32(_t2, _t3); + __m128i _d3 = _mm_unpackhi_epi32(_t2, _t3); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 4), _d2); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4), _mm_castsi128_pd(_d2)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 4 + 4 * 2), _d3); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 3), _mm_castsi128_pd(_d3)); + + __m128i _bf0h = _mm256_extractf128_si256(_bf0, 1); + __m128i _bf1h = _mm256_extractf128_si256(_bf1, 1); + __m128i _bf2h = _mm256_extractf128_si256(_bf2, 1); + __m128i _bf3h = _mm256_extractf128_si256(_bf3, 1); + __m128i _t4 = _mm_unpacklo_epi16(_bf0h, _bf1h); + __m128i _t5 = _mm_unpacklo_epi16(_bf2h, _bf3h); + __m128i _d4 = _mm_unpacklo_epi32(_t4, _t5); + __m128i _d5 = _mm_unpackhi_epi32(_t4, _t5); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8), _d4); + _mm_storeh_pd((double*)(p0 + out_hstep * 8 + 4), _mm_castsi128_pd(_d4)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 2), _d5); + _mm_storeh_pd((double*)(p0 + out_hstep * 8 + 4 * 3), _mm_castsi128_pd(_d5)); + + __m128i _t6 = _mm_unpackhi_epi16(_bf0h, _bf1h); + 
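// [editor's note] out_elempack is ncnn's packed-storage width: with
// out_elempack == N, N adjacent output rows are interleaved so that the N values
// for one column sit next to each other in memory, which is why every store path
// in this function is duplicated per out_elempack value. A minimal sketch of the
// packed addressing for the plain (non-transposed) case, under that assumption
// and with illustrative names (the transposed branches swap the row/column roles):
//
//   // element (row r, column c) when rows are stored packed by N, in bf16 units
//   size_t off = (size_t)(r / N) * out_hstep * N + (size_t)c * N + (r % N);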
__m128i _t7 = _mm_unpackhi_epi16(_bf2h, _bf3h); + __m128i _d6 = _mm_unpacklo_epi32(_t6, _t7); + __m128i _d7 = _mm_unpackhi_epi32(_t6, _t7); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 12), _d6); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4), _mm_castsi128_pd(_d6)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 12 + 4 * 2), _d7); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 3), _mm_castsi128_pd(_d7)); + } + if (out_elempack == 1) + { + _mm_storel_epi64((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep), _mm256_extractf128_si256(_bf1, 0)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 2), _mm256_extractf128_si256(_bf2, 0)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 3), _mm256_extractf128_si256(_bf3, 0)); + _mm_storeh_pd((double*)(p0 + out_hstep * 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 5), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 6), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 7), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 0))); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8), _mm256_extractf128_si256(_bf0, 1)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 9), _mm256_extractf128_si256(_bf1, 1)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 10), _mm256_extractf128_si256(_bf2, 1)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 11), _mm256_extractf128_si256(_bf3, 1)); + _mm_storeh_pd((double*)(p0 + out_hstep * 12), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 13), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 14), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 15), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 1))); + } + p0 += out_hstep * 16; + } + else + { + if (out_elempack == 4) + { + _mm_storel_epi64((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4), _mm256_extractf128_si256(_bf1, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 2), _mm256_extractf128_si256(_bf2, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 3), _mm256_extractf128_si256(_bf3, 0)); + _mm_storeh_pd((double*)(p0 + 4 * 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 0))); + _mm_storeh_pd((double*)(p0 + 4 * 5), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 0))); + _mm_storeh_pd((double*)(p0 + 4 * 6), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 0))); + _mm_storeh_pd((double*)(p0 + 4 * 7), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 0))); + _mm_storel_epi64((__m128i*)(p0 + 4 * 8), _mm256_extractf128_si256(_bf0, 1)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 9), _mm256_extractf128_si256(_bf1, 1)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 10), _mm256_extractf128_si256(_bf2, 1)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 11), _mm256_extractf128_si256(_bf3, 1)); + _mm_storeh_pd((double*)(p0 + 4 * 12), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 1))); + _mm_storeh_pd((double*)(p0 + 4 * 13), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 1))); + _mm_storeh_pd((double*)(p0 + 4 * 14), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 1))); + _mm_storeh_pd((double*)(p0 + 4 * 15), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 1))); + p0 += 64; + } + if (out_elempack == 1) + { + __m512i _idx_r0r1 = _mm512_set_epi16(61, 45, 29, 13, 57, 41, 25, 9, 53, 37, 21, 5, 49, 33, 17, 1, 60, 44, 28, 12, 56, 40, 24, 
8, 52, 36, 20, 4, 48, 32, 16, 0); + __m512i _idx_r2r3 = _mm512_set_epi16(63, 47, 31, 15, 59, 43, 27, 11, 55, 39, 23, 7, 51, 35, 19, 3, 62, 46, 30, 14, 58, 42, 26, 10, 54, 38, 22, 6, 50, 34, 18, 2); + + __m512i _bf01 = combine8x2_epi32(_bf0, _bf1); + __m512i _bf23 = combine8x2_epi32(_bf2, _bf3); + + __m512i _t01 = _mm512_permutex2var_epi16(_bf01, _idx_r0r1, _bf23); + __m512i _t23 = _mm512_permutex2var_epi16(_bf01, _idx_r2r3, _bf23); + + _mm256_storeu_si256((__m256i*)p0, _mm512_extracti32x8_epi32(_t01, 0)); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep), _mm512_extracti32x8_epi32(_t01, 1)); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 2), _mm512_extracti32x8_epi32(_t23, 0)); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 3), _mm512_extracti32x8_epi32(_t23, 1)); + p0 += 16; + } + } + } +#endif // __AVX512F__ + for (; jj + 7 < max_jj; jj += 8) + { + __m128 _f0 = _mm_load_ps(pp); + __m128 _f1 = _mm_load_ps(pp + 4); + __m128 _f2 = _mm_load_ps(pp + 8); + __m128 _f3 = _mm_load_ps(pp + 12); + __m128 _f4 = _mm_load_ps(pp + 16); + __m128 _f5 = _mm_load_ps(pp + 20); + __m128 _f6 = _mm_load_ps(pp + 24); + __m128 _f7 = _mm_load_ps(pp + 28); + pp += 32; + + // from + // 00 11 22 33 + // 04 15 26 37 + // 20 31 02 13 + // 24 35 06 17 + // 01 12 23 30 + // 05 16 27 34 + // 21 32 03 10 + // 25 36 07 14 + // to + // 00 10 20 30 + // 01 11 21 31 + // 02 12 22 32 + // 03 13 23 33 + // 04 14 24 34 + // 05 15 25 35 + // 06 16 26 36 + // 07 17 27 37 + { + _f4 = _mm_shuffle_ps(_f4, _f4, _MM_SHUFFLE(2, 1, 0, 3)); + _f5 = _mm_shuffle_ps(_f5, _f5, _MM_SHUFFLE(2, 1, 0, 3)); + _f6 = _mm_shuffle_ps(_f6, _f6, _MM_SHUFFLE(2, 1, 0, 3)); + _f7 = _mm_shuffle_ps(_f7, _f7, _MM_SHUFFLE(2, 1, 0, 3)); + __m128 _tmp0 = _mm_unpacklo_ps(_f0, _f6); + __m128 _tmp1 = _mm_unpackhi_ps(_f0, _f6); + __m128 _tmp2 = _mm_unpacklo_ps(_f1, _f7); + __m128 _tmp3 = _mm_unpackhi_ps(_f1, _f7); + __m128 _tmp4 = _mm_unpacklo_ps(_f2, _f4); + __m128 _tmp5 = _mm_unpackhi_ps(_f2, _f4); + __m128 _tmp6 = _mm_unpacklo_ps(_f3, _f5); + __m128 _tmp7 = _mm_unpackhi_ps(_f3, _f5); + _f0 = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd(_tmp0), _mm_castps_pd(_tmp4))); + _f1 = _mm_castpd_ps(_mm_unpackhi_pd(_mm_castps_pd(_tmp0), _mm_castps_pd(_tmp4))); + _f2 = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd(_tmp5), _mm_castps_pd(_tmp1))); + _f3 = _mm_castpd_ps(_mm_unpackhi_pd(_mm_castps_pd(_tmp5), _mm_castps_pd(_tmp1))); + _f4 = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd(_tmp2), _mm_castps_pd(_tmp6))); + _f5 = _mm_castpd_ps(_mm_unpackhi_pd(_mm_castps_pd(_tmp2), _mm_castps_pd(_tmp6))); + _f6 = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd(_tmp7), _mm_castps_pd(_tmp3))); + _f7 = _mm_castpd_ps(_mm_unpackhi_pd(_mm_castps_pd(_tmp7), _mm_castps_pd(_tmp3))); + _f1 = _mm_shuffle_ps(_f1, _f1, _MM_SHUFFLE(2, 1, 0, 3)); + _f3 = _mm_shuffle_ps(_f3, _f3, _MM_SHUFFLE(2, 1, 0, 3)); + _f5 = _mm_shuffle_ps(_f5, _f5, _MM_SHUFFLE(2, 1, 0, 3)); + _f7 = _mm_shuffle_ps(_f7, _f7, _MM_SHUFFLE(2, 1, 0, 3)); + } + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = _mm_add_ps(_f0, _c0); + _f1 = _mm_add_ps(_f1, _c0); + _f2 = _mm_add_ps(_f2, _c0); + _f3 = _mm_add_ps(_f3, _c0); + _f4 = _mm_add_ps(_f4, _c0); + _f5 = _mm_add_ps(_f5, _c0); + _f6 = _mm_add_ps(_f6, _c0); + _f7 = _mm_add_ps(_f7, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = _mm_add_ps(_f0, _c0); + _f1 = _mm_add_ps(_f1, _c0); + _f2 = _mm_add_ps(_f2, _c0); + _f3 = _mm_add_ps(_f3, _c0); + _f4 = _mm_add_ps(_f4, _c0); + _f5 = _mm_add_ps(_f5, _c0); + _f6 = _mm_add_ps(_f6, _c0); + _f7 = _mm_add_ps(_f7, _c0); 
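// [editor's note] Every tile runs the same epilogue: add beta * C when a C blob
// is present (the broadcast pattern is selected by broadcast_type_C), then scale
// by alpha, then narrow to bf16. beta == 1.f is special-cased as a plain add so
// the common bias path avoids the extra multiply. Scalar sketch for one element
// (names are illustrative):
//
//   float acc = dot;                  // accumulator from the micro-kernel
//   if (has_C) acc += c_val * beta;   // c_val picked per broadcast_type_C
//   acc *= alpha;
//   out = float32_to_bfloat16(acc);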
+ } + if (broadcast_type_C == 3) + { + __m128 _c1; + __m128 _c2; + __m128 _c3; + if (c_elempack == 4) + { + _c0 = _mm_loadu_ps(pC); + _c1 = _mm_loadu_ps(pC + 4); + _c2 = _mm_loadu_ps(pC + 8); + _c3 = _mm_loadu_ps(pC + 12); + } + else // if (c_elempack == 1) + { + _c0 = _mm_loadu_ps(pC); + _c1 = _mm_loadu_ps(pC + c_hstep); + _c2 = _mm_loadu_ps(pC + c_hstep * 2); + _c3 = _mm_loadu_ps(pC + c_hstep * 3); + _MM_TRANSPOSE4_PS(_c0, _c1, _c2, _c3); + } + if (beta == 1.f) + { + _f0 = _mm_add_ps(_f0, _c0); + _f1 = _mm_add_ps(_f1, _c1); + _f2 = _mm_add_ps(_f2, _c2); + _f3 = _mm_add_ps(_f3, _c3); + } + else + { + __m128 _beta = _mm_set1_ps(beta); + _f0 = _mm_comp_fmadd_ps(_c0, _beta, _f0); + _f1 = _mm_comp_fmadd_ps(_c1, _beta, _f1); + _f2 = _mm_comp_fmadd_ps(_c2, _beta, _f2); + _f3 = _mm_comp_fmadd_ps(_c3, _beta, _f3); + } + if (c_elempack == 4) + { + _c0 = _mm_loadu_ps(pC + 16); + _c1 = _mm_loadu_ps(pC + 20); + _c2 = _mm_loadu_ps(pC + 24); + _c3 = _mm_loadu_ps(pC + 28); + pC += 32; + } + else // if (c_elempack == 1) + { + _c0 = _mm_loadu_ps(pC + 4); + _c1 = _mm_loadu_ps(pC + c_hstep + 4); + _c2 = _mm_loadu_ps(pC + c_hstep * 2 + 4); + _c3 = _mm_loadu_ps(pC + c_hstep * 3 + 4); + _MM_TRANSPOSE4_PS(_c0, _c1, _c2, _c3); + pC += 8; + } + if (beta == 1.f) + { + _f4 = _mm_add_ps(_f4, _c0); + _f5 = _mm_add_ps(_f5, _c1); + _f6 = _mm_add_ps(_f6, _c2); + _f7 = _mm_add_ps(_f7, _c3); + } + else + { + __m128 _beta = _mm_set1_ps(beta); + _f4 = _mm_comp_fmadd_ps(_c0, _beta, _f4); + _f5 = _mm_comp_fmadd_ps(_c1, _beta, _f5); + _f6 = _mm_comp_fmadd_ps(_c2, _beta, _f6); + _f7 = _mm_comp_fmadd_ps(_c3, _beta, _f7); + } + } + if (broadcast_type_C == 4) + { + _c0 = _mm_set1_ps(pC[0] * beta); + __m128 _c1 = _mm_set1_ps(pC[1] * beta); + __m128 _c2 = _mm_set1_ps(pC[2] * beta); + __m128 _c3 = _mm_set1_ps(pC[3] * beta); + + _f0 = _mm_add_ps(_f0, _c0); + _f1 = _mm_add_ps(_f1, _c1); + _f2 = _mm_add_ps(_f2, _c2); + _f3 = _mm_add_ps(_f3, _c3); + + _c0 = _mm_set1_ps(pC[4] * beta); + _c1 = _mm_set1_ps(pC[5] * beta); + _c2 = _mm_set1_ps(pC[6] * beta); + _c3 = _mm_set1_ps(pC[7] * beta); + + _f4 = _mm_add_ps(_f4, _c0); + _f5 = _mm_add_ps(_f5, _c1); + _f6 = _mm_add_ps(_f6, _c2); + _f7 = _mm_add_ps(_f7, _c3); + pC += 8; + } + } + + if (alpha != 1.f) + { + __m128 _alpha = _mm_set1_ps(alpha); + _f0 = _mm_mul_ps(_f0, _alpha); + _f1 = _mm_mul_ps(_f1, _alpha); + _f2 = _mm_mul_ps(_f2, _alpha); + _f3 = _mm_mul_ps(_f3, _alpha); + _f4 = _mm_mul_ps(_f4, _alpha); + _f5 = _mm_mul_ps(_f5, _alpha); + _f6 = _mm_mul_ps(_f6, _alpha); + _f7 = _mm_mul_ps(_f7, _alpha); + } + + __m128i _bf04 = float2bfloat_sse(_f0, _f4); + __m128i _bf15 = float2bfloat_sse(_f1, _f5); + __m128i _bf26 = float2bfloat_sse(_f2, _f6); + __m128i _bf37 = float2bfloat_sse(_f3, _f7); + + if (output_transpose) + { +#if __AVX__ + if (out_elempack == 8) + { + __m128i _t0 = _mm_unpacklo_epi16(_bf04, _bf15); + __m128i _t1 = _mm_unpacklo_epi16(_bf26, _bf37); + __m128i _t2 = _mm_unpackhi_epi16(_bf04, _bf15); + __m128i _t3 = _mm_unpackhi_epi16(_bf26, _bf37); + _bf04 = _mm_unpacklo_epi32(_t0, _t1); + _bf15 = _mm_unpacklo_epi32(_t2, _t3); + _bf26 = _mm_unpackhi_epi32(_t0, _t1); + _bf37 = _mm_unpackhi_epi32(_t2, _t3); + _t0 = _mm_unpacklo_epi64(_bf04, _bf15); + _t1 = _mm_unpackhi_epi64(_bf04, _bf15); + _t2 = _mm_unpacklo_epi64(_bf26, _bf37); + _t3 = _mm_unpackhi_epi64(_bf26, _bf37); + + _mm_storel_epi64((__m128i*)p0, _t0); + _mm_storeh_pd((double*)(p0 + 4), _mm_castsi128_pd(_t0)); + _mm_storel_epi64((__m128i*)(p0 + 8), _t1); + _mm_storeh_pd((double*)(p0 + 12), _mm_castsi128_pd(_t1)); + 
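// [editor's note] When output_transpose is set the destination holds the
// transposed result: the column index j becomes the slow dimension, which is why
// p0 starts at j * out_hstep + (i + ii) * out_elempack and advances by
// out_hstep * <columns handled> after each jj step instead of by a column count.
// Sketch for the out_elempack == 1 case (illustrative indexing only):
//
//   // non-transposed: dst[(i + ii + r) * out_hstep + (j + jj + c)]
//   // transposed:     dst[(j + jj + c) * out_hstep + (i + ii + r)]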
_mm_storel_epi64((__m128i*)(p0 + 16), _t2); + _mm_storeh_pd((double*)(p0 + 20), _mm_castsi128_pd(_t2)); + _mm_storel_epi64((__m128i*)(p0 + 24), _t3); + _mm_storeh_pd((double*)(p0 + 28), _mm_castsi128_pd(_t3)); + } +#endif // __AVX__ + if (out_elempack == 4) + { + __m128i _t0 = _mm_unpacklo_epi16(_bf04, _bf15); + __m128i _t1 = _mm_unpacklo_epi16(_bf26, _bf37); + __m128i _d0 = _mm_unpacklo_epi32(_t0, _t1); + __m128i _d1 = _mm_unpackhi_epi32(_t0, _t1); + _mm_storel_epi64((__m128i*)p0, _d0); + _mm_storeh_pd((double*)(p0 + 4), _mm_castsi128_pd(_d0)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 2), _d1); + _mm_storeh_pd((double*)(p0 + 4 * 3), _mm_castsi128_pd(_d1)); + + __m128i _t2 = _mm_unpackhi_epi16(_bf04, _bf15); + __m128i _t3 = _mm_unpackhi_epi16(_bf26, _bf37); + __m128i _d2 = _mm_unpacklo_epi32(_t2, _t3); + __m128i _d3 = _mm_unpackhi_epi32(_t2, _t3); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 4), _d2); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4), _mm_castsi128_pd(_d2)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 4 + 4 * 2), _d3); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 3), _mm_castsi128_pd(_d3)); + } + if (out_elempack == 1) + { + _mm_storel_epi64((__m128i*)p0, _bf04); + _mm_storel_epi64((__m128i*)(p0 + out_hstep), _bf15); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 2), _bf26); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 3), _bf37); + _mm_storeh_pd((double*)(p0 + out_hstep * 4), _mm_castsi128_pd(_bf04)); + _mm_storeh_pd((double*)(p0 + out_hstep * 5), _mm_castsi128_pd(_bf15)); + _mm_storeh_pd((double*)(p0 + out_hstep * 6), _mm_castsi128_pd(_bf26)); + _mm_storeh_pd((double*)(p0 + out_hstep * 7), _mm_castsi128_pd(_bf37)); + } + p0 += out_hstep * 8; + } + else + { + if (out_elempack == 4) + { + _mm_storel_epi64((__m128i*)p0, _bf04); + _mm_storel_epi64((__m128i*)(p0 + 4), _bf15); + _mm_storel_epi64((__m128i*)(p0 + 4 * 2), _bf26); + _mm_storel_epi64((__m128i*)(p0 + 4 * 3), _bf37); + _mm_storeh_pd((double*)(p0 + 4 * 4), _mm_castsi128_pd(_bf04)); + _mm_storeh_pd((double*)(p0 + 4 * 5), _mm_castsi128_pd(_bf15)); + _mm_storeh_pd((double*)(p0 + 4 * 6), _mm_castsi128_pd(_bf26)); + _mm_storeh_pd((double*)(p0 + 4 * 7), _mm_castsi128_pd(_bf37)); + p0 += 32; + } + if (out_elempack == 1) + { + __m128i _t0 = _mm_unpacklo_epi16(_bf04, _bf15); + __m128i _t1 = _mm_unpacklo_epi16(_bf26, _bf37); + __m128i _t2 = _mm_unpackhi_epi16(_bf04, _bf15); + __m128i _t3 = _mm_unpackhi_epi16(_bf26, _bf37); + _bf04 = _mm_unpacklo_epi32(_t0, _t1); + _bf15 = _mm_unpacklo_epi32(_t2, _t3); + _bf26 = _mm_unpackhi_epi32(_t0, _t1); + _bf37 = _mm_unpackhi_epi32(_t2, _t3); + _t0 = _mm_unpacklo_epi64(_bf04, _bf15); + _t1 = _mm_unpackhi_epi64(_bf04, _bf15); + _t2 = _mm_unpacklo_epi64(_bf26, _bf37); + _t3 = _mm_unpackhi_epi64(_bf26, _bf37); + + _mm_storeu_si128((__m128i*)p0, _t0); + _mm_storeu_si128((__m128i*)(p0 + out_hstep), _t1); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 2), _t2); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 3), _t3); + p0 += 8; + } + } + } +#endif // defined(__x86_64__) || defined(_M_X64) + for (; jj + 3 < max_jj; jj += 4) + { + __m128 _f0 = _mm_load_ps(pp); + __m128 _f1 = _mm_load_ps(pp + 4); + __m128 _f2 = _mm_load_ps(pp + 8); + __m128 _f3 = _mm_load_ps(pp + 12); + pp += 16; + + // deshuffle from the shuffle-based 4x4 dpbf16_ps kernel + { + _f1 = _mm_shuffle_ps(_f1, _f1, _MM_SHUFFLE(2, 1, 0, 3)); + _f3 = _mm_shuffle_ps(_f3, _f3, _MM_SHUFFLE(2, 1, 0, 3)); + __m128 _tmp0 = _mm_unpacklo_ps(_f0, _f3); + __m128 _tmp1 = _mm_unpackhi_ps(_f0, _f3); + __m128 _tmp2 = _mm_unpacklo_ps(_f2, _f1); 
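// [editor's note] The shuffle-based micro-kernel rotates its inputs so it never
// needs per-lane broadcasts, which leaves each 4x4 accumulator group in a
// "diagonal" order; the shuffle / unpacklo / unpackhi / unpack_pd sequence here
// undoes that rotation. Following the from/to comments on the wider AVX-512 path
// above, the analogous per-register mapping (first digit row, second column)
// appears to be:
//
//   // before deshuffle      after deshuffle
//   // _f0: 00 11 22 33  ->  00 10 20 30
//   // _f1: 01 12 23 30  ->  01 11 21 31
//   // _f2: 20 31 02 13  ->  02 12 22 32
//   // _f3: 21 32 03 10  ->  03 13 23 33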
+ __m128 _tmp3 = _mm_unpackhi_ps(_f2, _f1); + _f0 = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd(_tmp0), _mm_castps_pd(_tmp2))); + _f1 = _mm_castpd_ps(_mm_unpackhi_pd(_mm_castps_pd(_tmp0), _mm_castps_pd(_tmp2))); + _f2 = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd(_tmp3), _mm_castps_pd(_tmp1))); + _f3 = _mm_castpd_ps(_mm_unpackhi_pd(_mm_castps_pd(_tmp3), _mm_castps_pd(_tmp1))); + _f1 = _mm_shuffle_ps(_f1, _f1, _MM_SHUFFLE(2, 1, 0, 3)); + _f3 = _mm_shuffle_ps(_f3, _f3, _MM_SHUFFLE(2, 1, 0, 3)); + } + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = _mm_add_ps(_f0, _c0); + _f1 = _mm_add_ps(_f1, _c0); + _f2 = _mm_add_ps(_f2, _c0); + _f3 = _mm_add_ps(_f3, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = _mm_add_ps(_f0, _c0); + _f1 = _mm_add_ps(_f1, _c0); + _f2 = _mm_add_ps(_f2, _c0); + _f3 = _mm_add_ps(_f3, _c0); + } + if (broadcast_type_C == 3) + { + __m128 _c1; + __m128 _c2; + __m128 _c3; + if (c_elempack == 4) + { + _c0 = _mm_loadu_ps(pC); + _c1 = _mm_loadu_ps(pC + 4); + _c2 = _mm_loadu_ps(pC + 8); + _c3 = _mm_loadu_ps(pC + 12); + pC += 16; + } + else // if (c_elempack == 1) + { + _c0 = _mm_loadu_ps(pC); + _c1 = _mm_loadu_ps(pC + c_hstep); + _c2 = _mm_loadu_ps(pC + c_hstep * 2); + _c3 = _mm_loadu_ps(pC + c_hstep * 3); + _MM_TRANSPOSE4_PS(_c0, _c1, _c2, _c3); + pC += 4; + } + if (beta == 1.f) + { + _f0 = _mm_add_ps(_f0, _c0); + _f1 = _mm_add_ps(_f1, _c1); + _f2 = _mm_add_ps(_f2, _c2); + _f3 = _mm_add_ps(_f3, _c3); + } + else + { + __m128 _beta = _mm_set1_ps(beta); + _f0 = _mm_comp_fmadd_ps(_c0, _beta, _f0); + _f1 = _mm_comp_fmadd_ps(_c1, _beta, _f1); + _f2 = _mm_comp_fmadd_ps(_c2, _beta, _f2); + _f3 = _mm_comp_fmadd_ps(_c3, _beta, _f3); + } + } + if (broadcast_type_C == 4) + { + _c0 = _mm_set1_ps(pC[0] * beta); + __m128 _c1 = _mm_set1_ps(pC[1] * beta); + __m128 _c2 = _mm_set1_ps(pC[2] * beta); + __m128 _c3 = _mm_set1_ps(pC[3] * beta); + + _f0 = _mm_add_ps(_f0, _c0); + _f1 = _mm_add_ps(_f1, _c1); + _f2 = _mm_add_ps(_f2, _c2); + _f3 = _mm_add_ps(_f3, _c3); + pC += 4; + } + } + + if (alpha != 1.f) + { + __m128 _alpha = _mm_set1_ps(alpha); + _f0 = _mm_mul_ps(_f0, _alpha); + _f1 = _mm_mul_ps(_f1, _alpha); + _f2 = _mm_mul_ps(_f2, _alpha); + _f3 = _mm_mul_ps(_f3, _alpha); + } + + __m128i _bf02 = float2bfloat_sse(_f0, _f2); + __m128i _bf13 = float2bfloat_sse(_f1, _f3); + + if (output_transpose) + { +#if !(defined(__x86_64__) || defined(_M_X64)) +#if __AVX__ +#if __AVX512F__ + if (out_elempack == 16) + { + __m128i _t0 = _mm_unpacklo_epi16(_bf02, _bf13); + __m128i _t1 = _mm_unpackhi_epi16(_bf02, _bf13); + __m128i _d0 = _mm_unpacklo_epi32(_t0, _t1); + __m128i _d1 = _mm_unpackhi_epi32(_t0, _t1); + const int jj_m16 = jj % 16; + unsigned short* p1 = p0 - out_hstep * jj_m16 + jj_m16; + _mm_storel_epi64((__m128i*)p1, _d0); + _mm_storeh_pd((double*)(p1 + 16), _mm_castsi128_pd(_d0)); + _mm_storel_epi64((__m128i*)(p1 + 32), _d1); + _mm_storeh_pd((double*)(p1 + 48), _mm_castsi128_pd(_d1)); + } +#endif // __AVX512F__ + if (out_elempack == 8) + { + __m128i _t0 = _mm_unpacklo_epi16(_bf02, _bf13); + __m128i _t1 = _mm_unpackhi_epi16(_bf02, _bf13); + __m128i _d0 = _mm_unpacklo_epi32(_t0, _t1); + __m128i _d1 = _mm_unpackhi_epi32(_t0, _t1); + const int jj_m8 = jj % 8; + unsigned short* p1 = p0 - out_hstep * jj_m8 + jj_m8; + _mm_storel_epi64((__m128i*)p1, _d0); + _mm_storeh_pd((double*)(p1 + 8), _mm_castsi128_pd(_d0)); + _mm_storel_epi64((__m128i*)(p1 + 16), _d1); + _mm_storeh_pd((double*)(p1 + 24), _mm_castsi128_pd(_d1)); + } +#endif // __AVX__ +#endif // !(defined(__x86_64__) || 
defined(_M_X64)) + if (out_elempack == 4) + { + __m128i _t0 = _mm_unpacklo_epi16(_bf02, _bf13); + __m128i _t1 = _mm_unpackhi_epi16(_bf02, _bf13); + __m128i _d0 = _mm_unpacklo_epi32(_t0, _t1); + __m128i _d1 = _mm_unpackhi_epi32(_t0, _t1); + _mm_storeu_si128((__m128i*)p0, _d0); + _mm_storeu_si128((__m128i*)(p0 + 8), _d1); + } + if (out_elempack == 1) + { + _mm_storel_epi64((__m128i*)p0, _bf02); + _mm_storel_epi64((__m128i*)(p0 + out_hstep), _bf13); + _mm_storeh_pd((double*)(p0 + out_hstep * 2), _mm_castsi128_pd(_bf02)); + _mm_storeh_pd((double*)(p0 + out_hstep * 3), _mm_castsi128_pd(_bf13)); + } + p0 += out_hstep * 4; + } + else + { + if (out_elempack == 4) + { + _mm_storel_epi64((__m128i*)p0, _bf02); + _mm_storel_epi64((__m128i*)(p0 + 4), _bf13); + _mm_storeh_pd((double*)(p0 + 4 * 2), _mm_castsi128_pd(_bf02)); + _mm_storeh_pd((double*)(p0 + 4 * 3), _mm_castsi128_pd(_bf13)); + p0 += 16; + } + if (out_elempack == 1) + { + __m128i _t0 = _mm_unpacklo_epi16(_bf02, _bf13); + __m128i _t1 = _mm_unpackhi_epi16(_bf02, _bf13); + _bf02 = _mm_unpacklo_epi32(_t0, _t1); + _bf13 = _mm_unpackhi_epi32(_t0, _t1); + + _mm_storel_epi64((__m128i*)(p0), _bf02); + _mm_storeh_pd((double*)(p0 + out_hstep), _mm_castsi128_pd(_bf02)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 2), _bf13); + _mm_storeh_pd((double*)(p0 + out_hstep * 3), _mm_castsi128_pd(_bf13)); + p0 += 4; + } + } + } + for (; jj + 1 < max_jj; jj += 2) + { + __m128 _f0 = _mm_load_ps(pp); + __m128 _f1 = _mm_load_ps(pp + 4); + pp += 8; + + // deshuffle from the shuffle-based 4x2 dpbf16_ps kernel + { + __m128 _tmp0 = _mm_shuffle_ps(_f0, _f0, _MM_SHUFFLE(3, 1, 2, 0)); + __m128 _tmp1 = _mm_shuffle_ps(_f1, _f1, _MM_SHUFFLE(0, 2, 3, 1)); + _f0 = _mm_unpacklo_ps(_tmp0, _tmp1); + _f1 = _mm_unpackhi_ps(_tmp0, _tmp1); + _f1 = _mm_shuffle_ps(_f1, _f1, _MM_SHUFFLE(2, 1, 0, 3)); + } + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = _mm_add_ps(_f0, _c0); + _f1 = _mm_add_ps(_f1, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = _mm_add_ps(_f0, _c0); + _f1 = _mm_add_ps(_f1, _c0); + } + if (broadcast_type_C == 3) + { + __m128 _c1; + if (c_elempack == 4) + { + _c0 = _mm_loadu_ps(pC); + _c1 = _mm_loadu_ps(pC + 4); + pC += 8; + } + else // if (c_elempack == 1) + { + _c0 = _mm_setr_ps(pC[0], pC[c_hstep], pC[c_hstep * 2], pC[c_hstep * 3]); + _c1 = _mm_setr_ps(pC[1], pC[c_hstep + 1], pC[c_hstep * 2 + 1], pC[c_hstep * 3 + 1]); + pC += 2; + } + if (beta == 1.f) + { + _f0 = _mm_add_ps(_f0, _c0); + _f1 = _mm_add_ps(_f1, _c1); + } + else + { + __m128 _beta = _mm_set1_ps(beta); + _f0 = _mm_comp_fmadd_ps(_c0, _beta, _f0); + _f1 = _mm_comp_fmadd_ps(_c1, _beta, _f1); + } + } + if (broadcast_type_C == 4) + { + _c0 = _mm_set1_ps(pC[0] * beta); + __m128 _c1 = _mm_set1_ps(pC[1] * beta); + _f0 = _mm_add_ps(_f0, _c0); + _f1 = _mm_add_ps(_f1, _c1); + pC += 2; + } + } + + if (alpha != 1.f) + { + __m128 _alpha = _mm_set1_ps(alpha); + _f0 = _mm_mul_ps(_f0, _alpha); + _f1 = _mm_mul_ps(_f1, _alpha); + } + + __m128i _bf01 = float2bfloat_sse(_f0, _f1); + + if (output_transpose) + { + if (out_elempack == 4) + { + unsigned short sum0[8]; + _mm_storeu_si128((__m128i*)sum0, _bf01); + + p0[0] = sum0[0]; + p0[1] = sum0[4]; + p0[4] = sum0[1]; + p0[5] = sum0[5]; + p0[8] = sum0[2]; + p0[9] = sum0[6]; + p0[12] = sum0[3]; + p0[13] = sum0[7]; + } + if (out_elempack == 1) + { + _mm_storel_epi64((__m128i*)p0, _bf01); + _mm_storeh_pd((double*)(p0 + out_hstep), _mm_castsi128_pd(_bf01)); + } + p0 += out_hstep * 2; + } + else + { + if (out_elempack == 4) + { + 
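// [editor's note] In this two-column tail, float2bfloat_sse(_f0, _f1) packs
// column jj (4 rows) into the low 64 bits of _bf01 and column jj + 1 into the
// high 64 bits, so the storel/storeh pair just below writes one column per
// 4-element pack. Equivalent scalar stores, with illustrative names:
//
//   for (int r = 0; r < 4; r++) p0[r]     = bf16_col0[r];
//   for (int r = 0; r < 4; r++) p0[4 + r] = bf16_col1[r];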
_mm_storel_epi64((__m128i*)p0, _bf01); + _mm_storeh_pd((double*)(p0 + 4), _mm_castsi128_pd(_bf01)); + p0 += 8; + } + if (out_elempack == 1) + { + unsigned short sum0[8]; + _mm_storeu_si128((__m128i*)sum0, _bf01); + + p0[0] = sum0[0]; + p0[1] = sum0[4]; + p0[out_hstep] = sum0[1]; + p0[out_hstep + 1] = sum0[5]; + p0[out_hstep * 2] = sum0[2]; + p0[out_hstep * 2 + 1] = sum0[6]; + p0[out_hstep * 3] = sum0[3]; + p0[out_hstep * 3 + 1] = sum0[7]; + p0 += 2; + } + } + } + for (; jj < max_jj; jj++) + { + __m128 _f = _mm_load_ps(pp); + pp += 4; + + if (pC) + { + if (broadcast_type_C == 0) + { + _f = _mm_add_ps(_f, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f = _mm_add_ps(_f, _c0); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 4) + { + _c0 = _mm_loadu_ps(pC); + pC += 4; + } + else // if (c_elempack == 1) + { + _c0 = _mm_setr_ps(pC[0], pC[c_hstep], pC[c_hstep * 2], pC[c_hstep * 3]); + pC += 1; + } + _f = _mm_comp_fmadd_ps(_c0, _mm_set1_ps(beta), _f); + } + if (broadcast_type_C == 4) + { + _c0 = _mm_set1_ps(pC[0] * beta); + _f = _mm_add_ps(_f, _c0); + pC += 1; + } + } + + if (alpha != 1.f) + { + _f = _mm_mul_ps(_f, _mm_set1_ps(alpha)); + } + + __m128i _bf = float2bfloat_sse(_f); + + if (output_transpose) + { + if (out_elempack == 4) + { + unsigned short sum0[4]; + _mm_storel_epi64((__m128i*)sum0, _bf); + + p0[0] = sum0[0]; + p0[4] = sum0[1]; + p0[4 * 2] = sum0[2]; + p0[4 * 3] = sum0[3]; + } + if (out_elempack == 1) + { + _mm_storel_epi64((__m128i*)p0, _bf); + } + p0 += out_hstep; + } + else + { + if (out_elempack == 4) + { + _mm_storel_epi64((__m128i*)p0, _bf); + p0 += 4; + } + if (out_elempack == 1) + { + unsigned short sum0[4]; + _mm_storel_epi64((__m128i*)sum0, _bf); + + p0[0] = sum0[0]; + p0[out_hstep] = sum0[1]; + p0[out_hstep * 2] = sum0[2]; + p0[out_hstep * 3] = sum0[3]; + p0++; + } + } + } + } +#endif // __SSE2__ + for (; ii + 1 < max_ii; ii += 2) + { + unsigned short* p0; + if (output_transpose) + { + p0 = (unsigned short*)top_blob + j * out_hstep + (i + ii) * out_elempack; + } + else + { + p0 = (unsigned short*)top_blob + (i + ii) * out_hstep + j * out_elempack; + } + + float c0 = 0.f; + float c1 = 0.f; +#if __SSE2__ + __m128 _c0 = _mm_set1_ps(0.f); + __m128 _c1 = _mm_set1_ps(0.f); +#if __AVX512F__ + __m512 _c0_avx512 = _mm512_set1_ps(0.f); + __m512 _c1_avx512 = _mm512_set1_ps(0.f); +#endif // __AVX512F__ +#endif + if (pC) + { + if (broadcast_type_C == 0) + { + c0 = pC[0] * beta; +#if __SSE2__ + _c0 = _mm_set1_ps(c0); +#if __AVX512F__ + _c0_avx512 = _mm512_set1_ps(c0); +#endif // __AVX512F__ +#endif + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + pC = (const float*)C + i + ii; + c0 = pC[0] * beta; + c1 = pC[1] * beta; +#if __SSE2__ + _c0 = _mm_set1_ps(c0); + _c1 = _mm_set1_ps(c1); +#if __AVX512F__ + _c0_avx512 = _mm512_set1_ps(c0); + _c1_avx512 = _mm512_set1_ps(c1); +#endif // __AVX512F__ +#endif + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + pC = (const float*)C + (i + ii) * c_hstep + j; + } + if (broadcast_type_C == 4) + { + pC = (const float*)C + j; + } + } + + int jj = 0; +#if __SSE2__ +#if defined(__x86_64__) || defined(_M_X64) +#if __AVX512F__ + for (; jj + 15 < max_jj; jj += 16) + { + __m512 _f0 = _mm512_loadu_ps(pp); + __m512 _f1 = _mm512_loadu_ps(pp + 16); + pp += 32; + + // deshuffle from the shuffle-based 2x16 dpbf16_ps kernel + { + __m512 _tmp0 = _mm512_unpacklo_ps(_f0, _f1); + __m512 _tmp1 = _mm512_unpackhi_ps(_f0, _f1); + _f0 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp0), 
_mm512_castps_pd(_tmp1))); + _f1 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp0), _mm512_castps_pd(_tmp1))); + _f1 = _mm512_permute_ps(_f1, _MM_SHUFFLE(2, 1, 0, 3)); + } + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = _mm512_add_ps(_f0, _c0_avx512); + _f1 = _mm512_add_ps(_f1, _c0_avx512); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = _mm512_add_ps(_f0, _c0_avx512); + _f1 = _mm512_add_ps(_f1, _c1_avx512); + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + _c0_avx512 = _mm512_loadu_ps(pC); + _c1_avx512 = _mm512_loadu_ps(pC + c_hstep); + if (beta == 1.f) + { + _f0 = _mm512_add_ps(_f0, _c0_avx512); + _f1 = _mm512_add_ps(_f1, _c1_avx512); + } + else + { + __m512 _beta = _mm512_set1_ps(beta); + _f0 = _mm512_fmadd_ps(_c0_avx512, _beta, _f0); + _f1 = _mm512_fmadd_ps(_c1_avx512, _beta, _f1); + } + pC += 16; + } + if (broadcast_type_C == 4) + { + _c0_avx512 = _mm512_loadu_ps(pC); + _c0_avx512 = _mm512_mul_ps(_c0_avx512, _mm512_set1_ps(beta)); + _f0 = _mm512_add_ps(_f0, _c0_avx512); + _f1 = _mm512_add_ps(_f1, _c0_avx512); + pC += 16; + } + } + + if (alpha != 1.f) + { + __m512 _alpha = _mm512_set1_ps(alpha); + _f0 = _mm512_mul_ps(_f0, _alpha); + _f1 = _mm512_mul_ps(_f1, _alpha); + } - for (int jj = 0; jj < max_jj; jj += 1) + __m256i _bf0 = float2bfloat_avx512(_f0); + __m256i _bf1 = float2bfloat_avx512(_f1); + + if (output_transpose) { - _mm256_storeu_si256((__m256i*)p0, float2bfloat_avx512(_mm512_loadu_ps(pp))); - pp += 16; - p0 += out_hstep; + if (out_elempack == 16) + { + _mm256_storeu_si256((__m256i*)p0, _bf0); + _mm256_storeu_si256((__m256i*)(p0 + 16), _bf1); + } + if (out_elempack == 8) + { + _mm_storeu_si128((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8), _mm256_extractf128_si256(_bf1, 0)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8), _mm256_extractf128_si256(_bf0, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8), _mm256_extractf128_si256(_bf1, 1)); + } + if (out_elempack == 4) + { + _mm_storel_epi64((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4), _mm256_extractf128_si256(_bf1, 0)); + _mm_storeh_pd((double*)(p0 + out_hstep * 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 0))); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8), _mm256_extractf128_si256(_bf0, 1)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4), _mm256_extractf128_si256(_bf1, 1)); + _mm_storeh_pd((double*)(p0 + out_hstep * 12), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 1))); + } + if (out_elempack == 1) + { + unsigned short sum0[16]; + unsigned short sum1[16]; + _mm256_storeu_si256((__m256i*)sum0, _bf0); + _mm256_storeu_si256((__m256i*)sum1, _bf1); + + p0[0] = sum0[0]; + p0[1] = sum1[0]; + p0[out_hstep] = sum0[1]; + p0[out_hstep + 1] = sum1[1]; + p0[out_hstep * 2] = sum0[2]; + p0[out_hstep * 2 + 1] = sum1[2]; + p0[out_hstep * 3] = sum0[3]; + p0[out_hstep * 3 + 1] = sum1[3]; + p0[out_hstep * 4] = sum0[4]; + p0[out_hstep * 4 + 1] = sum1[4]; + p0[out_hstep * 5] = sum0[5]; + p0[out_hstep * 5 + 1] = sum1[5]; + p0[out_hstep * 6] = sum0[6]; + p0[out_hstep * 6 + 1] = sum1[6]; + p0[out_hstep * 7] = sum0[7]; + p0[out_hstep * 7 + 1] = sum1[7]; + p0[out_hstep * 8] = sum0[8]; + p0[out_hstep * 8 + 1] = sum1[8]; + p0[out_hstep * 9] = sum0[9]; + p0[out_hstep * 9 + 1] = 
sum1[9]; + p0[out_hstep * 10] = sum0[10]; + p0[out_hstep * 10 + 1] = sum1[10]; + p0[out_hstep * 11] = sum0[11]; + p0[out_hstep * 11 + 1] = sum1[11]; + p0[out_hstep * 12] = sum0[12]; + p0[out_hstep * 12 + 1] = sum1[12]; + p0[out_hstep * 13] = sum0[13]; + p0[out_hstep * 13 + 1] = sum1[13]; + p0[out_hstep * 14] = sum0[14]; + p0[out_hstep * 14 + 1] = sum1[14]; + p0[out_hstep * 15] = sum0[15]; + p0[out_hstep * 15 + 1] = sum1[15]; + } + p0 += out_hstep * 16; + } + else + { + _mm256_storeu_si256((__m256i*)p0, _bf0); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep), _bf1); + p0 += 16; } } #endif // __AVX512F__ - for (; ii + 7 < max_ii; ii += 8) + for (; jj + 7 < max_jj; jj += 8) { - unsigned short* p0 = (unsigned short*)top_blob + j * out_hstep + (i + ii) * out_elempack; + __m128 _f0 = _mm_load_ps(pp); + __m128 _f1 = _mm_load_ps(pp + 4); + __m128 _f2 = _mm_load_ps(pp + 8); + __m128 _f3 = _mm_load_ps(pp + 12); + pp += 16; + + // 00 11 02 13 + // 04 15 06 17 + // 10 01 12 03 + // 14 05 16 07 + _f2 = _mm_shuffle_ps(_f2, _f2, _MM_SHUFFLE(2, 3, 0, 1)); + _f3 = _mm_shuffle_ps(_f3, _f3, _MM_SHUFFLE(2, 3, 0, 1)); + + __m128 _tmp0 = _mm_unpacklo_ps(_f0, _f2); + __m128 _tmp1 = _mm_unpackhi_ps(_f0, _f2); + __m128 _tmp2 = _mm_unpacklo_ps(_f1, _f3); + __m128 _tmp3 = _mm_unpackhi_ps(_f1, _f3); + + _f0 = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd(_tmp0), _mm_castps_pd(_tmp1))); + _f1 = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd(_tmp2), _mm_castps_pd(_tmp3))); + _f2 = _mm_castpd_ps(_mm_unpackhi_pd(_mm_castps_pd(_tmp0), _mm_castps_pd(_tmp1))); + _f3 = _mm_castpd_ps(_mm_unpackhi_pd(_mm_castps_pd(_tmp2), _mm_castps_pd(_tmp3))); - for (int jj = 0; jj < max_jj; jj += 1) + _f2 = _mm_shuffle_ps(_f2, _f2, _MM_SHUFFLE(2, 3, 0, 1)); + _f3 = _mm_shuffle_ps(_f3, _f3, _MM_SHUFFLE(2, 3, 0, 1)); + + if (pC) { - _mm_storeu_si128((__m128i*)p0, float2bfloat_avx(_mm256_loadu_ps(pp))); - pp += 8; - p0 += out_hstep; + if (broadcast_type_C == 0) + { + _f0 = _mm_add_ps(_f0, _c0); + _f1 = _mm_add_ps(_f1, _c0); + _f2 = _mm_add_ps(_f2, _c0); + _f3 = _mm_add_ps(_f3, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = _mm_add_ps(_f0, _c0); + _f1 = _mm_add_ps(_f1, _c0); + _f2 = _mm_add_ps(_f2, _c1); + _f3 = _mm_add_ps(_f3, _c1); + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + _c0 = _mm_loadu_ps(pC); + _c1 = _mm_loadu_ps(pC + 4); + __m128 _c2 = _mm_loadu_ps(pC + c_hstep); + __m128 _c3 = _mm_loadu_ps(pC + c_hstep + 4); + if (beta == 1.f) + { + _f0 = _mm_add_ps(_f0, _c0); + _f1 = _mm_add_ps(_f1, _c1); + _f2 = _mm_add_ps(_f2, _c2); + _f3 = _mm_add_ps(_f3, _c3); + } + else + { + __m128 _beta = _mm_set1_ps(beta); + _f0 = _mm_comp_fmadd_ps(_c0, _beta, _f0); + _f1 = _mm_comp_fmadd_ps(_c1, _beta, _f1); + _f2 = _mm_comp_fmadd_ps(_c2, _beta, _f2); + _f3 = _mm_comp_fmadd_ps(_c3, _beta, _f3); + } + pC += 8; + } + if (broadcast_type_C == 4) + { + _c0 = _mm_loadu_ps(pC); + _c1 = _mm_loadu_ps(pC + 4); + _c0 = _mm_mul_ps(_c0, _mm_set1_ps(beta)); + _c1 = _mm_mul_ps(_c1, _mm_set1_ps(beta)); + _f0 = _mm_add_ps(_f0, _c0); + _f1 = _mm_add_ps(_f1, _c1); + _f2 = _mm_add_ps(_f2, _c0); + _f3 = _mm_add_ps(_f3, _c1); + pC += 8; + } + } + + if (alpha != 1.f) + { + __m128 _alpha = _mm_set1_ps(alpha); + _f0 = _mm_mul_ps(_f0, _alpha); + _f1 = _mm_mul_ps(_f1, _alpha); + _f2 = _mm_mul_ps(_f2, _alpha); + _f3 = _mm_mul_ps(_f3, _alpha); + } + + __m128i _bf0 = float2bfloat_sse(_f0, _f1); + __m128i _bf1 = float2bfloat_sse(_f2, _f3); + + if (output_transpose) + { + if (out_elempack == 8) + { + _mm_storeu_si128((__m128i*)p0, _bf0); + 
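+                    // _bf1 holds the second output row (ii + 1) for the same eight jj,
+                    // written as the next 8-packed group of the transposed output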
_mm_storeu_si128((__m128i*)(p0 + 8), _bf1); + } + if (out_elempack == 4) + { + _mm_storel_epi64((__m128i*)p0, _bf0); + _mm_storel_epi64((__m128i*)(p0 + 4), _bf1); + _mm_storeh_pd((double*)(p0 + out_hstep * 4), _mm_castsi128_pd(_bf0)); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4), _mm_castsi128_pd(_bf1)); + } + if (out_elempack == 1) + { + unsigned short sum0[8]; + unsigned short sum1[8]; + _mm_storeu_si128((__m128i*)sum0, _bf0); + _mm_storeu_si128((__m128i*)sum1, _bf1); + + p0[0] = sum0[0]; + p0[1] = sum1[0]; + p0[out_hstep] = sum0[1]; + p0[out_hstep + 1] = sum1[1]; + p0[out_hstep * 2] = sum0[2]; + p0[out_hstep * 2 + 1] = sum1[2]; + p0[out_hstep * 3] = sum0[3]; + p0[out_hstep * 3 + 1] = sum1[3]; + p0[out_hstep * 4] = sum0[4]; + p0[out_hstep * 4 + 1] = sum1[4]; + p0[out_hstep * 5] = sum0[5]; + p0[out_hstep * 5 + 1] = sum1[5]; + p0[out_hstep * 6] = sum0[6]; + p0[out_hstep * 6 + 1] = sum1[6]; + p0[out_hstep * 7] = sum0[7]; + p0[out_hstep * 7 + 1] = sum1[7]; + } + p0 += out_hstep * 8; + } + else + { + _mm_storeu_si128((__m128i*)p0, _bf0); + _mm_storeu_si128((__m128i*)(p0 + out_hstep), _bf1); + p0 += 8; } } -#endif // __AVX__ - for (; ii + 3 < max_ii; ii += 4) +#endif // defined(__x86_64__) || defined(_M_X64) + for (; jj + 3 < max_jj; jj += 4) { - unsigned short* p0 = (unsigned short*)top_blob + j * out_hstep + (i + ii) * out_elempack; + __m128 _f0 = _mm_load_ps(pp); + __m128 _f1 = _mm_load_ps(pp + 4); + pp += 8; - for (int jj = 0; jj < max_jj; jj += 1) { - _mm_storel_epi64((__m128i*)p0, float2bfloat_sse(_mm_loadu_ps(pp), _mm_setzero_ps())); - pp += 4; - p0 += out_hstep; + __m128 _tmp0 = _mm_unpacklo_ps(_f0, _f1); + __m128 _tmp1 = _mm_unpackhi_ps(_f0, _f1); + _f0 = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd(_tmp0), _mm_castps_pd(_tmp1))); + _f1 = _mm_castpd_ps(_mm_unpackhi_pd(_mm_castps_pd(_tmp0), _mm_castps_pd(_tmp1))); + _f1 = _mm_shuffle_ps(_f1, _f1, _MM_SHUFFLE(2, 1, 0, 3)); + } + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = _mm_add_ps(_f0, _c0); + _f1 = _mm_add_ps(_f1, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = _mm_add_ps(_f0, _c0); + _f1 = _mm_add_ps(_f1, _c1); + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + _c0 = _mm_loadu_ps(pC); + _c1 = _mm_loadu_ps(pC + c_hstep); + if (beta == 1.f) + { + _f0 = _mm_add_ps(_f0, _c0); + _f1 = _mm_add_ps(_f1, _c1); + } + else + { + __m128 _beta = _mm_set1_ps(beta); + _f0 = _mm_comp_fmadd_ps(_c0, _beta, _f0); + _f1 = _mm_comp_fmadd_ps(_c1, _beta, _f1); + } + pC += 4; + } + if (broadcast_type_C == 4) + { + _c0 = _mm_loadu_ps(pC); + _c0 = _mm_mul_ps(_c0, _mm_set1_ps(beta)); + _f0 = _mm_add_ps(_f0, _c0); + _f1 = _mm_add_ps(_f1, _c0); + pC += 4; + } + } + + if (alpha != 1.f) + { + __m128 _alpha = _mm_set1_ps(alpha); + _f0 = _mm_mul_ps(_f0, _alpha); + _f1 = _mm_mul_ps(_f1, _alpha); + } + + __m128i _bf0 = float2bfloat_sse(_f0, _f1); + + if (output_transpose) + { +#if !(defined(__x86_64__) || defined(_M_X64)) +#if __AVX__ +#if __AVX512F__ + if (out_elempack == 16) + { + const int jj_m16 = jj % 16; + unsigned short* p1 = p0 - out_hstep * jj_m16 + jj_m16; + _mm_storel_epi64((__m128i*)p1, _bf0); + _mm_storeh_pd((double*)(p1 + 16), _mm_castsi128_pd(_bf0)); + } +#endif // __AVX512F__ + if (out_elempack == 8) + { + const int jj_m8 = jj % 8; + unsigned short* p1 = p0 - out_hstep * jj_m8 + jj_m8; + _mm_storel_epi64((__m128i*)p1, _bf0); + _mm_storeh_pd((double*)(p1 + 8), _mm_castsi128_pd(_bf0)); + } +#endif // __AVX__ +#endif // !(defined(__x86_64__) || defined(_M_X64)) + if (out_elempack == 4) + { + 
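+                // a single 16-byte store covers both transposed columns here:
+                // lanes 0-3 are the four jj values of column (i + ii),
+                // lanes 4-7 those of column (i + ii + 1)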
_mm_storeu_si128((__m128i*)p0, _bf0); + } + if (out_elempack == 1) + { + unsigned short sum0[8]; + _mm_storeu_si128((__m128i*)sum0, _bf0); + + p0[0] = sum0[0]; + p0[1] = sum0[4]; + p0[out_hstep] = sum0[1]; + p0[out_hstep + 1] = sum0[5]; + p0[out_hstep * 2] = sum0[2]; + p0[out_hstep * 2 + 1] = sum0[6]; + p0[out_hstep * 3] = sum0[3]; + p0[out_hstep * 3 + 1] = sum0[7]; + } + p0 += out_hstep * 4; + } + else + { + _mm_storel_epi64((__m128i*)p0, _bf0); + _mm_storel_epi64((__m128i*)(p0 + out_hstep), _mm_srli_si128(_bf0, 8)); + p0 += 4; } } #endif // __SSE2__ - for (; ii + 1 < max_ii; ii += 2) + for (; jj + 1 < max_jj; jj += 2) { - unsigned short* p0 = (unsigned short*)top_blob + j * out_hstep + (i + ii) * out_elempack; + float f00 = pp[0]; + float f01 = pp[1]; + float f10 = pp[2]; + float f11 = pp[3]; + pp += 4; - for (int jj = 0; jj < max_jj; jj += 1) + if (pC) { - p0[0] = float32_to_bfloat16(pp[0]); - p0[1] = float32_to_bfloat16(pp[1]); - pp += 2; - p0 += out_hstep; + if (broadcast_type_C == 0) + { + f00 += c0; + f01 += c0; + f10 += c0; + f11 += c0; + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + f00 += c0; + f01 += c0; + f10 += c1; + f11 += c1; + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + f00 += pC[0] * beta; + f01 += pC[1] * beta; + f10 += pC[c_hstep] * beta; + f11 += pC[c_hstep + 1] * beta; + pC += 2; + } + if (broadcast_type_C == 4) + { + f00 += pC[0] * beta; + f01 += pC[1] * beta; + f10 += pC[0] * beta; + f11 += pC[1] * beta; + pC += 2; + } + } + + f00 *= alpha; + f01 *= alpha; + f10 *= alpha; + f11 *= alpha; + + unsigned short bf00 = float32_to_bfloat16(f00); + unsigned short bf01 = float32_to_bfloat16(f01); + unsigned short bf10 = float32_to_bfloat16(f10); + unsigned short bf11 = float32_to_bfloat16(f11); + + if (output_transpose) + { + p0[0] = bf00; + p0[1] = bf10; + p0[out_hstep] = bf01; + p0[out_hstep + 1] = bf11; + p0 += out_hstep * 2; + } + else + { + p0[0] = bf00; + p0[1] = bf01; + p0[out_hstep] = bf10; + p0[out_hstep + 1] = bf11; + p0 += 2; } } - for (; ii < max_ii; ii += 1) + for (; jj < max_jj; jj++) { - unsigned short* p0 = (unsigned short*)top_blob + j * out_hstep + (i + ii) * out_elempack; + float f0 = pp[0]; + float f1 = pp[1]; + pp += 2; - for (int jj = 0; jj < max_jj; jj += 1) + if (pC) { - p0[0] = float32_to_bfloat16(pp[0]); - pp += 1; + if (broadcast_type_C == 0) + { + f0 += c0; + f1 += c0; + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + f0 += c0; + f1 += c1; + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + f0 += pC[0] * beta; + f1 += pC[c_hstep] * beta; + pC += 1; + } + if (broadcast_type_C == 4) + { + f0 += pC[0] * beta; + f1 += pC[0] * beta; + pC += 1; + } + } + + f0 *= alpha; + f1 *= alpha; + + unsigned short bf0 = float32_to_bfloat16(f0); + unsigned short bf1 = float32_to_bfloat16(f1); + + if (output_transpose) + { + p0[0] = bf0; + p0[1] = bf1; p0 += out_hstep; } + else + { + p0[0] = bf0; + p0[out_hstep] = bf1; + p0++; + } } } - else + for (; ii < max_ii; ii++) { - // non-transpose unpack: topT layout has ii values contiguous for each jj - // pp[0..ii-1] = results for jj=0, pp[ii..2*ii-1] = results for jj=1, etc. 
- // output: row (i+ii+k), col (j+jj), with out_elempack packing along the ii dimension - // For bf16 output with out_elempack==1: store pp[k] at row (i+ii+k), col (j+jj) - int ii = 0; + unsigned short* p0; + if (output_transpose) + { + p0 = (unsigned short*)top_blob + j * out_hstep + (i + ii) * out_elempack; + } + else + { + p0 = (unsigned short*)top_blob + (i + ii) * out_hstep + j * out_elempack; + } + + float c0 = 0.f; #if __SSE2__ -#if __AVX__ + __m128 _c0 = _mm_set1_ps(0.f); +#if __AVX512F__ + __m512 _c0_avx512 = _mm512_set1_ps(0.f); +#endif // __AVX512F__ +#endif + if (pC) + { + if (broadcast_type_C == 0) + { + c0 = pC[0] * beta; +#if __SSE2__ + _c0 = _mm_set1_ps(c0); +#if __AVX512F__ + _c0_avx512 = _mm512_set1_ps(c0); +#endif // __AVX512F__ +#endif + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + pC = (const float*)C + i + ii; + c0 = pC[0] * beta; +#if __SSE2__ + _c0 = _mm_set1_ps(c0); +#if __AVX512F__ + _c0_avx512 = _mm512_set1_ps(c0); +#endif // __AVX512F__ +#endif + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + pC = (const float*)C + (i + ii) * c_hstep + j; + } + if (broadcast_type_C == 4) + { + pC = (const float*)C + j; + } + } + + int jj = 0; +#if __SSE2__ +#if defined(__x86_64__) || defined(_M_X64) #if __AVX512F__ - for (; ii + 15 < max_ii; ii += 16) + for (; jj + 15 < max_jj; jj += 16) { - for (int jj = 0; jj < max_jj; jj += 1) + __m512 _f0 = _mm512_loadu_ps(pp); + pp += 16; + + if (pC) { - for (int k = 0; k < 16; k++) + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) { - *((unsigned short*)top_blob + (i + ii + k) * out_hstep + (j + jj) * out_elempack) = float32_to_bfloat16(pp[k]); + _f0 = _mm512_add_ps(_f0, _c0_avx512); } - pp += 16; + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // out_elempack == 1 + _c0_avx512 = _mm512_loadu_ps(pC); + _f0 = _mm512_fmadd_ps(_c0_avx512, _mm512_set1_ps(beta), _f0); + pC += 16; + } + } + + if (alpha != 1.f) + { + _f0 = _mm512_mul_ps(_f0, _mm512_set1_ps(alpha)); + } + + __m256i _bf0 = float2bfloat_avx512(_f0); + + if (output_transpose) + { + if (out_hstep == 1) + { + _mm256_storeu_si256((__m256i*)p0, _bf0); + } + else + { + if (out_elempack == 16) + { + _mm256_storeu_si256((__m256i*)p0, _bf0); + } + if (out_elempack == 8) + { + _mm_storeu_si128((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8), _mm256_extractf128_si256(_bf0, 1)); + } + if (out_elempack == 4) + { + _mm_storel_epi64((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); + _mm_storeh_pd((double*)(p0 + out_hstep * 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 0))); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8), _mm256_extractf128_si256(_bf0, 1)); + _mm_storeh_pd((double*)(p0 + out_hstep * 12), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 1))); + } + if (out_elempack == 1) + { + unsigned short sum0[16]; + _mm256_storeu_si256((__m256i*)sum0, _bf0); + + p0[0] = sum0[0]; + p0[out_hstep] = sum0[1]; + p0[out_hstep * 2] = sum0[2]; + p0[out_hstep * 3] = sum0[3]; + p0[out_hstep * 4] = sum0[4]; + p0[out_hstep * 5] = sum0[5]; + p0[out_hstep * 6] = sum0[6]; + p0[out_hstep * 7] = sum0[7]; + p0[out_hstep * 8] = sum0[8]; + p0[out_hstep * 9] = sum0[9]; + p0[out_hstep * 10] = sum0[10]; + p0[out_hstep * 11] = sum0[11]; + p0[out_hstep * 12] = sum0[12]; + p0[out_hstep * 13] = sum0[13]; + p0[out_hstep * 14] = sum0[14]; + p0[out_hstep * 15] = sum0[15]; + } + } + p0 += out_hstep * 16; + } + else + { + _mm256_storeu_si256((__m256i*)p0, _bf0); + p0 += 16; } } #endif // __AVX512F__ 
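+        // Illustrative scalar sketch of what this last-row tail computes, assuming the
+        // non-transposed, out_elempack == 1 output and the broadcast_type_C == 3/4
+        // per-column bias case (the vector paths above and below are unrolled equivalents):
+        //
+        //   for (int jj = 0; jj < max_jj; jj++)
+        //   {
+        //       float f = pp[jj];
+        //       if (pC) f += pC[jj] * beta;        // bias from C
+        //       f *= alpha;                        // alpha scaling
+        //       p0[jj] = float32_to_bfloat16(f);   // narrow fp32 accumulator to bf16
+        //   }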
- for (; ii + 7 < max_ii; ii += 8) + for (; jj + 7 < max_jj; jj += 8) { - for (int jj = 0; jj < max_jj; jj += 1) + __m128 _f0 = _mm_loadu_ps(pp); + __m128 _f1 = _mm_loadu_ps(pp + 4); + pp += 8; + + if (pC) { - for (int k = 0; k < 8; k++) + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) { - *((unsigned short*)top_blob + (i + ii + k) * out_hstep + (j + jj) * out_elempack) = float32_to_bfloat16(pp[k]); + _f0 = _mm_add_ps(_f0, _c0); + _f1 = _mm_add_ps(_f1, _c0); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // out_elempack == 1 + _c0 = _mm_loadu_ps(pC); + __m128 _c1 = _mm_loadu_ps(pC + 4); + _f0 = _mm_comp_fmadd_ps(_c0, _mm_set1_ps(beta), _f0); + _f1 = _mm_comp_fmadd_ps(_c1, _mm_set1_ps(beta), _f1); + pC += 8; } - pp += 8; } - } + + if (alpha != 1.f) + { + __m128 _alpha = _mm_set1_ps(alpha); + _f0 = _mm_mul_ps(_f0, _alpha); + _f1 = _mm_mul_ps(_f1, _alpha); + } + + __m128i _bf0 = float2bfloat_sse(_f0, _f1); + + if (output_transpose) + { + if (out_hstep == 1) + { + _mm_storeu_si128((__m128i*)p0, _bf0); + } + else + { +#if __AVX__ + if (out_elempack == 8) + { + _mm_storeu_si128((__m128i*)p0, _bf0); + } #endif // __AVX__ - for (; ii + 3 < max_ii; ii += 4) + if (out_elempack == 4) + { + _mm_storel_epi64((__m128i*)p0, float2bfloat_sse(_f0)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 4), float2bfloat_sse(_f1)); + } + if (out_elempack == 1) + { + unsigned short sum0[8]; + _mm_storeu_si128((__m128i*)sum0, _bf0); + + p0[0] = sum0[0]; + p0[out_hstep] = sum0[1]; + p0[out_hstep * 2] = sum0[2]; + p0[out_hstep * 3] = sum0[3]; + p0[out_hstep * 4] = sum0[4]; + p0[out_hstep * 5] = sum0[5]; + p0[out_hstep * 6] = sum0[6]; + p0[out_hstep * 7] = sum0[7]; + } + } + p0 += out_hstep * 8; + } + else + { + _mm_storeu_si128((__m128i*)p0, _bf0); + p0 += 8; + } + } +#endif // defined(__x86_64__) || defined(_M_X64) + for (; jj + 3 < max_jj; jj += 4) { - for (int jj = 0; jj < max_jj; jj += 1) + __m128 _f0 = _mm_loadu_ps(pp); + pp += 4; + + if (pC) { - *((unsigned short*)top_blob + (i + ii + 0) * out_hstep + (j + jj) * out_elempack) = float32_to_bfloat16(pp[0]); - *((unsigned short*)top_blob + (i + ii + 1) * out_hstep + (j + jj) * out_elempack) = float32_to_bfloat16(pp[1]); - *((unsigned short*)top_blob + (i + ii + 2) * out_hstep + (j + jj) * out_elempack) = float32_to_bfloat16(pp[2]); - *((unsigned short*)top_blob + (i + ii + 3) * out_hstep + (j + jj) * out_elempack) = float32_to_bfloat16(pp[3]); - pp += 4; + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = _mm_add_ps(_f0, _c0); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // out_elempack == 1 + _c0 = _mm_loadu_ps(pC); + _f0 = _mm_comp_fmadd_ps(_c0, _mm_set1_ps(beta), _f0); + pC += 4; + } + } + + _f0 = _mm_mul_ps(_f0, _mm_set1_ps(alpha)); + + __m128i _bf0 = float2bfloat_sse(_f0); + + if (output_transpose) + { + if (out_hstep == 1) + { + _mm_storel_epi64((__m128i*)p0, _bf0); + } + else + { +#if !(defined(__x86_64__) || defined(_M_X64)) +#if __AVX__ +#if __AVX512F__ + if (out_elempack == 16) + { + _mm_storel_epi64((__m128i*)(p0 - (jj % 16) / 4 * out_hstep * 4 + (jj % 16) / 4 * 4), _bf0); + } +#endif // __AVX512F__ + if (out_elempack == 8) + { + _mm_storel_epi64((__m128i*)(p0 - (jj % 8) / 4 * out_hstep * 4 + (jj % 8) / 4 * 4), _bf0); + } +#endif // __AVX__ +#endif // !(defined(__x86_64__) || defined(_M_X64)) + if (out_elempack == 4) + { + _mm_storel_epi64((__m128i*)p0, _bf0); + } + if (out_elempack == 1) + { + unsigned short sum0[4]; + _mm_storel_epi64((__m128i*)sum0, 
_bf0); + + p0[0] = sum0[0]; + p0[out_hstep] = sum0[1]; + p0[out_hstep * 2] = sum0[2]; + p0[out_hstep * 3] = sum0[3]; + } + } + p0 += out_hstep * 4; + } + else + { + _mm_storel_epi64((__m128i*)p0, _bf0); + p0 += 4; } } #endif // __SSE2__ - for (; ii + 1 < max_ii; ii += 2) + for (; jj + 1 < max_jj; jj += 2) { - for (int jj = 0; jj < max_jj; jj += 1) + float f0 = pp[0]; + float f1 = pp[1]; + pp += 2; + + if (pC) { - *((unsigned short*)top_blob + (i + ii + 0) * out_hstep + (j + jj) * out_elempack) = float32_to_bfloat16(pp[0]); - *((unsigned short*)top_blob + (i + ii + 1) * out_hstep + (j + jj) * out_elempack) = float32_to_bfloat16(pp[1]); - pp += 2; + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + f0 += c0; + f1 += c0; + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // out_elempack == 1 + f0 += pC[0] * beta; + f1 += pC[1] * beta; + pC += 2; + } + } + + f0 *= alpha; + f1 *= alpha; + + unsigned short bf0 = float32_to_bfloat16(f0); + unsigned short bf1 = float32_to_bfloat16(f1); + + if (output_transpose) + { + p0[0] = bf0; + p0[out_hstep] = bf1; + p0 += out_hstep * 2; + } + else + { + p0[0] = bf0; + p0[1] = bf1; + p0 += 2; } } - for (; ii < max_ii; ii += 1) + for (; jj < max_jj; jj++) { - for (int jj = 0; jj < max_jj; jj += 1) + float f0 = pp[0]; + pp += 1; + + if (pC) { - *((unsigned short*)top_blob + (i + ii) * out_hstep + (j + jj) * out_elempack) = float32_to_bfloat16(pp[0]); - pp += 1; + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + f0 += c0; + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // out_elempack == 1 + f0 += pC[0] * beta; + pC += 1; + } + } + + f0 *= alpha; + + p0[0] = float32_to_bfloat16(f0); + + if (output_transpose) + { + p0 += out_hstep; + } + else + { + p0++; } } } } -static void get_optimal_tile_mnk_bf16s(int M, int N, int K, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int& TILE_M, int& TILE_N, int& TILE_K, int nT) +static void get_optimal_tile_mnk_bf16(int M, int N, int K, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int& TILE_M, int& TILE_N, int& TILE_K, int nT) { - // resolve optimal tile size from cache size const size_t l2_cache_size = get_cpu_level2_cache_size(); if (nT == 0) nT = get_physical_big_cpu_count(); - int tile_size = (int)sqrt((float)l2_cache_size / 3 / sizeof(float)); + // bf16 takes half the space of fp32 plus fp32 accumulator + int tile_size = (int)sqrt((float)l2_cache_size / (2 * sizeof(unsigned short) + sizeof(float))); #if __AVX512F__ TILE_M = std::max(16, tile_size / 16 * 16); @@ -2265,7 +8796,7 @@ static void get_optimal_tile_mnk_bf16s(int M, int N, int K, int constant_TILE_M, TILE_K = std::max(16, tile_size / 16 * 16); #elif __AVX__ TILE_M = std::max(8, tile_size / 8 * 8); - TILE_N = std::max(4, tile_size / 4 * 4); + TILE_N = std::max(8, tile_size / 8 * 8); TILE_K = std::max(8, tile_size / 8 * 8); #elif __SSE2__ TILE_M = std::max(4, tile_size / 4 * 4); @@ -2292,14 +8823,14 @@ static void get_optimal_tile_mnk_bf16s(int M, int N, int K, int constant_TILE_M, if (nn_K == 1) { - tile_size = (int)((float)l2_cache_size / 2 / sizeof(float) / TILE_K); + tile_size = (int)((float)l2_cache_size / 2 / sizeof(unsigned short) / TILE_K); #if __AVX512F__ TILE_M = std::max(16, tile_size / 16 * 16); TILE_N = std::max(16, tile_size / 16 * 16); #elif __AVX__ TILE_M = std::max(8, tile_size / 8 * 8); - TILE_N = std::max(4, tile_size / 4 * 4); + TILE_N = std::max(8, tile_size / 8 * 8); #elif __SSE2__ TILE_M = std::max(4, tile_size / 4 
* 4); TILE_N = std::max(4, tile_size / 4 * 4); @@ -2332,7 +8863,7 @@ static void get_optimal_tile_mnk_bf16s(int M, int N, int K, int constant_TILE_M, #if __AVX512F__ TILE_N = std::min(TILE_N, ((N + nn_N - 1) / nn_N + 15) / 16 * 16); #elif __AVX__ - TILE_N = std::min(TILE_N, ((N + nn_N - 1) / nn_N + 3) / 4 * 4); + TILE_N = std::min(TILE_N, ((N + nn_N - 1) / nn_N + 7) / 8 * 8); #elif __SSE2__ TILE_N = std::min(TILE_N, ((N + nn_N - 1) / nn_N + 3) / 4 * 4); #else @@ -2353,7 +8884,6 @@ static void get_optimal_tile_mnk_bf16s(int M, int N, int K, int constant_TILE_M, #endif } - // always take constant TILE_M/N/K value when provided if (constant_TILE_M > 0) { #if __AVX512F__ @@ -2372,7 +8902,7 @@ static void get_optimal_tile_mnk_bf16s(int M, int N, int K, int constant_TILE_M, #if __AVX512F__ TILE_N = (constant_TILE_N + 15) / 16 * 16; #elif __AVX__ - TILE_N = (constant_TILE_N + 3) / 4 * 4; + TILE_N = (constant_TILE_N + 7) / 8 * 8; #elif __SSE2__ TILE_N = (constant_TILE_N + 3) / 4 * 4; #else diff --git a/src/layer/x86/gemm_x86.cpp b/src/layer/x86/gemm_x86.cpp index 0fda36ccd1d..a5acea649f9 100644 --- a/src/layer/x86/gemm_x86.cpp +++ b/src/layer/x86/gemm_x86.cpp @@ -7218,6 +7218,13 @@ int Gemm_x86::create_pipeline(const Option& opt) } #endif +#if NCNN_BF16 + if (opt.use_bf16_storage) + { + return create_pipeline_bf16s(opt); + } +#endif + if (constantA) { const int M = constantM; @@ -8381,137 +8388,148 @@ void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, Mat& t } // namespace Gemm_x86_utility #if NCNN_BF16 -static int gemm_x86_bf16s(const Mat& A, const Mat& B, const Mat& C, Mat& top_blob, int broadcast_type_C, int transA, int transB, int output_transpose, float alpha, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt) +int Gemm_x86::create_pipeline_bf16s(const Option& opt) { - const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; - const int K = transA ? (A.dims == 3 ? A.c : A.h) * A.elempack : A.w; - const int N = transB ? (B.dims == 3 ? 
B.c : B.h) * B.elempack : B.w; + if (constantA) + { + const int M = constantM; + const int K = constantK; - int TILE_M, TILE_N, TILE_K; - get_optimal_tile_mnk_bf16s(M, N, K, constant_TILE_M, constant_TILE_N, constant_TILE_K, TILE_M, TILE_N, TILE_K, nT); + int TILE_M, TILE_N, TILE_K; + get_optimal_tile_mnk_bf16(M, 0, K, constant_TILE_M, constant_TILE_N, constant_TILE_K, TILE_M, TILE_N, TILE_K, opt.num_threads); - const int nn_M = (M + TILE_M - 1) / TILE_M; - const int nn_N = (N + TILE_N - 1) / TILE_N; - const int nn_K = (K + TILE_K - 1) / TILE_K; + const int nn_M = (M + TILE_M - 1) / TILE_M; + const int nn_K = (K + TILE_K - 1) / TILE_K; - // BT is fp32 packed tile - Mat BT(TILE_K * TILE_N, nn_K, nn_N, 4u, opt.workspace_allocator); - if (BT.empty()) - return -100; + // cast A_data fp32 to bf16 + Mat A_data_bf16; + cast_float32_to_bfloat16(A_data, A_data_bf16, opt); - // pack B (bf16 -> fp32) - const int nn_NK = nn_N * nn_K; - #pragma omp parallel for num_threads(nT) - for (int ppjk = 0; ppjk < nn_NK; ppjk++) - { - const int ppj = ppjk / nn_K; - const int ppk = ppjk % nn_K; + AT_data.create(TILE_K * TILE_M, nn_K, nn_M, 2u, (Allocator*)0); + if (AT_data.empty()) + return -100; - const int j = ppj * TILE_N; - const int k = ppk * TILE_K; + const int nn_MK = nn_M * nn_K; + #pragma omp parallel for num_threads(opt.num_threads) + for (int ppik = 0; ppik < nn_MK; ppik++) + { + const int ppi = ppik / nn_K; + const int ppk = ppik % nn_K; - const int max_jj = std::min((N - j), TILE_N); - const int max_kk = std::min((K - k), TILE_K); + const int i = ppi * TILE_M; + const int k = ppk * TILE_K; - Mat BT_tile = BT.channel(j / TILE_N).row_range(k / TILE_K, 1); + const int max_ii = std::min((M - i), TILE_M); + const int max_kk = std::min((K - k), TILE_K); - if (transB) - { - pack_B_tile_bf16s(B, BT_tile, j, max_jj, k, max_kk); - } - else - { - transpose_pack_B_tile_bf16s(B, BT_tile, j, max_jj, k, max_kk); - } - } + Mat AT_tile = AT_data.channel(i / TILE_M).row_range(k / TILE_K, 1); - // topT is always needed for bf16 path (accumulate fp32, then convert to bf16) - Mat topT(TILE_N * TILE_M, 1, nT, 4u, opt.workspace_allocator); - if (topT.empty()) - return -100; + if (transA) + { + transpose_pack_A_tile_bf16(A_data_bf16, AT_tile, i, max_ii, k, max_kk); + } + else + { + pack_A_tile_bf16(A_data_bf16, AT_tile, i, max_ii, k, max_kk); + } + } - Mat ATX(TILE_K * TILE_M, nn_K, nT, 4u, opt.workspace_allocator); - if (ATX.empty()) - return -100; + if (opt.lightmode) + A_data.release(); + } - #pragma omp parallel for num_threads(nT) - for (int ppi = 0; ppi < nn_M; ppi++) + if (constantB) { - const int i = ppi * TILE_M; + const int N = constantN; + const int K = constantK; - const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; - const int K = transA ? (A.dims == 3 ? 
A.c : A.h) * A.elempack : A.w; + int TILE_M, TILE_N, TILE_K; + get_optimal_tile_mnk_bf16(0, N, K, constant_TILE_M, constant_TILE_N, constant_TILE_K, TILE_M, TILE_N, TILE_K, opt.num_threads); - const int max_ii = std::min((M - i), TILE_M); + const int nn_N = (N + TILE_N - 1) / TILE_N; + const int nn_K = (K + TILE_K - 1) / TILE_K; - Mat topT_tile = topT.channel(get_omp_thread_num()); + // cast B_data fp32 to bf16 + Mat B_data_bf16; + cast_float32_to_bfloat16(B_data, B_data_bf16, opt); - for (int j = 0; j < N; j += TILE_N) + BT_data.create(TILE_K * TILE_N, nn_K, nn_N, 2u, (Allocator*)0); + if (BT_data.empty()) + return -100; + + const int nn_NK = nn_N * nn_K; + #pragma omp parallel for num_threads(opt.num_threads) + for (int ppjk = 0; ppjk < nn_NK; ppjk++) { + const int ppj = ppjk / nn_K; + const int ppk = ppjk % nn_K; + + const int j = ppj * TILE_N; + const int k = ppk * TILE_K; + const int max_jj = std::min((N - j), TILE_N); + const int max_kk = std::min((K - k), TILE_K); - if (broadcast_type_C == 3) + Mat BT_tile = BT_data.channel(j / TILE_N).row_range(k / TILE_K, 1); + + if (transB) { - pack_A_tile(C, topT_tile, i, max_ii, j, max_jj); + pack_B_tile_bf16(B_data_bf16, BT_tile, j, max_jj, k, max_kk); } - - const Mat& CT_tile = broadcast_type_C == 3 ? topT_tile : C; - - for (int k = 0; k < K; k += TILE_K) + else { - const int max_kk = std::min((K - k), TILE_K); - - Mat AT_tile = ATX.channel(get_omp_thread_num()).row_range(k / TILE_K, 1); + transpose_pack_B_tile_bf16(B_data_bf16, BT_tile, j, max_jj, k, max_kk); + } + } - Mat BT_tile = BT.channel(j / TILE_N).row_range(k / TILE_K, 1); + if (opt.lightmode) + B_data.release(); + } - if (j == 0) - { - if (transA) - { - transpose_pack_A_tile_bf16s(A, AT_tile, i, max_ii, k, max_kk); - } - else - { - pack_A_tile_bf16s(A, AT_tile, i, max_ii, k, max_kk); - } - } + if (constantC && constant_broadcast_type_C != -1) + { + CT_data = C_data; - // always k_end=false, accumulate to topT as fp32 - gemm_transB_packed_tile(AT_tile, BT_tile, CT_tile, topT_tile, top_blob, broadcast_type_C, i, max_ii, j, max_jj, k, max_kk, false /*k_end*/); - } +#if __SSE2__ + if (constant_broadcast_type_C == 3 && opt.use_packing_layout) + { +#if __AVX512F__ + int C_elempack = constantM % 16 == 0 ? 16 : constantM % 8 == 0 ? 8 : constantM % 4 == 0 ? 4 : 1; +#elif __AVX__ + int C_elempack = constantM % 8 == 0 ? 8 : constantM % 4 == 0 ? 4 : 1; +#else + int C_elempack = constantM % 4 == 0 ? 
4 : 1; +#endif + convert_packing(C_data, CT_data, C_elempack, opt); + if (CT_data.empty()) + return -100; + } +#endif // __SSE2__ - // multiply alpha - if (alpha != 1.f) - { - float* outptr = topT_tile; - int size = max_ii * max_jj; - for (int q = 0; q < size; q++) - { - outptr[q] *= alpha; - } - } + if (opt.lightmode) + C_data.release(); + } - // convert fp32 topT to bf16 output - unpack_output_tile_bf16s(topT_tile, top_blob, i, max_ii, j, max_jj, output_transpose); - } + if (constantA || constantB || constantC) + { + nT = opt.num_threads; } return 0; } -static int gemm_AT_x86_bf16s(const Mat& AT, const Mat& B, const Mat& C, Mat& top_blob, int broadcast_type_C, int M, int K, int transB, int output_transpose, float alpha, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt) +static int gemm_AT_x86_bf16s(const Mat& AT, const Mat& B, const Mat& C, Mat& top_blob, int broadcast_type_C, int M, int K, int transB, int output_transpose, float alpha, float beta, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt) { const int N = transB ? (B.dims == 3 ? B.c : B.h) * B.elempack : B.w; int TILE_M, TILE_N, TILE_K; - get_optimal_tile_mnk_bf16s(M, N, K, constant_TILE_M, constant_TILE_N, constant_TILE_K, TILE_M, TILE_N, TILE_K, nT); + get_optimal_tile_mnk_bf16(M, N, K, constant_TILE_M, constant_TILE_N, constant_TILE_K, TILE_M, TILE_N, TILE_K, nT); const int nn_M = (M + TILE_M - 1) / TILE_M; const int nn_N = (N + TILE_N - 1) / TILE_N; const int nn_K = (K + TILE_K - 1) / TILE_K; - Mat BT(TILE_K * TILE_N, nn_K, nn_N, 4u, opt.workspace_allocator); + Mat BT(TILE_K * TILE_N, nn_K, nn_N, 2u, opt.workspace_allocator); if (BT.empty()) return -100; @@ -8532,11 +8550,11 @@ static int gemm_AT_x86_bf16s(const Mat& AT, const Mat& B, const Mat& C, Mat& top if (transB) { - pack_B_tile_bf16s(B, BT_tile, j, max_jj, k, max_kk); + pack_B_tile_bf16(B, BT_tile, j, max_jj, k, max_kk); } else { - transpose_pack_B_tile_bf16s(B, BT_tile, j, max_jj, k, max_kk); + transpose_pack_B_tile_bf16(B, BT_tile, j, max_jj, k, max_kk); } } @@ -8557,48 +8575,31 @@ static int gemm_AT_x86_bf16s(const Mat& AT, const Mat& B, const Mat& C, Mat& top { const int max_jj = std::min((N - j), TILE_N); - if (broadcast_type_C == 3) - { - pack_A_tile(C, topT_tile, i, max_ii, j, max_jj); - } - - const Mat& CT_tile = broadcast_type_C == 3 ? 
topT_tile : C; - for (int k = 0; k < K; k += TILE_K) { const int max_kk = std::min((K - k), TILE_K); - // AT is pre-packed fp32 + // AT is pre-packed bf16 Mat AT_tile = AT.channel(i / TILE_M).row_range(k / TILE_K, 1); Mat BT_tile = BT.channel(j / TILE_N).row_range(k / TILE_K, 1); - gemm_transB_packed_tile(AT_tile, BT_tile, CT_tile, topT_tile, top_blob, broadcast_type_C, i, max_ii, j, max_jj, k, max_kk, false /*k_end*/); + gemm_transB_packed_tile_bf16s(AT_tile, BT_tile, topT_tile, i, max_ii, j, max_jj, k, max_kk); } - if (alpha != 1.f) - { - float* outptr = topT_tile; - int size = max_ii * max_jj; - for (int q = 0; q < size; q++) - { - outptr[q] *= alpha; - } - } - - unpack_output_tile_bf16s(topT_tile, top_blob, i, max_ii, j, max_jj, output_transpose); + unpack_output_tile_fp32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, alpha, beta, output_transpose); } } return 0; } -static int gemm_BT_x86_bf16s(const Mat& A, const Mat& BT, const Mat& C, Mat& top_blob, int broadcast_type_C, int N, int K, int transA, int output_transpose, float alpha, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt) +static int gemm_BT_x86_bf16s(const Mat& A, const Mat& BT, const Mat& C, Mat& top_blob, int broadcast_type_C, int N, int K, int transA, int output_transpose, float alpha, float beta, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt) { const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; int TILE_M, TILE_N, TILE_K; - get_optimal_tile_mnk_bf16s(M, N, K, constant_TILE_M, constant_TILE_N, constant_TILE_K, TILE_M, TILE_N, TILE_K, nT); + get_optimal_tile_mnk_bf16(M, N, K, constant_TILE_M, constant_TILE_N, constant_TILE_K, TILE_M, TILE_N, TILE_K, nT); const int nn_M = (M + TILE_M - 1) / TILE_M; const int nn_K = (K + TILE_K - 1) / TILE_K; @@ -8607,7 +8608,7 @@ static int gemm_BT_x86_bf16s(const Mat& A, const Mat& BT, const Mat& C, Mat& top if (topT.empty()) return -100; - Mat ATX(TILE_K * TILE_M, nn_K, nT, 4u, opt.workspace_allocator); + Mat ATX(TILE_K * TILE_M, nn_K, nT, 2u, opt.workspace_allocator); if (ATX.empty()) return -100; @@ -8627,58 +8628,41 @@ static int gemm_BT_x86_bf16s(const Mat& A, const Mat& BT, const Mat& C, Mat& top { const int max_jj = std::min((N - j), TILE_N); - if (broadcast_type_C == 3) - { - pack_A_tile(C, topT_tile, i, max_ii, j, max_jj); - } - - const Mat& CT_tile = broadcast_type_C == 3 ? 
topT_tile : C; - for (int k = 0; k < K; k += TILE_K) { const int max_kk = std::min((K - k), TILE_K); Mat AT_tile = ATX.channel(get_omp_thread_num()).row_range(k / TILE_K, 1); - // BT is pre-packed fp32 + // BT is pre-packed bf16 Mat BT_tile = BT.channel(j / TILE_N).row_range(k / TILE_K, 1); if (j == 0) { if (transA) { - transpose_pack_A_tile_bf16s(A, AT_tile, i, max_ii, k, max_kk); + transpose_pack_A_tile_bf16(A, AT_tile, i, max_ii, k, max_kk); } else { - pack_A_tile_bf16s(A, AT_tile, i, max_ii, k, max_kk); + pack_A_tile_bf16(A, AT_tile, i, max_ii, k, max_kk); } } - gemm_transB_packed_tile(AT_tile, BT_tile, CT_tile, topT_tile, top_blob, broadcast_type_C, i, max_ii, j, max_jj, k, max_kk, false /*k_end*/); - } - - if (alpha != 1.f) - { - float* outptr = topT_tile; - int size = max_ii * max_jj; - for (int q = 0; q < size; q++) - { - outptr[q] *= alpha; - } + gemm_transB_packed_tile_bf16s(AT_tile, BT_tile, topT_tile, i, max_ii, j, max_jj, k, max_kk); } - unpack_output_tile_bf16s(topT_tile, top_blob, i, max_ii, j, max_jj, output_transpose); + unpack_output_tile_fp32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, alpha, beta, output_transpose); } } return 0; } -static int gemm_AT_BT_x86_bf16s(const Mat& AT, const Mat& BT, const Mat& C, Mat& top_blob, int broadcast_type_C, int M, int N, int K, int output_transpose, float alpha, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt) +static int gemm_AT_BT_x86_bf16s(const Mat& AT, const Mat& BT, const Mat& C, Mat& top_blob, int broadcast_type_C, int M, int N, int K, int output_transpose, float alpha, float beta, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt) { int TILE_M, TILE_N, TILE_K; - get_optimal_tile_mnk_bf16s(M, N, K, constant_TILE_M, constant_TILE_N, constant_TILE_K, TILE_M, TILE_N, TILE_K, nT); + get_optimal_tile_mnk_bf16(M, N, K, constant_TILE_M, constant_TILE_N, constant_TILE_K, TILE_M, TILE_N, TILE_K, nT); const int nn_M = (M + TILE_M - 1) / TILE_M; @@ -8699,13 +8683,6 @@ static int gemm_AT_BT_x86_bf16s(const Mat& AT, const Mat& BT, const Mat& C, Mat& { const int max_jj = std::min((N - j), TILE_N); - if (broadcast_type_C == 3) - { - pack_A_tile(C, topT_tile, i, max_ii, j, max_jj); - } - - const Mat& CT_tile = broadcast_type_C == 3 ? topT_tile : C; - for (int k = 0; k < K; k += TILE_K) { const int max_kk = std::min((K - k), TILE_K); @@ -8713,20 +8690,105 @@ static int gemm_AT_BT_x86_bf16s(const Mat& AT, const Mat& BT, const Mat& C, Mat& Mat AT_tile = AT.channel(i / TILE_M).row_range(k / TILE_K, 1); Mat BT_tile = BT.channel(j / TILE_N).row_range(k / TILE_K, 1); - gemm_transB_packed_tile(AT_tile, BT_tile, CT_tile, topT_tile, top_blob, broadcast_type_C, i, max_ii, j, max_jj, k, max_kk, false /*k_end*/); + gemm_transB_packed_tile_bf16s(AT_tile, BT_tile, topT_tile, i, max_ii, j, max_jj, k, max_kk); } - if (alpha != 1.f) + unpack_output_tile_fp32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, alpha, beta, output_transpose); + } + } + + return 0; +} + +static int gemm_x86_bf16s(const Mat& A, const Mat& B, const Mat& C, Mat& top_blob, int broadcast_type_C, int transA, int transB, int output_transpose, float alpha, float beta, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt) +{ + const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; + const int K = transA ? (A.dims == 3 ? A.c : A.h) * A.elempack : A.w; + const int N = transB ? (B.dims == 3 ? 
B.c : B.h) * B.elempack : B.w; + + int TILE_M, TILE_N, TILE_K; + get_optimal_tile_mnk_bf16(M, N, K, constant_TILE_M, constant_TILE_N, constant_TILE_K, TILE_M, TILE_N, TILE_K, nT); + + const int nn_M = (M + TILE_M - 1) / TILE_M; + const int nn_N = (N + TILE_N - 1) / TILE_N; + const int nn_K = (K + TILE_K - 1) / TILE_K; + + Mat BT(TILE_K * TILE_N, nn_K, nn_N, 2u, opt.workspace_allocator); + if (BT.empty()) + return -100; + + const int nn_NK = nn_N * nn_K; + #pragma omp parallel for num_threads(nT) + for (int ppjk = 0; ppjk < nn_NK; ppjk++) + { + const int ppj = ppjk / nn_K; + const int ppk = ppjk % nn_K; + + const int j = ppj * TILE_N; + const int k = ppk * TILE_K; + + const int max_jj = std::min((N - j), TILE_N); + const int max_kk = std::min((K - k), TILE_K); + + Mat BT_tile = BT.channel(j / TILE_N).row_range(k / TILE_K, 1); + + if (transB) + { + pack_B_tile_bf16(B, BT_tile, j, max_jj, k, max_kk); + } + else + { + transpose_pack_B_tile_bf16(B, BT_tile, j, max_jj, k, max_kk); + } + } + + Mat topT(TILE_N * TILE_M, 1, nT, 4u, opt.workspace_allocator); + if (topT.empty()) + return -100; + + Mat ATX(TILE_K * TILE_M, nn_K, nT, 2u, opt.workspace_allocator); + if (ATX.empty()) + return -100; + + #pragma omp parallel for num_threads(nT) + for (int ppi = 0; ppi < nn_M; ppi++) + { + const int i = ppi * TILE_M; + + const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; + const int K = transA ? (A.dims == 3 ? A.c : A.h) * A.elempack : A.w; + + const int max_ii = std::min((M - i), TILE_M); + + Mat topT_tile = topT.channel(get_omp_thread_num()); + + for (int j = 0; j < N; j += TILE_N) + { + const int max_jj = std::min((N - j), TILE_N); + + for (int k = 0; k < K; k += TILE_K) { - float* outptr = topT_tile; - int size = max_ii * max_jj; - for (int q = 0; q < size; q++) + const int max_kk = std::min((K - k), TILE_K); + + Mat AT_tile = ATX.channel(get_omp_thread_num()).row_range(k / TILE_K, 1); + Mat BT_tile = BT.channel(j / TILE_N).row_range(k / TILE_K, 1); + + if (j == 0) { - outptr[q] *= alpha; + if (transA) + { + transpose_pack_A_tile_bf16(A, AT_tile, i, max_ii, k, max_kk); + } + else + { + pack_A_tile_bf16(A, AT_tile, i, max_ii, k, max_kk); + } } + + gemm_transB_packed_tile_bf16s(AT_tile, BT_tile, topT_tile, i, max_ii, j, max_jj, k, max_kk); } - unpack_output_tile_bf16s(topT_tile, top_blob, i, max_ii, j, max_jj, output_transpose); + unpack_output_tile_fp32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, alpha, beta, output_transpose); } } @@ -8827,28 +8889,25 @@ int Gemm_x86::forward_bf16s(const std::vector& bottom_blobs, std::vector& bottom_blobs, std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; #endif #if NCNN_INT8 diff --git a/src/layer/x86/gemm_x86_avx512bf16.cpp b/src/layer/x86/gemm_x86_avx512bf16.cpp new file mode 100644 index 00000000000..700e343ab15 --- /dev/null +++ b/src/layer/x86/gemm_x86_avx512bf16.cpp @@ -0,0 +1,43 @@ +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#include "cpu.h" +#include "mat.h" +#if __SSE2__ +#include +#if __AVX__ +#include +#endif // __AVX__ +#endif // __SSE2__ +#include "x86_usability.h" + +namespace ncnn { + +#include "gemm_bf16s.h" + +void pack_A_tile_bf16_avx512bf16(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk) +{ + pack_A_tile_bf16(A, AT, i, max_ii, k, max_kk); +} + +void transpose_pack_A_tile_bf16_avx512bf16(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk) +{ + transpose_pack_A_tile_bf16(A, AT, i, max_ii, k, max_kk); +} + +void 
pack_B_tile_bf16_avx512bf16(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk) +{ + pack_B_tile_bf16(B, BT, j, max_jj, k, max_kk); +} + +void transpose_pack_B_tile_bf16_avx512bf16(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk) +{ + transpose_pack_B_tile_bf16(B, BT, j, max_jj, k, max_kk); +} + +void gemm_transB_packed_tile_bf16s_avx512bf16(const Mat& AT_tile, const Mat& BT_tile, Mat& topT_tile, int i, int max_ii, int j, int max_jj, int k, int max_kk) +{ + gemm_transB_packed_tile_bf16s(AT_tile, BT_tile, topT_tile, i, max_ii, j, max_jj, k, max_kk); +} + +} // namespace ncnn diff --git a/src/layer/x86/x86_usability.h b/src/layer/x86/x86_usability.h index 0da0b001f93..911df1fd176 100644 --- a/src/layer/x86/x86_usability.h +++ b/src/layer/x86/x86_usability.h @@ -406,6 +406,24 @@ static NCNN_FORCEINLINE __m128i float2bfloat_sse(const __m128& v0, const __m128& return _v; } +static NCNN_FORCEINLINE __m128i float2bfloat_sse(const __m128& v) +{ +#if __AVX512BF16__ + __m128i _v = (__m128i)_mm_cvtneps_pbh(v); +#else + __m128i _a = _mm_castps_si128(v); +#if __SSE4_1__ + _a = _mm_srli_epi32(_a, 16); + __m128i _v = _mm_packus_epi32(_a, _mm_setzero_si128()); +#else + _a = _mm_shufflelo_epi16(_a, _MM_SHUFFLE(2, 0, 3, 1)); + _a = _mm_shufflehi_epi16(_a, _MM_SHUFFLE(2, 0, 3, 1)); + __m128i _v = _mm_castps_si128(_mm_shuffle_ps(_mm_castsi128_ps(_a), _mm_setzero_ps(), _MM_SHUFFLE(2, 0, 2, 0))); +#endif +#endif + return _v; +} + static NCNN_FORCEINLINE __m128 _mm_comp_fmadd_ps(const __m128& _a, const __m128& _b, const __m128& _c) { #if __FMA__ @@ -465,6 +483,20 @@ static NCNN_FORCEINLINE __m128i _mm_comp_dpwssd_epi32(const __m128i& src, const } #if __AVX__ +static NCNN_FORCEINLINE __m256 combine4x2_ps(const __m128& a, const __m128& b) +{ + return _mm256_insertf128_ps(_mm256_castps128_ps256(a), b, 1); +} + +static NCNN_FORCEINLINE __m256i combine4x2_epi32(const __m128i& a, const __m128i& b) +{ +#if __AVX2__ + return _mm256_inserti128_si256(_mm256_castsi128_si256(a), b, 1); +#else + return _mm256_insertf128_si256(_mm256_castsi128_si256(a), b, 1); +#endif +} + static NCNN_FORCEINLINE __m256 _mm256_comp_fmadd_ps(const __m256& _a, const __m256& _b, const __m256& _c) { // return a * b + c @@ -843,6 +875,81 @@ static NCNN_FORCEINLINE void transpose8x2_epi32(__m256i& _r0, __m256i& _r1) _r1 = _mm256_permute2f128_si256(_tmp0, _tmp1, _MM_SHUFFLE(0, 3, 0, 1)); } +static NCNN_FORCEINLINE void transpose16x4_epi16(__m256i& _r0, __m256i& _r1, __m256i& _r2, __m256i& _r3) +{ +#if __AVX2__ + __m256i _tmp0 = _mm256_unpacklo_epi16(_r0, _r1); + __m256i _tmp1 = _mm256_unpackhi_epi16(_r0, _r1); + __m256i _tmp2 = _mm256_unpacklo_epi16(_r2, _r3); + __m256i _tmp3 = _mm256_unpackhi_epi16(_r2, _r3); + + __m256i _tmp4 = _mm256_unpacklo_epi32(_tmp0, _tmp2); + __m256i _tmp5 = _mm256_unpackhi_epi32(_tmp0, _tmp2); + __m256i _tmp6 = _mm256_unpacklo_epi32(_tmp1, _tmp3); + __m256i _tmp7 = _mm256_unpackhi_epi32(_tmp1, _tmp3); + + _r0 = _mm256_permute2f128_si256(_tmp4, _tmp5, _MM_SHUFFLE(0, 2, 0, 0)); + _r1 = _mm256_permute2f128_si256(_tmp6, _tmp7, _MM_SHUFFLE(0, 2, 0, 0)); + _r2 = _mm256_permute2f128_si256(_tmp4, _tmp5, _MM_SHUFFLE(0, 3, 0, 1)); + _r3 = _mm256_permute2f128_si256(_tmp6, _tmp7, _MM_SHUFFLE(0, 3, 0, 1)); +#else + __m128i _r0l = _mm256_extractf128_si256(_r0, 0); + __m128i _r0h = _mm256_extractf128_si256(_r0, 1); + __m128i _r1l = _mm256_extractf128_si256(_r1, 0); + __m128i _r1h = _mm256_extractf128_si256(_r1, 1); + __m128i _r2l = _mm256_extractf128_si256(_r2, 0); + __m128i _r2h = _mm256_extractf128_si256(_r2, 1); + 
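+    // AVX-only fallback: the 16x4 16-bit transpose is carried out on these 128-bit
+    // halves with SSE unpacks of 16-bit then 32-bit lanes, and the four rows are
+    // reassembled with combine4x2_epi32 at the end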
__m128i _r3l = _mm256_extractf128_si256(_r3, 0); + __m128i _r3h = _mm256_extractf128_si256(_r3, 1); + + __m128i _tmp0l = _mm_unpacklo_epi16(_r0l, _r1l); + __m128i _tmp1l = _mm_unpackhi_epi16(_r0l, _r1l); + __m128i _tmp2l = _mm_unpacklo_epi16(_r2l, _r3l); + __m128i _tmp3l = _mm_unpackhi_epi16(_r2l, _r3l); + + __m128i _tmp0h = _mm_unpacklo_epi16(_r0h, _r1h); + __m128i _tmp1h = _mm_unpackhi_epi16(_r0h, _r1h); + __m128i _tmp2h = _mm_unpacklo_epi16(_r2h, _r3h); + __m128i _tmp3h = _mm_unpackhi_epi16(_r2h, _r3h); + + __m128i _t0 = _mm_unpacklo_epi32(_tmp0l, _tmp2l); + __m128i _t1 = _mm_unpackhi_epi32(_tmp0l, _tmp2l); + __m128i _t2 = _mm_unpacklo_epi32(_tmp1l, _tmp3l); + __m128i _t3 = _mm_unpackhi_epi32(_tmp1l, _tmp3l); + + __m128i _t4 = _mm_unpacklo_epi32(_tmp0h, _tmp2h); + __m128i _t5 = _mm_unpackhi_epi32(_tmp0h, _tmp2h); + __m128i _t6 = _mm_unpacklo_epi32(_tmp1h, _tmp3h); + __m128i _t7 = _mm_unpackhi_epi32(_tmp1h, _tmp3h); + + _r0 = combine4x2_epi32(_t0, _t1); + _r1 = combine4x2_epi32(_t2, _t3); + _r2 = combine4x2_epi32(_t4, _t5); + _r3 = combine4x2_epi32(_t6, _t7); +#endif +} + +static NCNN_FORCEINLINE void transpose16x2_epi16(__m256i& _r0, __m256i& _r1) +{ +#if __AVX2__ + __m256i _tmp0 = _mm256_unpacklo_epi16(_r0, _r1); + __m256i _tmp1 = _mm256_unpackhi_epi16(_r0, _r1); + _r0 = _mm256_permute2f128_si256(_tmp0, _tmp1, _MM_SHUFFLE(0, 2, 0, 0)); + _r1 = _mm256_permute2f128_si256(_tmp0, _tmp1, _MM_SHUFFLE(0, 3, 0, 1)); +#else + __m128i _r0l = _mm256_extractf128_si256(_r0, 0); + __m128i _r0h = _mm256_extractf128_si256(_r0, 1); + __m128i _r1l = _mm256_extractf128_si256(_r1, 0); + __m128i _r1h = _mm256_extractf128_si256(_r1, 1); + __m128i _t0l = _mm_unpacklo_epi16(_r0l, _r1l); + __m128i _t0h = _mm_unpackhi_epi16(_r0l, _r1l); + __m128i _t1l = _mm_unpacklo_epi16(_r0h, _r1h); + __m128i _t1h = _mm_unpackhi_epi16(_r0h, _r1h); + _r0 = combine4x2_epi32(_t0l, _t0h); + _r1 = combine4x2_epi32(_t1l, _t1h); +#endif +} + static NCNN_FORCEINLINE __m256 HorizontalSums(__m256& v0, __m256& v1, __m256& v2, __m256& v3, __m256& v4, __m256& v5, __m256& v6, __m256& v7) { const __m256 s01 = _mm256_hadd_ps(v0, v1); @@ -880,20 +987,6 @@ static NCNN_FORCEINLINE __m128 HorizontalSums(__m256& v0, __m256& v1, __m256& v2 _mm256_castps256_ps128(s0123)); } -static NCNN_FORCEINLINE __m256 combine4x2_ps(const __m128& a, const __m128& b) -{ - return _mm256_insertf128_ps(_mm256_castps128_ps256(a), b, 1); -} - -static NCNN_FORCEINLINE __m256i combine4x2_epi32(const __m128i& a, const __m128i& b) -{ -#if __AVX2__ - return _mm256_inserti128_si256(_mm256_castsi128_si256(a), b, 1); -#else - return _mm256_insertf128_si256(_mm256_castsi128_si256(a), b, 1); -#endif -} - static NCNN_FORCEINLINE float _mm256_reduce_add_ps(const __m256& x) { /* ( x3+x7, x2+x6, x1+x5, x0+x4 ) */ From 7fa7a37eac90dee5d491a50c56c6e7a55993f063 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kenji=20Mouri=20/=20=E6=AF=9B=E5=88=A9=20=E7=A0=94?= =?UTF-8?q?=E4=BA=8C?= Date: Mon, 30 Mar 2026 10:21:46 +0800 Subject: [PATCH 28/36] Add benchmark results for several instances from Microsoft Azure. 
(#6552) --- benchmark/README.md | 1493 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 1493 insertions(+) diff --git a/benchmark/README.md b/benchmark/README.md index 1c1110f760e..3251df385fd 100644 --- a/benchmark/README.md +++ b/benchmark/README.md @@ -9701,3 +9701,1496 @@ cooling_down = 1 FastestDet min = 51.52 max = 62.65 avg = 55.04 ``` +### Microsoft Azure Standard D64ps v6 Instance + +- Type: 64 vcpu, 256 GiB RAM +- CPU: Azure Cobalt 100 (Neoverse-N2) @ 3.4GHz + - Note: lscpu or something like fastfetch only report Neoverse-N2, the "Azure + Cobalt 100" name is only mentioned in the Microsoft document. + - Note: CPU frequency is measured by 7-Zip Benchmark. +- OS: Debian 12 with Kernel 6.1.0-42-cloud-arm64 and GCC 12.2.0 +- ncnn version tag: 20260113 +- ncnn configuration command + > cmake -B build -DCMAKE_BUILD_TYPE=Release -DNCNN_BUILD_BENCHMARK=ON -DNCNN_BUILD_TOOLS=OFF -DNCNN_BUILD_EXAMPLES=OFF -DNCNN_BUILD_TESTS=OFF -DNCNN_VULKAN=OFF + +``` +misaki@HimiMisakiBenchmarkARM64:~/Workspace/ncnn/build/benchmark$ ./benchncnn 512 1 0 -1 0 +loop_count = 512 +num_threads = 1 +powersave = 0 +gpu_device = -1 +cooling_down = 0 + squeezenet min = 6.39 max = 6.47 avg = 6.41 + squeezenet_int8 min = 4.69 max = 4.77 avg = 4.71 + mobilenet min = 11.58 max = 11.67 avg = 11.60 + mobilenet_int8 min = 6.33 max = 6.41 avg = 6.35 + mobilenet_v2 min = 7.29 max = 7.38 avg = 7.31 + mobilenet_v3 min = 6.03 max = 6.11 avg = 6.05 + shufflenet min = 4.13 max = 4.19 avg = 4.15 + shufflenet_v2 min = 4.19 max = 4.22 avg = 4.20 + mnasnet min = 7.18 max = 7.28 avg = 7.20 + proxylessnasnet min = 8.58 max = 8.89 avg = 8.66 + efficientnet_b0 min = 14.03 max = 14.45 avg = 14.25 + efficientnetv2_b0 min = 16.07 max = 16.67 avg = 16.47 + regnety_400m min = 9.79 max = 10.33 avg = 10.03 + blazeface min = 1.51 max = 1.57 avg = 1.53 + googlenet min = 25.70 max = 26.45 avg = 25.96 + googlenet_int8 min = 18.80 max = 19.49 avg = 19.16 + resnet18 min = 17.45 max = 18.33 avg = 17.91 + resnet18_int8 min = 13.69 max = 14.21 avg = 13.96 + alexnet min = 18.27 max = 19.47 avg = 19.10 + vgg16 min = 98.73 max = 101.91 avg = 100.58 + vgg16_int8 min = 101.59 max = 105.64 avg = 103.66 + resnet50 min = 53.92 max = 55.47 avg = 54.74 + resnet50_int8 min = 30.40 max = 30.96 avg = 30.65 + squeezenet_ssd min = 16.09 max = 16.64 avg = 16.31 + squeezenet_ssd_int8 min = 13.32 max = 13.91 avg = 13.66 + mobilenet_ssd min = 24.45 max = 24.95 avg = 24.67 + mobilenet_ssd_int8 min = 13.34 max = 13.88 avg = 13.56 + mobilenet_yolo min = 55.07 max = 56.16 avg = 55.47 + mobilenetv2_yolov3 min = 27.47 max = 28.03 avg = 27.72 + yolov4-tiny min = 34.41 max = 35.77 avg = 34.80 + nanodet_m min = 10.18 max = 10.49 avg = 10.24 + yolo-fastest-1.1 min = 4.17 max = 4.26 avg = 4.18 + yolo-fastestv2 min = 3.48 max = 3.56 avg = 3.50 + vision_transformer min = 793.71 max = 800.32 avg = 796.59 + FastestDet min = 3.62 max = 3.76 avg = 3.65 +misaki@HimiMisakiBenchmarkARM64:~/Workspace/ncnn/build/benchmark$ ./benchncnn 512 2 0 -1 0 +loop_count = 512 +num_threads = 2 +powersave = 0 +gpu_device = -1 +cooling_down = 0 + squeezenet min = 3.54 max = 3.72 avg = 3.59 + squeezenet_int8 min = 2.71 max = 2.83 avg = 2.76 + mobilenet min = 6.06 max = 6.28 avg = 6.15 + mobilenet_int8 min = 3.44 max = 3.60 avg = 3.50 + mobilenet_v2 min = 4.08 max = 4.30 avg = 4.16 + mobilenet_v3 min = 3.47 max = 3.69 avg = 3.54 + shufflenet min = 2.74 max = 2.85 avg = 2.79 + shufflenet_v2 min = 2.61 max = 2.92 avg = 2.65 + mnasnet min = 4.09 max = 4.29 avg = 4.17 + proxylessnasnet min = 4.77 max 
= 4.97 avg = 4.84 + efficientnet_b0 min = 7.75 max = 8.01 avg = 7.85 + efficientnetv2_b0 min = 8.95 max = 9.22 avg = 9.05 + regnety_400m min = 6.77 max = 7.23 avg = 7.02 + blazeface min = 1.00 max = 1.22 avg = 1.01 + googlenet min = 13.25 max = 14.07 avg = 13.62 + googlenet_int8 min = 10.35 max = 10.60 avg = 10.43 + resnet18 min = 8.78 max = 9.35 avg = 8.89 + resnet18_int8 min = 7.03 max = 7.25 avg = 7.09 + alexnet min = 8.98 max = 9.64 avg = 9.22 + vgg16 min = 50.37 max = 51.34 avg = 50.73 + vgg16_int8 min = 52.02 max = 56.48 avg = 52.99 + resnet50 min = 27.38 max = 28.07 avg = 27.68 + resnet50_int8 min = 15.05 max = 15.91 avg = 15.25 + squeezenet_ssd min = 9.05 max = 9.53 avg = 9.23 + squeezenet_ssd_int8 min = 7.91 max = 8.10 avg = 7.97 + mobilenet_ssd min = 12.55 max = 12.80 avg = 12.61 + mobilenet_ssd_int8 min = 6.99 max = 7.27 avg = 7.04 + mobilenet_yolo min = 29.69 max = 30.18 avg = 29.87 + mobilenetv2_yolov3 min = 15.63 max = 15.96 avg = 15.71 + yolov4-tiny min = 19.00 max = 19.79 avg = 19.29 + nanodet_m min = 6.35 max = 6.65 avg = 6.40 + yolo-fastest-1.1 min = 2.94 max = 3.07 avg = 2.97 + yolo-fastestv2 min = 2.51 max = 2.61 avg = 2.54 + vision_transformer min = 424.01 max = 447.52 avg = 434.38 + FastestDet min = 2.56 max = 2.71 avg = 2.59 +misaki@HimiMisakiBenchmarkARM64:~/Workspace/ncnn/build/benchmark$ ./benchncnn 512 4 0 -1 0 +loop_count = 512 +num_threads = 4 +powersave = 0 +gpu_device = -1 +cooling_down = 0 + squeezenet min = 2.04 max = 2.46 avg = 2.07 + squeezenet_int8 min = 1.72 max = 1.81 avg = 1.74 + mobilenet min = 3.11 max = 3.23 avg = 3.14 + mobilenet_int8 min = 1.92 max = 2.02 avg = 1.94 + mobilenet_v2 min = 2.31 max = 2.46 avg = 2.35 + mobilenet_v3 min = 2.15 max = 2.34 avg = 2.18 + shufflenet min = 2.08 max = 2.29 avg = 2.11 + shufflenet_v2 min = 1.71 max = 1.96 avg = 1.73 + mnasnet min = 2.43 max = 2.61 avg = 2.46 + proxylessnasnet min = 2.72 max = 2.84 avg = 2.75 + efficientnet_b0 min = 4.29 max = 4.43 avg = 4.32 + efficientnetv2_b0 min = 5.17 max = 5.39 avg = 5.22 + regnety_400m min = 5.22 max = 5.41 avg = 5.27 + blazeface min = 0.75 max = 1.06 avg = 0.76 + googlenet min = 7.14 max = 7.42 avg = 7.21 + googlenet_int8 min = 5.97 max = 6.23 avg = 6.06 + resnet18 min = 4.69 max = 4.95 avg = 4.78 + resnet18_int8 min = 3.97 max = 4.12 avg = 4.02 + alexnet min = 4.67 max = 5.12 avg = 4.98 + vgg16 min = 25.81 max = 26.52 avg = 26.16 + vgg16_int8 min = 26.35 max = 27.23 avg = 26.72 + resnet50 min = 14.21 max = 14.82 avg = 14.47 + resnet50_int8 min = 8.33 max = 8.98 avg = 8.59 + squeezenet_ssd min = 5.86 max = 6.29 avg = 6.03 + squeezenet_ssd_int8 min = 4.93 max = 5.29 avg = 5.02 + mobilenet_ssd min = 6.81 max = 7.17 avg = 6.95 + mobilenet_ssd_int8 min = 4.04 max = 4.31 avg = 4.18 + mobilenet_yolo min = 17.36 max = 17.67 avg = 17.49 + mobilenetv2_yolov3 min = 9.04 max = 9.34 avg = 9.12 + yolov4-tiny min = 12.01 max = 12.36 avg = 12.16 + nanodet_m min = 4.14 max = 4.28 avg = 4.18 + yolo-fastest-1.1 min = 2.33 max = 2.47 avg = 2.37 + yolo-fastestv2 min = 1.90 max = 2.03 avg = 1.93 + vision_transformer min = 225.61 max = 241.58 avg = 233.70 + FastestDet min = 1.88 max = 2.21 avg = 1.90 +misaki@HimiMisakiBenchmarkARM64:~/Workspace/ncnn/build/benchmark$ ./benchncnn 512 8 0 -1 0 +loop_count = 512 +num_threads = 8 +powersave = 0 +gpu_device = -1 +cooling_down = 0 + squeezenet min = 1.40 max = 1.64 avg = 1.43 + squeezenet_int8 min = 1.33 max = 1.42 avg = 1.35 + mobilenet min = 1.80 max = 1.91 avg = 1.83 + mobilenet_int8 min = 1.29 max = 1.37 avg = 1.32 + mobilenet_v2 min = 1.79 
max = 1.98 avg = 1.83 + mobilenet_v3 min = 1.57 max = 1.74 avg = 1.60 + shufflenet min = 1.79 max = 2.40 avg = 1.82 + shufflenet_v2 min = 1.33 max = 1.75 avg = 1.35 + mnasnet min = 1.70 max = 1.97 avg = 1.72 + proxylessnasnet min = 1.89 max = 2.01 avg = 1.91 + efficientnet_b0 min = 2.97 max = 3.13 avg = 3.01 + efficientnetv2_b0 min = 3.77 max = 3.92 avg = 3.83 + regnety_400m min = 4.80 max = 4.95 avg = 4.85 + blazeface min = 0.75 max = 0.94 avg = 0.76 + googlenet min = 4.62 max = 4.85 avg = 4.70 + googlenet_int8 min = 4.08 max = 4.27 avg = 4.14 + resnet18 min = 2.62 max = 2.86 avg = 2.69 + resnet18_int8 min = 2.38 max = 2.60 avg = 2.44 + alexnet min = 2.64 max = 2.90 avg = 2.77 + vgg16 min = 13.77 max = 14.28 avg = 14.05 + vgg16_int8 min = 13.76 max = 14.31 avg = 14.02 + resnet50 min = 7.94 max = 8.27 avg = 8.09 + resnet50_int8 min = 5.11 max = 5.58 avg = 5.27 + squeezenet_ssd min = 4.38 max = 4.74 avg = 4.51 + squeezenet_ssd_int8 min = 3.79 max = 4.08 avg = 3.85 + mobilenet_ssd min = 4.23 max = 4.33 avg = 4.27 + mobilenet_ssd_int8 min = 2.86 max = 2.99 avg = 2.92 + mobilenet_yolo min = 11.38 max = 11.78 avg = 11.60 + mobilenetv2_yolov3 min = 6.80 max = 7.22 avg = 6.96 + yolov4-tiny min = 8.72 max = 9.03 avg = 8.80 + nanodet_m min = 3.36 max = 3.58 avg = 3.43 + yolo-fastest-1.1 min = 2.23 max = 2.48 avg = 2.26 + yolo-fastestv2 min = 1.82 max = 1.99 avg = 1.84 + vision_transformer min = 124.81 max = 133.32 avg = 129.22 + FastestDet min = 1.77 max = 2.12 avg = 1.80 +misaki@HimiMisakiBenchmarkARM64:~/Workspace/ncnn/build/benchmark$ ./benchncnn 512 16 0 -1 0 +loop_count = 512 +num_threads = 16 +powersave = 0 +gpu_device = -1 +cooling_down = 0 + squeezenet min = 1.27 max = 1.51 avg = 1.29 + squeezenet_int8 min = 1.33 max = 1.42 avg = 1.36 + mobilenet min = 1.23 max = 1.40 avg = 1.24 + mobilenet_int8 min = 1.11 max = 1.19 avg = 1.13 + mobilenet_v2 min = 1.63 max = 1.84 avg = 1.66 + mobilenet_v3 min = 1.39 max = 1.77 avg = 1.43 + shufflenet min = 1.88 max = 2.28 avg = 1.92 + shufflenet_v2 min = 1.36 max = 1.80 avg = 1.39 + mnasnet min = 1.38 max = 1.65 avg = 1.42 + proxylessnasnet min = 1.54 max = 1.85 avg = 1.57 + efficientnet_b0 min = 2.49 max = 2.85 avg = 2.53 + efficientnetv2_b0 min = 3.25 max = 3.59 avg = 3.35 + regnety_400m min = 5.05 max = 5.34 avg = 5.12 + blazeface min = 0.80 max = 0.88 avg = 0.81 + googlenet min = 3.48 max = 3.68 avg = 3.52 + googlenet_int8 min = 3.40 max = 3.61 avg = 3.45 + resnet18 min = 1.93 max = 2.10 avg = 1.98 + resnet18_int8 min = 2.02 max = 2.13 avg = 2.05 + alexnet min = 1.68 max = 1.84 avg = 1.77 + vgg16 min = 8.42 max = 8.63 avg = 8.50 + vgg16_int8 min = 8.23 max = 8.57 avg = 8.34 + resnet50 min = 5.10 max = 5.88 avg = 5.30 + resnet50_int8 min = 3.84 max = 4.04 avg = 3.91 + squeezenet_ssd min = 4.20 max = 4.41 avg = 4.28 + squeezenet_ssd_int8 min = 3.51 max = 3.82 avg = 3.59 + mobilenet_ssd min = 3.07 max = 3.37 avg = 3.18 + mobilenet_ssd_int8 min = 2.51 max = 2.75 avg = 2.57 + mobilenet_yolo min = 9.64 max = 10.24 avg = 9.80 + mobilenetv2_yolov3 min = 5.51 max = 5.74 avg = 5.60 + yolov4-tiny min = 7.65 max = 8.06 avg = 7.75 + nanodet_m min = 3.01 max = 3.21 avg = 3.07 + yolo-fastest-1.1 min = 2.21 max = 2.44 avg = 2.26 + yolo-fastestv2 min = 1.90 max = 2.22 avg = 1.95 + vision_transformer min = 73.19 max = 81.55 avg = 75.35 + FastestDet min = 1.75 max = 2.00 avg = 1.78 +misaki@HimiMisakiBenchmarkARM64:~/Workspace/ncnn/build/benchmark$ ./benchncnn 512 32 0 -1 0 +loop_count = 512 +num_threads = 32 +powersave = 0 +gpu_device = -1 +cooling_down = 0 + squeezenet 
min = 1.27 max = 1.38 avg = 1.29 + squeezenet_int8 min = 1.47 max = 1.56 avg = 1.50 + mobilenet min = 1.09 max = 1.30 avg = 1.12 + mobilenet_int8 min = 1.16 max = 1.26 avg = 1.19 + mobilenet_v2 min = 1.65 max = 1.78 avg = 1.68 + mobilenet_v3 min = 1.57 max = 1.66 avg = 1.59 + shufflenet min = 2.30 max = 2.42 avg = 2.34 + shufflenet_v2 min = 1.57 max = 1.68 avg = 1.60 + mnasnet min = 1.50 max = 1.63 avg = 1.53 + proxylessnasnet min = 1.69 max = 1.80 avg = 1.71 + efficientnet_b0 min = 2.52 max = 2.72 avg = 2.56 + efficientnetv2_b0 min = 3.67 max = 3.85 avg = 3.74 + regnety_400m min = 6.96 max = 7.28 avg = 7.12 + blazeface min = 1.04 max = 1.12 avg = 1.06 + googlenet min = 3.49 max = 3.82 avg = 3.56 + googlenet_int8 min = 3.61 max = 3.93 avg = 3.68 + resnet18 min = 1.96 max = 2.20 avg = 2.00 + resnet18_int8 min = 2.20 max = 2.31 avg = 2.24 + alexnet min = 1.62 max = 1.74 avg = 1.66 + vgg16 min = 6.69 max = 7.15 avg = 6.76 + vgg16_int8 min = 6.65 max = 7.07 avg = 6.74 + resnet50 min = 4.38 max = 4.61 avg = 4.45 + resnet50_int8 min = 3.81 max = 4.03 avg = 3.86 + squeezenet_ssd min = 4.72 max = 5.11 avg = 4.83 + squeezenet_ssd_int8 min = 4.06 max = 4.29 avg = 4.11 + mobilenet_ssd min = 3.05 max = 3.25 avg = 3.10 + mobilenet_ssd_int8 min = 2.80 max = 2.99 avg = 2.85 + mobilenet_yolo min = 10.68 max = 11.15 avg = 10.83 + mobilenetv2_yolov3 min = 5.30 max = 5.60 avg = 5.38 + yolov4-tiny min = 7.47 max = 7.76 avg = 7.55 + nanodet_m min = 3.50 max = 3.69 avg = 3.56 + yolo-fastest-1.1 min = 2.60 max = 2.78 avg = 2.63 + yolo-fastestv2 min = 2.31 max = 2.44 avg = 2.35 + vision_transformer min = 46.21 max = 49.20 avg = 47.48 + FastestDet min = 2.16 max = 2.27 avg = 2.20 +misaki@HimiMisakiBenchmarkARM64:~/Workspace/ncnn/build/benchmark$ ./benchncnn 512 64 0 -1 0 +loop_count = 512 +num_threads = 64 +powersave = 0 +gpu_device = -1 +cooling_down = 0 + squeezenet min = 1.52 max = 2.16 avg = 1.57 + squeezenet_int8 min = 1.79 max = 19.05 avg = 2.06 + mobilenet min = 1.24 max = 2.50 avg = 1.29 + mobilenet_int8 min = 1.41 max = 1.90 avg = 1.45 + mobilenet_v2 min = 2.11 max = 2.90 avg = 2.21 + mobilenet_v3 min = 2.07 max = 29.34 avg = 2.24 + shufflenet min = 3.08 max = 66.45 avg = 3.42 + shufflenet_v2 min = 2.11 max = 3.50 avg = 2.18 + mnasnet min = 1.81 max = 2.20 avg = 1.85 + proxylessnasnet min = 2.15 max = 2.87 avg = 2.23 + efficientnet_b0 min = 3.25 max = 4.02 avg = 3.34 + efficientnetv2_b0 min = 5.02 max = 23.60 avg = 5.33 + regnety_400m min = 11.57 max = 75.67 avg = 12.69 + blazeface min = 1.49 max = 11.01 avg = 1.56 + googlenet min = 4.08 max = 4.78 avg = 4.18 + googlenet_int8 min = 4.58 max = 9.01 avg = 4.69 + resnet18 min = 2.32 max = 19.52 avg = 2.45 + resnet18_int8 min = 2.70 max = 3.12 avg = 2.79 + alexnet min = 2.16 max = 4.41 avg = 2.25 + vgg16 min = 7.44 max = 9.87 avg = 7.60 + vgg16_int8 min = 7.86 max = 13.28 avg = 8.26 + resnet50 min = 4.98 max = 5.61 avg = 5.07 + resnet50_int8 min = 4.51 max = 18.84 avg = 4.64 + squeezenet_ssd min = 6.33 max = 18.64 avg = 6.65 + squeezenet_ssd_int8 min = 5.22 max = 66.17 avg = 5.88 + mobilenet_ssd min = 3.90 max = 4.42 avg = 4.02 + mobilenet_ssd_int8 min = 3.67 max = 37.05 avg = 4.08 + mobilenet_yolo min = 17.69 max = 52.61 avg = 18.20 + mobilenetv2_yolov3 min = 6.26 max = 12.46 avg = 6.42 + yolov4-tiny min = 9.02 max = 13.73 avg = 9.31 + nanodet_m min = 4.21 max = 4.88 avg = 4.33 + yolo-fastest-1.1 min = 3.34 max = 3.74 avg = 3.42 + yolo-fastestv2 min = 3.23 max = 30.29 avg = 3.43 + vision_transformer min = 48.63 max = 57.14 avg = 52.05 + FastestDet min = 2.88 
max = 3.50 avg = 2.95 +``` + +### Microsoft Azure Standard D64s v6 Instance + +- Type: 64 vcpu, 256 GiB RAM +- CPU: INTEL XEON PLATINUM 8573C @ 3.6GHz +- OS: Debian 12 with Kernel 6.1.0-43-cloud-amd64 and GCC 12.2.0 +- ncnn version tag: 20260113 +- ncnn configuration command + - With AVX512 + > cmake -B build -DCMAKE_BUILD_TYPE=Release -DNCNN_BUILD_BENCHMARK=ON -DNCNN_BUILD_TOOLS=OFF -DNCNN_BUILD_EXAMPLES=OFF -DNCNN_BUILD_TESTS=OFF -DNCNN_VULKAN=OFF + - Without AVX512 + > cmake -B build -DCMAKE_BUILD_TYPE=Release -DNCNN_BUILD_BENCHMARK=ON -DNCNN_BUILD_TOOLS=OFF -DNCNN_BUILD_EXAMPLES=OFF -DNCNN_BUILD_TESTS=OFF -DNCNN_VULKAN=OFF -DNCNN_AVX512=OFF -DNCNN_AVX512VNNI=OFF + +**Results (With AVX512)** + +``` +misaki@HimiMisakiBenchmarkIntel64:~/Workspace/ncnn/build/benchmark$ ./benchncnn 512 1 0 -1 0 +loop_count = 512 +num_threads = 1 +powersave = 0 +gpu_device = -1 +cooling_down = 0 + squeezenet min = 4.87 max = 5.13 avg = 4.89 + squeezenet_int8 min = 4.49 max = 4.57 avg = 4.52 + mobilenet min = 7.91 max = 8.02 avg = 7.94 + mobilenet_int8 min = 7.04 max = 7.15 avg = 7.07 + mobilenet_v2 min = 6.31 max = 6.42 avg = 6.33 + mobilenet_v3 min = 4.96 max = 5.02 avg = 4.99 + shufflenet min = 3.23 max = 3.32 avg = 3.26 + shufflenet_v2 min = 3.33 max = 3.40 avg = 3.36 + mnasnet min = 5.85 max = 5.97 avg = 5.87 + proxylessnasnet min = 6.59 max = 6.70 avg = 6.62 + efficientnet_b0 min = 9.42 max = 11.31 avg = 9.46 + efficientnetv2_b0 min = 10.40 max = 10.56 avg = 10.45 + regnety_400m min = 7.91 max = 10.56 avg = 7.95 + blazeface min = 1.07 max = 1.12 avg = 1.08 + googlenet min = 17.27 max = 17.59 avg = 17.34 + googlenet_int8 min = 12.27 max = 12.47 avg = 12.34 + resnet18 min = 14.53 max = 15.05 avg = 14.58 + resnet18_int8 min = 9.80 max = 10.22 avg = 9.90 + alexnet min = 11.49 max = 12.46 avg = 11.64 + vgg16 min = 102.94 max = 109.09 avg = 104.08 + vgg16_int8 min = 78.74 max = 83.18 avg = 79.86 + resnet50 min = 38.27 max = 42.39 avg = 39.06 + resnet50_int8 min = 23.34 max = 24.53 avg = 23.52 + squeezenet_ssd min = 13.16 max = 13.29 avg = 13.19 + squeezenet_ssd_int8 min = 12.25 max = 12.34 avg = 12.27 + mobilenet_ssd min = 16.45 max = 16.56 avg = 16.48 + mobilenet_ssd_int8 min = 13.86 max = 15.55 avg = 13.92 + mobilenet_yolo min = 37.41 max = 37.80 avg = 37.59 + mobilenetv2_yolov3 min = 22.35 max = 22.53 avg = 22.42 + yolov4-tiny min = 29.03 max = 31.28 avg = 29.24 + nanodet_m min = 8.50 max = 8.63 avg = 8.54 + yolo-fastest-1.1 min = 3.82 max = 3.88 avg = 3.85 + yolo-fastestv2 min = 3.72 max = 3.81 avg = 3.74 + vision_transformer min = 663.72 max = 670.45 avg = 665.62 + FastestDet min = 3.72 max = 4.68 avg = 3.75 +misaki@HimiMisakiBenchmarkIntel64:~/Workspace/ncnn/build/benchmark$ ./benchncnn 512 2 0 -1 0 +loop_count = 512 +num_threads = 2 +powersave = 0 +gpu_device = -1 +cooling_down = 0 + squeezenet min = 3.26 max = 3.49 avg = 3.31 + squeezenet_int8 min = 2.90 max = 4.09 avg = 2.94 + mobilenet min = 4.47 max = 4.56 avg = 4.50 + mobilenet_int8 min = 3.84 max = 3.94 avg = 3.88 + mobilenet_v2 min = 4.19 max = 4.30 avg = 4.24 + mobilenet_v3 min = 3.81 max = 3.92 avg = 3.87 + shufflenet min = 3.42 max = 5.23 avg = 3.47 + shufflenet_v2 min = 2.87 max = 2.96 avg = 2.91 + mnasnet min = 4.06 max = 4.20 avg = 4.10 + proxylessnasnet min = 4.48 max = 4.59 avg = 4.53 + efficientnet_b0 min = 5.94 max = 7.13 avg = 6.02 + efficientnetv2_b0 min = 6.97 max = 8.24 avg = 7.05 + regnety_400m min = 7.06 max = 8.03 avg = 7.12 + blazeface min = 1.03 max = 1.11 avg = 1.06 + googlenet min = 10.97 max = 11.28 avg = 11.07 + 
googlenet_int8 min = 7.52 max = 7.68 avg = 7.59 + resnet18 min = 8.32 max = 8.61 avg = 8.36 + resnet18_int8 min = 5.35 max = 5.50 avg = 5.39 + alexnet min = 6.36 max = 6.62 avg = 6.41 + vgg16 min = 53.24 max = 56.32 avg = 53.93 + vgg16_int8 min = 41.55 max = 44.76 avg = 42.65 + resnet50 min = 21.58 max = 23.22 avg = 21.88 + resnet50_int8 min = 12.47 max = 13.24 avg = 12.57 + squeezenet_ssd min = 8.84 max = 9.05 avg = 8.91 + squeezenet_ssd_int8 min = 7.75 max = 8.10 avg = 7.81 + mobilenet_ssd min = 9.47 max = 9.58 avg = 9.52 + mobilenet_ssd_int8 min = 7.66 max = 7.84 avg = 7.73 + mobilenet_yolo min = 21.73 max = 22.14 avg = 21.88 + mobilenetv2_yolov3 min = 14.36 max = 14.64 avg = 14.46 + yolov4-tiny min = 18.95 max = 19.73 avg = 19.10 + nanodet_m min = 6.87 max = 7.05 avg = 6.95 + yolo-fastest-1.1 min = 3.34 max = 3.44 avg = 3.39 + yolo-fastestv2 min = 3.53 max = 3.67 avg = 3.59 + vision_transformer min = 337.14 max = 343.88 avg = 338.89 + FastestDet min = 3.47 max = 3.72 avg = 3.54 +misaki@HimiMisakiBenchmarkIntel64:~/Workspace/ncnn/build/benchmark$ ./benchncnn 512 4 0 -1 0 +loop_count = 512 +num_threads = 4 +powersave = 0 +gpu_device = -1 +cooling_down = 0 + squeezenet min = 2.43 max = 2.70 avg = 2.48 + squeezenet_int8 min = 2.16 max = 2.31 avg = 2.19 + mobilenet min = 2.86 max = 2.99 avg = 2.90 + mobilenet_int8 min = 2.26 max = 2.32 avg = 2.29 + mobilenet_v2 min = 3.18 max = 3.29 avg = 3.24 + mobilenet_v3 min = 2.85 max = 2.99 avg = 2.90 + shufflenet min = 2.75 max = 2.88 avg = 2.80 + shufflenet_v2 min = 2.41 max = 2.52 avg = 2.44 + mnasnet min = 3.03 max = 3.17 avg = 3.07 + proxylessnasnet min = 3.25 max = 3.38 avg = 3.29 + efficientnet_b0 min = 4.34 max = 4.44 avg = 4.37 + efficientnetv2_b0 min = 5.03 max = 5.21 avg = 5.11 + regnety_400m min = 6.33 max = 8.40 avg = 6.37 + blazeface min = 0.98 max = 1.07 avg = 1.01 + googlenet min = 7.25 max = 7.59 avg = 7.35 + googlenet_int8 min = 4.94 max = 5.06 avg = 4.98 + resnet18 min = 4.75 max = 5.22 avg = 4.91 + resnet18_int8 min = 3.19 max = 6.17 avg = 3.25 + alexnet min = 3.68 max = 3.87 avg = 3.74 + vgg16 min = 28.09 max = 29.17 avg = 28.40 + vgg16_int8 min = 21.59 max = 23.28 avg = 22.09 + resnet50 min = 12.96 max = 14.14 avg = 13.23 + resnet50_int8 min = 7.21 max = 7.45 avg = 7.29 + squeezenet_ssd min = 6.41 max = 7.84 avg = 6.49 + squeezenet_ssd_int8 min = 5.56 max = 5.71 avg = 5.60 + mobilenet_ssd min = 5.88 max = 6.01 avg = 5.93 + mobilenet_ssd_int8 min = 4.57 max = 4.84 avg = 4.61 + mobilenet_yolo min = 14.46 max = 14.78 avg = 14.59 + mobilenetv2_yolov3 min = 10.34 max = 11.77 avg = 10.45 + yolov4-tiny min = 13.64 max = 15.67 avg = 13.80 + nanodet_m min = 5.71 max = 5.87 avg = 5.78 + yolo-fastest-1.1 min = 3.00 max = 3.15 avg = 3.05 + yolo-fastestv2 min = 3.08 max = 3.22 avg = 3.14 + vision_transformer min = 174.31 max = 177.93 avg = 175.36 + FastestDet min = 3.18 max = 3.31 avg = 3.23 +misaki@HimiMisakiBenchmarkIntel64:~/Workspace/ncnn/build/benchmark$ ./benchncnn 512 8 0 -1 0 +loop_count = 512 +num_threads = 8 +powersave = 0 +gpu_device = -1 +cooling_down = 0 + squeezenet min = 2.25 max = 2.43 avg = 2.28 + squeezenet_int8 min = 2.08 max = 2.16 avg = 2.11 + mobilenet min = 2.23 max = 2.37 avg = 2.28 + mobilenet_int8 min = 1.69 max = 1.75 avg = 1.71 + mobilenet_v2 min = 2.84 max = 2.93 avg = 2.88 + mobilenet_v3 min = 2.56 max = 6.93 avg = 2.61 + shufflenet min = 2.77 max = 2.86 avg = 2.80 + shufflenet_v2 min = 2.39 max = 2.47 avg = 2.42 + mnasnet min = 2.64 max = 2.75 avg = 2.68 + proxylessnasnet min = 2.77 max = 2.88 avg = 2.82 + 
efficientnet_b0 min = 3.52 max = 3.64 avg = 3.56 + efficientnetv2_b0 min = 4.38 max = 4.52 avg = 4.44 + regnety_400m min = 6.22 max = 9.01 avg = 6.26 + blazeface min = 1.01 max = 1.07 avg = 1.03 + googlenet min = 6.05 max = 6.37 avg = 6.14 + googlenet_int8 min = 4.33 max = 4.47 avg = 4.37 + resnet18 min = 3.66 max = 4.22 avg = 3.73 + resnet18_int8 min = 2.74 max = 2.88 avg = 2.79 + alexnet min = 2.41 max = 2.66 avg = 2.46 + vgg16 min = 18.34 max = 19.57 avg = 18.77 + vgg16_int8 min = 13.51 max = 14.27 avg = 13.77 + resnet50 min = 8.97 max = 9.43 avg = 9.13 + resnet50_int8 min = 5.30 max = 5.52 avg = 5.36 + squeezenet_ssd min = 5.68 max = 8.24 avg = 5.75 + squeezenet_ssd_int8 min = 5.08 max = 5.19 avg = 5.12 + mobilenet_ssd min = 4.57 max = 4.87 avg = 4.64 + mobilenet_ssd_int8 min = 3.45 max = 3.56 avg = 3.49 + mobilenet_yolo min = 11.58 max = 12.11 avg = 11.76 + mobilenetv2_yolov3 min = 8.80 max = 9.10 avg = 8.88 + yolov4-tiny min = 12.09 max = 12.81 avg = 12.25 + nanodet_m min = 5.54 max = 5.70 avg = 5.60 + yolo-fastest-1.1 min = 2.94 max = 3.05 avg = 2.97 + yolo-fastestv2 min = 3.05 max = 3.18 avg = 3.09 + vision_transformer min = 92.37 max = 97.02 avg = 93.10 + FastestDet min = 3.14 max = 3.27 avg = 3.19 +misaki@HimiMisakiBenchmarkIntel64:~/Workspace/ncnn/build/benchmark$ ./benchncnn 512 16 0 -1 0 +loop_count = 512 +num_threads = 16 +powersave = 0 +gpu_device = -1 +cooling_down = 0 + squeezenet min = 2.24 max = 2.44 avg = 2.26 + squeezenet_int8 min = 2.17 max = 2.30 avg = 2.20 + mobilenet min = 2.04 max = 2.15 avg = 2.08 + mobilenet_int8 min = 1.51 max = 3.09 avg = 1.53 + mobilenet_v2 min = 2.80 max = 2.90 avg = 2.83 + mobilenet_v3 min = 2.58 max = 2.69 avg = 2.62 + shufflenet min = 3.03 max = 3.11 avg = 3.06 + shufflenet_v2 min = 2.56 max = 2.66 avg = 2.58 + mnasnet min = 2.60 max = 2.68 avg = 2.63 + proxylessnasnet min = 2.77 max = 2.84 avg = 2.80 + efficientnet_b0 min = 3.47 max = 5.14 avg = 3.52 + efficientnetv2_b0 min = 4.55 max = 4.72 avg = 4.61 + regnety_400m min = 7.30 max = 7.49 avg = 7.37 + blazeface min = 1.12 max = 1.18 avg = 1.14 + googlenet min = 5.87 max = 8.36 avg = 5.94 + googlenet_int8 min = 4.33 max = 4.61 avg = 4.39 + resnet18 min = 3.46 max = 3.58 avg = 3.50 + resnet18_int8 min = 2.73 max = 2.84 avg = 2.77 + alexnet min = 2.09 max = 2.36 avg = 2.14 + vgg16 min = 14.36 max = 15.26 avg = 14.69 + vgg16_int8 min = 10.53 max = 11.27 avg = 10.75 + resnet50 min = 7.96 max = 8.26 avg = 8.06 + resnet50_int8 min = 4.80 max = 4.97 avg = 4.85 + squeezenet_ssd min = 5.93 max = 6.11 avg = 5.99 + squeezenet_ssd_int8 min = 5.33 max = 6.15 avg = 5.37 + mobilenet_ssd min = 4.33 max = 4.61 avg = 4.41 + mobilenet_ssd_int8 min = 3.20 max = 3.29 avg = 3.25 + mobilenet_yolo min = 12.23 max = 13.06 avg = 12.41 + mobilenetv2_yolov3 min = 8.27 max = 8.64 avg = 8.34 + yolov4-tiny min = 11.84 max = 12.90 avg = 12.01 + nanodet_m min = 5.91 max = 6.00 avg = 5.95 + yolo-fastest-1.1 min = 3.16 max = 3.24 avg = 3.20 + yolo-fastestv2 min = 3.31 max = 3.41 avg = 3.34 + vision_transformer min = 52.39 max = 54.43 avg = 52.79 + FastestDet min = 3.37 max = 3.48 avg = 3.41 +misaki@HimiMisakiBenchmarkIntel64:~/Workspace/ncnn/build/benchmark$ ./benchncnn 512 32 0 -1 0 +loop_count = 512 +num_threads = 32 +powersave = 0 +gpu_device = -1 +cooling_down = 0 + squeezenet min = 2.39 max = 5.09 avg = 2.53 + squeezenet_int8 min = 2.41 max = 2.55 avg = 2.49 + mobilenet min = 2.12 max = 2.35 avg = 2.25 + mobilenet_int8 min = 1.56 max = 1.67 avg = 1.62 + mobilenet_v2 min = 3.01 max = 3.26 avg = 3.16 + mobilenet_v3 min 
= 2.87 max = 3.08 avg = 3.01 + shufflenet min = 3.52 max = 3.79 avg = 3.69 + shufflenet_v2 min = 2.90 max = 3.16 avg = 3.03 + mnasnet min = 2.79 max = 3.02 avg = 2.93 + proxylessnasnet min = 2.98 max = 3.21 avg = 3.14 + efficientnet_b0 min = 3.85 max = 4.15 avg = 4.05 + efficientnetv2_b0 min = 5.28 max = 5.56 avg = 5.45 + regnety_400m min = 9.68 max = 10.41 avg = 9.81 + blazeface min = 1.33 max = 1.45 avg = 1.39 + googlenet min = 6.35 max = 6.72 avg = 6.58 + googlenet_int8 min = 4.98 max = 5.32 avg = 5.06 + resnet18 min = 3.69 max = 3.99 avg = 3.87 + resnet18_int8 min = 2.98 max = 3.20 avg = 3.12 + alexnet min = 2.10 max = 2.43 avg = 2.25 + vgg16 min = 13.61 max = 14.70 avg = 14.07 + vgg16_int8 min = 10.00 max = 11.30 avg = 10.26 + resnet50 min = 8.25 max = 8.85 avg = 8.59 + resnet50_int8 min = 5.18 max = 5.45 avg = 5.34 + squeezenet_ssd min = 6.52 max = 6.75 avg = 6.67 + squeezenet_ssd_int8 min = 6.00 max = 6.28 avg = 6.11 + mobilenet_ssd min = 4.53 max = 4.96 avg = 4.73 + mobilenet_ssd_int8 min = 3.50 max = 3.67 avg = 3.57 + mobilenet_yolo min = 14.89 max = 15.54 avg = 15.20 + mobilenetv2_yolov3 min = 8.48 max = 8.96 avg = 8.79 + yolov4-tiny min = 12.17 max = 12.86 avg = 12.51 + nanodet_m min = 6.65 max = 6.92 avg = 6.82 + yolo-fastest-1.1 min = 3.56 max = 3.76 avg = 3.68 + yolo-fastestv2 min = 3.72 max = 3.96 avg = 3.89 + vision_transformer min = 33.45 max = 36.96 avg = 33.96 + FastestDet min = 3.68 max = 3.93 avg = 3.86 +misaki@HimiMisakiBenchmarkIntel64:~/Workspace/ncnn/build/benchmark$ ./benchncnn 512 64 0 -1 0 +loop_count = 512 +num_threads = 64 +powersave = 0 +gpu_device = -1 +cooling_down = 0 + squeezenet min = 2.54 max = 2.94 avg = 2.64 + squeezenet_int8 min = 2.74 max = 4.14 avg = 2.85 + mobilenet min = 2.34 max = 2.76 avg = 2.48 + mobilenet_int8 min = 1.89 max = 2.03 avg = 1.95 + mobilenet_v2 min = 3.41 max = 3.84 avg = 3.54 + mobilenet_v3 min = 3.44 max = 4.46 avg = 3.56 + shufflenet min = 4.12 max = 20.37 avg = 4.31 + shufflenet_v2 min = 3.10 max = 3.41 avg = 3.23 + mnasnet min = 3.16 max = 4.36 avg = 3.31 + proxylessnasnet min = 3.43 max = 14.49 avg = 3.62 + efficientnet_b0 min = 4.57 max = 4.98 avg = 4.75 + efficientnetv2_b0 min = 6.06 max = 14.27 avg = 6.26 + regnety_400m min = 12.51 max = 14.28 avg = 12.79 + blazeface min = 1.56 max = 21.66 avg = 1.67 + googlenet min = 7.01 max = 9.07 avg = 7.21 + googlenet_int8 min = 5.89 max = 6.36 avg = 6.00 + resnet18 min = 3.96 max = 5.39 avg = 4.13 + resnet18_int8 min = 3.55 max = 4.86 avg = 3.70 + alexnet min = 2.27 max = 2.59 avg = 2.37 + vgg16 min = 14.62 max = 24.01 avg = 14.94 + vgg16_int8 min = 10.77 max = 13.28 avg = 11.03 + resnet50 min = 8.47 max = 10.08 avg = 8.74 + resnet50_int8 min = 6.10 max = 7.23 avg = 6.23 + squeezenet_ssd min = 6.85 max = 7.31 avg = 7.03 + squeezenet_ssd_int8 min = 6.81 max = 8.77 avg = 6.96 + mobilenet_ssd min = 5.06 max = 6.41 avg = 5.24 + mobilenet_ssd_int8 min = 4.23 max = 17.91 avg = 4.42 + mobilenet_yolo min = 18.58 max = 21.99 avg = 19.00 + mobilenetv2_yolov3 min = 9.18 max = 10.55 avg = 9.49 + yolov4-tiny min = 13.10 max = 15.28 avg = 13.45 + nanodet_m min = 6.81 max = 7.24 avg = 6.92 + yolo-fastest-1.1 min = 4.07 max = 5.13 avg = 4.22 + yolo-fastestv2 min = 4.14 max = 15.50 avg = 4.27 + vision_transformer min = 28.11 max = 39.27 avg = 28.71 + FastestDet min = 4.03 max = 5.15 avg = 4.19 +``` + +**Results (Without AVX512)** + +``` +misaki@HimiMisakiBenchmarkIntel64:~/Workspace/ncnn/build/benchmark$ ./benchncnn 512 1 0 -1 0 +loop_count = 512 +num_threads = 1 +powersave = 0 +gpu_device = -1 
+cooling_down = 0 + squeezenet min = 7.39 max = 9.12 avg = 7.44 + squeezenet_int8 min = 5.34 max = 5.43 avg = 5.37 + mobilenet min = 12.72 max = 12.86 avg = 12.76 + mobilenet_int8 min = 8.91 max = 10.01 avg = 8.94 + mobilenet_v2 min = 8.92 max = 9.03 avg = 8.95 + mobilenet_v3 min = 6.95 max = 7.06 avg = 6.99 + shufflenet min = 4.20 max = 4.28 avg = 4.22 + shufflenet_v2 min = 4.48 max = 4.66 avg = 4.51 + mnasnet min = 8.53 max = 8.63 avg = 8.56 + proxylessnasnet min = 10.18 max = 10.37 avg = 10.21 + efficientnet_b0 min = 20.49 max = 20.68 avg = 20.52 + efficientnetv2_b0 min = 20.58 max = 21.41 avg = 21.12 + regnety_400m min = 10.63 max = 12.41 avg = 10.69 + blazeface min = 1.21 max = 1.54 avg = 1.23 + googlenet min = 28.41 max = 30.01 avg = 28.49 + googlenet_int8 min = 18.90 max = 19.09 avg = 18.96 + resnet18 min = 22.79 max = 24.59 avg = 22.88 + resnet18_int8 min = 14.09 max = 14.30 avg = 14.14 + alexnet min = 18.10 max = 20.00 avg = 18.30 + vgg16 min = 152.62 max = 160.52 avg = 154.99 + vgg16_int8 min = 100.31 max = 105.26 avg = 101.77 + resnet50 min = 63.09 max = 67.16 avg = 63.98 + resnet50_int8 min = 33.63 max = 35.40 avg = 33.82 + squeezenet_ssd min = 18.47 max = 19.55 avg = 18.53 + squeezenet_ssd_int8 min = 14.29 max = 14.75 avg = 14.37 + mobilenet_ssd min = 26.11 max = 26.29 avg = 26.18 + mobilenet_ssd_int8 min = 17.38 max = 17.51 avg = 17.42 + mobilenet_yolo min = 59.23 max = 62.77 avg = 59.46 + mobilenetv2_yolov3 min = 31.99 max = 32.26 avg = 32.10 + yolov4-tiny min = 42.77 max = 47.08 avg = 43.00 + nanodet_m min = 11.31 max = 11.55 avg = 11.38 + yolo-fastest-1.1 min = 4.68 max = 4.75 avg = 4.70 + yolo-fastestv2 min = 4.27 max = 4.36 avg = 4.30 + vision_transformer min = 685.59 max = 695.49 avg = 689.80 + FastestDet min = 4.58 max = 4.67 avg = 4.61 +misaki@HimiMisakiBenchmarkIntel64:~/Workspace/ncnn/build/benchmark$ ./benchncnn 512 2 0 -1 0 +loop_count = 512 +num_threads = 2 +powersave = 0 +gpu_device = -1 +cooling_down = 0 + squeezenet min = 4.59 max = 4.70 avg = 4.63 + squeezenet_int8 min = 3.37 max = 3.53 avg = 3.42 + mobilenet min = 6.95 max = 7.07 avg = 6.98 + mobilenet_int8 min = 4.82 max = 4.98 avg = 4.86 + mobilenet_v2 min = 5.63 max = 5.76 avg = 5.69 + mobilenet_v3 min = 4.93 max = 5.09 avg = 5.00 + shufflenet min = 4.18 max = 4.25 avg = 4.21 + shufflenet_v2 min = 3.57 max = 3.66 avg = 3.61 + mnasnet min = 5.50 max = 5.66 avg = 5.55 + proxylessnasnet min = 6.38 max = 7.48 avg = 6.44 + efficientnet_b0 min = 11.56 max = 11.70 avg = 11.62 + efficientnetv2_b0 min = 12.13 max = 12.85 avg = 12.61 + regnety_400m min = 8.68 max = 11.43 avg = 8.73 + blazeface min = 1.12 max = 1.17 avg = 1.14 + googlenet min = 16.83 max = 18.62 avg = 16.97 + googlenet_int8 min = 10.89 max = 11.12 avg = 10.96 + resnet18 min = 12.54 max = 13.19 avg = 12.62 + resnet18_int8 min = 7.54 max = 7.69 avg = 7.59 + alexnet min = 9.83 max = 10.42 avg = 9.94 + vgg16 min = 77.39 max = 80.58 avg = 78.17 + vgg16_int8 min = 53.09 max = 56.45 avg = 54.12 + resnet50 min = 34.44 max = 36.32 avg = 34.88 + resnet50_int8 min = 18.03 max = 18.64 avg = 18.17 + squeezenet_ssd min = 11.69 max = 11.96 avg = 11.81 + squeezenet_ssd_int8 min = 8.92 max = 9.16 avg = 9.00 + mobilenet_ssd min = 14.55 max = 14.74 avg = 14.63 + mobilenet_ssd_int8 min = 9.64 max = 9.80 avg = 9.70 + mobilenet_yolo min = 33.79 max = 34.35 avg = 33.90 + mobilenetv2_yolov3 min = 19.21 max = 20.28 avg = 19.32 + yolov4-tiny min = 26.44 max = 28.62 avg = 26.64 + nanodet_m min = 8.64 max = 8.82 avg = 8.73 + yolo-fastest-1.1 min = 4.12 max = 4.30 avg = 4.18 + 
yolo-fastestv2 min = 4.10 max = 4.19 avg = 4.13 + vision_transformer min = 348.72 max = 355.93 avg = 351.27 + FastestDet min = 4.17 max = 4.28 avg = 4.21 +misaki@HimiMisakiBenchmarkIntel64:~/Workspace/ncnn/build/benchmark$ ./benchncnn 512 4 0 -1 0 +loop_count = 512 +num_threads = 4 +powersave = 0 +gpu_device = -1 +cooling_down = 0 + squeezenet min = 3.16 max = 3.45 avg = 3.21 + squeezenet_int8 min = 2.38 max = 2.52 avg = 2.41 + mobilenet min = 4.15 max = 4.30 avg = 4.17 + mobilenet_int8 min = 2.71 max = 2.81 avg = 2.76 + mobilenet_v2 min = 3.86 max = 5.24 avg = 3.93 + mobilenet_v3 min = 3.46 max = 3.55 avg = 3.50 + shufflenet min = 3.20 max = 3.36 avg = 3.25 + shufflenet_v2 min = 2.70 max = 2.78 avg = 2.73 + mnasnet min = 3.79 max = 3.90 avg = 3.84 + proxylessnasnet min = 4.19 max = 7.85 avg = 4.25 + efficientnet_b0 min = 7.07 max = 7.21 avg = 7.12 + efficientnetv2_b0 min = 7.61 max = 8.23 avg = 7.70 + regnety_400m min = 7.31 max = 9.70 avg = 7.36 + blazeface min = 0.97 max = 1.03 avg = 0.98 + googlenet min = 10.37 max = 10.74 avg = 10.49 + googlenet_int8 min = 6.66 max = 6.89 avg = 6.73 + resnet18 min = 7.03 max = 7.50 avg = 7.16 + resnet18_int8 min = 4.34 max = 4.53 avg = 4.40 + alexnet min = 5.61 max = 5.92 avg = 5.69 + vgg16 min = 40.95 max = 42.55 avg = 41.50 + vgg16_int8 min = 27.20 max = 28.98 avg = 27.76 + resnet50 min = 19.48 max = 20.39 avg = 19.70 + resnet50_int8 min = 10.05 max = 11.49 avg = 10.16 + squeezenet_ssd min = 7.80 max = 8.04 avg = 7.91 + squeezenet_ssd_int8 min = 6.10 max = 6.24 avg = 6.16 + mobilenet_ssd min = 8.41 max = 9.49 avg = 8.48 + mobilenet_ssd_int8 min = 5.51 max = 5.70 avg = 5.57 + mobilenet_yolo min = 20.48 max = 22.26 avg = 20.64 + mobilenetv2_yolov3 min = 12.65 max = 13.10 avg = 12.73 + yolov4-tiny min = 16.93 max = 17.57 avg = 17.11 + nanodet_m min = 6.55 max = 6.72 avg = 6.63 + yolo-fastest-1.1 min = 3.51 max = 3.62 avg = 3.55 + yolo-fastestv2 min = 3.17 max = 3.30 avg = 3.22 + vision_transformer min = 190.82 max = 196.80 avg = 192.54 + FastestDet min = 3.41 max = 3.54 avg = 3.48 +misaki@HimiMisakiBenchmarkIntel64:~/Workspace/ncnn/build/benchmark$ ./benchncnn 512 8 0 -1 0 +loop_count = 512 +num_threads = 8 +powersave = 0 +gpu_device = -1 +cooling_down = 0 + squeezenet min = 2.58 max = 2.97 avg = 2.64 + squeezenet_int8 min = 2.12 max = 2.24 avg = 2.15 + mobilenet min = 2.81 max = 2.96 avg = 2.88 + mobilenet_int8 min = 1.84 max = 1.93 avg = 1.86 + mobilenet_v2 min = 3.29 max = 3.42 avg = 3.35 + mobilenet_v3 min = 2.83 max = 2.93 avg = 2.88 + shufflenet min = 2.99 max = 3.10 avg = 3.04 + shufflenet_v2 min = 2.46 max = 2.56 avg = 2.50 + mnasnet min = 3.06 max = 3.16 avg = 3.11 + proxylessnasnet min = 3.31 max = 3.44 avg = 3.37 + efficientnet_b0 min = 5.27 max = 5.41 avg = 5.32 + efficientnetv2_b0 min = 5.78 max = 6.40 avg = 6.12 + regnety_400m min = 6.70 max = 6.84 avg = 6.76 + blazeface min = 0.97 max = 1.03 avg = 0.99 + googlenet min = 7.56 max = 7.83 avg = 7.65 + googlenet_int8 min = 5.09 max = 5.25 avg = 5.14 + resnet18 min = 4.51 max = 4.82 avg = 4.59 + resnet18_int8 min = 2.97 max = 3.12 avg = 3.03 + alexnet min = 3.29 max = 3.42 avg = 3.33 + vgg16 min = 23.18 max = 25.05 avg = 23.61 + vgg16_int8 min = 15.52 max = 16.51 avg = 15.94 + resnet50 min = 12.47 max = 13.27 avg = 12.68 + resnet50_int8 min = 6.63 max = 7.84 avg = 6.71 + squeezenet_ssd min = 6.38 max = 7.53 avg = 6.49 + squeezenet_ssd_int8 min = 5.16 max = 6.05 avg = 5.20 + mobilenet_ssd min = 5.69 max = 5.96 avg = 5.77 + mobilenet_ssd_int8 min = 3.83 max = 3.94 avg = 3.86 + mobilenet_yolo 
min = 14.53 max = 15.05 avg = 14.65 + mobilenetv2_yolov3 min = 9.74 max = 10.01 avg = 9.85 + yolov4-tiny min = 13.02 max = 13.70 avg = 13.32 + nanodet_m min = 6.05 max = 6.21 avg = 6.12 + yolo-fastest-1.1 min = 3.48 max = 3.56 avg = 3.51 + yolo-fastestv2 min = 3.20 max = 3.31 avg = 3.23 + vision_transformer min = 103.59 max = 110.48 avg = 104.85 + FastestDet min = 3.34 max = 3.48 avg = 3.39 +misaki@HimiMisakiBenchmarkIntel64:~/Workspace/ncnn/build/benchmark$ ./benchncnn 512 16 0 -1 0 +loop_count = 512 +num_threads = 16 +powersave = 0 +gpu_device = -1 +cooling_down = 0 + squeezenet min = 2.47 max = 2.68 avg = 2.51 + squeezenet_int8 min = 2.16 max = 2.24 avg = 2.18 + mobilenet min = 2.31 max = 2.48 avg = 2.38 + mobilenet_int8 min = 1.51 max = 1.57 avg = 1.53 + mobilenet_v2 min = 3.14 max = 3.24 avg = 3.19 + mobilenet_v3 min = 2.72 max = 2.83 avg = 2.77 + shufflenet min = 3.22 max = 3.33 avg = 3.26 + shufflenet_v2 min = 2.54 max = 2.62 avg = 2.57 + mnasnet min = 2.85 max = 2.97 avg = 2.89 + proxylessnasnet min = 3.06 max = 3.24 avg = 3.13 + efficientnet_b0 min = 4.48 max = 4.61 avg = 4.53 + efficientnetv2_b0 min = 5.40 max = 5.92 avg = 5.74 + regnety_400m min = 7.40 max = 7.62 avg = 7.46 + blazeface min = 1.05 max = 1.12 avg = 1.08 + googlenet min = 6.88 max = 7.18 avg = 7.01 + googlenet_int8 min = 4.73 max = 7.35 avg = 4.79 + resnet18 min = 3.86 max = 4.05 avg = 3.95 + resnet18_int8 min = 2.72 max = 2.81 avg = 2.75 + alexnet min = 2.37 max = 2.56 avg = 2.42 + vgg16 min = 16.66 max = 18.16 avg = 17.35 + vgg16_int8 min = 10.44 max = 11.23 avg = 10.71 + resnet50 min = 9.46 max = 10.18 avg = 9.59 + resnet50_int8 min = 5.30 max = 5.56 avg = 5.36 + squeezenet_ssd min = 5.94 max = 6.88 avg = 6.03 + squeezenet_ssd_int8 min = 5.01 max = 5.15 avg = 5.08 + mobilenet_ssd min = 4.68 max = 4.88 avg = 4.75 + mobilenet_ssd_int8 min = 3.25 max = 3.35 avg = 3.28 + mobilenet_yolo min = 13.27 max = 13.79 avg = 13.44 + mobilenetv2_yolov3 min = 8.68 max = 8.92 avg = 8.77 + yolov4-tiny min = 11.75 max = 14.46 avg = 11.98 + nanodet_m min = 6.02 max = 6.19 avg = 6.09 + yolo-fastest-1.1 min = 3.58 max = 3.75 avg = 3.62 + yolo-fastestv2 min = 3.34 max = 3.44 avg = 3.38 + vision_transformer min = 53.46 max = 57.18 avg = 54.30 + FastestDet min = 3.51 max = 3.60 avg = 3.55 +misaki@HimiMisakiBenchmarkIntel64:~/Workspace/ncnn/build/benchmark$ ./benchncnn 512 32 0 -1 0 +loop_count = 512 +num_threads = 32 +powersave = 0 +gpu_device = -1 +cooling_down = 0 + squeezenet min = 2.62 max = 2.89 avg = 2.71 + squeezenet_int8 min = 2.40 max = 2.69 avg = 2.51 + mobilenet min = 2.29 max = 2.59 avg = 2.40 + mobilenet_int8 min = 1.53 max = 1.68 avg = 1.57 + mobilenet_v2 min = 3.36 max = 3.68 avg = 3.45 + mobilenet_v3 min = 3.01 max = 3.33 avg = 3.12 + shufflenet min = 3.75 max = 4.10 avg = 3.84 + shufflenet_v2 min = 2.85 max = 3.21 avg = 3.01 + mnasnet min = 3.06 max = 3.36 avg = 3.15 + proxylessnasnet min = 3.29 max = 3.71 avg = 3.45 + efficientnet_b0 min = 4.58 max = 5.13 avg = 4.74 + efficientnetv2_b0 min = 6.05 max = 7.00 avg = 6.52 + regnety_400m min = 9.66 max = 10.64 avg = 9.95 + blazeface min = 1.28 max = 1.42 avg = 1.35 + googlenet min = 7.06 max = 8.21 avg = 7.27 + googlenet_int8 min = 5.08 max = 5.57 avg = 5.18 + resnet18 min = 3.84 max = 4.36 avg = 4.05 + resnet18_int8 min = 2.89 max = 3.22 avg = 2.96 + alexnet min = 2.23 max = 2.53 avg = 2.37 + vgg16 min = 13.52 max = 15.15 avg = 13.97 + vgg16_int8 min = 8.69 max = 9.65 avg = 9.06 + resnet50 min = 9.18 max = 10.04 avg = 9.45 + resnet50_int8 min = 5.28 max = 5.78 avg = 5.39 + 
squeezenet_ssd min = 6.60 max = 7.20 avg = 6.75 + squeezenet_ssd_int8 min = 5.80 max = 6.42 avg = 5.95 + mobilenet_ssd min = 4.75 max = 5.33 avg = 4.97 + mobilenet_ssd_int8 min = 3.49 max = 3.74 avg = 3.54 + mobilenet_yolo min = 15.59 max = 16.91 avg = 16.00 + mobilenetv2_yolov3 min = 8.75 max = 9.73 avg = 9.11 + yolov4-tiny min = 11.92 max = 14.56 avg = 12.36 + nanodet_m min = 6.79 max = 7.87 avg = 6.98 + yolo-fastest-1.1 min = 4.08 max = 4.65 avg = 4.24 + yolo-fastestv2 min = 3.88 max = 4.22 avg = 3.97 + vision_transformer min = 35.82 max = 40.32 avg = 36.82 + FastestDet min = 3.91 max = 4.27 avg = 4.00 +misaki@HimiMisakiBenchmarkIntel64:~/Workspace/ncnn/build/benchmark$ ./benchncnn 512 64 0 -1 0 +loop_count = 512 +num_threads = 64 +powersave = 0 +gpu_device = -1 +cooling_down = 0 + squeezenet min = 2.78 max = 3.11 avg = 2.89 + squeezenet_int8 min = 2.62 max = 3.18 avg = 2.74 + mobilenet min = 2.43 max = 3.79 avg = 2.55 + mobilenet_int8 min = 1.82 max = 2.05 avg = 1.88 + mobilenet_v2 min = 3.67 max = 4.14 avg = 3.81 + mobilenet_v3 min = 3.43 max = 3.85 avg = 3.56 + shufflenet min = 4.32 max = 104.85 avg = 4.68 + shufflenet_v2 min = 3.07 max = 3.40 avg = 3.16 + mnasnet min = 3.35 max = 3.86 avg = 3.48 + proxylessnasnet min = 3.70 max = 5.31 avg = 3.84 + efficientnet_b0 min = 5.19 max = 5.85 avg = 5.40 + efficientnetv2_b0 min = 6.67 max = 8.01 avg = 6.90 + regnety_400m min = 12.43 max = 17.49 avg = 12.78 + blazeface min = 1.50 max = 1.71 avg = 1.60 + googlenet min = 7.75 max = 9.38 avg = 7.91 + googlenet_int8 min = 5.91 max = 6.53 avg = 6.02 + resnet18 min = 4.16 max = 6.89 avg = 4.38 + resnet18_int8 min = 3.30 max = 12.33 avg = 3.39 + alexnet min = 2.41 max = 2.83 avg = 2.55 + vgg16 min = 14.07 max = 15.83 avg = 14.60 + vgg16_int8 min = 9.06 max = 11.13 avg = 9.42 + resnet50 min = 9.54 max = 12.98 avg = 9.80 + resnet50_int8 min = 6.15 max = 11.96 avg = 6.30 + squeezenet_ssd min = 7.08 max = 7.69 avg = 7.23 + squeezenet_ssd_int8 min = 6.35 max = 16.96 avg = 6.53 + mobilenet_ssd min = 5.07 max = 5.72 avg = 5.34 + mobilenet_ssd_int8 min = 4.31 max = 7.11 avg = 4.42 + mobilenet_yolo min = 18.91 max = 21.99 avg = 19.37 + mobilenetv2_yolov3 min = 9.35 max = 10.96 avg = 9.61 + yolov4-tiny min = 12.82 max = 14.50 avg = 13.31 + nanodet_m min = 6.89 max = 16.39 avg = 7.12 + yolo-fastest-1.1 min = 4.64 max = 5.18 avg = 4.82 + yolo-fastestv2 min = 4.20 max = 5.53 avg = 4.36 + vision_transformer min = 34.26 max = 39.12 avg = 35.44 + FastestDet min = 4.19 max = 4.68 avg = 4.35 +``` + +### Microsoft Azure Standard D64as v7 Instance + +- Type: 64 vcpu, 256 GiB RAM +- CPU: AMD EPYC 9V45 96-Core @ 4.3GHz +- OS: Debian 12 with Kernel 6.1.0-43-cloud-amd64 and GCC 12.2.0 +- ncnn version tag: 20260113 +- ncnn configuration command + - With AVX512 + > cmake -B build -DCMAKE_BUILD_TYPE=Release -DNCNN_BUILD_BENCHMARK=ON -DNCNN_BUILD_TOOLS=OFF -DNCNN_BUILD_EXAMPLES=OFF -DNCNN_BUILD_TESTS=OFF -DNCNN_VULKAN=OFF + - Without AVX512 + > cmake -B build -DCMAKE_BUILD_TYPE=Release -DNCNN_BUILD_BENCHMARK=ON -DNCNN_BUILD_TOOLS=OFF -DNCNN_BUILD_EXAMPLES=OFF -DNCNN_BUILD_TESTS=OFF -DNCNN_VULKAN=OFF -DNCNN_AVX512=OFF -DNCNN_AVX512VNNI=OFF + +**Results (With AVX512)** + +``` +misaki@HimiMisakiBenchmarkAMD64:~/Workspace/ncnn/build/benchmark$ ./benchncnn 512 1 0 -1 0 +loop_count = 512 +num_threads = 1 +powersave = 0 +gpu_device = -1 +cooling_down = 0 + squeezenet min = 3.05 max = 3.12 avg = 3.07 + squeezenet_int8 min = 2.11 max = 2.17 avg = 2.13 + mobilenet min = 5.36 max = 6.21 avg = 5.40 + mobilenet_int8 min = 4.61 max = 4.76 
avg = 4.63 + mobilenet_v2 min = 3.70 max = 3.80 avg = 3.73 + mobilenet_v3 min = 3.15 max = 3.20 avg = 3.17 + shufflenet min = 2.06 max = 2.14 avg = 2.07 + shufflenet_v2 min = 2.20 max = 2.24 avg = 2.21 + mnasnet min = 3.75 max = 5.14 avg = 3.81 + proxylessnasnet min = 4.39 max = 4.57 avg = 4.42 + efficientnet_b0 min = 6.17 max = 6.29 avg = 6.20 + efficientnetv2_b0 min = 7.34 max = 7.55 avg = 7.40 + regnety_400m min = 5.82 max = 6.28 avg = 5.86 + blazeface min = 0.70 max = 0.72 avg = 0.70 + googlenet min = 13.67 max = 14.25 avg = 13.85 + googlenet_int8 min = 8.69 max = 8.99 avg = 8.77 + resnet18 min = 13.65 max = 14.34 avg = 13.82 + resnet18_int8 min = 8.17 max = 8.86 avg = 8.30 + alexnet min = 9.76 max = 10.56 avg = 9.94 + vgg16 min = 58.78 max = 61.97 avg = 60.11 + vgg16_int8 min = 45.04 max = 48.17 avg = 46.30 + resnet50 min = 32.00 max = 33.84 avg = 32.43 + resnet50_int8 min = 17.30 max = 18.04 avg = 17.50 + squeezenet_ssd min = 9.85 max = 10.44 avg = 9.95 + squeezenet_ssd_int8 min = 7.70 max = 8.01 avg = 7.83 + mobilenet_ssd min = 11.50 max = 11.85 avg = 11.56 + mobilenet_ssd_int8 min = 8.92 max = 9.32 avg = 8.95 + mobilenet_yolo min = 25.84 max = 27.61 avg = 25.93 + mobilenetv2_yolov3 min = 14.21 max = 14.89 avg = 14.32 + yolov4-tiny min = 22.64 max = 23.84 avg = 23.07 + nanodet_m min = 5.18 max = 5.53 avg = 5.21 + yolo-fastest-1.1 min = 2.13 max = 2.22 avg = 2.14 + yolo-fastestv2 min = 2.15 max = 2.18 avg = 2.16 + vision_transformer min = 487.55 max = 495.92 avg = 489.27 + FastestDet min = 2.20 max = 2.28 avg = 2.22 +misaki@HimiMisakiBenchmarkAMD64:~/Workspace/ncnn/build/benchmark$ ./benchncnn 512 2 0 -1 0 +loop_count = 512 +num_threads = 2 +powersave = 0 +gpu_device = -1 +cooling_down = 0 + squeezenet min = 3.01 max = 3.12 avg = 3.04 + squeezenet_int8 min = 2.90 max = 3.10 avg = 2.96 + mobilenet min = 4.04 max = 4.23 avg = 4.07 + mobilenet_int8 min = 3.29 max = 3.54 avg = 3.39 + mobilenet_v2 min = 4.09 max = 4.23 avg = 4.13 + mobilenet_v3 min = 3.35 max = 3.41 avg = 3.38 + shufflenet min = 3.38 max = 3.54 avg = 3.41 + shufflenet_v2 min = 2.80 max = 2.88 avg = 2.82 + mnasnet min = 3.80 max = 3.85 avg = 3.82 + proxylessnasnet min = 4.18 max = 4.32 avg = 4.21 + efficientnet_b0 min = 5.41 max = 5.54 avg = 5.44 + efficientnetv2_b0 min = 6.73 max = 7.15 avg = 6.77 + regnety_400m min = 6.84 max = 7.07 avg = 6.88 + blazeface min = 1.01 max = 1.07 avg = 1.03 + googlenet min = 10.83 max = 11.54 avg = 10.99 + googlenet_int8 min = 7.04 max = 7.29 avg = 7.11 + resnet18 min = 9.06 max = 9.66 avg = 9.21 + resnet18_int8 min = 6.06 max = 6.51 avg = 6.21 + alexnet min = 6.28 max = 6.56 avg = 6.38 + vgg16 min = 35.92 max = 37.39 avg = 36.35 + vgg16_int8 min = 26.88 max = 28.22 avg = 27.28 + resnet50 min = 20.48 max = 21.40 avg = 20.66 + resnet50_int8 min = 13.60 max = 14.23 avg = 13.86 + squeezenet_ssd min = 9.18 max = 9.81 avg = 9.37 + squeezenet_ssd_int8 min = 7.89 max = 8.27 avg = 7.98 + mobilenet_ssd min = 8.71 max = 8.99 avg = 8.80 + mobilenet_ssd_int8 min = 6.28 max = 6.42 avg = 6.31 + mobilenet_yolo min = 19.86 max = 20.59 avg = 20.02 + mobilenetv2_yolov3 min = 13.00 max = 13.38 avg = 13.06 + yolov4-tiny min = 18.19 max = 18.93 avg = 18.37 + nanodet_m min = 6.38 max = 6.52 avg = 6.42 + yolo-fastest-1.1 min = 3.31 max = 3.53 avg = 3.34 + yolo-fastestv2 min = 3.17 max = 3.34 avg = 3.21 + vision_transformer min = 250.44 max = 261.90 avg = 251.33 + FastestDet min = 3.22 max = 3.30 avg = 3.24 +misaki@HimiMisakiBenchmarkAMD64:~/Workspace/ncnn/build/benchmark$ ./benchncnn 512 4 0 -1 0 +loop_count = 
512 +num_threads = 4 +powersave = 0 +gpu_device = -1 +cooling_down = 0 + squeezenet min = 2.32 max = 2.38 avg = 2.35 + squeezenet_int8 min = 2.43 max = 2.52 avg = 2.47 + mobilenet min = 2.54 max = 2.68 avg = 2.58 + mobilenet_int8 min = 2.02 max = 10.84 avg = 2.14 + mobilenet_v2 min = 3.14 max = 4.66 avg = 3.18 + mobilenet_v3 min = 2.83 max = 2.93 avg = 2.86 + shufflenet min = 3.10 max = 4.62 avg = 3.15 + shufflenet_v2 min = 2.39 max = 2.47 avg = 2.42 + mnasnet min = 2.97 max = 3.23 avg = 3.00 + proxylessnasnet min = 3.10 max = 3.25 avg = 3.13 + efficientnet_b0 min = 4.21 max = 6.42 avg = 4.26 + efficientnetv2_b0 min = 5.11 max = 5.36 avg = 5.17 + regnety_400m min = 6.71 max = 6.95 avg = 6.81 + blazeface min = 0.93 max = 1.00 avg = 0.97 + googlenet min = 7.15 max = 7.53 avg = 7.24 + googlenet_int8 min = 5.11 max = 5.41 avg = 5.22 + resnet18 min = 4.96 max = 16.14 avg = 5.08 + resnet18_int8 min = 3.34 max = 3.59 avg = 3.42 + alexnet min = 3.53 max = 3.70 avg = 3.61 + vgg16 min = 20.70 max = 22.45 avg = 20.96 + vgg16_int8 min = 16.60 max = 18.46 avg = 17.05 + resnet50 min = 11.78 max = 12.43 avg = 11.94 + resnet50_int8 min = 8.06 max = 9.63 avg = 8.17 + squeezenet_ssd min = 6.82 max = 7.60 avg = 7.00 + squeezenet_ssd_int8 min = 5.97 max = 6.29 avg = 6.10 + mobilenet_ssd min = 5.56 max = 5.77 avg = 5.62 + mobilenet_ssd_int8 min = 4.12 max = 4.22 avg = 4.15 + mobilenet_yolo min = 13.27 max = 14.16 avg = 13.50 + mobilenetv2_yolov3 min = 9.38 max = 9.79 avg = 9.43 + yolov4-tiny min = 12.67 max = 13.28 avg = 12.85 + nanodet_m min = 5.17 max = 5.44 avg = 5.23 + yolo-fastest-1.1 min = 3.35 max = 3.55 avg = 3.39 + yolo-fastestv2 min = 3.13 max = 3.22 avg = 3.17 + vision_transformer min = 131.06 max = 136.44 avg = 131.58 + FastestDet min = 3.11 max = 3.27 avg = 3.15 +misaki@HimiMisakiBenchmarkAMD64:~/Workspace/ncnn/build/benchmark$ ./benchncnn 512 8 0 -1 0 +loop_count = 512 +num_threads = 8 +powersave = 0 +gpu_device = -1 +cooling_down = 0 + squeezenet min = 2.11 max = 2.33 avg = 2.15 + squeezenet_int8 min = 2.26 max = 2.38 avg = 2.32 + mobilenet min = 2.01 max = 2.10 avg = 2.04 + mobilenet_int8 min = 1.61 max = 1.73 avg = 1.64 + mobilenet_v2 min = 2.85 max = 2.97 avg = 2.90 + mobilenet_v3 min = 2.66 max = 3.85 avg = 2.72 + shufflenet min = 3.13 max = 3.24 avg = 3.19 + shufflenet_v2 min = 2.27 max = 2.41 avg = 2.31 + mnasnet min = 2.64 max = 2.82 avg = 2.68 + proxylessnasnet min = 2.73 max = 2.82 avg = 2.78 + efficientnet_b0 min = 3.62 max = 3.75 avg = 3.67 + efficientnetv2_b0 min = 4.72 max = 5.52 avg = 4.79 + regnety_400m min = 6.61 max = 6.99 avg = 6.81 + blazeface min = 0.98 max = 1.13 avg = 1.01 + googlenet min = 6.18 max = 6.44 avg = 6.27 + googlenet_int8 min = 4.44 max = 4.61 avg = 4.50 + resnet18 min = 3.70 max = 5.01 avg = 3.77 + resnet18_int8 min = 3.08 max = 3.26 avg = 3.14 + alexnet min = 2.24 max = 2.43 avg = 2.28 + vgg16 min = 13.85 max = 15.77 avg = 14.04 + vgg16_int8 min = 10.57 max = 11.49 avg = 10.93 + resnet50 min = 8.22 max = 8.57 avg = 8.31 + resnet50_int8 min = 6.22 max = 6.57 avg = 6.35 + squeezenet_ssd min = 5.76 max = 6.11 avg = 5.85 + squeezenet_ssd_int8 min = 5.43 max = 5.72 avg = 5.52 + mobilenet_ssd min = 4.27 max = 4.91 avg = 4.33 + mobilenet_ssd_int8 min = 3.05 max = 3.23 avg = 3.10 + mobilenet_yolo min = 10.64 max = 11.34 avg = 10.87 + mobilenetv2_yolov3 min = 7.66 max = 8.00 avg = 7.73 + yolov4-tiny min = 10.77 max = 11.44 avg = 10.95 + nanodet_m min = 4.65 max = 4.83 avg = 4.70 + yolo-fastest-1.1 min = 3.33 max = 3.51 avg = 3.39 + yolo-fastestv2 min = 3.01 max = 3.15 
avg = 3.06 + vision_transformer min = 70.74 max = 73.98 avg = 71.04 + FastestDet min = 3.07 max = 3.29 avg = 3.13 +misaki@HimiMisakiBenchmarkAMD64:~/Workspace/ncnn/build/benchmark$ ./benchncnn 512 16 0 -1 0 +loop_count = 512 +num_threads = 16 +powersave = 0 +gpu_device = -1 +cooling_down = 0 + squeezenet min = 2.23 max = 2.47 avg = 2.28 + squeezenet_int8 min = 2.50 max = 2.62 avg = 2.56 + mobilenet min = 2.15 max = 2.37 avg = 2.23 + mobilenet_int8 min = 1.59 max = 2.80 avg = 1.65 + mobilenet_v2 min = 2.97 max = 3.12 avg = 3.04 + mobilenet_v3 min = 2.91 max = 3.07 avg = 3.00 + shufflenet min = 3.66 max = 3.86 avg = 3.73 + shufflenet_v2 min = 2.60 max = 2.79 avg = 2.68 + mnasnet min = 2.85 max = 2.98 avg = 2.91 + proxylessnasnet min = 3.01 max = 3.83 avg = 3.08 + efficientnet_b0 min = 3.93 max = 4.12 avg = 3.99 + efficientnetv2_b0 min = 5.20 max = 5.49 avg = 5.31 + regnety_400m min = 8.34 max = 8.81 avg = 8.54 + blazeface min = 1.18 max = 1.29 avg = 1.23 + googlenet min = 6.19 max = 6.81 avg = 6.28 + googlenet_int8 min = 4.82 max = 5.07 avg = 4.94 + resnet18 min = 3.62 max = 4.03 avg = 3.73 + resnet18_int8 min = 3.27 max = 3.52 avg = 3.37 + alexnet min = 1.98 max = 2.55 avg = 2.03 + vgg16 min = 13.41 max = 14.36 avg = 14.01 + vgg16_int8 min = 10.02 max = 10.36 avg = 10.13 + resnet50 min = 7.68 max = 8.25 avg = 7.86 + resnet50_int8 min = 5.88 max = 6.11 avg = 5.98 + squeezenet_ssd min = 6.18 max = 6.69 avg = 6.36 + squeezenet_ssd_int8 min = 6.01 max = 6.36 avg = 6.17 + mobilenet_ssd min = 4.43 max = 4.79 avg = 4.53 + mobilenet_ssd_int8 min = 3.27 max = 5.30 avg = 3.38 + mobilenet_yolo min = 11.74 max = 12.99 avg = 12.46 + mobilenetv2_yolov3 min = 7.56 max = 7.86 avg = 7.65 + yolov4-tiny min = 10.45 max = 11.00 avg = 10.60 + nanodet_m min = 5.20 max = 5.38 avg = 5.28 + yolo-fastest-1.1 min = 3.76 max = 3.90 avg = 3.82 + yolo-fastestv2 min = 3.43 max = 3.58 avg = 3.49 + vision_transformer min = 44.35 max = 45.94 avg = 44.59 + FastestDet min = 3.39 max = 3.56 avg = 3.46 +misaki@HimiMisakiBenchmarkAMD64:~/Workspace/ncnn/build/benchmark$ ./benchncnn 512 32 0 -1 0 +loop_count = 512 +num_threads = 32 +powersave = 0 +gpu_device = -1 +cooling_down = 0 + squeezenet min = 2.48 max = 2.75 avg = 2.59 + squeezenet_int8 min = 2.81 max = 3.03 avg = 2.90 + mobilenet min = 2.39 max = 2.84 avg = 2.48 + mobilenet_int8 min = 1.87 max = 1.99 avg = 1.92 + mobilenet_v2 min = 3.39 max = 3.60 avg = 3.47 + mobilenet_v3 min = 3.40 max = 3.66 avg = 3.53 + shufflenet min = 4.34 max = 4.58 avg = 4.45 + shufflenet_v2 min = 3.11 max = 3.32 avg = 3.21 + mnasnet min = 3.31 max = 3.54 avg = 3.38 + proxylessnasnet min = 3.49 max = 3.72 avg = 3.60 + efficientnet_b0 min = 4.66 max = 4.93 avg = 4.78 + efficientnetv2_b0 min = 6.06 max = 6.45 avg = 6.26 + regnety_400m min = 10.58 max = 11.61 avg = 11.01 + blazeface min = 1.42 max = 1.65 avg = 1.49 + googlenet min = 6.76 max = 8.08 avg = 7.07 + googlenet_int8 min = 5.60 max = 7.10 avg = 5.80 + resnet18 min = 3.58 max = 3.92 avg = 3.70 + resnet18_int8 min = 3.53 max = 3.76 avg = 3.63 + alexnet min = 2.05 max = 2.23 avg = 2.09 + vgg16 min = 12.94 max = 13.88 avg = 13.34 + vgg16_int8 min = 10.98 max = 11.59 avg = 11.30 + resnet50 min = 8.05 max = 8.42 avg = 8.20 + resnet50_int8 min = 6.35 max = 6.80 avg = 6.51 + squeezenet_ssd min = 6.58 max = 7.64 avg = 6.78 + squeezenet_ssd_int8 min = 6.53 max = 6.88 avg = 6.70 + mobilenet_ssd min = 4.73 max = 4.98 avg = 4.83 + mobilenet_ssd_int8 min = 3.81 max = 3.99 avg = 3.89 + mobilenet_yolo min = 14.36 max = 15.85 avg = 14.96 + mobilenetv2_yolov3 
min = 8.07 max = 8.43 avg = 8.29 + yolov4-tiny min = 10.94 max = 11.62 avg = 11.10 + nanodet_m min = 5.81 max = 6.16 avg = 5.95 + yolo-fastest-1.1 min = 4.11 max = 4.31 avg = 4.19 + yolo-fastestv2 min = 3.71 max = 4.03 avg = 3.81 + vision_transformer min = 31.80 max = 33.94 avg = 32.34 + FastestDet min = 3.68 max = 3.89 avg = 3.76 +misaki@HimiMisakiBenchmarkAMD64:~/Workspace/ncnn/build/benchmark$ ./benchncnn 512 64 0 -1 0 +loop_count = 512 +num_threads = 64 +powersave = 0 +gpu_device = -1 +cooling_down = 0 + squeezenet min = 2.92 max = 3.32 avg = 2.99 + squeezenet_int8 min = 3.35 max = 5.85 avg = 3.41 + mobilenet min = 2.78 max = 3.13 avg = 2.82 + mobilenet_int8 min = 2.36 max = 3.38 avg = 2.40 + mobilenet_v2 min = 3.94 max = 4.25 avg = 4.02 + mobilenet_v3 min = 4.01 max = 4.27 avg = 4.08 + shufflenet min = 5.14 max = 6.33 avg = 5.23 + shufflenet_v2 min = 3.63 max = 6.99 avg = 3.71 + mnasnet min = 3.68 max = 11.08 avg = 3.76 + proxylessnasnet min = 3.94 max = 10.92 avg = 4.01 + efficientnet_b0 min = 5.33 max = 10.15 avg = 5.46 + efficientnetv2_b0 min = 7.08 max = 42.44 avg = 7.33 + regnety_400m min = 13.65 max = 19.35 avg = 13.90 + blazeface min = 1.75 max = 1.91 avg = 1.79 + googlenet min = 7.83 max = 9.25 avg = 7.91 + googlenet_int8 min = 6.59 max = 7.00 avg = 6.67 + resnet18 min = 4.29 max = 5.96 avg = 4.38 + resnet18_int8 min = 4.23 max = 4.81 avg = 4.34 + alexnet min = 2.20 max = 2.50 avg = 2.26 + vgg16 min = 14.70 max = 16.55 avg = 15.05 + vgg16_int8 min = 12.53 max = 19.73 avg = 12.83 + resnet50 min = 9.02 max = 16.46 avg = 9.26 + resnet50_int8 min = 7.42 max = 7.91 avg = 7.56 + squeezenet_ssd min = 7.53 max = 15.76 avg = 7.64 + squeezenet_ssd_int8 min = 7.87 max = 15.31 avg = 8.07 + mobilenet_ssd min = 5.38 max = 5.68 avg = 5.47 + mobilenet_ssd_int8 min = 4.54 max = 5.51 avg = 4.64 + mobilenet_yolo min = 19.93 max = 26.80 avg = 20.50 + mobilenetv2_yolov3 min = 8.97 max = 10.37 avg = 9.13 + yolov4-tiny min = 12.47 max = 20.03 avg = 12.73 + nanodet_m min = 6.73 max = 17.09 avg = 6.93 + yolo-fastest-1.1 min = 4.97 max = 5.24 avg = 5.02 + yolo-fastestv2 min = 4.68 max = 6.10 avg = 4.73 + vision_transformer min = 27.70 max = 35.29 avg = 28.02 + FastestDet min = 4.48 max = 4.81 avg = 4.52 +``` + +**Results (Without AVX512)** + +``` +misaki@HimiMisakiBenchmarkAMD64:~/Workspace/ncnn/build/benchmark$ ./benchncnn 512 1 0 -1 0 +loop_count = 512 +num_threads = 1 +powersave = 0 +gpu_device = -1 +cooling_down = 0 + squeezenet min = 5.23 max = 5.47 avg = 5.26 + squeezenet_int8 min = 2.91 max = 2.99 avg = 2.93 + mobilenet min = 9.79 max = 10.28 avg = 9.83 + mobilenet_int8 min = 6.12 max = 6.45 avg = 6.14 + mobilenet_v2 min = 6.14 max = 6.24 avg = 6.17 + mobilenet_v3 min = 4.89 max = 5.17 avg = 4.92 + shufflenet min = 2.92 max = 3.09 avg = 2.95 + shufflenet_v2 min = 3.21 max = 4.54 avg = 3.23 + mnasnet min = 6.19 max = 6.35 avg = 6.22 + proxylessnasnet min = 7.49 max = 8.59 avg = 7.52 + efficientnet_b0 min = 14.91 max = 15.64 avg = 14.99 + efficientnetv2_b0 min = 15.83 max = 17.45 avg = 15.99 + regnety_400m min = 8.44 max = 8.81 avg = 8.48 + blazeface min = 0.82 max = 0.85 avg = 0.83 + googlenet min = 23.07 max = 24.18 avg = 23.29 + googlenet_int8 min = 13.67 max = 14.31 avg = 13.78 + resnet18 min = 20.52 max = 21.77 avg = 20.92 + resnet18_int8 min = 11.86 max = 12.41 avg = 12.00 + alexnet min = 15.25 max = 16.89 avg = 15.56 + vgg16 min = 94.69 max = 100.51 avg = 95.59 + vgg16_int8 min = 65.95 max = 69.48 avg = 66.82 + resnet50 min = 53.29 max = 56.36 avg = 53.69 + resnet50_int8 min = 25.35 max = 
27.04 avg = 25.55 + squeezenet_ssd min = 14.97 max = 15.50 avg = 15.14 + squeezenet_ssd_int8 min = 10.25 max = 10.86 avg = 10.42 + mobilenet_ssd min = 20.69 max = 21.83 avg = 20.78 + mobilenet_ssd_int8 min = 11.79 max = 12.38 avg = 11.83 + mobilenet_yolo min = 46.22 max = 48.51 avg = 46.35 + mobilenetv2_yolov3 min = 23.23 max = 24.54 avg = 23.35 + yolov4-tiny min = 35.22 max = 37.03 avg = 35.65 + nanodet_m min = 7.68 max = 7.78 avg = 7.70 + yolo-fastest-1.1 min = 2.88 max = 3.06 avg = 2.90 + yolo-fastestv2 min = 2.75 max = 2.82 avg = 2.76 + vision_transformer min = 339.55 max = 354.86 avg = 341.84 + FastestDet min = 3.09 max = 3.12 avg = 3.10 +misaki@HimiMisakiBenchmarkAMD64:~/Workspace/ncnn/build/benchmark$ ./benchncnn 512 2 0 -1 0 +loop_count = 512 +num_threads = 2 +powersave = 0 +gpu_device = -1 +cooling_down = 0 + squeezenet min = 3.84 max = 3.94 avg = 3.87 + squeezenet_int8 min = 3.48 max = 3.65 avg = 3.53 + mobilenet min = 6.12 max = 6.26 avg = 6.15 + mobilenet_int8 min = 4.19 max = 4.46 avg = 4.26 + mobilenet_v2 min = 5.07 max = 5.28 avg = 5.10 + mobilenet_v3 min = 4.06 max = 4.23 avg = 4.09 + shufflenet min = 3.70 max = 3.83 avg = 3.73 + shufflenet_v2 min = 3.09 max = 4.38 avg = 3.12 + mnasnet min = 4.80 max = 4.90 avg = 4.83 + proxylessnasnet min = 5.34 max = 5.57 avg = 5.37 + efficientnet_b0 min = 9.68 max = 10.01 avg = 9.73 + efficientnetv2_b0 min = 10.50 max = 11.00 avg = 10.60 + regnety_400m min = 7.90 max = 8.24 avg = 7.94 + blazeface min = 1.04 max = 1.12 avg = 1.07 + googlenet min = 15.48 max = 16.18 avg = 15.68 + googlenet_int8 min = 9.39 max = 9.84 avg = 9.48 + resnet18 min = 12.38 max = 13.10 avg = 12.55 + resnet18_int8 min = 7.77 max = 8.28 avg = 7.96 + alexnet min = 8.89 max = 9.38 avg = 9.07 + vgg16 min = 53.83 max = 56.74 avg = 54.85 + vgg16_int8 min = 37.73 max = 39.08 avg = 38.04 + resnet50 min = 31.36 max = 32.68 avg = 31.60 + resnet50_int8 min = 17.11 max = 18.15 avg = 17.52 + squeezenet_ssd min = 11.28 max = 12.05 avg = 11.52 + squeezenet_ssd_int8 min = 8.99 max = 9.47 avg = 9.14 + mobilenet_ssd min = 12.81 max = 13.46 avg = 12.92 + mobilenet_ssd_int8 min = 7.83 max = 8.23 avg = 7.89 + mobilenet_yolo min = 30.17 max = 31.36 avg = 30.34 + mobilenetv2_yolov3 min = 16.71 max = 17.32 avg = 16.79 + yolov4-tiny min = 23.81 max = 24.99 avg = 24.07 + nanodet_m min = 7.41 max = 8.34 avg = 7.51 + yolo-fastest-1.1 min = 3.55 max = 3.61 avg = 3.57 + yolo-fastestv2 min = 3.51 max = 3.64 avg = 3.56 + vision_transformer min = 177.35 max = 185.49 avg = 178.18 + FastestDet min = 3.51 max = 3.56 avg = 3.53 +misaki@HimiMisakiBenchmarkAMD64:~/Workspace/ncnn/build/benchmark$ ./benchncnn 512 4 0 -1 0 +loop_count = 512 +num_threads = 4 +powersave = 0 +gpu_device = -1 +cooling_down = 0 + squeezenet min = 2.69 max = 2.82 avg = 2.71 + squeezenet_int8 min = 2.48 max = 2.65 avg = 2.54 + mobilenet min = 3.55 max = 3.72 avg = 3.58 + mobilenet_int8 min = 2.55 max = 2.79 avg = 2.63 + mobilenet_v2 min = 3.43 max = 4.53 avg = 3.46 + mobilenet_v3 min = 3.15 max = 3.25 avg = 3.18 + shufflenet min = 3.18 max = 3.33 avg = 3.22 + shufflenet_v2 min = 2.55 max = 2.68 avg = 2.59 + mnasnet min = 3.31 max = 3.38 avg = 3.34 + proxylessnasnet min = 3.65 max = 3.79 avg = 3.67 + efficientnet_b0 min = 6.03 max = 7.18 avg = 6.07 + efficientnetv2_b0 min = 6.77 max = 7.10 avg = 6.84 + regnety_400m min = 7.30 max = 7.65 avg = 7.41 + blazeface min = 0.89 max = 0.99 avg = 0.95 + googlenet min = 9.32 max = 10.68 avg = 9.42 + googlenet_int8 min = 6.27 max = 6.53 avg = 6.32 + resnet18 min = 6.93 max = 7.70 avg = 7.06 + 
resnet18_int8 min = 4.16 max = 4.42 avg = 4.28 + alexnet min = 4.78 max = 5.13 avg = 4.92 + vgg16 min = 30.20 max = 31.60 avg = 30.64 + vgg16_int8 min = 21.71 max = 22.91 avg = 22.04 + resnet50 min = 17.74 max = 18.49 avg = 17.92 + resnet50_int8 min = 10.23 max = 10.67 avg = 10.34 + squeezenet_ssd min = 7.60 max = 8.20 avg = 7.77 + squeezenet_ssd_int8 min = 6.29 max = 7.40 avg = 6.39 + mobilenet_ssd min = 7.37 max = 7.59 avg = 7.44 + mobilenet_ssd_int8 min = 4.80 max = 5.04 avg = 4.84 + mobilenet_yolo min = 18.93 max = 19.98 avg = 19.12 + mobilenetv2_yolov3 min = 10.72 max = 11.16 avg = 10.79 + yolov4-tiny min = 15.47 max = 16.38 avg = 15.77 + nanodet_m min = 5.81 max = 5.99 avg = 5.85 + yolo-fastest-1.1 min = 3.27 max = 3.38 avg = 3.32 + yolo-fastestv2 min = 3.13 max = 3.33 avg = 3.17 + vision_transformer min = 94.34 max = 99.51 avg = 94.77 + FastestDet min = 3.13 max = 3.28 avg = 3.17 +misaki@HimiMisakiBenchmarkAMD64:~/Workspace/ncnn/build/benchmark$ ./benchncnn 512 8 0 -1 0 +loop_count = 512 +num_threads = 8 +powersave = 0 +gpu_device = -1 +cooling_down = 0 + squeezenet min = 2.19 max = 2.34 avg = 2.22 + squeezenet_int8 min = 2.16 max = 2.24 avg = 2.20 + mobilenet min = 2.40 max = 2.57 avg = 2.44 + mobilenet_int8 min = 1.67 max = 1.74 avg = 1.70 + mobilenet_v2 min = 2.89 max = 2.98 avg = 2.93 + mobilenet_v3 min = 2.71 max = 2.83 avg = 2.75 + shufflenet min = 3.21 max = 3.36 avg = 3.26 + shufflenet_v2 min = 2.31 max = 2.41 avg = 2.35 + mnasnet min = 2.69 max = 2.84 avg = 2.73 + proxylessnasnet min = 2.94 max = 3.11 avg = 2.99 + efficientnet_b0 min = 4.77 max = 4.90 avg = 4.82 + efficientnetv2_b0 min = 5.33 max = 5.62 avg = 5.42 + regnety_400m min = 7.18 max = 7.49 avg = 7.33 + blazeface min = 0.98 max = 1.07 avg = 1.02 + googlenet min = 6.82 max = 7.08 avg = 6.91 + googlenet_int8 min = 4.98 max = 5.23 avg = 5.05 + resnet18 min = 4.63 max = 4.79 avg = 4.71 + resnet18_int8 min = 2.95 max = 3.22 avg = 3.05 + alexnet min = 2.93 max = 3.06 avg = 2.98 + vgg16 min = 17.60 max = 18.43 avg = 17.77 + vgg16_int8 min = 12.29 max = 12.75 avg = 12.37 + resnet50 min = 10.98 max = 11.53 avg = 11.10 + resnet50_int8 min = 6.76 max = 7.16 avg = 6.85 + squeezenet_ssd min = 6.18 max = 6.73 avg = 6.39 + squeezenet_ssd_int8 min = 5.24 max = 5.47 avg = 5.34 + mobilenet_ssd min = 4.96 max = 5.07 avg = 5.01 + mobilenet_ssd_int8 min = 3.42 max = 3.65 avg = 3.48 + mobilenet_yolo min = 14.16 max = 14.89 avg = 14.46 + mobilenetv2_yolov3 min = 8.26 max = 8.70 avg = 8.32 + yolov4-tiny min = 11.31 max = 11.92 avg = 11.48 + nanodet_m min = 5.19 max = 5.35 avg = 5.25 + yolo-fastest-1.1 min = 3.30 max = 3.47 avg = 3.35 + yolo-fastestv2 min = 3.13 max = 3.26 avg = 3.18 + vision_transformer min = 53.74 max = 56.87 avg = 54.04 + FastestDet min = 3.11 max = 3.23 avg = 3.17 +misaki@HimiMisakiBenchmarkAMD64:~/Workspace/ncnn/build/benchmark$ ./benchncnn 512 16 0 -1 0 +loop_count = 512 +num_threads = 16 +powersave = 0 +gpu_device = -1 +cooling_down = 0 + squeezenet min = 2.20 max = 2.31 avg = 2.24 + squeezenet_int8 min = 2.27 max = 2.38 avg = 2.32 + mobilenet min = 2.09 max = 2.26 avg = 2.14 + mobilenet_int8 min = 1.53 max = 1.62 avg = 1.57 + mobilenet_v2 min = 2.85 max = 3.04 avg = 2.90 + mobilenet_v3 min = 2.74 max = 2.88 avg = 2.80 + shufflenet min = 3.52 max = 4.89 avg = 3.59 + shufflenet_v2 min = 2.45 max = 2.57 avg = 2.51 + mnasnet min = 2.62 max = 2.77 avg = 2.66 + proxylessnasnet min = 2.83 max = 2.99 avg = 2.89 + efficientnet_b0 min = 4.19 max = 4.46 avg = 4.26 + efficientnetv2_b0 min = 5.16 max = 5.51 avg = 5.26 + 
regnety_400m min = 7.83 max = 9.62 avg = 8.02 + blazeface min = 1.12 max = 1.26 avg = 1.17 + googlenet min = 6.33 max = 6.74 avg = 6.44 + googlenet_int8 min = 4.78 max = 5.31 avg = 4.86 + resnet18 min = 3.78 max = 3.94 avg = 3.83 + resnet18_int8 min = 2.87 max = 3.04 avg = 2.93 + alexnet min = 2.10 max = 2.27 avg = 2.14 + vgg16 min = 13.16 max = 13.64 avg = 13.37 + vgg16_int8 min = 9.61 max = 10.20 avg = 9.72 + resnet50 min = 8.60 max = 8.86 avg = 8.68 + resnet50_int8 min = 5.97 max = 6.32 avg = 6.15 + squeezenet_ssd min = 5.81 max = 6.19 avg = 6.00 + squeezenet_ssd_int8 min = 5.37 max = 6.09 avg = 5.46 + mobilenet_ssd min = 4.29 max = 4.62 avg = 4.37 + mobilenet_ssd_int8 min = 3.19 max = 3.41 avg = 3.26 + mobilenet_yolo min = 13.34 max = 14.07 avg = 13.73 + mobilenetv2_yolov3 min = 7.59 max = 7.89 avg = 7.69 + yolov4-tiny min = 10.11 max = 10.79 avg = 10.37 + nanodet_m min = 5.37 max = 5.61 avg = 5.44 + yolo-fastest-1.1 min = 3.53 max = 3.65 avg = 3.59 + yolo-fastestv2 min = 3.23 max = 3.39 avg = 3.31 + vision_transformer min = 35.67 max = 36.48 avg = 35.93 + FastestDet min = 3.27 max = 3.44 avg = 3.33 +misaki@HimiMisakiBenchmarkAMD64:~/Workspace/ncnn/build/benchmark$ ./benchncnn 512 32 0 -1 0 +loop_count = 512 +num_threads = 32 +powersave = 0 +gpu_device = -1 +cooling_down = 0 + squeezenet min = 2.47 max = 2.78 avg = 2.52 + squeezenet_int8 min = 2.68 max = 2.84 avg = 2.74 + mobilenet min = 2.31 max = 2.53 avg = 2.37 + mobilenet_int8 min = 1.71 max = 1.84 avg = 1.78 + mobilenet_v2 min = 3.31 max = 3.54 avg = 3.40 + mobilenet_v3 min = 3.38 max = 3.54 avg = 3.45 + shufflenet min = 4.42 max = 4.62 avg = 4.53 + shufflenet_v2 min = 3.12 max = 3.27 avg = 3.21 + mnasnet min = 3.08 max = 3.30 avg = 3.18 + proxylessnasnet min = 3.35 max = 3.50 avg = 3.41 + efficientnet_b0 min = 4.78 max = 5.05 avg = 4.93 + efficientnetv2_b0 min = 6.19 max = 6.57 avg = 6.42 + regnety_400m min = 11.23 max = 11.97 avg = 11.64 + blazeface min = 1.40 max = 1.54 avg = 1.47 + googlenet min = 6.85 max = 7.21 avg = 6.96 + googlenet_int8 min = 5.56 max = 5.77 avg = 5.66 + resnet18 min = 3.83 max = 4.44 avg = 3.93 + resnet18_int8 min = 3.17 max = 3.33 avg = 3.23 + alexnet min = 2.06 max = 2.25 avg = 2.11 + vgg16 min = 12.12 max = 12.82 avg = 12.39 + vgg16_int8 min = 9.33 max = 10.70 avg = 9.45 + resnet50 min = 8.45 max = 8.79 avg = 8.55 + resnet50_int8 min = 6.26 max = 6.62 avg = 6.43 + squeezenet_ssd min = 6.44 max = 6.74 avg = 6.57 + squeezenet_ssd_int8 min = 6.28 max = 6.65 avg = 6.42 + mobilenet_ssd min = 4.62 max = 5.24 avg = 4.68 + mobilenet_ssd_int8 min = 3.69 max = 3.90 avg = 3.78 + mobilenet_yolo min = 17.08 max = 18.83 avg = 18.16 + mobilenetv2_yolov3 min = 8.06 max = 8.31 avg = 8.16 + yolov4-tiny min = 10.17 max = 10.60 avg = 10.39 + nanodet_m min = 6.25 max = 6.56 avg = 6.40 + yolo-fastest-1.1 min = 4.16 max = 4.34 avg = 4.25 + yolo-fastestv2 min = 3.83 max = 4.04 avg = 3.92 + vision_transformer min = 30.37 max = 33.12 avg = 30.94 + FastestDet min = 3.85 max = 4.09 avg = 3.95 +misaki@HimiMisakiBenchmarkAMD64:~/Workspace/ncnn/build/benchmark$ ./benchncnn 512 64 0 -1 0 +loop_count = 512 +num_threads = 64 +powersave = 0 +gpu_device = -1 +cooling_down = 0 + squeezenet min = 2.86 max = 4.39 avg = 2.92 + squeezenet_int8 min = 3.12 max = 3.42 avg = 3.20 + mobilenet min = 2.57 max = 2.92 avg = 2.64 + mobilenet_int8 min = 2.11 max = 2.34 avg = 2.17 + mobilenet_v2 min = 3.77 max = 5.27 avg = 3.89 + mobilenet_v3 min = 3.88 max = 5.25 avg = 3.98 + shufflenet min = 5.05 max = 13.10 avg = 5.23 + shufflenet_v2 min = 3.61 max = 
5.79 avg = 3.70 + mnasnet min = 3.58 max = 3.78 avg = 3.64 + proxylessnasnet min = 3.85 max = 4.10 avg = 3.93 + efficientnet_b0 min = 5.57 max = 12.98 avg = 5.69 + efficientnetv2_b0 min = 7.09 max = 14.16 avg = 7.27 + regnety_400m min = 13.09 max = 21.23 avg = 13.42 + blazeface min = 1.64 max = 1.85 avg = 1.70 + googlenet min = 7.39 max = 7.92 avg = 7.55 + googlenet_int8 min = 6.38 max = 8.56 avg = 6.46 + resnet18 min = 4.00 max = 5.19 avg = 4.11 + resnet18_int8 min = 3.63 max = 4.96 avg = 3.72 + alexnet min = 2.15 max = 2.42 avg = 2.18 + vgg16 min = 13.49 max = 21.20 avg = 13.89 + vgg16_int8 min = 10.29 max = 18.64 avg = 10.48 + resnet50 min = 9.04 max = 14.31 avg = 9.18 + resnet50_int8 min = 7.24 max = 16.60 avg = 7.46 + squeezenet_ssd min = 7.18 max = 8.72 avg = 7.28 + squeezenet_ssd_int8 min = 7.17 max = 8.81 avg = 7.31 + mobilenet_ssd min = 4.97 max = 5.35 avg = 5.04 + mobilenet_ssd_int8 min = 4.39 max = 11.82 avg = 4.48 + mobilenet_yolo min = 19.90 max = 23.57 avg = 20.68 + mobilenetv2_yolov3 min = 8.64 max = 12.98 avg = 8.88 + yolov4-tiny min = 11.37 max = 13.51 avg = 11.57 + nanodet_m min = 7.00 max = 14.01 avg = 7.09 + yolo-fastest-1.1 min = 4.73 max = 11.79 avg = 4.87 + yolo-fastestv2 min = 4.46 max = 11.84 avg = 4.54 + vision_transformer min = 29.36 max = 36.71 avg = 29.64 + FastestDet min = 4.29 max = 5.33 avg = 4.36 +``` From e10d7c8caf3b30b4f33e97cbfa47ab5ffc01557a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=A2=A8=E7=9A=84=E5=BD=B7=E5=BE=A8?= <56149058+futz12@users.noreply.github.com> Date: Mon, 30 Mar 2026 10:21:56 +0800 Subject: [PATCH 29/36] update docs for new convertmodel website (#6617) --- .github/ISSUE_TEMPLATE/model-convert.md | 2 +- docs/how-to-use-and-FAQ/use-ncnn-with-pytorch-or-onnx.md | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/ISSUE_TEMPLATE/model-convert.md b/.github/ISSUE_TEMPLATE/model-convert.md index 132ab869460..b65df375341 100644 --- a/.github/ISSUE_TEMPLATE/model-convert.md +++ b/.github/ISSUE_TEMPLATE/model-convert.md @@ -1,6 +1,6 @@ --- name: "\U0001F6B8 model convert issue" -about: "Life is Short, Use pnnx and convertmodel.com" +about: "Life is Short, Use pnnx and pnnx.pchar.cn" --- ## error log | 日志或报错信息 | ログ diff --git a/docs/how-to-use-and-FAQ/use-ncnn-with-pytorch-or-onnx.md b/docs/how-to-use-and-FAQ/use-ncnn-with-pytorch-or-onnx.md index 5b4a7f961ff..96e3a956441 100644 --- a/docs/how-to-use-and-FAQ/use-ncnn-with-pytorch-or-onnx.md +++ b/docs/how-to-use-and-FAQ/use-ncnn-with-pytorch-or-onnx.md @@ -124,6 +124,8 @@ For users who already have an `.onnx` file, please use pnnx for conversion. * **Method 2 (Alternative):** For non-Python environments or to use a standalone program, you can download the latest executable from the [pnnx Releases page](https://github.com/pnnx/pnnx/releases). +* **Method 3 (Convert Everywhere with GUI):** To convert models to ncnn anytime and anywhere, you can use [pnnx.js](https://pnnx.pchar.cn). To protect your data security, the website is built based on wasm64 technology, and your models will not be uploaded to the server. + ### 2. Run the Command-Line Conversion Open a terminal, navigate to the directory containing your model file, and run the following command. 
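
For readers following the doc change above without the full page at hand, a typical pnnx invocation looks like the sketch below. The file name `model.pt` and the `1,3,224,224` input shape are placeholders chosen purely for illustration; substitute your own exported TorchScript file and its real input shape.

```
# minimal sketch: convert a TorchScript model to ncnn with pnnx
# model.pt and the input shape are hypothetical placeholders
pnnx model.pt inputshape=[1,3,224,224]
# on success pnnx writes ncnn files (model.ncnn.param / model.ncnn.bin) next to the input
```

As the docs above note, pnnx also accepts an existing `.onnx` file as its input, so the same command pattern covers the ONNX workflow described earlier.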
From a6a04ea238de020a4bc5cc675deb98cc3995401a Mon Sep 17 00:00:00 2001 From: nihui Date: Tue, 31 Mar 2026 15:29:45 +0800 Subject: [PATCH 30/36] gemm x86 support out_elemtype, multiheadattention and sdpa x86 support bf16 storage, skip mha bf16 tests (#6623) --- src/layer/x86/gemm_bf16s.h | 5533 ++++++++++++++-------- src/layer/x86/gemm_x86.cpp | 26 +- src/layer/x86/multiheadattention_x86.cpp | 31 +- src/layer/x86/sdpa_x86.cpp | 49 +- tests/test_gemm_2e.cpp | 6 +- tests/test_gemm_5.cpp | 327 ++ tests/testutil.cpp | 66 +- 7 files changed, 4120 insertions(+), 1918 deletions(-) create mode 100644 tests/test_gemm_5.cpp diff --git a/src/layer/x86/gemm_bf16s.h b/src/layer/x86/gemm_bf16s.h index 4f71ac5ca17..04f6a5eb12b 100644 --- a/src/layer/x86/gemm_bf16s.h +++ b/src/layer/x86/gemm_bf16s.h @@ -3660,11 +3660,11 @@ static void gemm_transB_packed_tile_bf16s(const Mat& AT_tile, const Mat& BT_tile } } -static void unpack_output_tile_fp32_to_bf16(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, float alpha, float beta, int output_transpose) +static void unpack_output_tile_fp32_to_bf16(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, float alpha, float beta, int output_transpose, int output_elemtype) { - // NCNN_LOGE("unpack_output_tile_fp32_to_bf16 %d %d %d %d", i, max_ii, j, max_jj); const int out_elempack = top_blob.elempack; const size_t out_hstep = top_blob.dims == 3 ? top_blob.cstep : (size_t)top_blob.w; + // NCNN_LOGE("unpack_output_tile_fp32_to_bf16 %d %d %d %d @ %d", i, max_ii, j, max_jj, out_elempack); const size_t c_hstep = C.dims == 3 ? C.cstep : (size_t)C.w; const int c_elempack = C.elempack; @@ -3679,13 +3679,16 @@ static void unpack_output_tile_fp32_to_bf16(const Mat& topT, const Mat& C, Mat& for (; ii + 15 < max_ii; ii += 16) { unsigned short* p0; + float* p0f; if (output_transpose) { p0 = (unsigned short*)top_blob + j * out_hstep + (i + ii) * out_elempack; + p0f = (float*)top_blob + j * out_hstep + (i + ii) * out_elempack; } else { p0 = (unsigned short*)top_blob + (i + ii) * out_hstep + j * out_elempack; + p0f = (float*)top_blob + (i + ii) * out_hstep + j * out_elempack; } __m512 _c0 = _mm512_set1_ps(0.f); @@ -4119,328 +4122,495 @@ static void unpack_output_tile_fp32_to_bf16(const Mat& topT, const Mat& C, Mat& _ff = _mm512_mul_ps(_ff, _alpha); } - __m256i _bf0 = float2bfloat_avx512(_f0); - __m256i _bf1 = float2bfloat_avx512(_f1); - __m256i _bf2 = float2bfloat_avx512(_f2); - __m256i _bf3 = float2bfloat_avx512(_f3); - __m256i _bf4 = float2bfloat_avx512(_f4); - __m256i _bf5 = float2bfloat_avx512(_f5); - __m256i _bf6 = float2bfloat_avx512(_f6); - __m256i _bf7 = float2bfloat_avx512(_f7); - __m256i _bf8 = float2bfloat_avx512(_f8); - __m256i _bf9 = float2bfloat_avx512(_f9); - __m256i _bfa = float2bfloat_avx512(_fa); - __m256i _bfb = float2bfloat_avx512(_fb); - __m256i _bfc = float2bfloat_avx512(_fc); - __m256i _bfd = float2bfloat_avx512(_fd); - __m256i _bfe = float2bfloat_avx512(_fe); - __m256i _bff = float2bfloat_avx512(_ff); - - // store bf16 - if (output_transpose) + if (output_elemtype == 1) { - if (out_elempack == 16) + // store fp32 + if (output_transpose) { - transpose16x16_epi16(_bf0, _bf1, _bf2, _bf3, _bf4, _bf5, _bf6, _bf7, _bf8, _bf9, _bfa, _bfb, _bfc, _bfd, _bfe, _bff); - - _mm256_storeu_si256((__m256i*)p0, _bf0); - _mm256_storeu_si256((__m256i*)(p0 + 16), _bf1); - _mm256_storeu_si256((__m256i*)(p0 + 16 * 2), _bf2); - _mm256_storeu_si256((__m256i*)(p0 + 16 * 3), _bf3); 
- _mm256_storeu_si256((__m256i*)(p0 + 16 * 4), _bf4); - _mm256_storeu_si256((__m256i*)(p0 + 16 * 5), _bf5); - _mm256_storeu_si256((__m256i*)(p0 + 16 * 6), _bf6); - _mm256_storeu_si256((__m256i*)(p0 + 16 * 7), _bf7); - _mm256_storeu_si256((__m256i*)(p0 + 16 * 8), _bf8); - _mm256_storeu_si256((__m256i*)(p0 + 16 * 9), _bf9); - _mm256_storeu_si256((__m256i*)(p0 + 16 * 10), _bfa); - _mm256_storeu_si256((__m256i*)(p0 + 16 * 11), _bfb); - _mm256_storeu_si256((__m256i*)(p0 + 16 * 12), _bfc); - _mm256_storeu_si256((__m256i*)(p0 + 16 * 13), _bfd); - _mm256_storeu_si256((__m256i*)(p0 + 16 * 14), _bfe); - _mm256_storeu_si256((__m256i*)(p0 + 16 * 15), _bff); - } - if (out_elempack == 8) - { - transpose16x8_epi16(_bf0, _bf1, _bf2, _bf3, _bf4, _bf5, _bf6, _bf7); - transpose16x8_epi16(_bf8, _bf9, _bfa, _bfb, _bfc, _bfd, _bfe, _bff); - - _mm_storeu_si128((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); - _mm_storeu_si128((__m128i*)(p0 + 8), _mm256_extractf128_si256(_bf0, 1)); - _mm_storeu_si128((__m128i*)(p0 + 8 * 2), _mm256_extractf128_si256(_bf1, 0)); - _mm_storeu_si128((__m128i*)(p0 + 8 * 3), _mm256_extractf128_si256(_bf1, 1)); - _mm_storeu_si128((__m128i*)(p0 + 8 * 4), _mm256_extractf128_si256(_bf2, 0)); - _mm_storeu_si128((__m128i*)(p0 + 8 * 5), _mm256_extractf128_si256(_bf2, 1)); - _mm_storeu_si128((__m128i*)(p0 + 8 * 6), _mm256_extractf128_si256(_bf3, 0)); - _mm_storeu_si128((__m128i*)(p0 + 8 * 7), _mm256_extractf128_si256(_bf3, 1)); - _mm_storeu_si128((__m128i*)(p0 + 8 * 8), _mm256_extractf128_si256(_bf4, 0)); - _mm_storeu_si128((__m128i*)(p0 + 8 * 9), _mm256_extractf128_si256(_bf4, 1)); - _mm_storeu_si128((__m128i*)(p0 + 8 * 10), _mm256_extractf128_si256(_bf5, 0)); - _mm_storeu_si128((__m128i*)(p0 + 8 * 11), _mm256_extractf128_si256(_bf5, 1)); - _mm_storeu_si128((__m128i*)(p0 + 8 * 12), _mm256_extractf128_si256(_bf6, 0)); - _mm_storeu_si128((__m128i*)(p0 + 8 * 13), _mm256_extractf128_si256(_bf6, 1)); - _mm_storeu_si128((__m128i*)(p0 + 8 * 14), _mm256_extractf128_si256(_bf7, 0)); - _mm_storeu_si128((__m128i*)(p0 + 8 * 15), _mm256_extractf128_si256(_bf7, 1)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8), _mm256_extractf128_si256(_bf8, 0)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8), _mm256_extractf128_si256(_bf8, 1)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 2), _mm256_extractf128_si256(_bf9, 0)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 3), _mm256_extractf128_si256(_bf9, 1)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 4), _mm256_extractf128_si256(_bfa, 0)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 5), _mm256_extractf128_si256(_bfa, 1)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 6), _mm256_extractf128_si256(_bfb, 0)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 7), _mm256_extractf128_si256(_bfb, 1)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 8), _mm256_extractf128_si256(_bfc, 0)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 9), _mm256_extractf128_si256(_bfc, 1)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 10), _mm256_extractf128_si256(_bfd, 0)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 11), _mm256_extractf128_si256(_bfd, 1)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 12), _mm256_extractf128_si256(_bfe, 0)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 13), _mm256_extractf128_si256(_bfe, 1)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 14), _mm256_extractf128_si256(_bff, 0)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 
+ 8 * 15), _mm256_extractf128_si256(_bff, 1)); - } - if (out_elempack == 4) - { - transpose16x4_epi16(_bf0, _bf1, _bf2, _bf3); - transpose16x4_epi16(_bf4, _bf5, _bf6, _bf7); - transpose16x4_epi16(_bf8, _bf9, _bfa, _bfb); - transpose16x4_epi16(_bfc, _bfd, _bfe, _bff); - - _mm_storel_epi64((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); - _mm_storeh_pd((double*)(p0 + 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 0))); - _mm_storel_epi64((__m128i*)(p0 + 8), _mm256_extractf128_si256(_bf0, 1)); - _mm_storeh_pd((double*)(p0 + 12), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 1))); - _mm_storel_epi64((__m128i*)(p0 + 16), _mm256_extractf128_si256(_bf1, 0)); - _mm_storeh_pd((double*)(p0 + 20), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 0))); - _mm_storel_epi64((__m128i*)(p0 + 24), _mm256_extractf128_si256(_bf1, 1)); - _mm_storeh_pd((double*)(p0 + 28), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 1))); - _mm_storel_epi64((__m128i*)(p0 + 32), _mm256_extractf128_si256(_bf2, 0)); - _mm_storeh_pd((double*)(p0 + 36), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 0))); - _mm_storel_epi64((__m128i*)(p0 + 40), _mm256_extractf128_si256(_bf2, 1)); - _mm_storeh_pd((double*)(p0 + 44), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 1))); - _mm_storel_epi64((__m128i*)(p0 + 48), _mm256_extractf128_si256(_bf3, 0)); - _mm_storeh_pd((double*)(p0 + 52), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 0))); - _mm_storel_epi64((__m128i*)(p0 + 56), _mm256_extractf128_si256(_bf3, 1)); - _mm_storeh_pd((double*)(p0 + 60), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 1))); - - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 4), _mm256_extractf128_si256(_bf4, 0)); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf4, 0))); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 4 + 8), _mm256_extractf128_si256(_bf4, 1)); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 12), _mm_castsi128_pd(_mm256_extractf128_si256(_bf4, 1))); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 4 + 16), _mm256_extractf128_si256(_bf5, 0)); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 20), _mm_castsi128_pd(_mm256_extractf128_si256(_bf5, 0))); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 4 + 24), _mm256_extractf128_si256(_bf5, 1)); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 28), _mm_castsi128_pd(_mm256_extractf128_si256(_bf5, 1))); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 4 + 32), _mm256_extractf128_si256(_bf6, 0)); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 36), _mm_castsi128_pd(_mm256_extractf128_si256(_bf6, 0))); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 4 + 40), _mm256_extractf128_si256(_bf6, 1)); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 44), _mm_castsi128_pd(_mm256_extractf128_si256(_bf6, 1))); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 4 + 48), _mm256_extractf128_si256(_bf7, 0)); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 52), _mm_castsi128_pd(_mm256_extractf128_si256(_bf7, 0))); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 4 + 56), _mm256_extractf128_si256(_bf7, 1)); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 60), _mm_castsi128_pd(_mm256_extractf128_si256(_bf7, 1))); - - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8), _mm256_extractf128_si256(_bf8, 0)); - _mm_storeh_pd((double*)(p0 + out_hstep * 8 + 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf8, 0))); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 8), _mm256_extractf128_si256(_bf8, 1)); - _mm_storeh_pd((double*)(p0 + out_hstep * 8 + 12), _mm_castsi128_pd(_mm256_extractf128_si256(_bf8, 
1))); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 16), _mm256_extractf128_si256(_bf9, 0)); - _mm_storeh_pd((double*)(p0 + out_hstep * 8 + 20), _mm_castsi128_pd(_mm256_extractf128_si256(_bf9, 0))); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 24), _mm256_extractf128_si256(_bf9, 1)); - _mm_storeh_pd((double*)(p0 + out_hstep * 8 + 28), _mm_castsi128_pd(_mm256_extractf128_si256(_bf9, 1))); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 32), _mm256_extractf128_si256(_bfa, 0)); - _mm_storeh_pd((double*)(p0 + out_hstep * 8 + 36), _mm_castsi128_pd(_mm256_extractf128_si256(_bfa, 0))); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 40), _mm256_extractf128_si256(_bfa, 1)); - _mm_storeh_pd((double*)(p0 + out_hstep * 8 + 44), _mm_castsi128_pd(_mm256_extractf128_si256(_bfa, 1))); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 48), _mm256_extractf128_si256(_bfb, 0)); - _mm_storeh_pd((double*)(p0 + out_hstep * 8 + 52), _mm_castsi128_pd(_mm256_extractf128_si256(_bfb, 0))); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 56), _mm256_extractf128_si256(_bfb, 1)); - _mm_storeh_pd((double*)(p0 + out_hstep * 8 + 60), _mm_castsi128_pd(_mm256_extractf128_si256(_bfb, 1))); - - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 12), _mm256_extractf128_si256(_bfc, 0)); - _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bfc, 0))); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 12 + 8), _mm256_extractf128_si256(_bfc, 1)); - _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 12), _mm_castsi128_pd(_mm256_extractf128_si256(_bfc, 1))); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 12 + 16), _mm256_extractf128_si256(_bfd, 0)); - _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 20), _mm_castsi128_pd(_mm256_extractf128_si256(_bfd, 0))); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 12 + 24), _mm256_extractf128_si256(_bfd, 1)); - _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 28), _mm_castsi128_pd(_mm256_extractf128_si256(_bfd, 1))); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 12 + 32), _mm256_extractf128_si256(_bfe, 0)); - _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 36), _mm_castsi128_pd(_mm256_extractf128_si256(_bfe, 0))); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 12 + 40), _mm256_extractf128_si256(_bfe, 1)); - _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 44), _mm_castsi128_pd(_mm256_extractf128_si256(_bfe, 1))); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 12 + 48), _mm256_extractf128_si256(_bff, 0)); - _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 52), _mm_castsi128_pd(_mm256_extractf128_si256(_bff, 0))); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 12 + 56), _mm256_extractf128_si256(_bff, 1)); - _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 60), _mm_castsi128_pd(_mm256_extractf128_si256(_bff, 1))); - } - if (out_elempack == 1) + if (out_elempack == 16) + { + transpose16x16_ps(_f0, _f1, _f2, _f3, _f4, _f5, _f6, _f7, _f8, _f9, _fa, _fb, _fc, _fd, _fe, _ff); + + _mm512_storeu_ps(p0f, _f0); + _mm512_storeu_ps(p0f + 16, _f1); + _mm512_storeu_ps(p0f + 16 * 2, _f2); + _mm512_storeu_ps(p0f + 16 * 3, _f3); + _mm512_storeu_ps(p0f + 16 * 4, _f4); + _mm512_storeu_ps(p0f + 16 * 5, _f5); + _mm512_storeu_ps(p0f + 16 * 6, _f6); + _mm512_storeu_ps(p0f + 16 * 7, _f7); + _mm512_storeu_ps(p0f + 16 * 8, _f8); + _mm512_storeu_ps(p0f + 16 * 9, _f9); + _mm512_storeu_ps(p0f + 16 * 10, _fa); + _mm512_storeu_ps(p0f + 16 * 11, _fb); + _mm512_storeu_ps(p0f + 16 * 12, _fc); + _mm512_storeu_ps(p0f + 16 * 13, _fd); + _mm512_storeu_ps(p0f + 16 * 14, _fe); + _mm512_storeu_ps(p0f + 
16 * 15, _ff); + } + if (out_elempack == 8) + { + transpose16x8_ps(_f0, _f1, _f2, _f3, _f4, _f5, _f6, _f7); + transpose16x8_ps(_f8, _f9, _fa, _fb, _fc, _fd, _fe, _ff); + + _mm512_storeu_ps(p0f, _f0); + _mm512_storeu_ps(p0f + 16, _f1); + _mm512_storeu_ps(p0f + 16 * 2, _f2); + _mm512_storeu_ps(p0f + 16 * 3, _f3); + _mm512_storeu_ps(p0f + 16 * 4, _f4); + _mm512_storeu_ps(p0f + 16 * 5, _f5); + _mm512_storeu_ps(p0f + 16 * 6, _f6); + _mm512_storeu_ps(p0f + 16 * 7, _f7); + _mm512_storeu_ps(p0f + out_hstep * 8, _f8); + _mm512_storeu_ps(p0f + out_hstep * 8 + 16, _f9); + _mm512_storeu_ps(p0f + out_hstep * 8 + 16 * 2, _fa); + _mm512_storeu_ps(p0f + out_hstep * 8 + 16 * 3, _fb); + _mm512_storeu_ps(p0f + out_hstep * 8 + 16 * 4, _fc); + _mm512_storeu_ps(p0f + out_hstep * 8 + 16 * 5, _fd); + _mm512_storeu_ps(p0f + out_hstep * 8 + 16 * 6, _fe); + _mm512_storeu_ps(p0f + out_hstep * 8 + 16 * 7, _ff); + } + if (out_elempack == 4) + { + transpose16x4_ps(_f0, _f1, _f2, _f3); + transpose16x4_ps(_f4, _f5, _f6, _f7); + transpose16x4_ps(_f8, _f9, _fa, _fb); + transpose16x4_ps(_fc, _fd, _fe, _ff); + + _mm512_storeu_ps(p0f, _f0); + _mm512_storeu_ps(p0f + 16, _f1); + _mm512_storeu_ps(p0f + 32, _f2); + _mm512_storeu_ps(p0f + 48, _f3); + _mm512_storeu_ps(p0f + out_hstep * 4, _f4); + _mm512_storeu_ps(p0f + out_hstep * 4 + 16, _f5); + _mm512_storeu_ps(p0f + out_hstep * 4 + 32, _f6); + _mm512_storeu_ps(p0f + out_hstep * 4 + 48, _f7); + _mm512_storeu_ps(p0f + out_hstep * 8, _f8); + _mm512_storeu_ps(p0f + out_hstep * 8 + 16, _f9); + _mm512_storeu_ps(p0f + out_hstep * 8 + 32, _fa); + _mm512_storeu_ps(p0f + out_hstep * 8 + 48, _fb); + _mm512_storeu_ps(p0f + out_hstep * 12, _fc); + _mm512_storeu_ps(p0f + out_hstep * 12 + 16, _fd); + _mm512_storeu_ps(p0f + out_hstep * 12 + 32, _fe); + _mm512_storeu_ps(p0f + out_hstep * 12 + 48, _ff); + } + if (out_elempack == 1) + { + _mm512_storeu_ps(p0f, _f0); + _mm512_storeu_ps(p0f + out_hstep, _f1); + _mm512_storeu_ps(p0f + out_hstep * 2, _f2); + _mm512_storeu_ps(p0f + out_hstep * 3, _f3); + _mm512_storeu_ps(p0f + out_hstep * 4, _f4); + _mm512_storeu_ps(p0f + out_hstep * 5, _f5); + _mm512_storeu_ps(p0f + out_hstep * 6, _f6); + _mm512_storeu_ps(p0f + out_hstep * 7, _f7); + _mm512_storeu_ps(p0f + out_hstep * 8, _f8); + _mm512_storeu_ps(p0f + out_hstep * 9, _f9); + _mm512_storeu_ps(p0f + out_hstep * 10, _fa); + _mm512_storeu_ps(p0f + out_hstep * 11, _fb); + _mm512_storeu_ps(p0f + out_hstep * 12, _fc); + _mm512_storeu_ps(p0f + out_hstep * 13, _fd); + _mm512_storeu_ps(p0f + out_hstep * 14, _fe); + _mm512_storeu_ps(p0f + out_hstep * 15, _ff); + } + p0f += out_hstep * 16; + } + else { - _mm256_storeu_si256((__m256i*)p0, _bf0); - _mm256_storeu_si256((__m256i*)(p0 + out_hstep), _bf1); - _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 2), _bf2); - _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 3), _bf3); - _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 4), _bf4); - _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 5), _bf5); - _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 6), _bf6); - _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 7), _bf7); - _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 8), _bf8); - _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 9), _bf9); - _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 10), _bfa); - _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 11), _bfb); - _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 12), _bfc); - _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 13), _bfd); - _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 14), _bfe); - 
_mm256_storeu_si256((__m256i*)(p0 + out_hstep * 15), _bff); - } - p0 += out_hstep * 16; + if (out_elempack == 16) + { + _mm512_storeu_ps(p0f, _f0); + _mm512_storeu_ps(p0f + 16, _f1); + _mm512_storeu_ps(p0f + 16 * 2, _f2); + _mm512_storeu_ps(p0f + 16 * 3, _f3); + _mm512_storeu_ps(p0f + 16 * 4, _f4); + _mm512_storeu_ps(p0f + 16 * 5, _f5); + _mm512_storeu_ps(p0f + 16 * 6, _f6); + _mm512_storeu_ps(p0f + 16 * 7, _f7); + _mm512_storeu_ps(p0f + 16 * 8, _f8); + _mm512_storeu_ps(p0f + 16 * 9, _f9); + _mm512_storeu_ps(p0f + 16 * 10, _fa); + _mm512_storeu_ps(p0f + 16 * 11, _fb); + _mm512_storeu_ps(p0f + 16 * 12, _fc); + _mm512_storeu_ps(p0f + 16 * 13, _fd); + _mm512_storeu_ps(p0f + 16 * 14, _fe); + _mm512_storeu_ps(p0f + 16 * 15, _ff); + p0f += 256; + } + if (out_elempack == 8) + { + _mm256_storeu_ps(p0f, _mm512_castps512_ps256(_f0)); + _mm256_storeu_ps(p0f + 8, _mm512_castps512_ps256(_f1)); + _mm256_storeu_ps(p0f + 8 * 2, _mm512_castps512_ps256(_f2)); + _mm256_storeu_ps(p0f + 8 * 3, _mm512_castps512_ps256(_f3)); + _mm256_storeu_ps(p0f + 8 * 4, _mm512_castps512_ps256(_f4)); + _mm256_storeu_ps(p0f + 8 * 5, _mm512_castps512_ps256(_f5)); + _mm256_storeu_ps(p0f + 8 * 6, _mm512_castps512_ps256(_f6)); + _mm256_storeu_ps(p0f + 8 * 7, _mm512_castps512_ps256(_f7)); + _mm256_storeu_ps(p0f + 8 * 8, _mm512_castps512_ps256(_f8)); + _mm256_storeu_ps(p0f + 8 * 9, _mm512_castps512_ps256(_f9)); + _mm256_storeu_ps(p0f + 8 * 10, _mm512_castps512_ps256(_fa)); + _mm256_storeu_ps(p0f + 8 * 11, _mm512_castps512_ps256(_fb)); + _mm256_storeu_ps(p0f + 8 * 12, _mm512_castps512_ps256(_fc)); + _mm256_storeu_ps(p0f + 8 * 13, _mm512_castps512_ps256(_fd)); + _mm256_storeu_ps(p0f + 8 * 14, _mm512_castps512_ps256(_fe)); + _mm256_storeu_ps(p0f + 8 * 15, _mm512_castps512_ps256(_ff)); + _mm256_storeu_ps(p0f + out_hstep * 8, _mm512_extractf32x8_ps(_f0, 1)); + _mm256_storeu_ps(p0f + out_hstep * 8 + 8, _mm512_extractf32x8_ps(_f1, 1)); + _mm256_storeu_ps(p0f + out_hstep * 8 + 8 * 2, _mm512_extractf32x8_ps(_f2, 1)); + _mm256_storeu_ps(p0f + out_hstep * 8 + 8 * 3, _mm512_extractf32x8_ps(_f3, 1)); + _mm256_storeu_ps(p0f + out_hstep * 8 + 8 * 4, _mm512_extractf32x8_ps(_f4, 1)); + _mm256_storeu_ps(p0f + out_hstep * 8 + 8 * 5, _mm512_extractf32x8_ps(_f5, 1)); + _mm256_storeu_ps(p0f + out_hstep * 8 + 8 * 6, _mm512_extractf32x8_ps(_f6, 1)); + _mm256_storeu_ps(p0f + out_hstep * 8 + 8 * 7, _mm512_extractf32x8_ps(_f7, 1)); + _mm256_storeu_ps(p0f + out_hstep * 8 + 8 * 8, _mm512_extractf32x8_ps(_f8, 1)); + _mm256_storeu_ps(p0f + out_hstep * 8 + 8 * 9, _mm512_extractf32x8_ps(_f9, 1)); + _mm256_storeu_ps(p0f + out_hstep * 8 + 8 * 10, _mm512_extractf32x8_ps(_fa, 1)); + _mm256_storeu_ps(p0f + out_hstep * 8 + 8 * 11, _mm512_extractf32x8_ps(_fb, 1)); + _mm256_storeu_ps(p0f + out_hstep * 8 + 8 * 12, _mm512_extractf32x8_ps(_fc, 1)); + _mm256_storeu_ps(p0f + out_hstep * 8 + 8 * 13, _mm512_extractf32x8_ps(_fd, 1)); + _mm256_storeu_ps(p0f + out_hstep * 8 + 8 * 14, _mm512_extractf32x8_ps(_fe, 1)); + _mm256_storeu_ps(p0f + out_hstep * 8 + 8 * 15, _mm512_extractf32x8_ps(_ff, 1)); + p0f += 128; + } + if (out_elempack == 4) + { + __m512 _tmp0 = _mm512_shuffle_f32x4(_f0, _f1, _MM_SHUFFLE(2, 0, 2, 0)); + __m512 _tmp1 = _mm512_shuffle_f32x4(_f2, _f3, _MM_SHUFFLE(2, 0, 2, 0)); + __m512 _tmp2 = _mm512_shuffle_f32x4(_f4, _f5, _MM_SHUFFLE(2, 0, 2, 0)); + __m512 _tmp3 = _mm512_shuffle_f32x4(_f6, _f7, _MM_SHUFFLE(2, 0, 2, 0)); + __m512 _tmp4 = _mm512_shuffle_f32x4(_f8, _f9, _MM_SHUFFLE(2, 0, 2, 0)); + __m512 _tmp5 = _mm512_shuffle_f32x4(_fa, _fb, _MM_SHUFFLE(2, 0, 2, 0)); + 
__m512 _tmp6 = _mm512_shuffle_f32x4(_fc, _fd, _MM_SHUFFLE(2, 0, 2, 0)); + __m512 _tmp7 = _mm512_shuffle_f32x4(_fe, _ff, _MM_SHUFFLE(2, 0, 2, 0)); + + __m512 _tmp8 = _mm512_shuffle_f32x4(_f0, _f1, _MM_SHUFFLE(3, 1, 3, 1)); + __m512 _tmp9 = _mm512_shuffle_f32x4(_f2, _f3, _MM_SHUFFLE(3, 1, 3, 1)); + __m512 _tmpa = _mm512_shuffle_f32x4(_f4, _f5, _MM_SHUFFLE(3, 1, 3, 1)); + __m512 _tmpb = _mm512_shuffle_f32x4(_f6, _f7, _MM_SHUFFLE(3, 1, 3, 1)); + __m512 _tmpc = _mm512_shuffle_f32x4(_f8, _f9, _MM_SHUFFLE(3, 1, 3, 1)); + __m512 _tmpd = _mm512_shuffle_f32x4(_fa, _fb, _MM_SHUFFLE(3, 1, 3, 1)); + __m512 _tmpe = _mm512_shuffle_f32x4(_fc, _fd, _MM_SHUFFLE(3, 1, 3, 1)); + __m512 _tmpf = _mm512_shuffle_f32x4(_fe, _ff, _MM_SHUFFLE(3, 1, 3, 1)); + + _f0 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(2, 0, 2, 0)); + _f1 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(2, 0, 2, 0)); + _f2 = _mm512_shuffle_f32x4(_tmp4, _tmp5, _MM_SHUFFLE(2, 0, 2, 0)); + _f3 = _mm512_shuffle_f32x4(_tmp6, _tmp7, _MM_SHUFFLE(2, 0, 2, 0)); + _f4 = _mm512_shuffle_f32x4(_tmp8, _tmp9, _MM_SHUFFLE(2, 0, 2, 0)); + _f5 = _mm512_shuffle_f32x4(_tmpa, _tmpb, _MM_SHUFFLE(2, 0, 2, 0)); + _f6 = _mm512_shuffle_f32x4(_tmpc, _tmpd, _MM_SHUFFLE(2, 0, 2, 0)); + _f7 = _mm512_shuffle_f32x4(_tmpe, _tmpf, _MM_SHUFFLE(2, 0, 2, 0)); + + _f8 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(3, 1, 3, 1)); + _f9 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(3, 1, 3, 1)); + _fa = _mm512_shuffle_f32x4(_tmp4, _tmp5, _MM_SHUFFLE(3, 1, 3, 1)); + _fb = _mm512_shuffle_f32x4(_tmp6, _tmp7, _MM_SHUFFLE(3, 1, 3, 1)); + _fc = _mm512_shuffle_f32x4(_tmp8, _tmp9, _MM_SHUFFLE(3, 1, 3, 1)); + _fd = _mm512_shuffle_f32x4(_tmpa, _tmpb, _MM_SHUFFLE(3, 1, 3, 1)); + _fe = _mm512_shuffle_f32x4(_tmpc, _tmpd, _MM_SHUFFLE(3, 1, 3, 1)); + _ff = _mm512_shuffle_f32x4(_tmpe, _tmpf, _MM_SHUFFLE(3, 1, 3, 1)); + + _mm512_storeu_ps(p0f, _f0); + _mm512_storeu_ps(p0f + 16, _f1); + _mm512_storeu_ps(p0f + 32, _f2); + _mm512_storeu_ps(p0f + 48, _f3); + _mm512_storeu_ps(p0f + out_hstep * 4, _f4); + _mm512_storeu_ps(p0f + out_hstep * 4 + 16, _f5); + _mm512_storeu_ps(p0f + out_hstep * 4 + 32, _f6); + _mm512_storeu_ps(p0f + out_hstep * 4 + 48, _f7); + _mm512_storeu_ps(p0f + out_hstep * 8, _f8); + _mm512_storeu_ps(p0f + out_hstep * 8 + 16, _f9); + _mm512_storeu_ps(p0f + out_hstep * 8 + 32, _fa); + _mm512_storeu_ps(p0f + out_hstep * 8 + 48, _fb); + _mm512_storeu_ps(p0f + out_hstep * 12, _fc); + _mm512_storeu_ps(p0f + out_hstep * 12 + 16, _fd); + _mm512_storeu_ps(p0f + out_hstep * 12 + 32, _fe); + _mm512_storeu_ps(p0f + out_hstep * 12 + 48, _ff); + p0f += 64; + } + if (out_elempack == 1) + { + transpose16x16_ps(_f0, _f1, _f2, _f3, _f4, _f5, _f6, _f7, _f8, _f9, _fa, _fb, _fc, _fd, _fe, _ff); + + _mm512_storeu_ps(p0f, _f0); + _mm512_storeu_ps(p0f + out_hstep, _f1); + _mm512_storeu_ps(p0f + out_hstep * 2, _f2); + _mm512_storeu_ps(p0f + out_hstep * 3, _f3); + _mm512_storeu_ps(p0f + out_hstep * 4, _f4); + _mm512_storeu_ps(p0f + out_hstep * 5, _f5); + _mm512_storeu_ps(p0f + out_hstep * 6, _f6); + _mm512_storeu_ps(p0f + out_hstep * 7, _f7); + _mm512_storeu_ps(p0f + out_hstep * 8, _f8); + _mm512_storeu_ps(p0f + out_hstep * 9, _f9); + _mm512_storeu_ps(p0f + out_hstep * 10, _fa); + _mm512_storeu_ps(p0f + out_hstep * 11, _fb); + _mm512_storeu_ps(p0f + out_hstep * 12, _fc); + _mm512_storeu_ps(p0f + out_hstep * 13, _fd); + _mm512_storeu_ps(p0f + out_hstep * 14, _fe); + _mm512_storeu_ps(p0f + out_hstep * 15, _ff); + p0f += 16; + } + } } else { - if (out_elempack == 16) + __m256i _bf0 = 
float2bfloat_avx512(_f0); + __m256i _bf1 = float2bfloat_avx512(_f1); + __m256i _bf2 = float2bfloat_avx512(_f2); + __m256i _bf3 = float2bfloat_avx512(_f3); + __m256i _bf4 = float2bfloat_avx512(_f4); + __m256i _bf5 = float2bfloat_avx512(_f5); + __m256i _bf6 = float2bfloat_avx512(_f6); + __m256i _bf7 = float2bfloat_avx512(_f7); + __m256i _bf8 = float2bfloat_avx512(_f8); + __m256i _bf9 = float2bfloat_avx512(_f9); + __m256i _bfa = float2bfloat_avx512(_fa); + __m256i _bfb = float2bfloat_avx512(_fb); + __m256i _bfc = float2bfloat_avx512(_fc); + __m256i _bfd = float2bfloat_avx512(_fd); + __m256i _bfe = float2bfloat_avx512(_fe); + __m256i _bff = float2bfloat_avx512(_ff); + + // store bf16 + if (output_transpose) { - _mm256_storeu_si256((__m256i*)p0, _bf0); - _mm256_storeu_si256((__m256i*)(p0 + 16), _bf1); - _mm256_storeu_si256((__m256i*)(p0 + 16 * 2), _bf2); - _mm256_storeu_si256((__m256i*)(p0 + 16 * 3), _bf3); - _mm256_storeu_si256((__m256i*)(p0 + 16 * 4), _bf4); - _mm256_storeu_si256((__m256i*)(p0 + 16 * 5), _bf5); - _mm256_storeu_si256((__m256i*)(p0 + 16 * 6), _bf6); - _mm256_storeu_si256((__m256i*)(p0 + 16 * 7), _bf7); - _mm256_storeu_si256((__m256i*)(p0 + 16 * 8), _bf8); - _mm256_storeu_si256((__m256i*)(p0 + 16 * 9), _bf9); - _mm256_storeu_si256((__m256i*)(p0 + 16 * 10), _bfa); - _mm256_storeu_si256((__m256i*)(p0 + 16 * 11), _bfb); - _mm256_storeu_si256((__m256i*)(p0 + 16 * 12), _bfc); - _mm256_storeu_si256((__m256i*)(p0 + 16 * 13), _bfd); - _mm256_storeu_si256((__m256i*)(p0 + 16 * 14), _bfe); - _mm256_storeu_si256((__m256i*)(p0 + 16 * 15), _bff); - p0 += 256; - } - if (out_elempack == 8) - { - _mm_storeu_si128((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); - _mm_storeu_si128((__m128i*)(p0 + 8), _mm256_extractf128_si256(_bf1, 0)); - _mm_storeu_si128((__m128i*)(p0 + 8 * 2), _mm256_extractf128_si256(_bf2, 0)); - _mm_storeu_si128((__m128i*)(p0 + 8 * 3), _mm256_extractf128_si256(_bf3, 0)); - _mm_storeu_si128((__m128i*)(p0 + 8 * 4), _mm256_extractf128_si256(_bf4, 0)); - _mm_storeu_si128((__m128i*)(p0 + 8 * 5), _mm256_extractf128_si256(_bf5, 0)); - _mm_storeu_si128((__m128i*)(p0 + 8 * 6), _mm256_extractf128_si256(_bf6, 0)); - _mm_storeu_si128((__m128i*)(p0 + 8 * 7), _mm256_extractf128_si256(_bf7, 0)); - _mm_storeu_si128((__m128i*)(p0 + 8 * 8), _mm256_extractf128_si256(_bf8, 0)); - _mm_storeu_si128((__m128i*)(p0 + 8 * 9), _mm256_extractf128_si256(_bf9, 0)); - _mm_storeu_si128((__m128i*)(p0 + 8 * 10), _mm256_extractf128_si256(_bfa, 0)); - _mm_storeu_si128((__m128i*)(p0 + 8 * 11), _mm256_extractf128_si256(_bfb, 0)); - _mm_storeu_si128((__m128i*)(p0 + 8 * 12), _mm256_extractf128_si256(_bfc, 0)); - _mm_storeu_si128((__m128i*)(p0 + 8 * 13), _mm256_extractf128_si256(_bfd, 0)); - _mm_storeu_si128((__m128i*)(p0 + 8 * 14), _mm256_extractf128_si256(_bfe, 0)); - _mm_storeu_si128((__m128i*)(p0 + 8 * 15), _mm256_extractf128_si256(_bff, 0)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8), _mm256_extractf128_si256(_bf0, 1)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8), _mm256_extractf128_si256(_bf1, 1)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 2), _mm256_extractf128_si256(_bf2, 1)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 3), _mm256_extractf128_si256(_bf3, 1)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 4), _mm256_extractf128_si256(_bf4, 1)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 5), _mm256_extractf128_si256(_bf5, 1)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 6), _mm256_extractf128_si256(_bf6, 1)); - 
_mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 7), _mm256_extractf128_si256(_bf7, 1)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 8), _mm256_extractf128_si256(_bf8, 1)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 9), _mm256_extractf128_si256(_bf9, 1)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 10), _mm256_extractf128_si256(_bfa, 1)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 11), _mm256_extractf128_si256(_bfb, 1)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 12), _mm256_extractf128_si256(_bfc, 1)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 13), _mm256_extractf128_si256(_bfd, 1)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 14), _mm256_extractf128_si256(_bfe, 1)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 15), _mm256_extractf128_si256(_bff, 1)); - p0 += 128; - } - if (out_elempack == 4) - { - _mm_storel_epi64((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); - _mm_storel_epi64((__m128i*)(p0 + 4), _mm256_extractf128_si256(_bf1, 0)); - _mm_storel_epi64((__m128i*)(p0 + 4 * 2), _mm256_extractf128_si256(_bf2, 0)); - _mm_storel_epi64((__m128i*)(p0 + 4 * 3), _mm256_extractf128_si256(_bf3, 0)); - _mm_storel_epi64((__m128i*)(p0 + 4 * 4), _mm256_extractf128_si256(_bf4, 0)); - _mm_storel_epi64((__m128i*)(p0 + 4 * 5), _mm256_extractf128_si256(_bf5, 0)); - _mm_storel_epi64((__m128i*)(p0 + 4 * 6), _mm256_extractf128_si256(_bf6, 0)); - _mm_storel_epi64((__m128i*)(p0 + 4 * 7), _mm256_extractf128_si256(_bf7, 0)); - _mm_storel_epi64((__m128i*)(p0 + 4 * 8), _mm256_extractf128_si256(_bf8, 0)); - _mm_storel_epi64((__m128i*)(p0 + 4 * 9), _mm256_extractf128_si256(_bf9, 0)); - _mm_storel_epi64((__m128i*)(p0 + 4 * 10), _mm256_extractf128_si256(_bfa, 0)); - _mm_storel_epi64((__m128i*)(p0 + 4 * 11), _mm256_extractf128_si256(_bfb, 0)); - _mm_storel_epi64((__m128i*)(p0 + 4 * 12), _mm256_extractf128_si256(_bfc, 0)); - _mm_storel_epi64((__m128i*)(p0 + 4 * 13), _mm256_extractf128_si256(_bfd, 0)); - _mm_storel_epi64((__m128i*)(p0 + 4 * 14), _mm256_extractf128_si256(_bfe, 0)); - _mm_storel_epi64((__m128i*)(p0 + 4 * 15), _mm256_extractf128_si256(_bff, 0)); - _mm_storeh_pd((double*)(p0 + out_hstep * 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 0))); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 0))); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 2), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 0))); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 3), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 0))); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf4, 0))); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 5), _mm_castsi128_pd(_mm256_extractf128_si256(_bf5, 0))); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 6), _mm_castsi128_pd(_mm256_extractf128_si256(_bf6, 0))); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 7), _mm_castsi128_pd(_mm256_extractf128_si256(_bf7, 0))); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 8), _mm_castsi128_pd(_mm256_extractf128_si256(_bf8, 0))); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 9), _mm_castsi128_pd(_mm256_extractf128_si256(_bf9, 0))); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 10), _mm_castsi128_pd(_mm256_extractf128_si256(_bfa, 0))); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 11), _mm_castsi128_pd(_mm256_extractf128_si256(_bfb, 0))); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 12), 
_mm_castsi128_pd(_mm256_extractf128_si256(_bfc, 0))); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 13), _mm_castsi128_pd(_mm256_extractf128_si256(_bfd, 0))); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 14), _mm_castsi128_pd(_mm256_extractf128_si256(_bfe, 0))); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 15), _mm_castsi128_pd(_mm256_extractf128_si256(_bff, 0))); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8), _mm256_extractf128_si256(_bf0, 1)); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4), _mm256_extractf128_si256(_bf1, 1)); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 2), _mm256_extractf128_si256(_bf2, 1)); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 3), _mm256_extractf128_si256(_bf3, 1)); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 4), _mm256_extractf128_si256(_bf4, 1)); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 5), _mm256_extractf128_si256(_bf5, 1)); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 6), _mm256_extractf128_si256(_bf6, 1)); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 7), _mm256_extractf128_si256(_bf7, 1)); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 8), _mm256_extractf128_si256(_bf8, 1)); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 9), _mm256_extractf128_si256(_bf9, 1)); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 10), _mm256_extractf128_si256(_bfa, 1)); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 11), _mm256_extractf128_si256(_bfb, 1)); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 12), _mm256_extractf128_si256(_bfc, 1)); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 13), _mm256_extractf128_si256(_bfd, 1)); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 14), _mm256_extractf128_si256(_bfe, 1)); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 15), _mm256_extractf128_si256(_bff, 1)); - _mm_storeh_pd((double*)(p0 + out_hstep * 12), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 1))); - _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 1))); - _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 2), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 1))); - _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 3), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 1))); - _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf4, 1))); - _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 5), _mm_castsi128_pd(_mm256_extractf128_si256(_bf5, 1))); - _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 6), _mm_castsi128_pd(_mm256_extractf128_si256(_bf6, 1))); - _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 7), _mm_castsi128_pd(_mm256_extractf128_si256(_bf7, 1))); - _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 8), _mm_castsi128_pd(_mm256_extractf128_si256(_bf8, 1))); - _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 9), _mm_castsi128_pd(_mm256_extractf128_si256(_bf9, 1))); - _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 10), _mm_castsi128_pd(_mm256_extractf128_si256(_bfa, 1))); - _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 11), _mm_castsi128_pd(_mm256_extractf128_si256(_bfb, 1))); - _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 12), _mm_castsi128_pd(_mm256_extractf128_si256(_bfc, 1))); - _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 13), _mm_castsi128_pd(_mm256_extractf128_si256(_bfd, 1))); - _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 14), _mm_castsi128_pd(_mm256_extractf128_si256(_bfe, 1))); - 
_mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 15), _mm_castsi128_pd(_mm256_extractf128_si256(_bff, 1))); - p0 += 64; - } - if (out_elempack == 1) - { - transpose16x16_epi16(_bf0, _bf1, _bf2, _bf3, _bf4, _bf5, _bf6, _bf7, _bf8, _bf9, _bfa, _bfb, _bfc, _bfd, _bfe, _bff); + if (out_elempack == 16) + { + transpose16x16_epi16(_bf0, _bf1, _bf2, _bf3, _bf4, _bf5, _bf6, _bf7, _bf8, _bf9, _bfa, _bfb, _bfc, _bfd, _bfe, _bff); - _mm256_storeu_si256((__m256i*)p0, _bf0); - _mm256_storeu_si256((__m256i*)(p0 + out_hstep), _bf1); - _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 2), _bf2); - _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 3), _bf3); - _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 4), _bf4); - _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 5), _bf5); - _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 6), _bf6); - _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 7), _bf7); - _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 8), _bf8); - _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 9), _bf9); - _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 10), _bfa); - _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 11), _bfb); - _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 12), _bfc); - _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 13), _bfd); - _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 14), _bfe); - _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 15), _bff); - p0 += 16; + _mm256_storeu_si256((__m256i*)p0, _bf0); + _mm256_storeu_si256((__m256i*)(p0 + 16), _bf1); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 2), _bf2); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 3), _bf3); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 4), _bf4); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 5), _bf5); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 6), _bf6); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 7), _bf7); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 8), _bf8); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 9), _bf9); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 10), _bfa); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 11), _bfb); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 12), _bfc); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 13), _bfd); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 14), _bfe); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 15), _bff); + } + if (out_elempack == 8) + { + transpose16x8_epi16(_bf0, _bf1, _bf2, _bf3, _bf4, _bf5, _bf6, _bf7); + transpose16x8_epi16(_bf8, _bf9, _bfa, _bfb, _bfc, _bfd, _bfe, _bff); + + _mm256_storeu_si256((__m256i*)p0, _bf0); + _mm256_storeu_si256((__m256i*)(p0 + 16), _bf1); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 2), _bf2); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 3), _bf3); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 4), _bf4); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 5), _bf5); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 6), _bf6); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 7), _bf7); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 8), _bf8); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 8 + 16), _bf9); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 8 + 16 * 2), _bfa); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 8 + 16 * 3), _bfb); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 8 + 16 * 4), _bfc); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 8 + 16 * 5), _bfd); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 8 + 16 * 6), _bfe); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 8 + 16 * 7), _bff); + } + if (out_elempack == 4) + { + transpose16x4_epi16(_bf0, _bf1, _bf2, _bf3); + transpose16x4_epi16(_bf4, _bf5, _bf6, _bf7); + transpose16x4_epi16(_bf8, _bf9, 
_bfa, _bfb); + transpose16x4_epi16(_bfc, _bfd, _bfe, _bff); + + _mm256_storeu_si256((__m256i*)p0, _bf0); + _mm256_storeu_si256((__m256i*)(p0 + 16), _bf1); + _mm256_storeu_si256((__m256i*)(p0 + 32), _bf2); + _mm256_storeu_si256((__m256i*)(p0 + 48), _bf3); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 4), _bf4); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 4 + 16), _bf5); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 4 + 32), _bf6); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 4 + 48), _bf7); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 8), _bf8); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 8 + 16), _bf9); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 8 + 32), _bfa); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 8 + 48), _bfb); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 12), _bfc); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 12 + 16), _bfd); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 12 + 32), _bfe); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 12 + 48), _bff); + } + if (out_elempack == 1) + { + _mm256_storeu_si256((__m256i*)p0, _bf0); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep), _bf1); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 2), _bf2); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 3), _bf3); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 4), _bf4); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 5), _bf5); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 6), _bf6); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 7), _bf7); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 8), _bf8); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 9), _bf9); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 10), _bfa); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 11), _bfb); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 12), _bfc); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 13), _bfd); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 14), _bfe); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 15), _bff); + } + p0 += out_hstep * 16; + } + else + { + if (out_elempack == 16) + { + _mm256_storeu_si256((__m256i*)p0, _bf0); + _mm256_storeu_si256((__m256i*)(p0 + 16), _bf1); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 2), _bf2); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 3), _bf3); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 4), _bf4); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 5), _bf5); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 6), _bf6); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 7), _bf7); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 8), _bf8); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 9), _bf9); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 10), _bfa); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 11), _bfb); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 12), _bfc); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 13), _bfd); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 14), _bfe); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 15), _bff); + p0 += 256; + } + if (out_elempack == 8) + { + _mm_storeu_si128((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8), _mm256_extractf128_si256(_bf1, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 2), _mm256_extractf128_si256(_bf2, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 3), _mm256_extractf128_si256(_bf3, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 4), _mm256_extractf128_si256(_bf4, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 5), _mm256_extractf128_si256(_bf5, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 6), _mm256_extractf128_si256(_bf6, 0)); + 
_mm_storeu_si128((__m128i*)(p0 + 8 * 7), _mm256_extractf128_si256(_bf7, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 8), _mm256_extractf128_si256(_bf8, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 9), _mm256_extractf128_si256(_bf9, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 10), _mm256_extractf128_si256(_bfa, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 11), _mm256_extractf128_si256(_bfb, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 12), _mm256_extractf128_si256(_bfc, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 13), _mm256_extractf128_si256(_bfd, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 14), _mm256_extractf128_si256(_bfe, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 15), _mm256_extractf128_si256(_bff, 0)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8), _mm256_extractf128_si256(_bf0, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8), _mm256_extractf128_si256(_bf1, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 2), _mm256_extractf128_si256(_bf2, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 3), _mm256_extractf128_si256(_bf3, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 4), _mm256_extractf128_si256(_bf4, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 5), _mm256_extractf128_si256(_bf5, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 6), _mm256_extractf128_si256(_bf6, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 7), _mm256_extractf128_si256(_bf7, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 8), _mm256_extractf128_si256(_bf8, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 9), _mm256_extractf128_si256(_bf9, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 10), _mm256_extractf128_si256(_bfa, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 11), _mm256_extractf128_si256(_bfb, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 12), _mm256_extractf128_si256(_bfc, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 13), _mm256_extractf128_si256(_bfd, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 14), _mm256_extractf128_si256(_bfe, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 15), _mm256_extractf128_si256(_bff, 1)); + p0 += 128; + } + if (out_elempack == 4) + { + _mm_storel_epi64((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4), _mm256_extractf128_si256(_bf1, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 2), _mm256_extractf128_si256(_bf2, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 3), _mm256_extractf128_si256(_bf3, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 4), _mm256_extractf128_si256(_bf4, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 5), _mm256_extractf128_si256(_bf5, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 6), _mm256_extractf128_si256(_bf6, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 7), _mm256_extractf128_si256(_bf7, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 8), _mm256_extractf128_si256(_bf8, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 9), _mm256_extractf128_si256(_bf9, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 10), _mm256_extractf128_si256(_bfa, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 11), _mm256_extractf128_si256(_bfb, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 12), _mm256_extractf128_si256(_bfc, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 13), _mm256_extractf128_si256(_bfd, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 14), _mm256_extractf128_si256(_bfe, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 15), _mm256_extractf128_si256(_bff, 0)); + 
_mm_storeh_pd((double*)(p0 + out_hstep * 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 2), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 3), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf4, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 5), _mm_castsi128_pd(_mm256_extractf128_si256(_bf5, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 6), _mm_castsi128_pd(_mm256_extractf128_si256(_bf6, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 7), _mm_castsi128_pd(_mm256_extractf128_si256(_bf7, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 8), _mm_castsi128_pd(_mm256_extractf128_si256(_bf8, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 9), _mm_castsi128_pd(_mm256_extractf128_si256(_bf9, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 10), _mm_castsi128_pd(_mm256_extractf128_si256(_bfa, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 11), _mm_castsi128_pd(_mm256_extractf128_si256(_bfb, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 12), _mm_castsi128_pd(_mm256_extractf128_si256(_bfc, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 13), _mm_castsi128_pd(_mm256_extractf128_si256(_bfd, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 14), _mm_castsi128_pd(_mm256_extractf128_si256(_bfe, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 15), _mm_castsi128_pd(_mm256_extractf128_si256(_bff, 0))); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8), _mm256_extractf128_si256(_bf0, 1)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4), _mm256_extractf128_si256(_bf1, 1)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 2), _mm256_extractf128_si256(_bf2, 1)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 3), _mm256_extractf128_si256(_bf3, 1)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 4), _mm256_extractf128_si256(_bf4, 1)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 5), _mm256_extractf128_si256(_bf5, 1)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 6), _mm256_extractf128_si256(_bf6, 1)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 7), _mm256_extractf128_si256(_bf7, 1)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 8), _mm256_extractf128_si256(_bf8, 1)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 9), _mm256_extractf128_si256(_bf9, 1)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 10), _mm256_extractf128_si256(_bfa, 1)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 11), _mm256_extractf128_si256(_bfb, 1)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 12), _mm256_extractf128_si256(_bfc, 1)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 13), _mm256_extractf128_si256(_bfd, 1)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 14), _mm256_extractf128_si256(_bfe, 1)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 15), _mm256_extractf128_si256(_bff, 1)); + _mm_storeh_pd((double*)(p0 + out_hstep * 12), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 2), 
_mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 3), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf4, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 5), _mm_castsi128_pd(_mm256_extractf128_si256(_bf5, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 6), _mm_castsi128_pd(_mm256_extractf128_si256(_bf6, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 7), _mm_castsi128_pd(_mm256_extractf128_si256(_bf7, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 8), _mm_castsi128_pd(_mm256_extractf128_si256(_bf8, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 9), _mm_castsi128_pd(_mm256_extractf128_si256(_bf9, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 10), _mm_castsi128_pd(_mm256_extractf128_si256(_bfa, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 11), _mm_castsi128_pd(_mm256_extractf128_si256(_bfb, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 12), _mm_castsi128_pd(_mm256_extractf128_si256(_bfc, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 13), _mm_castsi128_pd(_mm256_extractf128_si256(_bfd, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 14), _mm_castsi128_pd(_mm256_extractf128_si256(_bfe, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 15), _mm_castsi128_pd(_mm256_extractf128_si256(_bff, 1))); + p0 += 64; + } + if (out_elempack == 1) + { + transpose16x16_epi16(_bf0, _bf1, _bf2, _bf3, _bf4, _bf5, _bf6, _bf7, _bf8, _bf9, _bfa, _bfb, _bfc, _bfd, _bfe, _bff); + + _mm256_storeu_si256((__m256i*)p0, _bf0); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep), _bf1); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 2), _bf2); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 3), _bf3); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 4), _bf4); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 5), _bf5); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 6), _bf6); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 7), _bf7); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 8), _bf8); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 9), _bf9); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 10), _bfa); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 11), _bfb); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 12), _bfc); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 13), _bfd); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 14), _bfe); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 15), _bff); + p0 += 16; + } } } } @@ -4695,180 +4865,283 @@ static void unpack_output_tile_fp32_to_bf16(const Mat& topT, const Mat& C, Mat& _f7 = _mm512_mul_ps(_f7, _alpha); } - __m256i _bf0 = float2bfloat_avx512(_f0); - __m256i _bf1 = float2bfloat_avx512(_f1); - __m256i _bf2 = float2bfloat_avx512(_f2); - __m256i _bf3 = float2bfloat_avx512(_f3); - __m256i _bf4 = float2bfloat_avx512(_f4); - __m256i _bf5 = float2bfloat_avx512(_f5); - __m256i _bf6 = float2bfloat_avx512(_f6); - __m256i _bf7 = float2bfloat_avx512(_f7); - - if (output_transpose) - { - if (out_elempack == 8) - { - transpose16x8_epi16(_bf0, _bf1, _bf2, _bf3, _bf4, _bf5, _bf6, _bf7); - - _mm_storeu_si128((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); - _mm_storeu_si128((__m128i*)(p0 + 8), _mm256_extractf128_si256(_bf0, 1)); - _mm_storeu_si128((__m128i*)(p0 + 8 * 2), _mm256_extractf128_si256(_bf1, 0)); - _mm_storeu_si128((__m128i*)(p0 + 8 * 3), _mm256_extractf128_si256(_bf1, 
1)); - _mm_storeu_si128((__m128i*)(p0 + 8 * 4), _mm256_extractf128_si256(_bf2, 0)); - _mm_storeu_si128((__m128i*)(p0 + 8 * 5), _mm256_extractf128_si256(_bf2, 1)); - _mm_storeu_si128((__m128i*)(p0 + 8 * 6), _mm256_extractf128_si256(_bf3, 0)); - _mm_storeu_si128((__m128i*)(p0 + 8 * 7), _mm256_extractf128_si256(_bf3, 1)); - _mm_storeu_si128((__m128i*)(p0 + 8 * 8), _mm256_extractf128_si256(_bf4, 0)); - _mm_storeu_si128((__m128i*)(p0 + 8 * 9), _mm256_extractf128_si256(_bf4, 1)); - _mm_storeu_si128((__m128i*)(p0 + 8 * 10), _mm256_extractf128_si256(_bf5, 0)); - _mm_storeu_si128((__m128i*)(p0 + 8 * 11), _mm256_extractf128_si256(_bf5, 1)); - _mm_storeu_si128((__m128i*)(p0 + 8 * 12), _mm256_extractf128_si256(_bf6, 0)); - _mm_storeu_si128((__m128i*)(p0 + 8 * 13), _mm256_extractf128_si256(_bf6, 1)); - _mm_storeu_si128((__m128i*)(p0 + 8 * 14), _mm256_extractf128_si256(_bf7, 0)); - _mm_storeu_si128((__m128i*)(p0 + 8 * 15), _mm256_extractf128_si256(_bf7, 1)); - } - if (out_elempack == 4) - { - transpose16x4_epi16(_bf0, _bf1, _bf2, _bf3); - transpose16x4_epi16(_bf4, _bf5, _bf6, _bf7); - - _mm_storel_epi64((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); - _mm_storeh_pd((double*)(p0 + 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 0))); - _mm_storel_epi64((__m128i*)(p0 + 4 * 2), _mm256_extractf128_si256(_bf0, 1)); - _mm_storeh_pd((double*)(p0 + 4 * 3), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 1))); - _mm_storel_epi64((__m128i*)(p0 + 4 * 4), _mm256_extractf128_si256(_bf1, 0)); - _mm_storeh_pd((double*)(p0 + 4 * 5), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 0))); - _mm_storel_epi64((__m128i*)(p0 + 4 * 6), _mm256_extractf128_si256(_bf1, 1)); - _mm_storeh_pd((double*)(p0 + 4 * 7), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 1))); - _mm_storel_epi64((__m128i*)(p0 + 4 * 8), _mm256_extractf128_si256(_bf2, 0)); - _mm_storeh_pd((double*)(p0 + 4 * 9), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 0))); - _mm_storel_epi64((__m128i*)(p0 + 4 * 10), _mm256_extractf128_si256(_bf2, 1)); - _mm_storeh_pd((double*)(p0 + 4 * 11), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 1))); - _mm_storel_epi64((__m128i*)(p0 + 4 * 12), _mm256_extractf128_si256(_bf3, 0)); - _mm_storeh_pd((double*)(p0 + 4 * 13), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 0))); - _mm_storel_epi64((__m128i*)(p0 + 4 * 14), _mm256_extractf128_si256(_bf3, 1)); - _mm_storeh_pd((double*)(p0 + 4 * 15), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 1))); - - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 4), _mm256_extractf128_si256(_bf4, 0)); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf4, 0))); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 4 + 4 * 2), _mm256_extractf128_si256(_bf4, 1)); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 3), _mm_castsi128_pd(_mm256_extractf128_si256(_bf4, 1))); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 4 + 4 * 4), _mm256_extractf128_si256(_bf5, 0)); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 5), _mm_castsi128_pd(_mm256_extractf128_si256(_bf5, 0))); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 4 + 4 * 6), _mm256_extractf128_si256(_bf5, 1)); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 7), _mm_castsi128_pd(_mm256_extractf128_si256(_bf5, 1))); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 4 + 4 * 8), _mm256_extractf128_si256(_bf6, 0)); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 9), _mm_castsi128_pd(_mm256_extractf128_si256(_bf6, 0))); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 4 + 4 * 10), 
_mm256_extractf128_si256(_bf6, 1)); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 11), _mm_castsi128_pd(_mm256_extractf128_si256(_bf6, 1))); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 4 + 4 * 12), _mm256_extractf128_si256(_bf7, 0)); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 13), _mm_castsi128_pd(_mm256_extractf128_si256(_bf7, 0))); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 4 + 4 * 14), _mm256_extractf128_si256(_bf7, 1)); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 15), _mm_castsi128_pd(_mm256_extractf128_si256(_bf7, 1))); - } - if (out_elempack == 1) + if (output_elemtype == 1) + { + if (output_transpose) { - _mm256_storeu_si256((__m256i*)p0, _bf0); - _mm256_storeu_si256((__m256i*)(p0 + out_hstep), _bf1); - _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 2), _bf2); - _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 3), _bf3); - _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 4), _bf4); - _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 5), _bf5); - _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 6), _bf6); - _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 7), _bf7); + if (out_elempack == 8) + { + transpose16x8_ps(_f0, _f1, _f2, _f3, _f4, _f5, _f6, _f7); + _mm512_storeu_ps(p0f, _f0); + _mm512_storeu_ps(p0f + 16, _f1); + _mm512_storeu_ps(p0f + 16 * 2, _f2); + _mm512_storeu_ps(p0f + 16 * 3, _f3); + _mm512_storeu_ps(p0f + 16 * 4, _f4); + _mm512_storeu_ps(p0f + 16 * 5, _f5); + _mm512_storeu_ps(p0f + 16 * 6, _f6); + _mm512_storeu_ps(p0f + 16 * 7, _f7); + } + if (out_elempack == 4) + { + transpose16x4_ps(_f0, _f1, _f2, _f3); + transpose16x4_ps(_f4, _f5, _f6, _f7); + _mm512_storeu_ps(p0f, _f0); + _mm512_storeu_ps(p0f + 16, _f1); + _mm512_storeu_ps(p0f + 32, _f2); + _mm512_storeu_ps(p0f + 48, _f3); + _mm512_storeu_ps(p0f + out_hstep * 4, _f4); + _mm512_storeu_ps(p0f + out_hstep * 4 + 16, _f5); + _mm512_storeu_ps(p0f + out_hstep * 4 + 32, _f6); + _mm512_storeu_ps(p0f + out_hstep * 4 + 48, _f7); + } + if (out_elempack == 1) + { + _mm512_storeu_ps(p0f, _f0); + _mm512_storeu_ps(p0f + out_hstep, _f1); + _mm512_storeu_ps(p0f + out_hstep * 2, _f2); + _mm512_storeu_ps(p0f + out_hstep * 3, _f3); + _mm512_storeu_ps(p0f + out_hstep * 4, _f4); + _mm512_storeu_ps(p0f + out_hstep * 5, _f5); + _mm512_storeu_ps(p0f + out_hstep * 6, _f6); + _mm512_storeu_ps(p0f + out_hstep * 7, _f7); + } + p0f += out_hstep * 8; + } + else + { + if (out_elempack == 16) + { + _mm512_storeu_ps(p0f, _f0); + _mm512_storeu_ps(p0f + 16, _f1); + _mm512_storeu_ps(p0f + 16 * 2, _f2); + _mm512_storeu_ps(p0f + 16 * 3, _f3); + _mm512_storeu_ps(p0f + 16 * 4, _f4); + _mm512_storeu_ps(p0f + 16 * 5, _f5); + _mm512_storeu_ps(p0f + 16 * 6, _f6); + _mm512_storeu_ps(p0f + 16 * 7, _f7); + p0f += 128; + } + if (out_elempack == 8) + { + _mm256_storeu_ps(p0f, _mm512_castps512_ps256(_f0)); + _mm256_storeu_ps(p0f + 8, _mm512_castps512_ps256(_f1)); + _mm256_storeu_ps(p0f + 8 * 2, _mm512_castps512_ps256(_f2)); + _mm256_storeu_ps(p0f + 8 * 3, _mm512_castps512_ps256(_f3)); + _mm256_storeu_ps(p0f + 8 * 4, _mm512_castps512_ps256(_f4)); + _mm256_storeu_ps(p0f + 8 * 5, _mm512_castps512_ps256(_f5)); + _mm256_storeu_ps(p0f + 8 * 6, _mm512_castps512_ps256(_f6)); + _mm256_storeu_ps(p0f + 8 * 7, _mm512_castps512_ps256(_f7)); + _mm256_storeu_ps(p0f + out_hstep * 8, _mm512_extractf32x8_ps(_f0, 1)); + _mm256_storeu_ps(p0f + out_hstep * 8 + 8, _mm512_extractf32x8_ps(_f1, 1)); + _mm256_storeu_ps(p0f + out_hstep * 8 + 8 * 2, _mm512_extractf32x8_ps(_f2, 1)); + _mm256_storeu_ps(p0f + out_hstep * 8 + 8 * 3, _mm512_extractf32x8_ps(_f3, 1)); + 
_mm256_storeu_ps(p0f + out_hstep * 8 + 8 * 4, _mm512_extractf32x8_ps(_f4, 1)); + _mm256_storeu_ps(p0f + out_hstep * 8 + 8 * 5, _mm512_extractf32x8_ps(_f5, 1)); + _mm256_storeu_ps(p0f + out_hstep * 8 + 8 * 6, _mm512_extractf32x8_ps(_f6, 1)); + _mm256_storeu_ps(p0f + out_hstep * 8 + 8 * 7, _mm512_extractf32x8_ps(_f7, 1)); + p0f += 64; + } + if (out_elempack == 4) + { + _mm_storeu_ps(p0f, _mm512_castps512_ps128(_f0)); + _mm_storeu_ps(p0f + 4, _mm512_castps512_ps128(_f1)); + _mm_storeu_ps(p0f + 4 * 2, _mm512_castps512_ps128(_f2)); + _mm_storeu_ps(p0f + 4 * 3, _mm512_castps512_ps128(_f3)); + _mm_storeu_ps(p0f + 4 * 4, _mm512_castps512_ps128(_f4)); + _mm_storeu_ps(p0f + 4 * 5, _mm512_castps512_ps128(_f5)); + _mm_storeu_ps(p0f + 4 * 6, _mm512_castps512_ps128(_f6)); + _mm_storeu_ps(p0f + 4 * 7, _mm512_castps512_ps128(_f7)); + _mm_storeu_ps(p0f + out_hstep * 4, _mm256_extractf128_ps(_mm512_castps512_ps256(_f0), 1)); + _mm_storeu_ps(p0f + out_hstep * 4 + 4, _mm256_extractf128_ps(_mm512_castps512_ps256(_f1), 1)); + _mm_storeu_ps(p0f + out_hstep * 4 + 4 * 2, _mm256_extractf128_ps(_mm512_castps512_ps256(_f2), 1)); + _mm_storeu_ps(p0f + out_hstep * 4 + 4 * 3, _mm256_extractf128_ps(_mm512_castps512_ps256(_f3), 1)); + _mm_storeu_ps(p0f + out_hstep * 4 + 4 * 4, _mm256_extractf128_ps(_mm512_castps512_ps256(_f4), 1)); + _mm_storeu_ps(p0f + out_hstep * 4 + 4 * 5, _mm256_extractf128_ps(_mm512_castps512_ps256(_f5), 1)); + _mm_storeu_ps(p0f + out_hstep * 4 + 4 * 6, _mm256_extractf128_ps(_mm512_castps512_ps256(_f6), 1)); + _mm_storeu_ps(p0f + out_hstep * 4 + 4 * 7, _mm256_extractf128_ps(_mm512_castps512_ps256(_f7), 1)); + _mm_storeu_ps(p0f + out_hstep * 8, _mm256_castps256_ps128(_mm512_extractf32x8_ps(_f0, 1))); + _mm_storeu_ps(p0f + out_hstep * 8 + 4, _mm256_castps256_ps128(_mm512_extractf32x8_ps(_f1, 1))); + _mm_storeu_ps(p0f + out_hstep * 8 + 4 * 2, _mm256_castps256_ps128(_mm512_extractf32x8_ps(_f2, 1))); + _mm_storeu_ps(p0f + out_hstep * 8 + 4 * 3, _mm256_castps256_ps128(_mm512_extractf32x8_ps(_f3, 1))); + _mm_storeu_ps(p0f + out_hstep * 8 + 4 * 4, _mm256_castps256_ps128(_mm512_extractf32x8_ps(_f4, 1))); + _mm_storeu_ps(p0f + out_hstep * 8 + 4 * 5, _mm256_castps256_ps128(_mm512_extractf32x8_ps(_f5, 1))); + _mm_storeu_ps(p0f + out_hstep * 8 + 4 * 6, _mm256_castps256_ps128(_mm512_extractf32x8_ps(_f6, 1))); + _mm_storeu_ps(p0f + out_hstep * 8 + 4 * 7, _mm256_castps256_ps128(_mm512_extractf32x8_ps(_f7, 1))); + _mm_storeu_ps(p0f + out_hstep * 12, _mm256_extractf128_ps(_mm512_extractf32x8_ps(_f0, 1), 1)); + _mm_storeu_ps(p0f + out_hstep * 12 + 4, _mm256_extractf128_ps(_mm512_extractf32x8_ps(_f1, 1), 1)); + _mm_storeu_ps(p0f + out_hstep * 12 + 4 * 2, _mm256_extractf128_ps(_mm512_extractf32x8_ps(_f2, 1), 1)); + _mm_storeu_ps(p0f + out_hstep * 12 + 4 * 3, _mm256_extractf128_ps(_mm512_extractf32x8_ps(_f3, 1), 1)); + _mm_storeu_ps(p0f + out_hstep * 12 + 4 * 4, _mm256_extractf128_ps(_mm512_extractf32x8_ps(_f4, 1), 1)); + _mm_storeu_ps(p0f + out_hstep * 12 + 4 * 5, _mm256_extractf128_ps(_mm512_extractf32x8_ps(_f5, 1), 1)); + _mm_storeu_ps(p0f + out_hstep * 12 + 4 * 6, _mm256_extractf128_ps(_mm512_extractf32x8_ps(_f6, 1), 1)); + _mm_storeu_ps(p0f + out_hstep * 12 + 4 * 7, _mm256_extractf128_ps(_mm512_extractf32x8_ps(_f7, 1), 1)); + p0f += 32; + } + if (out_elempack == 1) + { + transpose16x8_ps(_f0, _f1, _f2, _f3, _f4, _f5, _f6, _f7); + _mm256_storeu_ps(p0f, _mm512_castps512_ps256(_f0)); + _mm256_storeu_ps(p0f + out_hstep, _mm512_extractf32x8_ps(_f0, 1)); + _mm256_storeu_ps(p0f + out_hstep * 2, _mm512_castps512_ps256(_f1)); 
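The fp32 output path (output_elemtype == 1) keeps each accumulator register in float and distributes it over the packed output rows with cast/extract intrinsics instead of converting to bf16 first. Below is a minimal sketch of the out_elempack == 4 split used above; store_f32x16_to_pack4() is a hypothetical helper name for illustration, not part of the patch.

    #include <immintrin.h>

    // Sketch: scatter one column of 16 fp32 results into four pack-4 output rows
    // (ncnn elempack == 4: groups of 4 channels stored interleaved).
    // The patch writes the quarters at p0f, p0f + out_hstep * 4, * 8 and * 12.
    static void store_f32x16_to_pack4(__m512 v, float* p, int out_hstep)
    {
        __m256 lo = _mm512_castps512_ps256(v);    // lanes 0..7
        __m256 hi = _mm512_extractf32x8_ps(v, 1); // lanes 8..15 (AVX512DQ)
        _mm_storeu_ps(p, _mm256_castps256_ps128(lo));                  // lanes 0..3
        _mm_storeu_ps(p + out_hstep * 4, _mm256_extractf128_ps(lo, 1)); // lanes 4..7
        _mm_storeu_ps(p + out_hstep * 8, _mm256_castps256_ps128(hi));   // lanes 8..11
        _mm_storeu_ps(p + out_hstep * 12, _mm256_extractf128_ps(hi, 1)); // lanes 12..15
    }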
+ _mm256_storeu_ps(p0f + out_hstep * 3, _mm512_extractf32x8_ps(_f1, 1)); + _mm256_storeu_ps(p0f + out_hstep * 4, _mm512_castps512_ps256(_f2)); + _mm256_storeu_ps(p0f + out_hstep * 5, _mm512_extractf32x8_ps(_f2, 1)); + _mm256_storeu_ps(p0f + out_hstep * 6, _mm512_castps512_ps256(_f3)); + _mm256_storeu_ps(p0f + out_hstep * 7, _mm512_extractf32x8_ps(_f3, 1)); + _mm256_storeu_ps(p0f + out_hstep * 8, _mm512_castps512_ps256(_f4)); + _mm256_storeu_ps(p0f + out_hstep * 9, _mm512_extractf32x8_ps(_f4, 1)); + _mm256_storeu_ps(p0f + out_hstep * 10, _mm512_castps512_ps256(_f5)); + _mm256_storeu_ps(p0f + out_hstep * 11, _mm512_extractf32x8_ps(_f5, 1)); + _mm256_storeu_ps(p0f + out_hstep * 12, _mm512_castps512_ps256(_f6)); + _mm256_storeu_ps(p0f + out_hstep * 13, _mm512_extractf32x8_ps(_f6, 1)); + _mm256_storeu_ps(p0f + out_hstep * 14, _mm512_castps512_ps256(_f7)); + _mm256_storeu_ps(p0f + out_hstep * 15, _mm512_extractf32x8_ps(_f7, 1)); + p0f += 8; + } } - p0 += out_hstep * 8; } else { - if (out_elempack == 16) + __m256i _bf0 = float2bfloat_avx512(_f0); + __m256i _bf1 = float2bfloat_avx512(_f1); + __m256i _bf2 = float2bfloat_avx512(_f2); + __m256i _bf3 = float2bfloat_avx512(_f3); + __m256i _bf4 = float2bfloat_avx512(_f4); + __m256i _bf5 = float2bfloat_avx512(_f5); + __m256i _bf6 = float2bfloat_avx512(_f6); + __m256i _bf7 = float2bfloat_avx512(_f7); + + if (output_transpose) { - _mm256_storeu_si256((__m256i*)p0, _bf0); - _mm256_storeu_si256((__m256i*)(p0 + 16), _bf1); - _mm256_storeu_si256((__m256i*)(p0 + 16 * 2), _bf2); - _mm256_storeu_si256((__m256i*)(p0 + 16 * 3), _bf3); - _mm256_storeu_si256((__m256i*)(p0 + 16 * 4), _bf4); - _mm256_storeu_si256((__m256i*)(p0 + 16 * 5), _bf5); - _mm256_storeu_si256((__m256i*)(p0 + 16 * 6), _bf6); - _mm256_storeu_si256((__m256i*)(p0 + 16 * 7), _bf7); - p0 += 128; - } - if (out_elempack == 8) - { - _mm_storeu_si128((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); - _mm_storeu_si128((__m128i*)(p0 + 8), _mm256_extractf128_si256(_bf1, 0)); - _mm_storeu_si128((__m128i*)(p0 + 8 * 2), _mm256_extractf128_si256(_bf2, 0)); - _mm_storeu_si128((__m128i*)(p0 + 8 * 3), _mm256_extractf128_si256(_bf3, 0)); - _mm_storeu_si128((__m128i*)(p0 + 8 * 4), _mm256_extractf128_si256(_bf4, 0)); - _mm_storeu_si128((__m128i*)(p0 + 8 * 5), _mm256_extractf128_si256(_bf5, 0)); - _mm_storeu_si128((__m128i*)(p0 + 8 * 6), _mm256_extractf128_si256(_bf6, 0)); - _mm_storeu_si128((__m128i*)(p0 + 8 * 7), _mm256_extractf128_si256(_bf7, 0)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8), _mm256_extractf128_si256(_bf0, 1)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8), _mm256_extractf128_si256(_bf1, 1)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 2), _mm256_extractf128_si256(_bf2, 1)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 3), _mm256_extractf128_si256(_bf3, 1)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 4), _mm256_extractf128_si256(_bf4, 1)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 5), _mm256_extractf128_si256(_bf5, 1)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 6), _mm256_extractf128_si256(_bf6, 1)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 7), _mm256_extractf128_si256(_bf7, 1)); - p0 += 64; - } - if (out_elempack == 4) - { - _mm_storel_epi64((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); - _mm_storel_epi64((__m128i*)(p0 + 4), _mm256_extractf128_si256(_bf1, 0)); - _mm_storel_epi64((__m128i*)(p0 + 4 * 2), _mm256_extractf128_si256(_bf2, 0)); - _mm_storel_epi64((__m128i*)(p0 + 4 * 3), 
_mm256_extractf128_si256(_bf3, 0)); - _mm_storel_epi64((__m128i*)(p0 + 4 * 4), _mm256_extractf128_si256(_bf4, 0)); - _mm_storel_epi64((__m128i*)(p0 + 4 * 5), _mm256_extractf128_si256(_bf5, 0)); - _mm_storel_epi64((__m128i*)(p0 + 4 * 6), _mm256_extractf128_si256(_bf6, 0)); - _mm_storel_epi64((__m128i*)(p0 + 4 * 7), _mm256_extractf128_si256(_bf7, 0)); - _mm_storeh_pd((double*)(p0 + out_hstep * 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 0))); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 0))); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 2), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 0))); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 3), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 0))); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf4, 0))); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 5), _mm_castsi128_pd(_mm256_extractf128_si256(_bf5, 0))); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 6), _mm_castsi128_pd(_mm256_extractf128_si256(_bf6, 0))); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 7), _mm_castsi128_pd(_mm256_extractf128_si256(_bf7, 0))); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8), _mm256_extractf128_si256(_bf0, 1)); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4), _mm256_extractf128_si256(_bf1, 1)); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 2), _mm256_extractf128_si256(_bf2, 1)); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 3), _mm256_extractf128_si256(_bf3, 1)); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 4), _mm256_extractf128_si256(_bf4, 1)); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 5), _mm256_extractf128_si256(_bf5, 1)); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 6), _mm256_extractf128_si256(_bf6, 1)); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 7), _mm256_extractf128_si256(_bf7, 1)); - _mm_storeh_pd((double*)(p0 + out_hstep * 12), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 1))); - _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 1))); - _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 2), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 1))); - _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 3), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 1))); - _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf4, 1))); - _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 5), _mm_castsi128_pd(_mm256_extractf128_si256(_bf5, 1))); - _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 6), _mm_castsi128_pd(_mm256_extractf128_si256(_bf6, 1))); - _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 7), _mm_castsi128_pd(_mm256_extractf128_si256(_bf7, 1))); - p0 += 32; - } - if (out_elempack == 1) - { - transpose16x8_epi16(_bf0, _bf1, _bf2, _bf3, _bf4, _bf5, _bf6, _bf7); - _mm_storeu_si128((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep), _mm256_extractf128_si256(_bf0, 1)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 2), _mm256_extractf128_si256(_bf1, 0)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 3), _mm256_extractf128_si256(_bf1, 1)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 4), _mm256_extractf128_si256(_bf2, 0)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 5), _mm256_extractf128_si256(_bf2, 1)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 6), _mm256_extractf128_si256(_bf3, 0)); - 
_mm_storeu_si128((__m128i*)(p0 + out_hstep * 7), _mm256_extractf128_si256(_bf3, 1)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8), _mm256_extractf128_si256(_bf4, 0)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 9), _mm256_extractf128_si256(_bf4, 1)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 10), _mm256_extractf128_si256(_bf5, 0)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 11), _mm256_extractf128_si256(_bf5, 1)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 12), _mm256_extractf128_si256(_bf6, 0)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 13), _mm256_extractf128_si256(_bf6, 1)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 14), _mm256_extractf128_si256(_bf7, 0)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 15), _mm256_extractf128_si256(_bf7, 1)); - p0 += 8; + if (out_elempack == 8) + { + transpose16x8_epi16(_bf0, _bf1, _bf2, _bf3, _bf4, _bf5, _bf6, _bf7); + _mm256_storeu_si256((__m256i*)p0, _bf0); + _mm256_storeu_si256((__m256i*)(p0 + 16), _bf1); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 2), _bf2); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 3), _bf3); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 4), _bf4); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 5), _bf5); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 6), _bf6); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 7), _bf7); + } + if (out_elempack == 4) + { + transpose16x4_epi16(_bf0, _bf1, _bf2, _bf3); + transpose16x4_epi16(_bf4, _bf5, _bf6, _bf7); + _mm256_storeu_si256((__m256i*)p0, _bf0); + _mm256_storeu_si256((__m256i*)(p0 + 16), _bf1); + _mm256_storeu_si256((__m256i*)(p0 + 32), _bf2); + _mm256_storeu_si256((__m256i*)(p0 + 48), _bf3); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 4), _bf4); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 4 + 16), _bf5); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 4 + 32), _bf6); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 4 + 48), _bf7); + } + if (out_elempack == 1) + { + _mm256_storeu_si256((__m256i*)p0, _bf0); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep), _bf1); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 2), _bf2); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 3), _bf3); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 4), _bf4); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 5), _bf5); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 6), _bf6); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 7), _bf7); + } + p0 += out_hstep * 8; + } + else + { + if (out_elempack == 16) + { + _mm256_storeu_si256((__m256i*)p0, _bf0); + _mm256_storeu_si256((__m256i*)(p0 + 16), _bf1); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 2), _bf2); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 3), _bf3); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 4), _bf4); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 5), _bf5); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 6), _bf6); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 7), _bf7); + p0 += 128; + } + if (out_elempack == 8) + { + _mm_storeu_si128((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8), _mm256_extractf128_si256(_bf1, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 2), _mm256_extractf128_si256(_bf2, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 3), _mm256_extractf128_si256(_bf3, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 4), _mm256_extractf128_si256(_bf4, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 5), _mm256_extractf128_si256(_bf5, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 6), _mm256_extractf128_si256(_bf6, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 7), _mm256_extractf128_si256(_bf7, 0)); + 
_mm_storeu_si128((__m128i*)(p0 + out_hstep * 8), _mm256_extractf128_si256(_bf0, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8), _mm256_extractf128_si256(_bf1, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 2), _mm256_extractf128_si256(_bf2, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 3), _mm256_extractf128_si256(_bf3, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 4), _mm256_extractf128_si256(_bf4, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 5), _mm256_extractf128_si256(_bf5, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 6), _mm256_extractf128_si256(_bf6, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 7), _mm256_extractf128_si256(_bf7, 1)); + p0 += 64; + } + if (out_elempack == 4) + { + _mm_storel_epi64((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4), _mm256_extractf128_si256(_bf1, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 2), _mm256_extractf128_si256(_bf2, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 3), _mm256_extractf128_si256(_bf3, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 4), _mm256_extractf128_si256(_bf4, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 5), _mm256_extractf128_si256(_bf5, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 6), _mm256_extractf128_si256(_bf6, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 7), _mm256_extractf128_si256(_bf7, 0)); + _mm_storeh_pd((double*)(p0 + out_hstep * 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 2), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 3), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf4, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 5), _mm_castsi128_pd(_mm256_extractf128_si256(_bf5, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 6), _mm_castsi128_pd(_mm256_extractf128_si256(_bf6, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 7), _mm_castsi128_pd(_mm256_extractf128_si256(_bf7, 0))); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8), _mm256_extractf128_si256(_bf0, 1)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4), _mm256_extractf128_si256(_bf1, 1)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 2), _mm256_extractf128_si256(_bf2, 1)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 3), _mm256_extractf128_si256(_bf3, 1)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 4), _mm256_extractf128_si256(_bf4, 1)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 5), _mm256_extractf128_si256(_bf5, 1)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 6), _mm256_extractf128_si256(_bf6, 1)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 7), _mm256_extractf128_si256(_bf7, 1)); + _mm_storeh_pd((double*)(p0 + out_hstep * 12), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 2), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 3), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf4, 1))); 
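The bf16 out_elempack == 4 stores reuse one __m128i (eight bf16 values) for two pack-4 rows: _mm_storel_epi64 writes the low four values and _mm_storeh_pd writes the high four after a bit-cast to __m128d. A minimal sketch of that pairing, with a hypothetical helper name:

    #include <immintrin.h>
    #include <stdint.h>

    // Sketch: split 8 bf16 values (one __m128i) into two pack-4 destinations.
    // The low 64 bits hold lanes 0..3, the high 64 bits hold lanes 4..7.
    static void store_bf16x8_to_two_pack4(__m128i v, uint16_t* row0, uint16_t* row1)
    {
        _mm_storel_epi64((__m128i*)row0, v);               // bf16 lanes 0..3
        _mm_storeh_pd((double*)row1, _mm_castsi128_pd(v)); // bf16 lanes 4..7
    }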
+ _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 5), _mm_castsi128_pd(_mm256_extractf128_si256(_bf5, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 6), _mm_castsi128_pd(_mm256_extractf128_si256(_bf6, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 7), _mm_castsi128_pd(_mm256_extractf128_si256(_bf7, 1))); + p0 += 32; + } + if (out_elempack == 1) + { + transpose16x8_epi16(_bf0, _bf1, _bf2, _bf3, _bf4, _bf5, _bf6, _bf7); + _mm_storeu_si128((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep), _mm256_extractf128_si256(_bf0, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 2), _mm256_extractf128_si256(_bf1, 0)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 3), _mm256_extractf128_si256(_bf1, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 4), _mm256_extractf128_si256(_bf2, 0)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 5), _mm256_extractf128_si256(_bf2, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 6), _mm256_extractf128_si256(_bf3, 0)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 7), _mm256_extractf128_si256(_bf3, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8), _mm256_extractf128_si256(_bf4, 0)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 9), _mm256_extractf128_si256(_bf4, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 10), _mm256_extractf128_si256(_bf5, 0)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 11), _mm256_extractf128_si256(_bf5, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 12), _mm256_extractf128_si256(_bf6, 0)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 13), _mm256_extractf128_si256(_bf6, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 14), _mm256_extractf128_si256(_bf7, 0)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 15), _mm256_extractf128_si256(_bf7, 1)); + p0 += 8; + } } } } @@ -5024,156 +5297,281 @@ static void unpack_output_tile_fp32_to_bf16(const Mat& topT, const Mat& C, Mat& _f3 = _mm512_mul_ps(_f3, _alpha); } - __m256i _bf0 = float2bfloat_avx512(_f0); - __m256i _bf1 = float2bfloat_avx512(_f1); - __m256i _bf2 = float2bfloat_avx512(_f2); - __m256i _bf3 = float2bfloat_avx512(_f3); - - if (output_transpose) + if (output_elemtype == 1) { + if (output_transpose) + { #if !(defined(__x86_64__) || defined(_M_X64)) #if __AVX__ #if __AVX512F__ - if (out_elempack == 16) - { - transpose16x4_epi16(_bf0, _bf1, _bf2, _bf3); - const int jj_m16 = jj % 16; - unsigned short* p1 = p0 - out_hstep * jj_m16 + jj_m16; - _mm_storel_epi64((__m128i*)p1, _mm256_extractf128_si256(_bf0, 0)); - _mm_storeh_pd((double*)(p1 + 16), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 0))); - _mm_storel_epi64((__m128i*)(p1 + 32), _mm256_extractf128_si256(_bf0, 1)); - _mm_storeh_pd((double*)(p1 + 48), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 1))); - _mm_storel_epi64((__m128i*)(p1 + 64), _mm256_extractf128_si256(_bf1, 0)); - _mm_storeh_pd((double*)(p1 + 80), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 0))); - _mm_storel_epi64((__m128i*)(p1 + 96), _mm256_extractf128_si256(_bf1, 1)); - _mm_storeh_pd((double*)(p1 + 112), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 1))); - _mm_storel_epi64((__m128i*)(p1 + 128), _mm256_extractf128_si256(_bf2, 0)); - _mm_storeh_pd((double*)(p1 + 144), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 0))); - _mm_storel_epi64((__m128i*)(p1 + 160), _mm256_extractf128_si256(_bf2, 1)); - _mm_storeh_pd((double*)(p1 + 176), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 1))); - _mm_storel_epi64((__m128i*)(p1 + 192), _mm256_extractf128_si256(_bf3, 0)); - 
_mm_storeh_pd((double*)(p1 + 208), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 0))); - _mm_storel_epi64((__m128i*)(p1 + 224), _mm256_extractf128_si256(_bf3, 1)); - _mm_storeh_pd((double*)(p1 + 240), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 1))); - } + if (out_elempack == 16) + { + transpose16x4_ps(_f0, _f1, _f2, _f3); + const int jj_m16 = jj % 16; + float* p1f = p0f - out_hstep * jj_m16 + jj_m16; + _mm_storeu_ps(p1f, _mm512_extractf32x4_ps(_f0, 0)); + _mm_storeu_ps(p1f + 16, _mm512_extractf32x4_ps(_f0, 1)); + _mm_storeu_ps(p1f + 32, _mm512_extractf32x4_ps(_f0, 2)); + _mm_storeu_ps(p1f + 48, _mm512_extractf32x4_ps(_f0, 3)); + _mm_storeu_ps(p1f + 64, _mm512_extractf32x4_ps(_f1, 0)); + _mm_storeu_ps(p1f + 80, _mm512_extractf32x4_ps(_f1, 1)); + _mm_storeu_ps(p1f + 96, _mm512_extractf32x4_ps(_f1, 2)); + _mm_storeu_ps(p1f + 112, _mm512_extractf32x4_ps(_f1, 3)); + _mm_storeu_ps(p1f + 128, _mm512_extractf32x4_ps(_f2, 0)); + _mm_storeu_ps(p1f + 144, _mm512_extractf32x4_ps(_f2, 1)); + _mm_storeu_ps(p1f + 160, _mm512_extractf32x4_ps(_f2, 2)); + _mm_storeu_ps(p1f + 176, _mm512_extractf32x4_ps(_f2, 3)); + _mm_storeu_ps(p1f + 192, _mm512_extractf32x4_ps(_f3, 0)); + _mm_storeu_ps(p1f + 208, _mm512_extractf32x4_ps(_f3, 1)); + _mm_storeu_ps(p1f + 224, _mm512_extractf32x4_ps(_f3, 2)); + _mm_storeu_ps(p1f + 240, _mm512_extractf32x4_ps(_f3, 3)); + } #endif // __AVX512F__ - if (out_elempack == 8) - { - transpose16x4_epi16(_bf0, _bf1, _bf2, _bf3); - const int jj_m8 = jj % 8; - unsigned short* p1 = p0 - out_hstep * jj_m8 + jj_m8; - _mm_storel_epi64((__m128i*)p1, _mm256_extractf128_si256(_bf0, 0)); - _mm_storeh_pd((double*)(p1 + 8), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 0))); - _mm_storel_epi64((__m128i*)(p1 + 16), _mm256_extractf128_si256(_bf0, 1)); - _mm_storeh_pd((double*)(p1 + 24), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 1))); - _mm_storel_epi64((__m128i*)(p1 + 32), _mm256_extractf128_si256(_bf1, 0)); - _mm_storeh_pd((double*)(p1 + 40), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 0))); - _mm_storel_epi64((__m128i*)(p1 + 48), _mm256_extractf128_si256(_bf1, 1)); - _mm_storeh_pd((double*)(p1 + 56), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 1))); - _mm_storel_epi64((__m128i*)(p1 + 64), _mm256_extractf128_si256(_bf2, 0)); - _mm_storeh_pd((double*)(p1 + 72), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 0))); - _mm_storel_epi64((__m128i*)(p1 + 80), _mm256_extractf128_si256(_bf2, 1)); - _mm_storeh_pd((double*)(p1 + 88), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 1))); - _mm_storel_epi64((__m128i*)(p1 + 96), _mm256_extractf128_si256(_bf3, 0)); - _mm_storeh_pd((double*)(p1 + 104), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 0))); - _mm_storel_epi64((__m128i*)(p1 + 112), _mm256_extractf128_si256(_bf3, 1)); - _mm_storeh_pd((double*)(p1 + 120), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 1))); - } + if (out_elempack == 8) + { + transpose16x4_ps(_f0, _f1, _f2, _f3); + const int jj_m8 = jj % 8; + float* p1f = p0f - out_hstep * jj_m8 + jj_m8; + _mm_storeu_ps(p1f, _mm512_extractf32x4_ps(_f0, 0)); + _mm_storeu_ps(p1f + 8, _mm512_extractf32x4_ps(_f0, 1)); + _mm_storeu_ps(p1f + 16, _mm512_extractf32x4_ps(_f0, 2)); + _mm_storeu_ps(p1f + 24, _mm512_extractf32x4_ps(_f0, 3)); + _mm_storeu_ps(p1f + 32, _mm512_extractf32x4_ps(_f1, 0)); + _mm_storeu_ps(p1f + 40, _mm512_extractf32x4_ps(_f1, 1)); + _mm_storeu_ps(p1f + 48, _mm512_extractf32x4_ps(_f1, 2)); + _mm_storeu_ps(p1f + 56, _mm512_extractf32x4_ps(_f1, 3)); + _mm_storeu_ps(p1f + 64, _mm512_extractf32x4_ps(_f2, 0)); + 
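In the transposed fp32 branches for out_elempack == 16 / 8 the tile is first rotated with transpose16x4_ps(), after which every __m512 holds four output rows of four floats; each quarter is then peeled off with _mm512_extractf32x4_ps and written at a stride of one elempack group. A small sketch of that peeling, using a hypothetical helper name:

    #include <immintrin.h>

    // Sketch: after a 16x4 transpose, one __m512 carries 4 output rows of 4 floats.
    // Write each 4-float quarter at consecutive pack-sized offsets (stride = elempack).
    static void store_4rows_from_f32x16(__m512 v, float* p, int elempack)
    {
        _mm_storeu_ps(p, _mm512_extractf32x4_ps(v, 0));
        _mm_storeu_ps(p + elempack, _mm512_extractf32x4_ps(v, 1));
        _mm_storeu_ps(p + elempack * 2, _mm512_extractf32x4_ps(v, 2));
        _mm_storeu_ps(p + elempack * 3, _mm512_extractf32x4_ps(v, 3));
    }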
_mm_storeu_ps(p1f + 72, _mm512_extractf32x4_ps(_f2, 1)); + _mm_storeu_ps(p1f + 80, _mm512_extractf32x4_ps(_f2, 2)); + _mm_storeu_ps(p1f + 88, _mm512_extractf32x4_ps(_f2, 3)); + _mm_storeu_ps(p1f + 96, _mm512_extractf32x4_ps(_f3, 0)); + _mm_storeu_ps(p1f + 104, _mm512_extractf32x4_ps(_f3, 1)); + _mm_storeu_ps(p1f + 112, _mm512_extractf32x4_ps(_f3, 2)); + _mm_storeu_ps(p1f + 120, _mm512_extractf32x4_ps(_f3, 3)); + } #endif // __AVX__ #endif // !(defined(__x86_64__) || defined(_M_X64)) - if (out_elempack == 4) - { - transpose16x4_epi16(_bf0, _bf1, _bf2, _bf3); - - _mm_storel_epi64((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); - _mm_storeh_pd((double*)(p0 + 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 0))); - _mm_storel_epi64((__m128i*)(p0 + 4 * 2), _mm256_extractf128_si256(_bf0, 1)); - _mm_storeh_pd((double*)(p0 + 4 * 3), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 1))); - _mm_storel_epi64((__m128i*)(p0 + 4 * 4), _mm256_extractf128_si256(_bf1, 0)); - _mm_storeh_pd((double*)(p0 + 4 * 5), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 0))); - _mm_storel_epi64((__m128i*)(p0 + 4 * 6), _mm256_extractf128_si256(_bf1, 1)); - _mm_storeh_pd((double*)(p0 + 4 * 7), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 1))); - _mm_storel_epi64((__m128i*)(p0 + 4 * 8), _mm256_extractf128_si256(_bf2, 0)); - _mm_storeh_pd((double*)(p0 + 4 * 9), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 0))); - _mm_storel_epi64((__m128i*)(p0 + 4 * 10), _mm256_extractf128_si256(_bf2, 1)); - _mm_storeh_pd((double*)(p0 + 4 * 11), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 1))); - _mm_storel_epi64((__m128i*)(p0 + 4 * 12), _mm256_extractf128_si256(_bf3, 0)); - _mm_storeh_pd((double*)(p0 + 4 * 13), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 0))); - _mm_storel_epi64((__m128i*)(p0 + 4 * 14), _mm256_extractf128_si256(_bf3, 1)); - _mm_storeh_pd((double*)(p0 + 4 * 15), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 1))); - } - if (out_elempack == 1) + if (out_elempack == 4) + { + transpose16x4_ps(_f0, _f1, _f2, _f3); + _mm512_storeu_ps(p0f, _f0); + _mm512_storeu_ps(p0f + 16, _f1); + _mm512_storeu_ps(p0f + 32, _f2); + _mm512_storeu_ps(p0f + 48, _f3); + } + if (out_elempack == 1) + { + _mm512_storeu_ps(p0f, _f0); + _mm512_storeu_ps(p0f + out_hstep, _f1); + _mm512_storeu_ps(p0f + out_hstep * 2, _f2); + _mm512_storeu_ps(p0f + out_hstep * 3, _f3); + } + p0f += out_hstep * 4; + } + else { - _mm256_storeu_si256((__m256i*)p0, _bf0); - _mm256_storeu_si256((__m256i*)(p0 + out_hstep), _bf1); - _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 2), _bf2); - _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 3), _bf3); + if (out_elempack == 16) + { + _mm512_storeu_ps(p0f, _f0); + _mm512_storeu_ps(p0f + 16, _f1); + _mm512_storeu_ps(p0f + 16 * 2, _f2); + _mm512_storeu_ps(p0f + 16 * 3, _f3); + p0f += 64; + } + if (out_elempack == 8) + { + __m512 _tmp0 = _mm512_shuffle_f32x4(_f0, _f1, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp1 = _mm512_shuffle_f32x4(_f2, _f3, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp2 = _mm512_shuffle_f32x4(_f0, _f1, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp3 = _mm512_shuffle_f32x4(_f2, _f3, _MM_SHUFFLE(3, 2, 3, 2)); + + _mm512_storeu_ps(p0f, _tmp0); + _mm512_storeu_ps(p0f + 16, _tmp1); + _mm512_storeu_ps(p0f + out_hstep * 8, _tmp2); + _mm512_storeu_ps(p0f + out_hstep * 8 + 16, _tmp3); + p0f += 32; + } + if (out_elempack == 4) + { + __m512 _tmp0 = _mm512_shuffle_f32x4(_f0, _f1, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp1 = _mm512_shuffle_f32x4(_f2, _f3, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp2 = _mm512_shuffle_f32x4(_f0, 
_f1, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp3 = _mm512_shuffle_f32x4(_f2, _f3, _MM_SHUFFLE(3, 2, 3, 2)); + _f0 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(2, 0, 2, 0)); + _f1 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(3, 1, 3, 1)); + _f2 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(2, 0, 2, 0)); + _f3 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(3, 1, 3, 1)); + + _mm512_storeu_ps(p0f, _f0); + _mm512_storeu_ps(p0f + out_hstep * 4, _f1); + _mm512_storeu_ps(p0f + out_hstep * 8, _f2); + _mm512_storeu_ps(p0f + out_hstep * 12, _f3); + p0f += 16; + } + if (out_elempack == 1) + { + transpose16x4_ps(_f0, _f1, _f2, _f3); + + _mm_storeu_ps(p0f, _mm512_extractf32x4_ps(_f0, 0)); + _mm_storeu_ps(p0f + out_hstep, _mm512_extractf32x4_ps(_f0, 1)); + _mm_storeu_ps(p0f + out_hstep * 2, _mm512_extractf32x4_ps(_f0, 2)); + _mm_storeu_ps(p0f + out_hstep * 3, _mm512_extractf32x4_ps(_f0, 3)); + _mm_storeu_ps(p0f + out_hstep * 4, _mm512_extractf32x4_ps(_f1, 0)); + _mm_storeu_ps(p0f + out_hstep * 5, _mm512_extractf32x4_ps(_f1, 1)); + _mm_storeu_ps(p0f + out_hstep * 6, _mm512_extractf32x4_ps(_f1, 2)); + _mm_storeu_ps(p0f + out_hstep * 7, _mm512_extractf32x4_ps(_f1, 3)); + _mm_storeu_ps(p0f + out_hstep * 8, _mm512_extractf32x4_ps(_f2, 0)); + _mm_storeu_ps(p0f + out_hstep * 9, _mm512_extractf32x4_ps(_f2, 1)); + _mm_storeu_ps(p0f + out_hstep * 10, _mm512_extractf32x4_ps(_f2, 2)); + _mm_storeu_ps(p0f + out_hstep * 11, _mm512_extractf32x4_ps(_f2, 3)); + _mm_storeu_ps(p0f + out_hstep * 12, _mm512_extractf32x4_ps(_f3, 0)); + _mm_storeu_ps(p0f + out_hstep * 13, _mm512_extractf32x4_ps(_f3, 1)); + _mm_storeu_ps(p0f + out_hstep * 14, _mm512_extractf32x4_ps(_f3, 2)); + _mm_storeu_ps(p0f + out_hstep * 15, _mm512_extractf32x4_ps(_f3, 3)); + p0f += 4; + } } - p0 += out_hstep * 4; } else { - if (out_elempack == 16) + __m256i _bf0 = float2bfloat_avx512(_f0); + __m256i _bf1 = float2bfloat_avx512(_f1); + __m256i _bf2 = float2bfloat_avx512(_f2); + __m256i _bf3 = float2bfloat_avx512(_f3); + + if (output_transpose) { - _mm256_storeu_si256((__m256i*)p0, _bf0); - _mm256_storeu_si256((__m256i*)(p0 + 16), _bf1); - _mm256_storeu_si256((__m256i*)(p0 + 16 * 2), _bf2); - _mm256_storeu_si256((__m256i*)(p0 + 16 * 3), _bf3); - p0 += 64; - } - if (out_elempack == 8) - { - _mm_storeu_si128((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); - _mm_storeu_si128((__m128i*)(p0 + 8), _mm256_extractf128_si256(_bf1, 0)); - _mm_storeu_si128((__m128i*)(p0 + 8 * 2), _mm256_extractf128_si256(_bf2, 0)); - _mm_storeu_si128((__m128i*)(p0 + 8 * 3), _mm256_extractf128_si256(_bf3, 0)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8), _mm256_extractf128_si256(_bf0, 1)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8), _mm256_extractf128_si256(_bf1, 1)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 2), _mm256_extractf128_si256(_bf2, 1)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 3), _mm256_extractf128_si256(_bf3, 1)); - p0 += 32; - } - if (out_elempack == 4) - { - _mm_storel_epi64((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); - _mm_storel_epi64((__m128i*)(p0 + 4), _mm256_extractf128_si256(_bf1, 0)); - _mm_storel_epi64((__m128i*)(p0 + 4 * 2), _mm256_extractf128_si256(_bf2, 0)); - _mm_storel_epi64((__m128i*)(p0 + 4 * 3), _mm256_extractf128_si256(_bf3, 0)); - _mm_storeh_pd((double*)(p0 + out_hstep * 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 0))); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 0))); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 2), 
_mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 0))); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 3), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 0))); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8), _mm256_extractf128_si256(_bf0, 1)); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4), _mm256_extractf128_si256(_bf1, 1)); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 2), _mm256_extractf128_si256(_bf2, 1)); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 3), _mm256_extractf128_si256(_bf3, 1)); - _mm_storeh_pd((double*)(p0 + out_hstep * 12), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 1))); - _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 1))); - _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 2), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 1))); - _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 3), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 1))); - p0 += 16; +#if !(defined(__x86_64__) || defined(_M_X64)) +#if __AVX__ +#if __AVX512F__ + if (out_elempack == 16) + { + transpose16x4_epi16(_bf0, _bf1, _bf2, _bf3); + const int jj_m16 = jj % 16; + unsigned short* p1 = p0 - out_hstep * jj_m16 + jj_m16; + _mm_storel_epi64((__m128i*)p1, _mm256_extractf128_si256(_bf0, 0)); + _mm_storeh_pd((double*)(p1 + 16), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 0))); + _mm_storel_epi64((__m128i*)(p1 + 32), _mm256_extractf128_si256(_bf0, 1)); + _mm_storeh_pd((double*)(p1 + 48), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 1))); + _mm_storel_epi64((__m128i*)(p1 + 64), _mm256_extractf128_si256(_bf1, 0)); + _mm_storeh_pd((double*)(p1 + 80), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 0))); + _mm_storel_epi64((__m128i*)(p1 + 96), _mm256_extractf128_si256(_bf1, 1)); + _mm_storeh_pd((double*)(p1 + 112), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 1))); + _mm_storel_epi64((__m128i*)(p1 + 128), _mm256_extractf128_si256(_bf2, 0)); + _mm_storeh_pd((double*)(p1 + 144), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 0))); + _mm_storel_epi64((__m128i*)(p1 + 160), _mm256_extractf128_si256(_bf2, 1)); + _mm_storeh_pd((double*)(p1 + 176), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 1))); + _mm_storel_epi64((__m128i*)(p1 + 192), _mm256_extractf128_si256(_bf3, 0)); + _mm_storeh_pd((double*)(p1 + 208), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 0))); + _mm_storel_epi64((__m128i*)(p1 + 224), _mm256_extractf128_si256(_bf3, 1)); + _mm_storeh_pd((double*)(p1 + 240), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 1))); + } +#endif // __AVX512F__ + if (out_elempack == 8) + { + transpose16x4_epi16(_bf0, _bf1, _bf2, _bf3); + const int jj_m8 = jj % 8; + unsigned short* p1 = p0 - out_hstep * jj_m8 + jj_m8; + _mm_storel_epi64((__m128i*)p1, _mm256_extractf128_si256(_bf0, 0)); + _mm_storeh_pd((double*)(p1 + 8), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 0))); + _mm_storel_epi64((__m128i*)(p1 + 16), _mm256_extractf128_si256(_bf0, 1)); + _mm_storeh_pd((double*)(p1 + 24), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 1))); + _mm_storel_epi64((__m128i*)(p1 + 32), _mm256_extractf128_si256(_bf1, 0)); + _mm_storeh_pd((double*)(p1 + 40), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 0))); + _mm_storel_epi64((__m128i*)(p1 + 48), _mm256_extractf128_si256(_bf1, 1)); + _mm_storeh_pd((double*)(p1 + 56), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 1))); + _mm_storel_epi64((__m128i*)(p1 + 64), _mm256_extractf128_si256(_bf2, 0)); + _mm_storeh_pd((double*)(p1 + 72), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 0))); + 
_mm_storel_epi64((__m128i*)(p1 + 80), _mm256_extractf128_si256(_bf2, 1)); + _mm_storeh_pd((double*)(p1 + 88), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 1))); + _mm_storel_epi64((__m128i*)(p1 + 96), _mm256_extractf128_si256(_bf3, 0)); + _mm_storeh_pd((double*)(p1 + 104), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 0))); + _mm_storel_epi64((__m128i*)(p1 + 112), _mm256_extractf128_si256(_bf3, 1)); + _mm_storeh_pd((double*)(p1 + 120), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 1))); + } +#endif // __AVX__ +#endif // !(defined(__x86_64__) || defined(_M_X64)) + if (out_elempack == 4) + { + transpose16x4_epi16(_bf0, _bf1, _bf2, _bf3); + _mm256_storeu_si256((__m256i*)p0, _bf0); + _mm256_storeu_si256((__m256i*)(p0 + 16), _bf1); + _mm256_storeu_si256((__m256i*)(p0 + 32), _bf2); + _mm256_storeu_si256((__m256i*)(p0 + 48), _bf3); + } + if (out_elempack == 1) + { + _mm256_storeu_si256((__m256i*)p0, _bf0); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep), _bf1); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 2), _bf2); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 3), _bf3); + } + p0 += out_hstep * 4; } - if (out_elempack == 1) - { - transpose16x4_epi16(_bf0, _bf1, _bf2, _bf3); - - _mm_storel_epi64((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); - _mm_storeh_pd((double*)(p0 + out_hstep), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 0))); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 2), _mm256_extractf128_si256(_bf0, 1)); - _mm_storeh_pd((double*)(p0 + out_hstep * 3), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 1))); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 4), _mm256_extractf128_si256(_bf1, 0)); - _mm_storeh_pd((double*)(p0 + out_hstep * 5), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 0))); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 6), _mm256_extractf128_si256(_bf1, 1)); - _mm_storeh_pd((double*)(p0 + out_hstep * 7), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 1))); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8), _mm256_extractf128_si256(_bf2, 0)); - _mm_storeh_pd((double*)(p0 + out_hstep * 9), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 0))); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 10), _mm256_extractf128_si256(_bf2, 1)); - _mm_storeh_pd((double*)(p0 + out_hstep * 11), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 1))); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 12), _mm256_extractf128_si256(_bf3, 0)); - _mm_storeh_pd((double*)(p0 + out_hstep * 13), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 0))); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 14), _mm256_extractf128_si256(_bf3, 1)); - _mm_storeh_pd((double*)(p0 + out_hstep * 15), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 1))); - p0 += 4; + else + { + if (out_elempack == 16) + { + _mm256_storeu_si256((__m256i*)p0, _bf0); + _mm256_storeu_si256((__m256i*)(p0 + 16), _bf1); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 2), _bf2); + _mm256_storeu_si256((__m256i*)(p0 + 16 * 3), _bf3); + p0 += 64; + } + if (out_elempack == 8) + { + _mm_storeu_si128((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8), _mm256_extractf128_si256(_bf1, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 2), _mm256_extractf128_si256(_bf2, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 3), _mm256_extractf128_si256(_bf3, 0)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8), _mm256_extractf128_si256(_bf0, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8), _mm256_extractf128_si256(_bf1, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 2), 
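float2bfloat_avx512() narrows each fp32 lane to the upper 16 bits of its IEEE-754 encoding, which is exactly the bf16 representation, so a full __m512 collapses into a __m256i of 16 bf16 values. A scalar reference of that conversion is sketched below; it assumes the helper follows the same truncation convention as ncnn's scalar float32_to_bfloat16, and the round-to-nearest variant is shown only for comparison, not as what the patch does.

    #include <stdint.h>
    #include <string.h>

    // Scalar reference for the fp32 -> bf16 narrowing done vector-wise by
    // float2bfloat_avx512(): bf16 keeps only the high 16 bits of the float.
    static inline uint16_t float32_to_bfloat16_ref(float x)
    {
        uint32_t u;
        memcpy(&u, &x, 4);
        return (uint16_t)(u >> 16); // truncation (drop the low 16 mantissa bits)
        // round-to-nearest-even variant (assumption, for comparison only):
        // uint32_t lsb = (u >> 16) & 1;
        // return (uint16_t)((u + 0x7fff + lsb) >> 16);
    }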
_mm256_extractf128_si256(_bf2, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 3), _mm256_extractf128_si256(_bf3, 1)); + p0 += 32; + } + if (out_elempack == 4) + { + _mm_storel_epi64((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4), _mm256_extractf128_si256(_bf1, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 2), _mm256_extractf128_si256(_bf2, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 3), _mm256_extractf128_si256(_bf3, 0)); + _mm_storeh_pd((double*)(p0 + out_hstep * 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 2), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 3), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 0))); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8), _mm256_extractf128_si256(_bf0, 1)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4), _mm256_extractf128_si256(_bf1, 1)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 2), _mm256_extractf128_si256(_bf2, 1)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 3), _mm256_extractf128_si256(_bf3, 1)); + _mm_storeh_pd((double*)(p0 + out_hstep * 12), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 2), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 3), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 1))); + p0 += 16; + } + if (out_elempack == 1) + { + transpose16x4_epi16(_bf0, _bf1, _bf2, _bf3); + + _mm_storel_epi64((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); + _mm_storeh_pd((double*)(p0 + out_hstep), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 0))); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 2), _mm256_extractf128_si256(_bf0, 1)); + _mm_storeh_pd((double*)(p0 + out_hstep * 3), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 1))); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 4), _mm256_extractf128_si256(_bf1, 0)); + _mm_storeh_pd((double*)(p0 + out_hstep * 5), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 0))); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 6), _mm256_extractf128_si256(_bf1, 1)); + _mm_storeh_pd((double*)(p0 + out_hstep * 7), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 1))); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8), _mm256_extractf128_si256(_bf2, 0)); + _mm_storeh_pd((double*)(p0 + out_hstep * 9), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 0))); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 10), _mm256_extractf128_si256(_bf2, 1)); + _mm_storeh_pd((double*)(p0 + out_hstep * 11), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 1))); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 12), _mm256_extractf128_si256(_bf3, 0)); + _mm_storeh_pd((double*)(p0 + out_hstep * 13), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 0))); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 14), _mm256_extractf128_si256(_bf3, 1)); + _mm_storeh_pd((double*)(p0 + out_hstep * 15), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 1))); + p0 += 4; + } } } } @@ -5271,56 +5669,104 @@ static void unpack_output_tile_fp32_to_bf16(const Mat& topT, const Mat& C, Mat& _f1 = _mm512_mul_ps(_f1, _alpha); } - __m256i _bf0 = float2bfloat_avx512(_f0); - __m256i _bf1 = float2bfloat_avx512(_f1); - - if (output_transpose) 
- { - _mm256_storeu_si256((__m256i*)p0, _bf0); - _mm256_storeu_si256((__m256i*)(p0 + out_hstep), _bf1); - p0 += out_hstep * 2; - } - else + if (output_elemtype == 1) { - if (out_elempack == 16) + if (output_transpose) { - _mm256_storeu_si256((__m256i*)p0, _bf0); - _mm256_storeu_si256((__m256i*)(p0 + 16), _bf1); - p0 += 32; + _mm512_storeu_ps(p0f, _f0); + _mm512_storeu_ps(p0f + out_hstep, _f1); + p0f += out_hstep * 2; } - if (out_elempack == 8) + else { - _mm_storeu_si128((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); - _mm_storeu_si128((__m128i*)(p0 + 8), _mm256_extractf128_si256(_bf1, 0)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8), _mm256_extractf128_si256(_bf0, 1)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8), _mm256_extractf128_si256(_bf1, 1)); - p0 += 16; + if (out_elempack == 16) + { + _mm512_storeu_ps(p0f, _f0); + _mm512_storeu_ps(p0f + 16, _f1); + p0f += 32; + } + if (out_elempack == 8) + { + _mm256_storeu_ps(p0f, _mm512_castps512_ps256(_f0)); + _mm256_storeu_ps(p0f + 8, _mm512_castps512_ps256(_f1)); + _mm256_storeu_ps(p0f + out_hstep * 8, _mm512_extractf32x8_ps(_f0, 1)); + _mm256_storeu_ps(p0f + out_hstep * 8 + 8, _mm512_extractf32x8_ps(_f1, 1)); + p0f += 16; + } + if (out_elempack == 4) + { + _mm_storeu_ps(p0f, _mm512_castps512_ps128(_f0)); + _mm_storeu_ps(p0f + 4, _mm512_castps512_ps128(_f1)); + _mm_storeu_ps(p0f + out_hstep * 4, _mm256_extractf128_ps(_mm512_castps512_ps256(_f0), 1)); + _mm_storeu_ps(p0f + out_hstep * 4 + 4, _mm256_extractf128_ps(_mm512_castps512_ps256(_f1), 1)); + _mm_storeu_ps(p0f + out_hstep * 8, _mm256_castps256_ps128(_mm512_extractf32x8_ps(_f0, 1))); + _mm_storeu_ps(p0f + out_hstep * 8 + 4, _mm256_castps256_ps128(_mm512_extractf32x8_ps(_f1, 1))); + _mm_storeu_ps(p0f + out_hstep * 12, _mm256_extractf128_ps(_mm512_extractf32x8_ps(_f0, 1), 1)); + _mm_storeu_ps(p0f + out_hstep * 12 + 4, _mm256_extractf128_ps(_mm512_extractf32x8_ps(_f1, 1), 1)); + p0f += 8; + } + if (out_elempack == 1) + { + __m512i _vindex = _mm512_mullo_epi32(_mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15), _mm512_set1_epi32(out_hstep)); + _mm512_i32scatter_ps(p0f, _vindex, _f0, sizeof(float)); + _mm512_i32scatter_ps(p0f + 1, _vindex, _f1, sizeof(float)); + p0f += 2; + } } - if (out_elempack == 4) + } + else + { + __m256i _bf0 = float2bfloat_avx512(_f0); + __m256i _bf1 = float2bfloat_avx512(_f1); + + if (output_transpose) { - _mm_storel_epi64((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); - _mm_storel_epi64((__m128i*)(p0 + 4), _mm256_extractf128_si256(_bf1, 0)); - _mm_storeh_pd((double*)(p0 + out_hstep * 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 0))); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 0))); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8), _mm256_extractf128_si256(_bf0, 1)); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4), _mm256_extractf128_si256(_bf1, 1)); - _mm_storeh_pd((double*)(p0 + out_hstep * 12), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 1))); - _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 1))); - p0 += 8; + _mm256_storeu_si256((__m256i*)p0, _bf0); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep), _bf1); + p0 += out_hstep * 2; } - if (out_elempack == 1) + else { - transpose16x2_epi16(_bf0, _bf1); - __m512i _bf01 = combine8x2_epi32(_bf0, _bf1); - __m512i _vindex = _mm512_mullo_epi32(_mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15), _mm512_set1_epi32(out_hstep)); - 
_mm512_i32scatter_epi32(p0, _vindex, _bf01, sizeof(unsigned short)); - p0 += 2; - } - } - } - for (; jj < max_jj; jj++) - { - __m512 _f0 = _mm512_load_ps(pp); + if (out_elempack == 16) + { + _mm256_storeu_si256((__m256i*)p0, _bf0); + _mm256_storeu_si256((__m256i*)(p0 + 16), _bf1); + p0 += 32; + } + if (out_elempack == 8) + { + _mm_storeu_si128((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8), _mm256_extractf128_si256(_bf1, 0)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8), _mm256_extractf128_si256(_bf0, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8), _mm256_extractf128_si256(_bf1, 1)); + p0 += 16; + } + if (out_elempack == 4) + { + _mm_storel_epi64((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4), _mm256_extractf128_si256(_bf1, 0)); + _mm_storeh_pd((double*)(p0 + out_hstep * 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 0))); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8), _mm256_extractf128_si256(_bf0, 1)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4), _mm256_extractf128_si256(_bf1, 1)); + _mm_storeh_pd((double*)(p0 + out_hstep * 12), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 1))); + p0 += 8; + } + if (out_elempack == 1) + { + transpose16x2_epi16(_bf0, _bf1); + __m512i _bf01 = combine8x2_epi32(_bf0, _bf1); + __m512i _vindex = _mm512_mullo_epi32(_mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15), _mm512_set1_epi32(out_hstep)); + _mm512_i32scatter_epi32(p0, _vindex, _bf01, sizeof(unsigned short)); + p0 += 2; + } + } + } + } + for (; jj < max_jj; jj++) + { + __m512 _f0 = _mm512_load_ps(pp); pp += 16; if (pC) @@ -5377,56 +5823,112 @@ static void unpack_output_tile_fp32_to_bf16(const Mat& topT, const Mat& C, Mat& _f0 = _mm512_mul_ps(_f0, _mm512_set1_ps(alpha)); } - __m256i _bf0 = float2bfloat_avx512(_f0); - - if (output_transpose) + if (output_elemtype == 1) { - _mm256_storeu_si256((__m256i*)p0, _bf0); - p0 += out_hstep; + if (output_transpose) + { + _mm512_storeu_ps(p0f, _f0); + p0f += out_hstep; + } + else + { + if (out_elempack == 16) + { + _mm512_storeu_ps(p0f, _f0); + p0f += 16; + } + if (out_elempack == 8) + { + _mm256_storeu_ps(p0f, _mm512_castps512_ps256(_f0)); + _mm256_storeu_ps(p0f + out_hstep * 8, _mm512_extractf32x8_ps(_f0, 1)); + p0f += 8; + } + if (out_elempack == 4) + { + _mm_storeu_ps(p0f, _mm512_castps512_ps128(_f0)); + _mm_storeu_ps(p0f + out_hstep * 4, _mm256_extractf128_ps(_mm512_castps512_ps256(_f0), 1)); + _mm_storeu_ps(p0f + out_hstep * 8, _mm256_castps256_ps128(_mm512_extractf32x8_ps(_f0, 1))); + _mm_storeu_ps(p0f + out_hstep * 12, _mm256_extractf128_ps(_mm512_extractf32x8_ps(_f0, 1), 1)); + p0f += 4; + } + if (out_elempack == 1) + { + float tmp[16]; + _mm512_storeu_ps(tmp, _f0); + + p0f[0] = tmp[0]; + p0f[out_hstep] = tmp[1]; + p0f[out_hstep * 2] = tmp[2]; + p0f[out_hstep * 3] = tmp[3]; + p0f[out_hstep * 4] = tmp[4]; + p0f[out_hstep * 5] = tmp[5]; + p0f[out_hstep * 6] = tmp[6]; + p0f[out_hstep * 7] = tmp[7]; + p0f[out_hstep * 8] = tmp[8]; + p0f[out_hstep * 9] = tmp[9]; + p0f[out_hstep * 10] = tmp[10]; + p0f[out_hstep * 11] = tmp[11]; + p0f[out_hstep * 12] = tmp[12]; + p0f[out_hstep * 13] = tmp[13]; + p0f[out_hstep * 14] = tmp[14]; + p0f[out_hstep * 15] = tmp[15]; + p0f++; + } + } } else { - if (out_elempack == 16) + __m256i 
_bf0 = float2bfloat_avx512(_f0); + + if (output_transpose) { _mm256_storeu_si256((__m256i*)p0, _bf0); - p0 += 16; - } - if (out_elempack == 8) - { - _mm_storeu_si128((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8), _mm256_extractf128_si256(_bf0, 1)); - p0 += 8; + p0 += out_hstep; } - if (out_elempack == 4) + else { - _mm_storel_epi64((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); - _mm_storeh_pd((double*)(p0 + out_hstep * 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 0))); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8), _mm256_extractf128_si256(_bf0, 1)); - _mm_storeh_pd((double*)(p0 + out_hstep * 12), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 1))); - p0 += 4; - } - if (out_elempack == 1) - { - unsigned short tmp[16]; - _mm256_storeu_si256((__m256i*)tmp, _bf0); - - p0[0] = tmp[0]; - p0[out_hstep] = tmp[1]; - p0[out_hstep * 2] = tmp[2]; - p0[out_hstep * 3] = tmp[3]; - p0[out_hstep * 4] = tmp[4]; - p0[out_hstep * 5] = tmp[5]; - p0[out_hstep * 6] = tmp[6]; - p0[out_hstep * 7] = tmp[7]; - p0[out_hstep * 8] = tmp[8]; - p0[out_hstep * 9] = tmp[9]; - p0[out_hstep * 10] = tmp[10]; - p0[out_hstep * 11] = tmp[11]; - p0[out_hstep * 12] = tmp[12]; - p0[out_hstep * 13] = tmp[13]; - p0[out_hstep * 14] = tmp[14]; - p0[out_hstep * 15] = tmp[15]; - p0++; + if (out_elempack == 16) + { + _mm256_storeu_si256((__m256i*)p0, _bf0); + p0 += 16; + } + if (out_elempack == 8) + { + _mm_storeu_si128((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8), _mm256_extractf128_si256(_bf0, 1)); + p0 += 8; + } + if (out_elempack == 4) + { + _mm_storel_epi64((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); + _mm_storeh_pd((double*)(p0 + out_hstep * 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 0))); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8), _mm256_extractf128_si256(_bf0, 1)); + _mm_storeh_pd((double*)(p0 + out_hstep * 12), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 1))); + p0 += 4; + } + if (out_elempack == 1) + { + unsigned short tmp[16]; + _mm256_storeu_si256((__m256i*)tmp, _bf0); + + p0[0] = tmp[0]; + p0[out_hstep] = tmp[1]; + p0[out_hstep * 2] = tmp[2]; + p0[out_hstep * 3] = tmp[3]; + p0[out_hstep * 4] = tmp[4]; + p0[out_hstep * 5] = tmp[5]; + p0[out_hstep * 6] = tmp[6]; + p0[out_hstep * 7] = tmp[7]; + p0[out_hstep * 8] = tmp[8]; + p0[out_hstep * 9] = tmp[9]; + p0[out_hstep * 10] = tmp[10]; + p0[out_hstep * 11] = tmp[11]; + p0[out_hstep * 12] = tmp[12]; + p0[out_hstep * 13] = tmp[13]; + p0[out_hstep * 14] = tmp[14]; + p0[out_hstep * 15] = tmp[15]; + p0++; + } } } } @@ -5435,13 +5937,16 @@ static void unpack_output_tile_fp32_to_bf16(const Mat& topT, const Mat& C, Mat& for (; ii + 7 < max_ii; ii += 8) { unsigned short* p0; + float* p0f; if (output_transpose) { p0 = (unsigned short*)top_blob + j * out_hstep + (i + ii) * out_elempack; + p0f = (float*)top_blob + j * out_hstep + (i + ii) * out_elempack; } else { p0 = (unsigned short*)top_blob + (i + ii) * out_hstep + j * out_elempack; + p0f = (float*)top_blob + (i + ii) * out_hstep + j * out_elempack; } __m256 _c0 = _mm256_set1_ps(0.f); @@ -5761,256 +6266,434 @@ static void unpack_output_tile_fp32_to_bf16(const Mat& topT, const Mat& C, Mat& _f7 = _mm512_mul_ps(_f7, _alpha); } - __m256i _bf0 = float2bfloat_avx512(_f0); - __m256i _bf1 = float2bfloat_avx512(_f1); - __m256i _bf2 = float2bfloat_avx512(_f2); - __m256i _bf3 = float2bfloat_avx512(_f3); - __m256i _bf4 = float2bfloat_avx512(_f4); - __m256i _bf5 = float2bfloat_avx512(_f5); - __m256i 
_bf6 = float2bfloat_avx512(_f6); - __m256i _bf7 = float2bfloat_avx512(_f7); - - if (output_transpose) - { - if (out_elempack == 16) - { - transpose16x8_epi16(_bf0, _bf1, _bf2, _bf3, _bf4, _bf5, _bf6, _bf7); - - _mm_store_si128((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); - _mm_store_si128((__m128i*)(p0 + 8), _mm256_extractf128_si256(_bf4, 0)); - _mm_store_si128((__m128i*)(p0 + 16), _mm256_extractf128_si256(_bf0, 1)); - _mm_store_si128((__m128i*)(p0 + 16 + 8), _mm256_extractf128_si256(_bf4, 1)); - _mm_store_si128((__m128i*)(p0 + 32), _mm256_extractf128_si256(_bf1, 0)); - _mm_store_si128((__m128i*)(p0 + 32 + 8), _mm256_extractf128_si256(_bf5, 0)); - _mm_store_si128((__m128i*)(p0 + 48), _mm256_extractf128_si256(_bf1, 1)); - _mm_store_si128((__m128i*)(p0 + 48 + 8), _mm256_extractf128_si256(_bf5, 1)); - _mm_store_si128((__m128i*)(p0 + 64), _mm256_extractf128_si256(_bf2, 0)); - _mm_store_si128((__m128i*)(p0 + 64 + 8), _mm256_extractf128_si256(_bf6, 0)); - _mm_store_si128((__m128i*)(p0 + 80), _mm256_extractf128_si256(_bf2, 1)); - _mm_store_si128((__m128i*)(p0 + 80 + 8), _mm256_extractf128_si256(_bf6, 1)); - _mm_store_si128((__m128i*)(p0 + 96), _mm256_extractf128_si256(_bf3, 0)); - _mm_store_si128((__m128i*)(p0 + 96 + 8), _mm256_extractf128_si256(_bf7, 0)); - _mm_store_si128((__m128i*)(p0 + 112), _mm256_extractf128_si256(_bf3, 1)); - _mm_store_si128((__m128i*)(p0 + 112 + 8), _mm256_extractf128_si256(_bf7, 1)); - } - if (out_elempack == 8) - { - __m128i _bf0l = _mm256_extractf128_si256(_bf0, 0); - __m128i _bf1l = _mm256_extractf128_si256(_bf1, 0); - __m128i _bf2l = _mm256_extractf128_si256(_bf2, 0); - __m128i _bf3l = _mm256_extractf128_si256(_bf3, 0); - __m128i _bf4l = _mm256_extractf128_si256(_bf4, 0); - __m128i _bf5l = _mm256_extractf128_si256(_bf5, 0); - __m128i _bf6l = _mm256_extractf128_si256(_bf6, 0); - __m128i _bf7l = _mm256_extractf128_si256(_bf7, 0); - transpose8x8_epi16(_bf0l, _bf1l, _bf2l, _bf3l, _bf4l, _bf5l, _bf6l, _bf7l); - _mm_storeu_si128((__m128i*)p0, _bf0l); - _mm_storeu_si128((__m128i*)(p0 + 8), _bf1l); - _mm_storeu_si128((__m128i*)(p0 + 8 * 2), _bf2l); - _mm_storeu_si128((__m128i*)(p0 + 8 * 3), _bf3l); - _mm_storeu_si128((__m128i*)(p0 + 8 * 4), _bf4l); - _mm_storeu_si128((__m128i*)(p0 + 8 * 5), _bf5l); - _mm_storeu_si128((__m128i*)(p0 + 8 * 6), _bf6l); - _mm_storeu_si128((__m128i*)(p0 + 8 * 7), _bf7l); - __m128i _bf0h = _mm256_extractf128_si256(_bf0, 1); - __m128i _bf1h = _mm256_extractf128_si256(_bf1, 1); - __m128i _bf2h = _mm256_extractf128_si256(_bf2, 1); - __m128i _bf3h = _mm256_extractf128_si256(_bf3, 1); - __m128i _bf4h = _mm256_extractf128_si256(_bf4, 1); - __m128i _bf5h = _mm256_extractf128_si256(_bf5, 1); - __m128i _bf6h = _mm256_extractf128_si256(_bf6, 1); - __m128i _bf7h = _mm256_extractf128_si256(_bf7, 1); - transpose8x8_epi16(_bf0h, _bf1h, _bf2h, _bf3h, _bf4h, _bf5h, _bf6h, _bf7h); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8), _bf0h); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8), _bf1h); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 2), _bf2h); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 3), _bf3h); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 4), _bf4h); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 5), _bf5h); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 6), _bf6h); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 7), _bf7h); - } - if (out_elempack == 4) - { - __m128i _bf0l = _mm256_extractf128_si256(_bf0, 0); - __m128i _bf1l = _mm256_extractf128_si256(_bf1, 0); - __m128i _bf2l = 
_mm256_extractf128_si256(_bf2, 0); - __m128i _bf3l = _mm256_extractf128_si256(_bf3, 0); - __m128i _bf4l = _mm256_extractf128_si256(_bf4, 0); - __m128i _bf5l = _mm256_extractf128_si256(_bf5, 0); - __m128i _bf6l = _mm256_extractf128_si256(_bf6, 0); - __m128i _bf7l = _mm256_extractf128_si256(_bf7, 0); - transpose8x4_epi16(_bf0l, _bf1l, _bf2l, _bf3l); - transpose8x4_epi16(_bf4l, _bf5l, _bf6l, _bf7l); - _mm_storel_epi64((__m128i*)p0, _bf0l); - _mm_storeh_pd((double*)(p0 + 4), _mm_castsi128_pd(_bf0l)); - _mm_storel_epi64((__m128i*)(p0 + 4 * 2), _bf1l); - _mm_storeh_pd((double*)(p0 + 4 * 3), _mm_castsi128_pd(_bf1l)); - _mm_storel_epi64((__m128i*)(p0 + 4 * 4), _bf2l); - _mm_storeh_pd((double*)(p0 + 4 * 5), _mm_castsi128_pd(_bf2l)); - _mm_storel_epi64((__m128i*)(p0 + 4 * 6), _bf3l); - _mm_storeh_pd((double*)(p0 + 4 * 7), _mm_castsi128_pd(_bf3l)); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 4), _bf4l); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4), _mm_castsi128_pd(_bf4l)); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 4 + 4 * 2), _bf5l); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 3), _mm_castsi128_pd(_bf5l)); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 4 + 4 * 4), _bf6l); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 5), _mm_castsi128_pd(_bf6l)); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 4 + 4 * 6), _bf7l); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 7), _mm_castsi128_pd(_bf7l)); - __m128i _bf0h = _mm256_extractf128_si256(_bf0, 1); - __m128i _bf1h = _mm256_extractf128_si256(_bf1, 1); - __m128i _bf2h = _mm256_extractf128_si256(_bf2, 1); - __m128i _bf3h = _mm256_extractf128_si256(_bf3, 1); - __m128i _bf4h = _mm256_extractf128_si256(_bf4, 1); - __m128i _bf5h = _mm256_extractf128_si256(_bf5, 1); - __m128i _bf6h = _mm256_extractf128_si256(_bf6, 1); - __m128i _bf7h = _mm256_extractf128_si256(_bf7, 1); - transpose8x4_epi16(_bf0h, _bf1h, _bf2h, _bf3h); - transpose8x4_epi16(_bf4h, _bf5h, _bf6h, _bf7h); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8), _bf0h); - _mm_storeh_pd((double*)(p0 + out_hstep * 8 + 4), _mm_castsi128_pd(_bf0h)); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 2), _bf1h); - _mm_storeh_pd((double*)(p0 + out_hstep * 8 + 4 * 3), _mm_castsi128_pd(_bf1h)); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 4), _bf2h); - _mm_storeh_pd((double*)(p0 + out_hstep * 8 + 4 * 5), _mm_castsi128_pd(_bf2h)); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 6), _bf3h); - _mm_storeh_pd((double*)(p0 + out_hstep * 8 + 4 * 7), _mm_castsi128_pd(_bf3h)); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 12), _bf4h); - _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4), _mm_castsi128_pd(_bf4h)); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 12 + 4 * 2), _bf5h); - _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 3), _mm_castsi128_pd(_bf5h)); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 12 + 4 * 4), _bf6h); - _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 5), _mm_castsi128_pd(_bf6h)); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 12 + 4 * 6), _bf7h); - _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 7), _mm_castsi128_pd(_bf7h)); - } - if (out_elempack == 1) - { - _mm_storeu_si128((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep), _mm256_extractf128_si256(_bf1, 0)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 2), _mm256_extractf128_si256(_bf2, 0)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 3), _mm256_extractf128_si256(_bf3, 0)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 4), 
_mm256_extractf128_si256(_bf4, 0)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 5), _mm256_extractf128_si256(_bf5, 0)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 6), _mm256_extractf128_si256(_bf6, 0)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 7), _mm256_extractf128_si256(_bf7, 0)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8), _mm256_extractf128_si256(_bf0, 1)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 9), _mm256_extractf128_si256(_bf1, 1)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 10), _mm256_extractf128_si256(_bf2, 1)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 11), _mm256_extractf128_si256(_bf3, 1)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 12), _mm256_extractf128_si256(_bf4, 1)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 13), _mm256_extractf128_si256(_bf5, 1)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 14), _mm256_extractf128_si256(_bf6, 1)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 15), _mm256_extractf128_si256(_bf7, 1)); - } - p0 += out_hstep * 16; + if (output_elemtype == 1) + { + if (output_transpose) + { + if (out_elempack == 16) + { + transpose16x8_ps(_f0, _f1, _f2, _f3, _f4, _f5, _f6, _f7); + + _mm256_store_ps(p0f, _mm512_extractf32x8_ps(_f0, 0)); + _mm256_store_ps(p0f + 8, _mm512_extractf32x8_ps(_f4, 0)); + _mm256_store_ps(p0f + 16, _mm512_extractf32x8_ps(_f0, 1)); + _mm256_store_ps(p0f + 16 + 8, _mm512_extractf32x8_ps(_f4, 1)); + _mm256_store_ps(p0f + 16 * 2, _mm512_extractf32x8_ps(_f1, 0)); + _mm256_store_ps(p0f + 16 * 2 + 8, _mm512_extractf32x8_ps(_f5, 0)); + _mm256_store_ps(p0f + 16 * 3, _mm512_extractf32x8_ps(_f1, 1)); + _mm256_store_ps(p0f + 16 * 3 + 8, _mm512_extractf32x8_ps(_f5, 1)); + _mm256_store_ps(p0f + 16 * 4, _mm512_extractf32x8_ps(_f2, 0)); + _mm256_store_ps(p0f + 16 * 4 + 8, _mm512_extractf32x8_ps(_f6, 0)); + _mm256_store_ps(p0f + 16 * 5, _mm512_extractf32x8_ps(_f2, 1)); + _mm256_store_ps(p0f + 16 * 5 + 8, _mm512_extractf32x8_ps(_f6, 1)); + _mm256_store_ps(p0f + 16 * 6, _mm512_extractf32x8_ps(_f3, 0)); + _mm256_store_ps(p0f + 16 * 6 + 8, _mm512_extractf32x8_ps(_f7, 0)); + _mm256_store_ps(p0f + 16 * 7, _mm512_extractf32x8_ps(_f3, 1)); + _mm256_store_ps(p0f + 16 * 7 + 8, _mm512_extractf32x8_ps(_f7, 1)); + } + if (out_elempack == 8) + { + transpose16x8_ps(_f0, _f1, _f2, _f3, _f4, _f5, _f6, _f7); + + _mm512_storeu_ps(p0f, _f0); + _mm512_storeu_ps(p0f + 16, _f1); + _mm512_storeu_ps(p0f + 16 * 2, _f2); + _mm512_storeu_ps(p0f + 16 * 3, _f3); + _mm512_storeu_ps(p0f + out_hstep * 8, _f4); + _mm512_storeu_ps(p0f + out_hstep * 8 + 16, _f5); + _mm512_storeu_ps(p0f + out_hstep * 8 + 16 * 2, _f6); + _mm512_storeu_ps(p0f + out_hstep * 8 + 16 * 3, _f7); + } + if (out_elempack == 4) + { + transpose16x4_ps(_f0, _f1, _f2, _f3); + transpose16x4_ps(_f4, _f5, _f6, _f7); + + _mm512_storeu_ps(p0f, _f0); + _mm512_storeu_ps(p0f + 16, _f1); + _mm512_storeu_ps(p0f + out_hstep * 4, _f4); + _mm512_storeu_ps(p0f + out_hstep * 4 + 16, _f5); + _mm512_storeu_ps(p0f + out_hstep * 8, _f2); + _mm512_storeu_ps(p0f + out_hstep * 8 + 16, _f3); + _mm512_storeu_ps(p0f + out_hstep * 12, _f6); + _mm512_storeu_ps(p0f + out_hstep * 12 + 16, _f7); + } + if (out_elempack == 1) + { + _mm256_storeu_ps(p0f, _mm512_castps512_ps256(_f0)); + _mm256_storeu_ps(p0f + out_hstep, _mm512_castps512_ps256(_f1)); + _mm256_storeu_ps(p0f + out_hstep * 2, _mm512_castps512_ps256(_f2)); + _mm256_storeu_ps(p0f + out_hstep * 3, _mm512_castps512_ps256(_f3)); + _mm256_storeu_ps(p0f + out_hstep * 4, _mm512_castps512_ps256(_f4)); + _mm256_storeu_ps(p0f + out_hstep * 5, 
_mm512_castps512_ps256(_f5)); + _mm256_storeu_ps(p0f + out_hstep * 6, _mm512_castps512_ps256(_f6)); + _mm256_storeu_ps(p0f + out_hstep * 7, _mm512_castps512_ps256(_f7)); + _mm256_storeu_ps(p0f + out_hstep * 8, _mm512_extractf32x8_ps(_f0, 1)); + _mm256_storeu_ps(p0f + out_hstep * 9, _mm512_extractf32x8_ps(_f1, 1)); + _mm256_storeu_ps(p0f + out_hstep * 10, _mm512_extractf32x8_ps(_f2, 1)); + _mm256_storeu_ps(p0f + out_hstep * 11, _mm512_extractf32x8_ps(_f3, 1)); + _mm256_storeu_ps(p0f + out_hstep * 12, _mm512_extractf32x8_ps(_f4, 1)); + _mm256_storeu_ps(p0f + out_hstep * 13, _mm512_extractf32x8_ps(_f5, 1)); + _mm256_storeu_ps(p0f + out_hstep * 14, _mm512_extractf32x8_ps(_f6, 1)); + _mm256_storeu_ps(p0f + out_hstep * 15, _mm512_extractf32x8_ps(_f7, 1)); + } + p0f += out_hstep * 16; + } + else + { + if (out_elempack == 8) + { + _mm256_storeu_ps(p0f, _mm512_castps512_ps256(_f0)); + _mm256_storeu_ps(p0f + 8, _mm512_castps512_ps256(_f1)); + _mm256_storeu_ps(p0f + 8 * 2, _mm512_castps512_ps256(_f2)); + _mm256_storeu_ps(p0f + 8 * 3, _mm512_castps512_ps256(_f3)); + _mm256_storeu_ps(p0f + 8 * 4, _mm512_castps512_ps256(_f4)); + _mm256_storeu_ps(p0f + 8 * 5, _mm512_castps512_ps256(_f5)); + _mm256_storeu_ps(p0f + 8 * 6, _mm512_castps512_ps256(_f6)); + _mm256_storeu_ps(p0f + 8 * 7, _mm512_castps512_ps256(_f7)); + _mm256_storeu_ps(p0f + 8 * 8, _mm512_extractf32x8_ps(_f0, 1)); + _mm256_storeu_ps(p0f + 8 * 9, _mm512_extractf32x8_ps(_f1, 1)); + _mm256_storeu_ps(p0f + 8 * 10, _mm512_extractf32x8_ps(_f2, 1)); + _mm256_storeu_ps(p0f + 8 * 11, _mm512_extractf32x8_ps(_f3, 1)); + _mm256_storeu_ps(p0f + 8 * 12, _mm512_extractf32x8_ps(_f4, 1)); + _mm256_storeu_ps(p0f + 8 * 13, _mm512_extractf32x8_ps(_f5, 1)); + _mm256_storeu_ps(p0f + 8 * 14, _mm512_extractf32x8_ps(_f6, 1)); + _mm256_storeu_ps(p0f + 8 * 15, _mm512_extractf32x8_ps(_f7, 1)); + p0f += 128; + } + if (out_elempack == 4) + { + __m512 _tmp0 = _mm512_shuffle_f32x4(_f0, _f1, _MM_SHUFFLE(2, 0, 2, 0)); + __m512 _tmp1 = _mm512_shuffle_f32x4(_f2, _f3, _MM_SHUFFLE(2, 0, 2, 0)); + __m512 _tmp2 = _mm512_shuffle_f32x4(_f4, _f5, _MM_SHUFFLE(2, 0, 2, 0)); + __m512 _tmp3 = _mm512_shuffle_f32x4(_f6, _f7, _MM_SHUFFLE(2, 0, 2, 0)); + __m512 _tmp4 = _mm512_shuffle_f32x4(_f0, _f1, _MM_SHUFFLE(3, 1, 3, 1)); + __m512 _tmp5 = _mm512_shuffle_f32x4(_f2, _f3, _MM_SHUFFLE(3, 1, 3, 1)); + __m512 _tmp6 = _mm512_shuffle_f32x4(_f4, _f5, _MM_SHUFFLE(3, 1, 3, 1)); + __m512 _tmp7 = _mm512_shuffle_f32x4(_f6, _f7, _MM_SHUFFLE(3, 1, 3, 1)); + + _f0 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(2, 0, 2, 0)); + _f1 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(2, 0, 2, 0)); + _f2 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(3, 1, 3, 1)); + _f3 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(3, 1, 3, 1)); + _f4 = _mm512_shuffle_f32x4(_tmp4, _tmp5, _MM_SHUFFLE(2, 0, 2, 0)); + _f5 = _mm512_shuffle_f32x4(_tmp6, _tmp7, _MM_SHUFFLE(2, 0, 2, 0)); + _f6 = _mm512_shuffle_f32x4(_tmp4, _tmp5, _MM_SHUFFLE(3, 1, 3, 1)); + _f7 = _mm512_shuffle_f32x4(_tmp6, _tmp7, _MM_SHUFFLE(3, 1, 3, 1)); + + _mm512_storeu_ps(p0f, _f0); + _mm512_storeu_ps(p0f + 16, _f1); + _mm512_storeu_ps(p0f + 32, _f2); + _mm512_storeu_ps(p0f + 48, _f3); + _mm512_storeu_ps(p0f + out_hstep * 4, _f4); + _mm512_storeu_ps(p0f + out_hstep * 4 + 16, _f5); + _mm512_storeu_ps(p0f + out_hstep * 4 + 32, _f6); + _mm512_storeu_ps(p0f + out_hstep * 4 + 48, _f7); + p0f += 64; + } + if (out_elempack == 1) + { + __m512 _tmp0 = _mm512_unpacklo_ps(_f0, _f1); + __m512 _tmp1 = _mm512_unpacklo_ps(_f2, _f3); + __m512 _tmp2 = _mm512_unpacklo_ps(_f4, 
_f5); + __m512 _tmp3 = _mm512_unpacklo_ps(_f6, _f7); + __m512 _tmp4 = _mm512_unpackhi_ps(_f0, _f1); + __m512 _tmp5 = _mm512_unpackhi_ps(_f2, _f3); + __m512 _tmp6 = _mm512_unpackhi_ps(_f4, _f5); + __m512 _tmp7 = _mm512_unpackhi_ps(_f6, _f7); + + _f0 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp0), _mm512_castps_pd(_tmp1))); + _f1 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp2), _mm512_castps_pd(_tmp3))); + _f2 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp0), _mm512_castps_pd(_tmp1))); + _f3 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp2), _mm512_castps_pd(_tmp3))); + _f4 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp4), _mm512_castps_pd(_tmp5))); + _f5 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp6), _mm512_castps_pd(_tmp7))); + _f6 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp4), _mm512_castps_pd(_tmp5))); + _f7 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp6), _mm512_castps_pd(_tmp7))); + + _tmp0 = _mm512_shuffle_f32x4(_f0, _f1, _MM_SHUFFLE(2, 0, 2, 0)); + _tmp1 = _mm512_shuffle_f32x4(_f2, _f3, _MM_SHUFFLE(2, 0, 2, 0)); + _tmp2 = _mm512_shuffle_f32x4(_f4, _f5, _MM_SHUFFLE(2, 0, 2, 0)); + _tmp3 = _mm512_shuffle_f32x4(_f6, _f7, _MM_SHUFFLE(2, 0, 2, 0)); + _tmp4 = _mm512_shuffle_f32x4(_f0, _f1, _MM_SHUFFLE(3, 1, 3, 1)); + _tmp5 = _mm512_shuffle_f32x4(_f2, _f3, _MM_SHUFFLE(3, 1, 3, 1)); + _tmp6 = _mm512_shuffle_f32x4(_f4, _f5, _MM_SHUFFLE(3, 1, 3, 1)); + _tmp7 = _mm512_shuffle_f32x4(_f6, _f7, _MM_SHUFFLE(3, 1, 3, 1)); + + _f0 = _mm512_shuffle_f32x4(_tmp0, _tmp0, _MM_SHUFFLE(3, 1, 2, 0)); + _f1 = _mm512_shuffle_f32x4(_tmp1, _tmp1, _MM_SHUFFLE(3, 1, 2, 0)); + _f2 = _mm512_shuffle_f32x4(_tmp2, _tmp2, _MM_SHUFFLE(3, 1, 2, 0)); + _f3 = _mm512_shuffle_f32x4(_tmp3, _tmp3, _MM_SHUFFLE(3, 1, 2, 0)); + _f4 = _mm512_shuffle_f32x4(_tmp4, _tmp4, _MM_SHUFFLE(3, 1, 2, 0)); + _f5 = _mm512_shuffle_f32x4(_tmp5, _tmp5, _MM_SHUFFLE(3, 1, 2, 0)); + _f6 = _mm512_shuffle_f32x4(_tmp6, _tmp6, _MM_SHUFFLE(3, 1, 2, 0)); + _f7 = _mm512_shuffle_f32x4(_tmp7, _tmp7, _MM_SHUFFLE(3, 1, 2, 0)); + + _mm512_storeu_ps(p0f, _f0); + _mm512_storeu_ps(p0f + out_hstep, _f1); + _mm512_storeu_ps(p0f + out_hstep * 2, _f2); + _mm512_storeu_ps(p0f + out_hstep * 3, _f3); + _mm512_storeu_ps(p0f + out_hstep * 4, _f4); + _mm512_storeu_ps(p0f + out_hstep * 5, _f5); + _mm512_storeu_ps(p0f + out_hstep * 6, _f6); + _mm512_storeu_ps(p0f + out_hstep * 7, _f7); + p0f += 16; + } + } } else { - if (out_elempack == 8) - { - _mm_storeu_si128((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); - _mm_storeu_si128((__m128i*)(p0 + 8), _mm256_extractf128_si256(_bf1, 0)); - _mm_storeu_si128((__m128i*)(p0 + 8 * 2), _mm256_extractf128_si256(_bf2, 0)); - _mm_storeu_si128((__m128i*)(p0 + 8 * 3), _mm256_extractf128_si256(_bf3, 0)); - _mm_storeu_si128((__m128i*)(p0 + 8 * 4), _mm256_extractf128_si256(_bf4, 0)); - _mm_storeu_si128((__m128i*)(p0 + 8 * 5), _mm256_extractf128_si256(_bf5, 0)); - _mm_storeu_si128((__m128i*)(p0 + 8 * 6), _mm256_extractf128_si256(_bf6, 0)); - _mm_storeu_si128((__m128i*)(p0 + 8 * 7), _mm256_extractf128_si256(_bf7, 0)); - _mm_storeu_si128((__m128i*)(p0 + 8 * 8), _mm256_extractf128_si256(_bf0, 1)); - _mm_storeu_si128((__m128i*)(p0 + 8 * 9), _mm256_extractf128_si256(_bf1, 1)); - _mm_storeu_si128((__m128i*)(p0 + 8 * 10), _mm256_extractf128_si256(_bf2, 1)); - _mm_storeu_si128((__m128i*)(p0 + 8 * 11), _mm256_extractf128_si256(_bf3, 1)); - _mm_storeu_si128((__m128i*)(p0 + 8 * 12), _mm256_extractf128_si256(_bf4, 1)); - 
_mm_storeu_si128((__m128i*)(p0 + 8 * 13), _mm256_extractf128_si256(_bf5, 1)); - _mm_storeu_si128((__m128i*)(p0 + 8 * 14), _mm256_extractf128_si256(_bf6, 1)); - _mm_storeu_si128((__m128i*)(p0 + 8 * 15), _mm256_extractf128_si256(_bf7, 1)); - p0 += 128; - } - if (out_elempack == 4) - { - _mm_storel_epi64((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); - _mm_storel_epi64((__m128i*)(p0 + 4), _mm256_extractf128_si256(_bf1, 0)); - _mm_storel_epi64((__m128i*)(p0 + 4 * 2), _mm256_extractf128_si256(_bf2, 0)); - _mm_storel_epi64((__m128i*)(p0 + 4 * 3), _mm256_extractf128_si256(_bf3, 0)); - _mm_storel_epi64((__m128i*)(p0 + 4 * 4), _mm256_extractf128_si256(_bf4, 0)); - _mm_storel_epi64((__m128i*)(p0 + 4 * 5), _mm256_extractf128_si256(_bf5, 0)); - _mm_storel_epi64((__m128i*)(p0 + 4 * 6), _mm256_extractf128_si256(_bf6, 0)); - _mm_storel_epi64((__m128i*)(p0 + 4 * 7), _mm256_extractf128_si256(_bf7, 0)); - _mm_storeh_pd((double*)(p0 + out_hstep * 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 0))); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 0))); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 2), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 0))); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 3), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 0))); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf4, 0))); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 5), _mm_castsi128_pd(_mm256_extractf128_si256(_bf5, 0))); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 6), _mm_castsi128_pd(_mm256_extractf128_si256(_bf6, 0))); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 7), _mm_castsi128_pd(_mm256_extractf128_si256(_bf7, 0))); - _mm_storel_epi64((__m128i*)(p0 + 4 * 8), _mm256_extractf128_si256(_bf0, 1)); - _mm_storel_epi64((__m128i*)(p0 + 4 * 9), _mm256_extractf128_si256(_bf1, 1)); - _mm_storel_epi64((__m128i*)(p0 + 4 * 10), _mm256_extractf128_si256(_bf2, 1)); - _mm_storel_epi64((__m128i*)(p0 + 4 * 11), _mm256_extractf128_si256(_bf3, 1)); - _mm_storel_epi64((__m128i*)(p0 + 4 * 12), _mm256_extractf128_si256(_bf4, 1)); - _mm_storel_epi64((__m128i*)(p0 + 4 * 13), _mm256_extractf128_si256(_bf5, 1)); - _mm_storel_epi64((__m128i*)(p0 + 4 * 14), _mm256_extractf128_si256(_bf6, 1)); - _mm_storel_epi64((__m128i*)(p0 + 4 * 15), _mm256_extractf128_si256(_bf7, 1)); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 8), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 1))); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 9), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 1))); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 10), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 1))); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 11), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 1))); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 12), _mm_castsi128_pd(_mm256_extractf128_si256(_bf4, 1))); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 13), _mm_castsi128_pd(_mm256_extractf128_si256(_bf5, 1))); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 14), _mm_castsi128_pd(_mm256_extractf128_si256(_bf6, 1))); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 15), _mm_castsi128_pd(_mm256_extractf128_si256(_bf7, 1))); - p0 += 64; - } - if (out_elempack == 1) - { - __m512 _tmp0 = _mm512_unpacklo_ps(_f0, _f1); - __m512 _tmp1 = _mm512_unpacklo_ps(_f2, _f3); - __m512 _tmp2 = _mm512_unpacklo_ps(_f4, _f5); - __m512 _tmp3 = _mm512_unpacklo_ps(_f6, _f7); - __m512 _tmp4 = _mm512_unpackhi_ps(_f0, 
_f1); - __m512 _tmp5 = _mm512_unpackhi_ps(_f2, _f3); - __m512 _tmp6 = _mm512_unpackhi_ps(_f4, _f5); - __m512 _tmp7 = _mm512_unpackhi_ps(_f6, _f7); - - _f0 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp0), _mm512_castps_pd(_tmp1))); - _f1 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp2), _mm512_castps_pd(_tmp3))); - _f2 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp0), _mm512_castps_pd(_tmp1))); - _f3 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp2), _mm512_castps_pd(_tmp3))); - _f4 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp4), _mm512_castps_pd(_tmp5))); - _f5 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp6), _mm512_castps_pd(_tmp7))); - _f6 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp4), _mm512_castps_pd(_tmp5))); - _f7 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp6), _mm512_castps_pd(_tmp7))); - - _tmp0 = _mm512_shuffle_f32x4(_f0, _f1, _MM_SHUFFLE(2, 0, 2, 0)); - _tmp1 = _mm512_shuffle_f32x4(_f2, _f3, _MM_SHUFFLE(2, 0, 2, 0)); - _tmp2 = _mm512_shuffle_f32x4(_f4, _f5, _MM_SHUFFLE(2, 0, 2, 0)); - _tmp3 = _mm512_shuffle_f32x4(_f6, _f7, _MM_SHUFFLE(2, 0, 2, 0)); - _tmp4 = _mm512_shuffle_f32x4(_f0, _f1, _MM_SHUFFLE(3, 1, 3, 1)); - _tmp5 = _mm512_shuffle_f32x4(_f2, _f3, _MM_SHUFFLE(3, 1, 3, 1)); - _tmp6 = _mm512_shuffle_f32x4(_f4, _f5, _MM_SHUFFLE(3, 1, 3, 1)); - _tmp7 = _mm512_shuffle_f32x4(_f6, _f7, _MM_SHUFFLE(3, 1, 3, 1)); - - _f0 = _mm512_shuffle_f32x4(_tmp0, _tmp0, _MM_SHUFFLE(3, 1, 2, 0)); - _f1 = _mm512_shuffle_f32x4(_tmp1, _tmp1, _MM_SHUFFLE(3, 1, 2, 0)); - _f2 = _mm512_shuffle_f32x4(_tmp2, _tmp2, _MM_SHUFFLE(3, 1, 2, 0)); - _f3 = _mm512_shuffle_f32x4(_tmp3, _tmp3, _MM_SHUFFLE(3, 1, 2, 0)); - _f4 = _mm512_shuffle_f32x4(_tmp4, _tmp4, _MM_SHUFFLE(3, 1, 2, 0)); - _f5 = _mm512_shuffle_f32x4(_tmp5, _tmp5, _MM_SHUFFLE(3, 1, 2, 0)); - _f6 = _mm512_shuffle_f32x4(_tmp6, _tmp6, _MM_SHUFFLE(3, 1, 2, 0)); - _f7 = _mm512_shuffle_f32x4(_tmp7, _tmp7, _MM_SHUFFLE(3, 1, 2, 0)); - - _mm256_storeu_si256((__m256i*)p0, float2bfloat_avx512(_f0)); - _mm256_storeu_si256((__m256i*)(p0 + out_hstep), float2bfloat_avx512(_f1)); - _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 2), float2bfloat_avx512(_f2)); - _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 3), float2bfloat_avx512(_f3)); - _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 4), float2bfloat_avx512(_f4)); - _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 5), float2bfloat_avx512(_f5)); - _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 6), float2bfloat_avx512(_f6)); - _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 7), float2bfloat_avx512(_f7)); - p0 += 16; + __m256i _bf0 = float2bfloat_avx512(_f0); + __m256i _bf1 = float2bfloat_avx512(_f1); + __m256i _bf2 = float2bfloat_avx512(_f2); + __m256i _bf3 = float2bfloat_avx512(_f3); + __m256i _bf4 = float2bfloat_avx512(_f4); + __m256i _bf5 = float2bfloat_avx512(_f5); + __m256i _bf6 = float2bfloat_avx512(_f6); + __m256i _bf7 = float2bfloat_avx512(_f7); + + if (output_transpose) + { + if (out_elempack == 16) + { + transpose16x8_epi16(_bf0, _bf1, _bf2, _bf3, _bf4, _bf5, _bf6, _bf7); + + _mm_store_si128((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); + _mm_store_si128((__m128i*)(p0 + 8), _mm256_extractf128_si256(_bf4, 0)); + _mm_store_si128((__m128i*)(p0 + 16), _mm256_extractf128_si256(_bf0, 1)); + _mm_store_si128((__m128i*)(p0 + 16 + 8), _mm256_extractf128_si256(_bf4, 1)); + _mm_store_si128((__m128i*)(p0 + 32), _mm256_extractf128_si256(_bf1, 0)); + _mm_store_si128((__m128i*)(p0 + 32 + 8), 
_mm256_extractf128_si256(_bf5, 0)); + _mm_store_si128((__m128i*)(p0 + 48), _mm256_extractf128_si256(_bf1, 1)); + _mm_store_si128((__m128i*)(p0 + 48 + 8), _mm256_extractf128_si256(_bf5, 1)); + _mm_store_si128((__m128i*)(p0 + 64), _mm256_extractf128_si256(_bf2, 0)); + _mm_store_si128((__m128i*)(p0 + 64 + 8), _mm256_extractf128_si256(_bf6, 0)); + _mm_store_si128((__m128i*)(p0 + 80), _mm256_extractf128_si256(_bf2, 1)); + _mm_store_si128((__m128i*)(p0 + 80 + 8), _mm256_extractf128_si256(_bf6, 1)); + _mm_store_si128((__m128i*)(p0 + 96), _mm256_extractf128_si256(_bf3, 0)); + _mm_store_si128((__m128i*)(p0 + 96 + 8), _mm256_extractf128_si256(_bf7, 0)); + _mm_store_si128((__m128i*)(p0 + 112), _mm256_extractf128_si256(_bf3, 1)); + _mm_store_si128((__m128i*)(p0 + 112 + 8), _mm256_extractf128_si256(_bf7, 1)); + } + if (out_elempack == 8) + { + __m128i _bf0l = _mm256_extractf128_si256(_bf0, 0); + __m128i _bf1l = _mm256_extractf128_si256(_bf1, 0); + __m128i _bf2l = _mm256_extractf128_si256(_bf2, 0); + __m128i _bf3l = _mm256_extractf128_si256(_bf3, 0); + __m128i _bf4l = _mm256_extractf128_si256(_bf4, 0); + __m128i _bf5l = _mm256_extractf128_si256(_bf5, 0); + __m128i _bf6l = _mm256_extractf128_si256(_bf6, 0); + __m128i _bf7l = _mm256_extractf128_si256(_bf7, 0); + transpose8x8_epi16(_bf0l, _bf1l, _bf2l, _bf3l, _bf4l, _bf5l, _bf6l, _bf7l); + _mm_storeu_si128((__m128i*)p0, _bf0l); + _mm_storeu_si128((__m128i*)(p0 + 8), _bf1l); + _mm_storeu_si128((__m128i*)(p0 + 8 * 2), _bf2l); + _mm_storeu_si128((__m128i*)(p0 + 8 * 3), _bf3l); + _mm_storeu_si128((__m128i*)(p0 + 8 * 4), _bf4l); + _mm_storeu_si128((__m128i*)(p0 + 8 * 5), _bf5l); + _mm_storeu_si128((__m128i*)(p0 + 8 * 6), _bf6l); + _mm_storeu_si128((__m128i*)(p0 + 8 * 7), _bf7l); + __m128i _bf0h = _mm256_extractf128_si256(_bf0, 1); + __m128i _bf1h = _mm256_extractf128_si256(_bf1, 1); + __m128i _bf2h = _mm256_extractf128_si256(_bf2, 1); + __m128i _bf3h = _mm256_extractf128_si256(_bf3, 1); + __m128i _bf4h = _mm256_extractf128_si256(_bf4, 1); + __m128i _bf5h = _mm256_extractf128_si256(_bf5, 1); + __m128i _bf6h = _mm256_extractf128_si256(_bf6, 1); + __m128i _bf7h = _mm256_extractf128_si256(_bf7, 1); + transpose8x8_epi16(_bf0h, _bf1h, _bf2h, _bf3h, _bf4h, _bf5h, _bf6h, _bf7h); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8), _bf0h); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8), _bf1h); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 2), _bf2h); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 3), _bf3h); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 4), _bf4h); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 5), _bf5h); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 6), _bf6h); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8 * 7), _bf7h); + } + if (out_elempack == 4) + { + __m128i _bf0l = _mm256_extractf128_si256(_bf0, 0); + __m128i _bf1l = _mm256_extractf128_si256(_bf1, 0); + __m128i _bf2l = _mm256_extractf128_si256(_bf2, 0); + __m128i _bf3l = _mm256_extractf128_si256(_bf3, 0); + __m128i _bf4l = _mm256_extractf128_si256(_bf4, 0); + __m128i _bf5l = _mm256_extractf128_si256(_bf5, 0); + __m128i _bf6l = _mm256_extractf128_si256(_bf6, 0); + __m128i _bf7l = _mm256_extractf128_si256(_bf7, 0); + transpose8x4_epi16(_bf0l, _bf1l, _bf2l, _bf3l); + transpose8x4_epi16(_bf4l, _bf5l, _bf6l, _bf7l); + _mm_storel_epi64((__m128i*)p0, _bf0l); + _mm_storeh_pd((double*)(p0 + 4), _mm_castsi128_pd(_bf0l)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 2), _bf1l); + _mm_storeh_pd((double*)(p0 + 4 * 3), _mm_castsi128_pd(_bf1l)); + 
_mm_storel_epi64((__m128i*)(p0 + 4 * 4), _bf2l); + _mm_storeh_pd((double*)(p0 + 4 * 5), _mm_castsi128_pd(_bf2l)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 6), _bf3l); + _mm_storeh_pd((double*)(p0 + 4 * 7), _mm_castsi128_pd(_bf3l)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 4), _bf4l); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4), _mm_castsi128_pd(_bf4l)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 4 + 4 * 2), _bf5l); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 3), _mm_castsi128_pd(_bf5l)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 4 + 4 * 4), _bf6l); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 5), _mm_castsi128_pd(_bf6l)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 4 + 4 * 6), _bf7l); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 7), _mm_castsi128_pd(_bf7l)); + __m128i _bf0h = _mm256_extractf128_si256(_bf0, 1); + __m128i _bf1h = _mm256_extractf128_si256(_bf1, 1); + __m128i _bf2h = _mm256_extractf128_si256(_bf2, 1); + __m128i _bf3h = _mm256_extractf128_si256(_bf3, 1); + __m128i _bf4h = _mm256_extractf128_si256(_bf4, 1); + __m128i _bf5h = _mm256_extractf128_si256(_bf5, 1); + __m128i _bf6h = _mm256_extractf128_si256(_bf6, 1); + __m128i _bf7h = _mm256_extractf128_si256(_bf7, 1); + transpose8x4_epi16(_bf0h, _bf1h, _bf2h, _bf3h); + transpose8x4_epi16(_bf4h, _bf5h, _bf6h, _bf7h); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8), _bf0h); + _mm_storeh_pd((double*)(p0 + out_hstep * 8 + 4), _mm_castsi128_pd(_bf0h)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 2), _bf1h); + _mm_storeh_pd((double*)(p0 + out_hstep * 8 + 4 * 3), _mm_castsi128_pd(_bf1h)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 4), _bf2h); + _mm_storeh_pd((double*)(p0 + out_hstep * 8 + 4 * 5), _mm_castsi128_pd(_bf2h)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 6), _bf3h); + _mm_storeh_pd((double*)(p0 + out_hstep * 8 + 4 * 7), _mm_castsi128_pd(_bf3h)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 12), _bf4h); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4), _mm_castsi128_pd(_bf4h)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 12 + 4 * 2), _bf5h); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 3), _mm_castsi128_pd(_bf5h)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 12 + 4 * 4), _bf6h); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 5), _mm_castsi128_pd(_bf6h)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 12 + 4 * 6), _bf7h); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 7), _mm_castsi128_pd(_bf7h)); + } + if (out_elempack == 1) + { + _mm_storeu_si128((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep), _mm256_extractf128_si256(_bf1, 0)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 2), _mm256_extractf128_si256(_bf2, 0)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 3), _mm256_extractf128_si256(_bf3, 0)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 4), _mm256_extractf128_si256(_bf4, 0)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 5), _mm256_extractf128_si256(_bf5, 0)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 6), _mm256_extractf128_si256(_bf6, 0)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 7), _mm256_extractf128_si256(_bf7, 0)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8), _mm256_extractf128_si256(_bf0, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 9), _mm256_extractf128_si256(_bf1, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 10), _mm256_extractf128_si256(_bf2, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 11), _mm256_extractf128_si256(_bf3, 
1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 12), _mm256_extractf128_si256(_bf4, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 13), _mm256_extractf128_si256(_bf5, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 14), _mm256_extractf128_si256(_bf6, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 15), _mm256_extractf128_si256(_bf7, 1)); + } + p0 += out_hstep * 16; + } + else + { + if (out_elempack == 8) + { + _mm_storeu_si128((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8), _mm256_extractf128_si256(_bf1, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 2), _mm256_extractf128_si256(_bf2, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 3), _mm256_extractf128_si256(_bf3, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 4), _mm256_extractf128_si256(_bf4, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 5), _mm256_extractf128_si256(_bf5, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 6), _mm256_extractf128_si256(_bf6, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 7), _mm256_extractf128_si256(_bf7, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 8), _mm256_extractf128_si256(_bf0, 1)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 9), _mm256_extractf128_si256(_bf1, 1)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 10), _mm256_extractf128_si256(_bf2, 1)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 11), _mm256_extractf128_si256(_bf3, 1)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 12), _mm256_extractf128_si256(_bf4, 1)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 13), _mm256_extractf128_si256(_bf5, 1)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 14), _mm256_extractf128_si256(_bf6, 1)); + _mm_storeu_si128((__m128i*)(p0 + 8 * 15), _mm256_extractf128_si256(_bf7, 1)); + p0 += 128; + } + if (out_elempack == 4) + { + _mm_storel_epi64((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4), _mm256_extractf128_si256(_bf1, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 2), _mm256_extractf128_si256(_bf2, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 3), _mm256_extractf128_si256(_bf3, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 4), _mm256_extractf128_si256(_bf4, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 5), _mm256_extractf128_si256(_bf5, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 6), _mm256_extractf128_si256(_bf6, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 7), _mm256_extractf128_si256(_bf7, 0)); + _mm_storeh_pd((double*)(p0 + out_hstep * 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 2), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 3), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf4, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 5), _mm_castsi128_pd(_mm256_extractf128_si256(_bf5, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 6), _mm_castsi128_pd(_mm256_extractf128_si256(_bf6, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 7), _mm_castsi128_pd(_mm256_extractf128_si256(_bf7, 0))); + _mm_storel_epi64((__m128i*)(p0 + 4 * 8), _mm256_extractf128_si256(_bf0, 1)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 9), _mm256_extractf128_si256(_bf1, 1)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 10), _mm256_extractf128_si256(_bf2, 1)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 11), _mm256_extractf128_si256(_bf3, 1)); + 
_mm_storel_epi64((__m128i*)(p0 + 4 * 12), _mm256_extractf128_si256(_bf4, 1)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 13), _mm256_extractf128_si256(_bf5, 1)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 14), _mm256_extractf128_si256(_bf6, 1)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 15), _mm256_extractf128_si256(_bf7, 1)); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 8), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 9), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 10), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 11), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 12), _mm_castsi128_pd(_mm256_extractf128_si256(_bf4, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 13), _mm_castsi128_pd(_mm256_extractf128_si256(_bf5, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 14), _mm_castsi128_pd(_mm256_extractf128_si256(_bf6, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 15), _mm_castsi128_pd(_mm256_extractf128_si256(_bf7, 1))); + p0 += 64; + } + if (out_elempack == 1) + { + __m512 _tmp0 = _mm512_unpacklo_ps(_f0, _f1); + __m512 _tmp1 = _mm512_unpacklo_ps(_f2, _f3); + __m512 _tmp2 = _mm512_unpacklo_ps(_f4, _f5); + __m512 _tmp3 = _mm512_unpacklo_ps(_f6, _f7); + __m512 _tmp4 = _mm512_unpackhi_ps(_f0, _f1); + __m512 _tmp5 = _mm512_unpackhi_ps(_f2, _f3); + __m512 _tmp6 = _mm512_unpackhi_ps(_f4, _f5); + __m512 _tmp7 = _mm512_unpackhi_ps(_f6, _f7); + + _f0 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp0), _mm512_castps_pd(_tmp1))); + _f1 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp2), _mm512_castps_pd(_tmp3))); + _f2 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp0), _mm512_castps_pd(_tmp1))); + _f3 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp2), _mm512_castps_pd(_tmp3))); + _f4 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp4), _mm512_castps_pd(_tmp5))); + _f5 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp6), _mm512_castps_pd(_tmp7))); + _f6 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp4), _mm512_castps_pd(_tmp5))); + _f7 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp6), _mm512_castps_pd(_tmp7))); + + _tmp0 = _mm512_shuffle_f32x4(_f0, _f1, _MM_SHUFFLE(2, 0, 2, 0)); + _tmp1 = _mm512_shuffle_f32x4(_f2, _f3, _MM_SHUFFLE(2, 0, 2, 0)); + _tmp2 = _mm512_shuffle_f32x4(_f4, _f5, _MM_SHUFFLE(2, 0, 2, 0)); + _tmp3 = _mm512_shuffle_f32x4(_f6, _f7, _MM_SHUFFLE(2, 0, 2, 0)); + _tmp4 = _mm512_shuffle_f32x4(_f0, _f1, _MM_SHUFFLE(3, 1, 3, 1)); + _tmp5 = _mm512_shuffle_f32x4(_f2, _f3, _MM_SHUFFLE(3, 1, 3, 1)); + _tmp6 = _mm512_shuffle_f32x4(_f4, _f5, _MM_SHUFFLE(3, 1, 3, 1)); + _tmp7 = _mm512_shuffle_f32x4(_f6, _f7, _MM_SHUFFLE(3, 1, 3, 1)); + + _f0 = _mm512_shuffle_f32x4(_tmp0, _tmp0, _MM_SHUFFLE(3, 1, 2, 0)); + _f1 = _mm512_shuffle_f32x4(_tmp1, _tmp1, _MM_SHUFFLE(3, 1, 2, 0)); + _f2 = _mm512_shuffle_f32x4(_tmp2, _tmp2, _MM_SHUFFLE(3, 1, 2, 0)); + _f3 = _mm512_shuffle_f32x4(_tmp3, _tmp3, _MM_SHUFFLE(3, 1, 2, 0)); + _f4 = _mm512_shuffle_f32x4(_tmp4, _tmp4, _MM_SHUFFLE(3, 1, 2, 0)); + _f5 = _mm512_shuffle_f32x4(_tmp5, _tmp5, _MM_SHUFFLE(3, 1, 2, 0)); + _f6 = _mm512_shuffle_f32x4(_tmp6, _tmp6, _MM_SHUFFLE(3, 1, 2, 0)); + _f7 = _mm512_shuffle_f32x4(_tmp7, _tmp7, _MM_SHUFFLE(3, 1, 2, 0)); + + _mm256_storeu_si256((__m256i*)p0, float2bfloat_avx512(_f0)); + 
_mm256_storeu_si256((__m256i*)(p0 + out_hstep), float2bfloat_avx512(_f1)); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 2), float2bfloat_avx512(_f2)); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 3), float2bfloat_avx512(_f3)); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 4), float2bfloat_avx512(_f4)); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 5), float2bfloat_avx512(_f5)); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 6), float2bfloat_avx512(_f6)); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 7), float2bfloat_avx512(_f7)); + p0 += 16; + } } } } @@ -6229,113 +6912,199 @@ static void unpack_output_tile_fp32_to_bf16(const Mat& topT, const Mat& C, Mat& _f7 = _mm256_mul_ps(_f7, _alpha); } - __m128i _bf0 = float2bfloat_avx(_f0); - __m128i _bf1 = float2bfloat_avx(_f1); - __m128i _bf2 = float2bfloat_avx(_f2); - __m128i _bf3 = float2bfloat_avx(_f3); - __m128i _bf4 = float2bfloat_avx(_f4); - __m128i _bf5 = float2bfloat_avx(_f5); - __m128i _bf6 = float2bfloat_avx(_f6); - __m128i _bf7 = float2bfloat_avx(_f7); - - if (output_transpose) + if (output_elemtype == 1) { - if (out_elempack == 8) + if (output_transpose) { - transpose8x8_epi16(_bf0, _bf1, _bf2, _bf3, _bf4, _bf5, _bf6, _bf7); - - _mm_storeu_si128((__m128i*)p0, _bf0); - _mm_storeu_si128((__m128i*)(p0 + 8), _bf1); - _mm_storeu_si128((__m128i*)(p0 + 8 * 2), _bf2); - _mm_storeu_si128((__m128i*)(p0 + 8 * 3), _bf3); - _mm_storeu_si128((__m128i*)(p0 + 8 * 4), _bf4); - _mm_storeu_si128((__m128i*)(p0 + 8 * 5), _bf5); - _mm_storeu_si128((__m128i*)(p0 + 8 * 6), _bf6); - _mm_storeu_si128((__m128i*)(p0 + 8 * 7), _bf7); + if (out_elempack == 8) + { + transpose8x8_ps(_f0, _f1, _f2, _f3, _f4, _f5, _f6, _f7); + + _mm256_storeu_ps(p0f, _f0); + _mm256_storeu_ps(p0f + 8, _f1); + _mm256_storeu_ps(p0f + 8 * 2, _f2); + _mm256_storeu_ps(p0f + 8 * 3, _f3); + _mm256_storeu_ps(p0f + 8 * 4, _f4); + _mm256_storeu_ps(p0f + 8 * 5, _f5); + _mm256_storeu_ps(p0f + 8 * 6, _f6); + _mm256_storeu_ps(p0f + 8 * 7, _f7); + } + if (out_elempack == 4) + { + transpose8x4_ps(_f0, _f1, _f2, _f3); + transpose8x4_ps(_f4, _f5, _f6, _f7); + _mm256_storeu_ps(p0f, _f0); + _mm256_storeu_ps(p0f + 8, _f1); + _mm256_storeu_ps(p0f + 16, _f2); + _mm256_storeu_ps(p0f + 24, _f3); + _mm256_storeu_ps(p0f + out_hstep * 4, _f4); + _mm256_storeu_ps(p0f + out_hstep * 4 + 8, _f5); + _mm256_storeu_ps(p0f + out_hstep * 4 + 16, _f6); + _mm256_storeu_ps(p0f + out_hstep * 4 + 24, _f7); + } + if (out_elempack == 1) + { + _mm256_storeu_ps(p0f, _f0); + _mm256_storeu_ps(p0f + out_hstep, _f1); + _mm256_storeu_ps(p0f + out_hstep * 2, _f2); + _mm256_storeu_ps(p0f + out_hstep * 3, _f3); + _mm256_storeu_ps(p0f + out_hstep * 4, _f4); + _mm256_storeu_ps(p0f + out_hstep * 5, _f5); + _mm256_storeu_ps(p0f + out_hstep * 6, _f6); + _mm256_storeu_ps(p0f + out_hstep * 7, _f7); + } + p0f += out_hstep * 8; } - if (out_elempack == 4) - { - transpose8x4_epi16(_bf0, _bf1, _bf2, _bf3); - transpose8x4_epi16(_bf4, _bf5, _bf6, _bf7); - - _mm_storel_epi64((__m128i*)p0, _bf0); - _mm_storeh_pd((double*)(p0 + 4), _mm_castsi128_pd(_bf0)); - _mm_storel_epi64((__m128i*)(p0 + 4 * 2), _bf1); - _mm_storeh_pd((double*)(p0 + 4 * 3), _mm_castsi128_pd(_bf1)); - _mm_storel_epi64((__m128i*)(p0 + 4 * 4), _bf2); - _mm_storeh_pd((double*)(p0 + 4 * 5), _mm_castsi128_pd(_bf2)); - _mm_storel_epi64((__m128i*)(p0 + 4 * 6), _bf3); - _mm_storeh_pd((double*)(p0 + 4 * 7), _mm_castsi128_pd(_bf3)); - - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 4), _bf4); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4), _mm_castsi128_pd(_bf4)); - 
_mm_storel_epi64((__m128i*)(p0 + out_hstep * 4 + 4 * 2), _bf5); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 3), _mm_castsi128_pd(_bf5)); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 4 + 4 * 4), _bf6); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 5), _mm_castsi128_pd(_bf6)); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 4 + 4 * 6), _bf7); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 7), _mm_castsi128_pd(_bf7)); - } - if (out_elempack == 1) + else { - _mm_storeu_si128((__m128i*)p0, _bf0); - _mm_storeu_si128((__m128i*)(p0 + out_hstep), _bf1); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 2), _bf2); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 3), _bf3); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 4), _bf4); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 5), _bf5); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 6), _bf6); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 7), _bf7); + if (out_elempack == 8) + { + _mm256_storeu_ps(p0f, _f0); + _mm256_storeu_ps(p0f + 8, _f1); + _mm256_storeu_ps(p0f + 8 * 2, _f2); + _mm256_storeu_ps(p0f + 8 * 3, _f3); + _mm256_storeu_ps(p0f + 8 * 4, _f4); + _mm256_storeu_ps(p0f + 8 * 5, _f5); + _mm256_storeu_ps(p0f + 8 * 6, _f6); + _mm256_storeu_ps(p0f + 8 * 7, _f7); + p0f += 64; + } + if (out_elempack == 4) + { + __m256 _tmp0 = _mm256_permute2f128_ps(_f0, _f1, _MM_SHUFFLE(0, 2, 0, 0)); + __m256 _tmp1 = _mm256_permute2f128_ps(_f2, _f3, _MM_SHUFFLE(0, 2, 0, 0)); + __m256 _tmp2 = _mm256_permute2f128_ps(_f4, _f5, _MM_SHUFFLE(0, 2, 0, 0)); + __m256 _tmp3 = _mm256_permute2f128_ps(_f6, _f7, _MM_SHUFFLE(0, 2, 0, 0)); + __m256 _tmp4 = _mm256_permute2f128_ps(_f0, _f1, _MM_SHUFFLE(0, 3, 0, 1)); + __m256 _tmp5 = _mm256_permute2f128_ps(_f2, _f3, _MM_SHUFFLE(0, 3, 0, 1)); + __m256 _tmp6 = _mm256_permute2f128_ps(_f4, _f5, _MM_SHUFFLE(0, 3, 0, 1)); + __m256 _tmp7 = _mm256_permute2f128_ps(_f6, _f7, _MM_SHUFFLE(0, 3, 0, 1)); + + _mm256_storeu_ps(p0f, _tmp0); + _mm256_storeu_ps(p0f + 8, _tmp1); + _mm256_storeu_ps(p0f + 16, _tmp2); + _mm256_storeu_ps(p0f + 24, _tmp3); + _mm256_storeu_ps(p0f + out_hstep * 4, _tmp4); + _mm256_storeu_ps(p0f + out_hstep * 4 + 8, _tmp5); + _mm256_storeu_ps(p0f + out_hstep * 4 + 16, _tmp6); + _mm256_storeu_ps(p0f + out_hstep * 4 + 24, _tmp7); + p0f += 32; + } + if (out_elempack == 1) + { + transpose8x8_ps(_f0, _f1, _f2, _f3, _f4, _f5, _f6, _f7); + _mm256_storeu_ps(p0f, _f0); + _mm256_storeu_ps(p0f + out_hstep, _f1); + _mm256_storeu_ps(p0f + out_hstep * 2, _f2); + _mm256_storeu_ps(p0f + out_hstep * 3, _f3); + _mm256_storeu_ps(p0f + out_hstep * 4, _f4); + _mm256_storeu_ps(p0f + out_hstep * 5, _f5); + _mm256_storeu_ps(p0f + out_hstep * 6, _f6); + _mm256_storeu_ps(p0f + out_hstep * 7, _f7); + p0f += 8; + } } - p0 += out_hstep * 8; } else { - if (out_elempack == 8) + __m128i _bf0 = float2bfloat_avx(_f0); + __m128i _bf1 = float2bfloat_avx(_f1); + __m128i _bf2 = float2bfloat_avx(_f2); + __m128i _bf3 = float2bfloat_avx(_f3); + __m128i _bf4 = float2bfloat_avx(_f4); + __m128i _bf5 = float2bfloat_avx(_f5); + __m128i _bf6 = float2bfloat_avx(_f6); + __m128i _bf7 = float2bfloat_avx(_f7); + + if (output_transpose) { - _mm_storeu_si128((__m128i*)p0, _bf0); - _mm_storeu_si128((__m128i*)(p0 + 8), _bf1); - _mm_storeu_si128((__m128i*)(p0 + 8 * 2), _bf2); - _mm_storeu_si128((__m128i*)(p0 + 8 * 3), _bf3); - _mm_storeu_si128((__m128i*)(p0 + 8 * 4), _bf4); - _mm_storeu_si128((__m128i*)(p0 + 8 * 5), _bf5); - _mm_storeu_si128((__m128i*)(p0 + 8 * 6), _bf6); - _mm_storeu_si128((__m128i*)(p0 + 8 * 7), _bf7); - p0 += 64; + if (out_elempack == 8) + { + 
transpose8x8_epi16(_bf0, _bf1, _bf2, _bf3, _bf4, _bf5, _bf6, _bf7); + + _mm_storeu_si128((__m128i*)p0, _bf0); + _mm_storeu_si128((__m128i*)(p0 + 8), _bf1); + _mm_storeu_si128((__m128i*)(p0 + 8 * 2), _bf2); + _mm_storeu_si128((__m128i*)(p0 + 8 * 3), _bf3); + _mm_storeu_si128((__m128i*)(p0 + 8 * 4), _bf4); + _mm_storeu_si128((__m128i*)(p0 + 8 * 5), _bf5); + _mm_storeu_si128((__m128i*)(p0 + 8 * 6), _bf6); + _mm_storeu_si128((__m128i*)(p0 + 8 * 7), _bf7); + } + if (out_elempack == 4) + { + transpose8x4_epi16(_bf0, _bf1, _bf2, _bf3); + transpose8x4_epi16(_bf4, _bf5, _bf6, _bf7); + _mm_storeu_si128((__m128i*)p0, _bf0); + _mm_storeu_si128((__m128i*)(p0 + 8), _bf1); + _mm_storeu_si128((__m128i*)(p0 + 16), _bf2); + _mm_storeu_si128((__m128i*)(p0 + 24), _bf3); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 4), _bf4); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 4 + 8), _bf5); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 4 + 16), _bf6); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 4 + 24), _bf7); + } + if (out_elempack == 1) + { + _mm_storeu_si128((__m128i*)p0, _bf0); + _mm_storeu_si128((__m128i*)(p0 + out_hstep), _bf1); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 2), _bf2); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 3), _bf3); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 4), _bf4); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 5), _bf5); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 6), _bf6); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 7), _bf7); + } + p0 += out_hstep * 8; } - if (out_elempack == 4) + else { - _mm_storel_epi64((__m128i*)p0, _bf0); - _mm_storel_epi64((__m128i*)(p0 + 4), _bf1); - _mm_storel_epi64((__m128i*)(p0 + 4 * 2), _bf2); - _mm_storel_epi64((__m128i*)(p0 + 4 * 3), _bf3); - _mm_storel_epi64((__m128i*)(p0 + 4 * 4), _bf4); - _mm_storel_epi64((__m128i*)(p0 + 4 * 5), _bf5); - _mm_storel_epi64((__m128i*)(p0 + 4 * 6), _bf6); - _mm_storel_epi64((__m128i*)(p0 + 4 * 7), _bf7); - _mm_storeh_pd((double*)(p0 + out_hstep * 4), _mm_castsi128_pd(_bf0)); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4), _mm_castsi128_pd(_bf1)); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 2), _mm_castsi128_pd(_bf2)); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 3), _mm_castsi128_pd(_bf3)); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 4), _mm_castsi128_pd(_bf4)); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 5), _mm_castsi128_pd(_bf5)); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 6), _mm_castsi128_pd(_bf6)); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 7), _mm_castsi128_pd(_bf7)); - p0 += 32; - } - if (out_elempack == 1) - { - transpose8x8_epi16(_bf0, _bf1, _bf2, _bf3, _bf4, _bf5, _bf6, _bf7); + if (out_elempack == 8) + { + _mm_storeu_si128((__m128i*)p0, _bf0); + _mm_storeu_si128((__m128i*)(p0 + 8), _bf1); + _mm_storeu_si128((__m128i*)(p0 + 8 * 2), _bf2); + _mm_storeu_si128((__m128i*)(p0 + 8 * 3), _bf3); + _mm_storeu_si128((__m128i*)(p0 + 8 * 4), _bf4); + _mm_storeu_si128((__m128i*)(p0 + 8 * 5), _bf5); + _mm_storeu_si128((__m128i*)(p0 + 8 * 6), _bf6); + _mm_storeu_si128((__m128i*)(p0 + 8 * 7), _bf7); + p0 += 64; + } + if (out_elempack == 4) + { + _mm_storel_epi64((__m128i*)p0, _bf0); + _mm_storel_epi64((__m128i*)(p0 + 4), _bf1); + _mm_storel_epi64((__m128i*)(p0 + 4 * 2), _bf2); + _mm_storel_epi64((__m128i*)(p0 + 4 * 3), _bf3); + _mm_storel_epi64((__m128i*)(p0 + 4 * 4), _bf4); + _mm_storel_epi64((__m128i*)(p0 + 4 * 5), _bf5); + _mm_storel_epi64((__m128i*)(p0 + 4 * 6), _bf6); + _mm_storel_epi64((__m128i*)(p0 + 4 * 7), _bf7); + _mm_storeh_pd((double*)(p0 + 
out_hstep * 4), _mm_castsi128_pd(_bf0)); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4), _mm_castsi128_pd(_bf1)); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 2), _mm_castsi128_pd(_bf2)); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 3), _mm_castsi128_pd(_bf3)); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 4), _mm_castsi128_pd(_bf4)); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 5), _mm_castsi128_pd(_bf5)); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 6), _mm_castsi128_pd(_bf6)); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 7), _mm_castsi128_pd(_bf7)); + p0 += 32; + } + if (out_elempack == 1) + { + transpose8x8_epi16(_bf0, _bf1, _bf2, _bf3, _bf4, _bf5, _bf6, _bf7); - _mm_storeu_si128((__m128i*)p0, _bf0); - _mm_storeu_si128((__m128i*)(p0 + out_hstep), _bf1); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 2), _bf2); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 3), _bf3); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 4), _bf4); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 5), _bf5); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 6), _bf6); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 7), _bf7); - p0 += 8; + _mm_storeu_si128((__m128i*)p0, _bf0); + _mm_storeu_si128((__m128i*)(p0 + out_hstep), _bf1); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 2), _bf2); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 3), _bf3); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 4), _bf4); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 5), _bf5); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 6), _bf6); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 7), _bf7); + p0 += 8; + } } } } @@ -6471,99 +7240,194 @@ static void unpack_output_tile_fp32_to_bf16(const Mat& topT, const Mat& C, Mat& _f3 = _mm256_mul_ps(_f3, _alpha); } - __m128i _bf0 = float2bfloat_avx(_f0); - __m128i _bf1 = float2bfloat_avx(_f1); - __m128i _bf2 = float2bfloat_avx(_f2); - __m128i _bf3 = float2bfloat_avx(_f3); - - if (output_transpose) + if (output_elemtype == 1) { + if (output_transpose) + { #if !(defined(__x86_64__) || defined(_M_X64)) #if __AVX__ #if __AVX512F__ - if (out_elempack == 16) - { - transpose8x4_epi16(_bf0, _bf1, _bf2, _bf3); - const int jj_m16 = jj % 16; - unsigned short* p1 = p0 - out_hstep * jj_m16 + jj_m16; - _mm_storel_epi64((__m128i*)p1, _bf0); - _mm_storeh_pd((double*)(p1 + 16), _mm_castsi128_pd(_bf0)); - _mm_storel_epi64((__m128i*)(p1 + 32), _bf1); - _mm_storeh_pd((double*)(p1 + 48), _mm_castsi128_pd(_bf1)); - _mm_storel_epi64((__m128i*)(p1 + 64), _bf2); - _mm_storeh_pd((double*)(p1 + 80), _mm_castsi128_pd(_bf2)); - _mm_storel_epi64((__m128i*)(p1 + 96), _bf3); - _mm_storeh_pd((double*)(p1 + 112), _mm_castsi128_pd(_bf3)); - } + if (out_elempack == 16) + { + transpose8x4_ps(_f0, _f1, _f2, _f3); + const int jj_m16 = jj % 16; + float* p1f = p0f - out_hstep * jj_m16 + jj_m16; + _mm_storeu_ps(p1f, _mm256_castps256_ps128(_f0)); + _mm_storeu_ps(p1f + 16, _mm256_extractf128_ps(_f0, 1)); + _mm_storeu_ps(p1f + 32, _mm256_castps256_ps128(_f1)); + _mm_storeu_ps(p1f + 48, _mm256_extractf128_ps(_f1, 1)); + _mm_storeu_ps(p1f + 64, _mm256_castps256_ps128(_f2)); + _mm_storeu_ps(p1f + 80, _mm256_extractf128_ps(_f2, 1)); + _mm_storeu_ps(p1f + 96, _mm256_castps256_ps128(_f3)); + _mm_storeu_ps(p1f + 112, _mm256_extractf128_ps(_f3, 1)); + } #endif // __AVX512F__ - if (out_elempack == 8) - { - transpose8x4_epi16(_bf0, _bf1, _bf2, _bf3); - const int jj_m8 = jj % 8; - unsigned short* p1 = p0 - out_hstep * jj_m8 + jj_m8; - _mm_storel_epi64((__m128i*)p1, _bf0); - _mm_storeh_pd((double*)(p1 + 8), 
_mm_castsi128_pd(_bf0)); - _mm_storel_epi64((__m128i*)(p1 + 16), _bf1); - _mm_storeh_pd((double*)(p1 + 24), _mm_castsi128_pd(_bf1)); - _mm_storel_epi64((__m128i*)(p1 + 32), _bf2); - _mm_storeh_pd((double*)(p1 + 40), _mm_castsi128_pd(_bf2)); - _mm_storel_epi64((__m128i*)(p1 + 48), _bf3); - _mm_storeh_pd((double*)(p1 + 56), _mm_castsi128_pd(_bf3)); - } + if (out_elempack == 8) + { + transpose8x4_ps(_f0, _f1, _f2, _f3); + const int jj_m8 = jj % 8; + float* p1f = p0f - out_hstep * jj_m8 + jj_m8; + _mm_storeu_ps(p1f, _mm256_castps256_ps128(_f0)); + _mm_storeu_ps(p1f + 8, _mm256_extractf128_ps(_f0, 1)); + _mm_storeu_ps(p1f + 16, _mm256_castps256_ps128(_f1)); + _mm_storeu_ps(p1f + 24, _mm256_extractf128_ps(_f1, 1)); + _mm_storeu_ps(p1f + 32, _mm256_castps256_ps128(_f2)); + _mm_storeu_ps(p1f + 40, _mm256_extractf128_ps(_f2, 1)); + _mm_storeu_ps(p1f + 48, _mm256_castps256_ps128(_f3)); + _mm_storeu_ps(p1f + 56, _mm256_extractf128_ps(_f3, 1)); + } #endif // __AVX__ #endif // !(defined(__x86_64__) || defined(_M_X64)) - if (out_elempack == 4) - { - transpose8x4_epi16(_bf0, _bf1, _bf2, _bf3); - _mm_storeu_si128((__m128i*)p0, _bf0); - _mm_storeu_si128((__m128i*)(p0 + 8), _bf1); - _mm_storeu_si128((__m128i*)(p0 + 16), _bf2); - _mm_storeu_si128((__m128i*)(p0 + 24), _bf3); + if (out_elempack == 4) + { + transpose8x4_ps(_f0, _f1, _f2, _f3); + _mm256_storeu_ps(p0f, _f0); + _mm256_storeu_ps(p0f + 8, _f1); + _mm256_storeu_ps(p0f + 16, _f2); + _mm256_storeu_ps(p0f + 24, _f3); + } + if (out_elempack == 1) + { + _mm256_storeu_ps(p0f, _f0); + _mm256_storeu_ps(p0f + out_hstep, _f1); + _mm256_storeu_ps(p0f + out_hstep * 2, _f2); + _mm256_storeu_ps(p0f + out_hstep * 3, _f3); + } + p0f += out_hstep * 4; } - if (out_elempack == 1) + else { - _mm_storeu_si128((__m128i*)p0, _bf0); - _mm_storeu_si128((__m128i*)(p0 + out_hstep), _bf1); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 2), _bf2); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 3), _bf3); + if (out_elempack == 8) + { + _mm256_storeu_ps(p0f, _f0); + _mm256_storeu_ps(p0f + 8, _f1); + _mm256_storeu_ps(p0f + 8 * 2, _f2); + _mm256_storeu_ps(p0f + 8 * 3, _f3); + p0f += 32; + } + if (out_elempack == 4) + { + _mm_storeu_ps(p0f, _mm256_castps256_ps128(_f0)); + _mm_storeu_ps(p0f + 4, _mm256_castps256_ps128(_f1)); + _mm_storeu_ps(p0f + 4 * 2, _mm256_castps256_ps128(_f2)); + _mm_storeu_ps(p0f + 4 * 3, _mm256_castps256_ps128(_f3)); + _mm_storeu_ps(p0f + out_hstep * 4, _mm256_extractf128_ps(_f0, 1)); + _mm_storeu_ps(p0f + out_hstep * 4 + 4, _mm256_extractf128_ps(_f1, 1)); + _mm_storeu_ps(p0f + out_hstep * 4 + 4 * 2, _mm256_extractf128_ps(_f2, 1)); + _mm_storeu_ps(p0f + out_hstep * 4 + 4 * 3, _mm256_extractf128_ps(_f3, 1)); + p0f += 16; + } + if (out_elempack == 1) + { + transpose8x4_ps(_f0, _f1, _f2, _f3); + _mm_storeu_ps(p0f, _mm256_extractf128_ps(_f0, 0)); + _mm_storeu_ps(p0f + out_hstep, _mm256_extractf128_ps(_f0, 1)); + _mm_storeu_ps(p0f + out_hstep * 2, _mm256_extractf128_ps(_f1, 0)); + _mm_storeu_ps(p0f + out_hstep * 3, _mm256_extractf128_ps(_f1, 1)); + _mm_storeu_ps(p0f + out_hstep * 4, _mm256_extractf128_ps(_f2, 0)); + _mm_storeu_ps(p0f + out_hstep * 5, _mm256_extractf128_ps(_f2, 1)); + _mm_storeu_ps(p0f + out_hstep * 6, _mm256_extractf128_ps(_f3, 0)); + _mm_storeu_ps(p0f + out_hstep * 7, _mm256_extractf128_ps(_f3, 1)); + p0f += 4; + } } - p0 += out_hstep * 4; } else { - if (out_elempack == 8) - { - _mm_storeu_si128((__m128i*)p0, _bf0); - _mm_storeu_si128((__m128i*)(p0 + 8), _bf1); - _mm_storeu_si128((__m128i*)(p0 + 8 * 2), _bf2); - _mm_storeu_si128((__m128i*)(p0 + 8 * 
3), _bf3); - p0 += 32; - } - if (out_elempack == 4) + __m128i _bf0 = float2bfloat_avx(_f0); + __m128i _bf1 = float2bfloat_avx(_f1); + __m128i _bf2 = float2bfloat_avx(_f2); + __m128i _bf3 = float2bfloat_avx(_f3); + + if (output_transpose) { - _mm_storel_epi64((__m128i*)p0, _bf0); - _mm_storel_epi64((__m128i*)(p0 + 4), _bf1); - _mm_storel_epi64((__m128i*)(p0 + 4 * 2), _bf2); - _mm_storel_epi64((__m128i*)(p0 + 4 * 3), _bf3); - _mm_storeh_pd((double*)(p0 + out_hstep * 4), _mm_castsi128_pd(_bf0)); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4), _mm_castsi128_pd(_bf1)); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 2), _mm_castsi128_pd(_bf2)); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 3), _mm_castsi128_pd(_bf3)); - p0 += 16; +#if !(defined(__x86_64__) || defined(_M_X64)) +#if __AVX__ +#if __AVX512F__ + if (out_elempack == 16) + { + transpose8x4_epi16(_bf0, _bf1, _bf2, _bf3); + const int jj_m16 = jj % 16; + unsigned short* p1 = p0 - out_hstep * jj_m16 + jj_m16; + _mm_storel_epi64((__m128i*)p1, _bf0); + _mm_storeh_pd((double*)(p1 + 16), _mm_castsi128_pd(_bf0)); + _mm_storel_epi64((__m128i*)(p1 + 32), _bf1); + _mm_storeh_pd((double*)(p1 + 48), _mm_castsi128_pd(_bf1)); + _mm_storel_epi64((__m128i*)(p1 + 64), _bf2); + _mm_storeh_pd((double*)(p1 + 80), _mm_castsi128_pd(_bf2)); + _mm_storel_epi64((__m128i*)(p1 + 96), _bf3); + _mm_storeh_pd((double*)(p1 + 112), _mm_castsi128_pd(_bf3)); + } +#endif // __AVX512F__ + if (out_elempack == 8) + { + transpose8x4_epi16(_bf0, _bf1, _bf2, _bf3); + const int jj_m8 = jj % 8; + unsigned short* p1 = p0 - out_hstep * jj_m8 + jj_m8; + _mm_storel_epi64((__m128i*)p1, _bf0); + _mm_storeh_pd((double*)(p1 + 8), _mm_castsi128_pd(_bf0)); + _mm_storel_epi64((__m128i*)(p1 + 16), _bf1); + _mm_storeh_pd((double*)(p1 + 24), _mm_castsi128_pd(_bf1)); + _mm_storel_epi64((__m128i*)(p1 + 32), _bf2); + _mm_storeh_pd((double*)(p1 + 40), _mm_castsi128_pd(_bf2)); + _mm_storel_epi64((__m128i*)(p1 + 48), _bf3); + _mm_storeh_pd((double*)(p1 + 56), _mm_castsi128_pd(_bf3)); + } +#endif // __AVX__ +#endif // !(defined(__x86_64__) || defined(_M_X64)) + if (out_elempack == 4) + { + transpose8x4_epi16(_bf0, _bf1, _bf2, _bf3); + _mm_storeu_si128((__m128i*)p0, _bf0); + _mm_storeu_si128((__m128i*)(p0 + 8), _bf1); + _mm_storeu_si128((__m128i*)(p0 + 16), _bf2); + _mm_storeu_si128((__m128i*)(p0 + 24), _bf3); + } + if (out_elempack == 1) + { + _mm_storeu_si128((__m128i*)p0, _bf0); + _mm_storeu_si128((__m128i*)(p0 + out_hstep), _bf1); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 2), _bf2); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 3), _bf3); + } + p0 += out_hstep * 4; } - if (out_elempack == 1) + else { - transpose8x4_epi16(_bf0, _bf1, _bf2, _bf3); + if (out_elempack == 8) + { + _mm_storeu_si128((__m128i*)p0, _bf0); + _mm_storeu_si128((__m128i*)(p0 + 8), _bf1); + _mm_storeu_si128((__m128i*)(p0 + 8 * 2), _bf2); + _mm_storeu_si128((__m128i*)(p0 + 8 * 3), _bf3); + p0 += 32; + } + if (out_elempack == 4) + { + _mm_storel_epi64((__m128i*)p0, _bf0); + _mm_storel_epi64((__m128i*)(p0 + 4), _bf1); + _mm_storel_epi64((__m128i*)(p0 + 4 * 2), _bf2); + _mm_storel_epi64((__m128i*)(p0 + 4 * 3), _bf3); + _mm_storeh_pd((double*)(p0 + out_hstep * 4), _mm_castsi128_pd(_bf0)); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4), _mm_castsi128_pd(_bf1)); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 2), _mm_castsi128_pd(_bf2)); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 3), _mm_castsi128_pd(_bf3)); + p0 += 16; + } + if (out_elempack == 1) + { + transpose8x4_epi16(_bf0, _bf1, _bf2, _bf3); - 
_mm_storel_epi64((__m128i*)p0, _bf0); - _mm_storeh_pd((double*)(p0 + out_hstep), _mm_castsi128_pd(_bf0)); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 2), _bf1); - _mm_storeh_pd((double*)(p0 + out_hstep * 3), _mm_castsi128_pd(_bf1)); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 4), _bf2); - _mm_storeh_pd((double*)(p0 + out_hstep * 5), _mm_castsi128_pd(_bf2)); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 6), _bf3); - _mm_storeh_pd((double*)(p0 + out_hstep * 7), _mm_castsi128_pd(_bf3)); - p0 += 4; + _mm_storel_epi64((__m128i*)p0, _bf0); + _mm_storeh_pd((double*)(p0 + out_hstep), _mm_castsi128_pd(_bf0)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 2), _bf1); + _mm_storeh_pd((double*)(p0 + out_hstep * 3), _mm_castsi128_pd(_bf1)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 4), _bf2); + _mm_storeh_pd((double*)(p0 + out_hstep * 5), _mm_castsi128_pd(_bf2)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 6), _bf3); + _mm_storeh_pd((double*)(p0 + out_hstep * 7), _mm_castsi128_pd(_bf3)); + p0 += 4; + } } } } @@ -6652,70 +7516,145 @@ static void unpack_output_tile_fp32_to_bf16(const Mat& topT, const Mat& C, Mat& _f1 = _mm256_mul_ps(_f1, _alpha); } - __m128i _bf0 = float2bfloat_avx(_f0); - __m128i _bf1 = float2bfloat_avx(_f1); - - if (output_transpose) + if (output_elemtype == 1) { - if (out_elempack == 8) - { - _mm_storeu_si128((__m128i*)p0, _bf0); - _mm_storeu_si128((__m128i*)(p0 + 8), _bf1); - } - if (out_elempack == 4) + if (output_transpose) { - _mm_storel_epi64((__m128i*)p0, _bf0); - _mm_storel_epi64((__m128i*)(p0 + 4), _bf1); - _mm_storeh_pd((double*)(p0 + out_hstep * 4), _mm_castsi128_pd(_bf0)); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4), _mm_castsi128_pd(_bf1)); + if (out_elempack == 8) + { + _mm256_storeu_ps(p0f, _f0); + _mm256_storeu_ps(p0f + 8, _f1); + } + if (out_elempack == 4) + { + _mm_storeu_ps(p0f, _mm256_castps256_ps128(_f0)); + _mm_storeu_ps(p0f + 4, _mm256_castps256_ps128(_f1)); + _mm_storeu_ps(p0f + out_hstep * 4, _mm256_extractf128_ps(_f0, 1)); + _mm_storeu_ps(p0f + out_hstep * 4 + 4, _mm256_extractf128_ps(_f1, 1)); + } + if (out_elempack == 1) + { + _mm256_storeu_ps(p0f, _f0); + _mm256_storeu_ps(p0f + out_hstep, _f1); + } + p0f += out_hstep * 2; } - if (out_elempack == 1) + else { - _mm_storeu_si128((__m128i*)p0, _bf0); - _mm_storeu_si128((__m128i*)(p0 + out_hstep), _bf1); + if (out_elempack == 8) + { + _mm256_storeu_ps(p0f, _f0); + _mm256_storeu_ps(p0f + 8, _f1); + p0f += 16; + } + if (out_elempack == 4) + { + _mm_storeu_ps(p0f, _mm256_castps256_ps128(_f0)); + _mm_storeu_ps(p0f + 4, _mm256_castps256_ps128(_f1)); + _mm_storeu_ps(p0f + out_hstep * 4, _mm256_extractf128_ps(_f0, 1)); + _mm_storeu_ps(p0f + out_hstep * 4 + 4, _mm256_extractf128_ps(_f1, 1)); + p0f += 8; + } + if (out_elempack == 1) + { +#if __AVX512F__ + __m256i _vindex = _mm256_mullo_epi32(_mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7), _mm256_set1_epi32(out_hstep)); + _mm256_i32scatter_ps(p0f, _vindex, _f0, sizeof(float)); + _mm256_i32scatter_ps(p0f + 1, _vindex, _f1, sizeof(float)); +#else + float sum0[8]; + float sum1[8]; + _mm256_storeu_ps(sum0, _f0); + _mm256_storeu_ps(sum1, _f1); + + p0f[0] = sum0[0]; + p0f[1] = sum1[0]; + p0f[out_hstep] = sum0[1]; + p0f[out_hstep + 1] = sum1[1]; + p0f[out_hstep * 2] = sum0[2]; + p0f[out_hstep * 2 + 1] = sum1[2]; + p0f[out_hstep * 3] = sum0[3]; + p0f[out_hstep * 3 + 1] = sum1[3]; + p0f[out_hstep * 4] = sum0[4]; + p0f[out_hstep * 4 + 1] = sum1[4]; + p0f[out_hstep * 5] = sum0[5]; + p0f[out_hstep * 5 + 1] = sum1[5]; + p0f[out_hstep * 6] = sum0[6]; + p0f[out_hstep * 6 
+ 1] = sum1[6]; + p0f[out_hstep * 7] = sum0[7]; + p0f[out_hstep * 7 + 1] = sum1[7]; +#endif // __AVX512F__ + p0f += 2; + } } - p0 += out_hstep * 2; } else { - if (out_elempack == 8) + __m128i _bf0 = float2bfloat_avx(_f0); + __m128i _bf1 = float2bfloat_avx(_f1); + + if (output_transpose) { - _mm_storeu_si128((__m128i*)p0, _bf0); - _mm_storeu_si128((__m128i*)(p0 + 8), _bf1); - p0 += 16; + if (out_elempack == 8) + { + _mm_storeu_si128((__m128i*)p0, _bf0); + _mm_storeu_si128((__m128i*)(p0 + 8), _bf1); + } + if (out_elempack == 4) + { + _mm_storel_epi64((__m128i*)p0, _bf0); + _mm_storel_epi64((__m128i*)(p0 + 4), _bf1); + _mm_storeh_pd((double*)(p0 + out_hstep * 4), _mm_castsi128_pd(_bf0)); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4), _mm_castsi128_pd(_bf1)); + } + if (out_elempack == 1) + { + _mm_storeu_si128((__m128i*)p0, _bf0); + _mm_storeu_si128((__m128i*)(p0 + out_hstep), _bf1); + } + p0 += out_hstep * 2; } - if (out_elempack == 4) + else { - _mm_storel_epi64((__m128i*)p0, _bf0); - _mm_storel_epi64((__m128i*)(p0 + 4), _bf1); - _mm_storeh_pd((double*)(p0 + out_hstep * 4), _mm_castsi128_pd(_bf0)); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4), _mm_castsi128_pd(_bf1)); - p0 += 8; - } - if (out_elempack == 1) - { - unsigned short sum0[8]; - unsigned short sum1[8]; - _mm_storeu_si128((__m128i*)sum0, _bf0); - _mm_storeu_si128((__m128i*)sum1, _bf1); - - p0[0] = sum0[0]; - p0[1] = sum1[0]; - p0[out_hstep] = sum0[1]; - p0[out_hstep + 1] = sum1[1]; - p0[out_hstep * 2] = sum0[2]; - p0[out_hstep * 2 + 1] = sum1[2]; - p0[out_hstep * 3] = sum0[3]; - p0[out_hstep * 3 + 1] = sum1[3]; - p0[out_hstep * 4] = sum0[4]; - p0[out_hstep * 4 + 1] = sum1[4]; - p0[out_hstep * 5] = sum0[5]; - p0[out_hstep * 5 + 1] = sum1[5]; - p0[out_hstep * 6] = sum0[6]; - p0[out_hstep * 6 + 1] = sum1[6]; - p0[out_hstep * 7] = sum0[7]; - p0[out_hstep * 7 + 1] = sum1[7]; - p0 += 2; + if (out_elempack == 8) + { + _mm_storeu_si128((__m128i*)p0, _bf0); + _mm_storeu_si128((__m128i*)(p0 + 8), _bf1); + p0 += 16; + } + if (out_elempack == 4) + { + _mm_storel_epi64((__m128i*)p0, _bf0); + _mm_storel_epi64((__m128i*)(p0 + 4), _bf1); + _mm_storeh_pd((double*)(p0 + out_hstep * 4), _mm_castsi128_pd(_bf0)); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4), _mm_castsi128_pd(_bf1)); + p0 += 8; + } + if (out_elempack == 1) + { + unsigned short sum0[8]; + unsigned short sum1[8]; + _mm_storeu_si128((__m128i*)sum0, _bf0); + _mm_storeu_si128((__m128i*)sum1, _bf1); + + p0[0] = sum0[0]; + p0[1] = sum1[0]; + p0[out_hstep] = sum0[1]; + p0[out_hstep + 1] = sum1[1]; + p0[out_hstep * 2] = sum0[2]; + p0[out_hstep * 2 + 1] = sum1[2]; + p0[out_hstep * 3] = sum0[3]; + p0[out_hstep * 3 + 1] = sum1[3]; + p0[out_hstep * 4] = sum0[4]; + p0[out_hstep * 4 + 1] = sum1[4]; + p0[out_hstep * 5] = sum0[5]; + p0[out_hstep * 5 + 1] = sum1[5]; + p0[out_hstep * 6] = sum0[6]; + p0[out_hstep * 6 + 1] = sum1[6]; + p0[out_hstep * 7] = sum0[7]; + p0[out_hstep * 7 + 1] = sum1[7]; + p0 += 2; + } } } } @@ -6773,52 +7712,108 @@ static void unpack_output_tile_fp32_to_bf16(const Mat& topT, const Mat& C, Mat& _f = _mm256_mul_ps(_f, _mm256_set1_ps(alpha)); } - __m128i _bf = float2bfloat_avx(_f); - - if (output_transpose) + if (output_elemtype == 1) { - if (out_elempack == 8) + if (output_transpose) { - _mm_storeu_si128((__m128i*)p0, _bf); - } - if (out_elempack == 4) - { - _mm_storel_epi64((__m128i*)p0, _bf); - _mm_storeh_pd((double*)(p0 + out_hstep * 4), _mm_castsi128_pd(_bf)); + if (out_elempack == 8) + { + _mm256_storeu_ps(p0f, _f); + } + if (out_elempack == 4) + { + 
_mm_storeu_ps(p0f, _mm256_castps256_ps128(_f)); + _mm_storeu_ps(p0f + out_hstep * 4, _mm256_extractf128_ps(_f, 1)); + } + if (out_elempack == 1) + { + _mm256_storeu_ps(p0f, _f); + } + p0f += out_hstep; } - if (out_elempack == 1) + else { - _mm_storeu_si128((__m128i*)p0, _bf); + if (out_elempack == 8) + { + _mm256_storeu_ps(p0f, _f); + p0f += 8; + } + if (out_elempack == 4) + { + _mm_storeu_ps(p0f, _mm256_castps256_ps128(_f)); + _mm_storeu_ps(p0f + out_hstep * 4, _mm256_extractf128_ps(_f, 1)); + p0f += 4; + } + if (out_elempack == 1) + { +#if __AVX512F__ + __m256i _vindex = _mm256_mullo_epi32(_mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7), _mm256_set1_epi32(out_hstep)); + _mm256_i32scatter_ps(p0f, _vindex, _f, sizeof(float)); +#else + float sum0[8]; + _mm256_storeu_ps(sum0, _f); + p0f[0] = sum0[0]; + p0f[out_hstep] = sum0[1]; + p0f[out_hstep * 2] = sum0[2]; + p0f[out_hstep * 3] = sum0[3]; + p0f[out_hstep * 4] = sum0[4]; + p0f[out_hstep * 5] = sum0[5]; + p0f[out_hstep * 6] = sum0[6]; + p0f[out_hstep * 7] = sum0[7]; +#endif // __AVX512F__ + p0f++; + } } - p0 += out_hstep; } else { - if (out_elempack == 8) - { - _mm_storeu_si128((__m128i*)p0, _bf); - p0 += 8; - } - if (out_elempack == 4) + __m128i _bf = float2bfloat_avx(_f); + + if (output_transpose) { - _mm_storel_epi64((__m128i*)p0, _bf); - _mm_storeh_pd((double*)(p0 + out_hstep * 4), _mm_castsi128_pd(_bf)); - p0 += 4; + if (out_elempack == 8) + { + _mm_storeu_si128((__m128i*)p0, _bf); + } + if (out_elempack == 4) + { + _mm_storel_epi64((__m128i*)p0, _bf); + _mm_storeh_pd((double*)(p0 + out_hstep * 4), _mm_castsi128_pd(_bf)); + } + if (out_elempack == 1) + { + _mm_storeu_si128((__m128i*)p0, _bf); + } + p0 += out_hstep; } - if (out_elempack == 1) + else { - unsigned short sum0[8]; - _mm_storeu_si128((__m128i*)sum0, _bf); + if (out_elempack == 8) + { + _mm_storeu_si128((__m128i*)p0, _bf); + p0 += 8; + } + if (out_elempack == 4) + { + _mm_storel_epi64((__m128i*)p0, _bf); + _mm_storeh_pd((double*)(p0 + out_hstep * 4), _mm_castsi128_pd(_bf)); + p0 += 4; + } + if (out_elempack == 1) + { + unsigned short sum0[8]; + _mm_storeu_si128((__m128i*)sum0, _bf); - p0[0] = sum0[0]; - p0[out_hstep] = sum0[1]; - p0[out_hstep * 2] = sum0[2]; - p0[out_hstep * 3] = sum0[3]; - p0[out_hstep * 4] = sum0[4]; - p0[out_hstep * 5] = sum0[5]; - p0[out_hstep * 6] = sum0[6]; - p0[out_hstep * 7] = sum0[7]; - p0++; + p0[0] = sum0[0]; + p0[out_hstep] = sum0[1]; + p0[out_hstep * 2] = sum0[2]; + p0[out_hstep * 3] = sum0[3]; + p0[out_hstep * 4] = sum0[4]; + p0[out_hstep * 5] = sum0[5]; + p0[out_hstep * 6] = sum0[6]; + p0[out_hstep * 7] = sum0[7]; + p0++; + } } } } @@ -6827,13 +7822,16 @@ static void unpack_output_tile_fp32_to_bf16(const Mat& topT, const Mat& C, Mat& for (; ii + 3 < max_ii; ii += 4) { unsigned short* p0; + float* p0f; if (output_transpose) { p0 = (unsigned short*)top_blob + j * out_hstep + (i + ii) * out_elempack; + p0f = (float*)top_blob + j * out_hstep + (i + ii) * out_elempack; } else { p0 = (unsigned short*)top_blob + (i + ii) * out_hstep + j * out_elempack; + p0f = (float*)top_blob + (i + ii) * out_hstep + j * out_elempack; } __m128 _c0 = _mm_set1_ps(0.f); @@ -7002,159 +8000,273 @@ static void unpack_output_tile_fp32_to_bf16(const Mat& topT, const Mat& C, Mat& _f3 = _mm512_mul_ps(_f3, _alpha); } - __m256i _bf0 = float2bfloat_avx512(_f0); - __m256i _bf1 = float2bfloat_avx512(_f1); - __m256i _bf2 = float2bfloat_avx512(_f2); - __m256i _bf3 = float2bfloat_avx512(_f3); - - if (output_transpose) - { - if (out_elempack == 16) - { - transpose16x4_epi16(_bf0, _bf1, _bf2, 
_bf3); - _mm_storel_epi64((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); - _mm_storel_epi64((__m128i*)(p0 + 4), _mm256_extractf128_si256(_bf1, 0)); - _mm_storel_epi64((__m128i*)(p0 + 8), _mm256_extractf128_si256(_bf2, 0)); - _mm_storel_epi64((__m128i*)(p0 + 12), _mm256_extractf128_si256(_bf3, 0)); - _mm_storeh_pd((double*)(p0 + 16), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 0))); - _mm_storeh_pd((double*)(p0 + 16 + 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 0))); - _mm_storeh_pd((double*)(p0 + 16 + 8), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 0))); - _mm_storeh_pd((double*)(p0 + 16 + 12), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 0))); - _mm_storel_epi64((__m128i*)(p0 + 32), _mm256_extractf128_si256(_bf0, 1)); - _mm_storel_epi64((__m128i*)(p0 + 32 + 4), _mm256_extractf128_si256(_bf1, 1)); - _mm_storel_epi64((__m128i*)(p0 + 32 + 8), _mm256_extractf128_si256(_bf2, 1)); - _mm_storel_epi64((__m128i*)(p0 + 32 + 12), _mm256_extractf128_si256(_bf3, 1)); - _mm_storeh_pd((double*)(p0 + 48), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 1))); - _mm_storeh_pd((double*)(p0 + 48 + 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 1))); - _mm_storeh_pd((double*)(p0 + 48 + 8), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 1))); - _mm_storeh_pd((double*)(p0 + 48 + 12), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 1))); - } - if (out_elempack == 8) - { - transpose16x4_epi16(_bf0, _bf1, _bf2, _bf3); - _mm_storel_epi64((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); - _mm_storel_epi64((__m128i*)(p0 + 4), _mm256_extractf128_si256(_bf1, 0)); - _mm_storeh_pd((double*)(p0 + 8), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 0))); - _mm_storeh_pd((double*)(p0 + 12), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 0))); - _mm_storel_epi64((__m128i*)(p0 + 16), _mm256_extractf128_si256(_bf0, 1)); - _mm_storel_epi64((__m128i*)(p0 + 16 + 4), _mm256_extractf128_si256(_bf1, 1)); - _mm_storeh_pd((double*)(p0 + 16 + 8), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 1))); - _mm_storeh_pd((double*)(p0 + 16 + 12), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 1))); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8), _mm256_extractf128_si256(_bf2, 0)); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4), _mm256_extractf128_si256(_bf3, 0)); - _mm_storeh_pd((double*)(p0 + out_hstep * 8 + 8), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 0))); - _mm_storeh_pd((double*)(p0 + out_hstep * 8 + 12), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 0))); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 16), _mm256_extractf128_si256(_bf2, 1)); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 16 + 4), _mm256_extractf128_si256(_bf3, 1)); - _mm_storeh_pd((double*)(p0 + out_hstep * 8 + 16 + 8), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 1))); - _mm_storeh_pd((double*)(p0 + out_hstep * 8 + 16 + 12), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 1))); - } - if (out_elempack == 4) - { - __m128i _bf0l = _mm256_extractf128_si256(_bf0, 0); - __m128i _bf1l = _mm256_extractf128_si256(_bf1, 0); - __m128i _bf2l = _mm256_extractf128_si256(_bf2, 0); - __m128i _bf3l = _mm256_extractf128_si256(_bf3, 0); - - __m128i _t0 = _mm_unpacklo_epi16(_bf0l, _bf1l); - __m128i _t1 = _mm_unpacklo_epi16(_bf2l, _bf3l); - __m128i _d0 = _mm_unpacklo_epi32(_t0, _t1); - __m128i _d1 = _mm_unpackhi_epi32(_t0, _t1); - _mm_storel_epi64((__m128i*)p0, _d0); - _mm_storeh_pd((double*)(p0 + 4), _mm_castsi128_pd(_d0)); - _mm_storel_epi64((__m128i*)(p0 + 4 * 2), _d1); - _mm_storeh_pd((double*)(p0 + 4 * 3), 
_mm_castsi128_pd(_d1)); - - __m128i _t2 = _mm_unpackhi_epi16(_bf0l, _bf1l); - __m128i _t3 = _mm_unpackhi_epi16(_bf2l, _bf3l); - __m128i _d2 = _mm_unpacklo_epi32(_t2, _t3); - __m128i _d3 = _mm_unpackhi_epi32(_t2, _t3); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 4), _d2); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4), _mm_castsi128_pd(_d2)); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 4 + 4 * 2), _d3); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 3), _mm_castsi128_pd(_d3)); - - __m128i _bf0h = _mm256_extractf128_si256(_bf0, 1); - __m128i _bf1h = _mm256_extractf128_si256(_bf1, 1); - __m128i _bf2h = _mm256_extractf128_si256(_bf2, 1); - __m128i _bf3h = _mm256_extractf128_si256(_bf3, 1); - __m128i _t4 = _mm_unpacklo_epi16(_bf0h, _bf1h); - __m128i _t5 = _mm_unpacklo_epi16(_bf2h, _bf3h); - __m128i _d4 = _mm_unpacklo_epi32(_t4, _t5); - __m128i _d5 = _mm_unpackhi_epi32(_t4, _t5); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8), _d4); - _mm_storeh_pd((double*)(p0 + out_hstep * 8 + 4), _mm_castsi128_pd(_d4)); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 2), _d5); - _mm_storeh_pd((double*)(p0 + out_hstep * 8 + 4 * 3), _mm_castsi128_pd(_d5)); - - __m128i _t6 = _mm_unpackhi_epi16(_bf0h, _bf1h); - __m128i _t7 = _mm_unpackhi_epi16(_bf2h, _bf3h); - __m128i _d6 = _mm_unpacklo_epi32(_t6, _t7); - __m128i _d7 = _mm_unpackhi_epi32(_t6, _t7); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 12), _d6); - _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4), _mm_castsi128_pd(_d6)); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 12 + 4 * 2), _d7); - _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 3), _mm_castsi128_pd(_d7)); - } - if (out_elempack == 1) - { - _mm_storel_epi64((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); - _mm_storel_epi64((__m128i*)(p0 + out_hstep), _mm256_extractf128_si256(_bf1, 0)); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 2), _mm256_extractf128_si256(_bf2, 0)); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 3), _mm256_extractf128_si256(_bf3, 0)); - _mm_storeh_pd((double*)(p0 + out_hstep * 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 0))); - _mm_storeh_pd((double*)(p0 + out_hstep * 5), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 0))); - _mm_storeh_pd((double*)(p0 + out_hstep * 6), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 0))); - _mm_storeh_pd((double*)(p0 + out_hstep * 7), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 0))); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8), _mm256_extractf128_si256(_bf0, 1)); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 9), _mm256_extractf128_si256(_bf1, 1)); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 10), _mm256_extractf128_si256(_bf2, 1)); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 11), _mm256_extractf128_si256(_bf3, 1)); - _mm_storeh_pd((double*)(p0 + out_hstep * 12), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 1))); - _mm_storeh_pd((double*)(p0 + out_hstep * 13), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 1))); - _mm_storeh_pd((double*)(p0 + out_hstep * 14), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 1))); - _mm_storeh_pd((double*)(p0 + out_hstep * 15), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 1))); - } - p0 += out_hstep * 16; + if (output_elemtype == 1) + { + if (output_transpose) + { + if (out_elempack == 16) + { + transpose16x4_ps(_f0, _f1, _f2, _f3); + _mm_store_ps(p0f, _mm512_extractf32x4_ps(_f0, 0)); + _mm_store_ps(p0f + 4, _mm512_extractf32x4_ps(_f1, 0)); + _mm_store_ps(p0f + 8, _mm512_extractf32x4_ps(_f2, 0)); + _mm_store_ps(p0f + 12, _mm512_extractf32x4_ps(_f3, 0)); 
+ _mm_store_ps(p0f + 16, _mm512_extractf32x4_ps(_f0, 1)); + _mm_store_ps(p0f + 16 + 4, _mm512_extractf32x4_ps(_f1, 1)); + _mm_store_ps(p0f + 16 + 8, _mm512_extractf32x4_ps(_f2, 1)); + _mm_store_ps(p0f + 16 + 12, _mm512_extractf32x4_ps(_f3, 1)); + _mm_store_ps(p0f + 32, _mm512_extractf32x4_ps(_f0, 2)); + _mm_store_ps(p0f + 32 + 4, _mm512_extractf32x4_ps(_f1, 2)); + _mm_store_ps(p0f + 32 + 8, _mm512_extractf32x4_ps(_f2, 2)); + _mm_store_ps(p0f + 32 + 12, _mm512_extractf32x4_ps(_f3, 2)); + _mm_store_ps(p0f + 48, _mm512_extractf32x4_ps(_f0, 3)); + _mm_store_ps(p0f + 48 + 4, _mm512_extractf32x4_ps(_f1, 3)); + _mm_store_ps(p0f + 48 + 8, _mm512_extractf32x4_ps(_f2, 3)); + _mm_store_ps(p0f + 48 + 12, _mm512_extractf32x4_ps(_f3, 3)); + } + if (out_elempack == 8) + { + transpose16x4_ps(_f0, _f1, _f2, _f3); + _mm_store_ps(p0f, _mm512_extractf32x4_ps(_f0, 0)); + _mm_store_ps(p0f + 4, _mm512_extractf32x4_ps(_f1, 0)); + _mm_store_ps(p0f + 8, _mm512_extractf32x4_ps(_f0, 1)); + _mm_store_ps(p0f + 12, _mm512_extractf32x4_ps(_f1, 1)); + _mm_store_ps(p0f + 16, _mm512_extractf32x4_ps(_f0, 2)); + _mm_store_ps(p0f + 16 + 4, _mm512_extractf32x4_ps(_f1, 2)); + _mm_store_ps(p0f + 16 + 8, _mm512_extractf32x4_ps(_f0, 3)); + _mm_store_ps(p0f + 16 + 12, _mm512_extractf32x4_ps(_f1, 3)); + _mm_store_ps(p0f + out_hstep * 8, _mm512_extractf32x4_ps(_f2, 0)); + _mm_store_ps(p0f + out_hstep * 8 + 4, _mm512_extractf32x4_ps(_f3, 0)); + _mm_store_ps(p0f + out_hstep * 8 + 8, _mm512_extractf32x4_ps(_f2, 1)); + _mm_store_ps(p0f + out_hstep * 8 + 12, _mm512_extractf32x4_ps(_f3, 1)); + _mm_store_ps(p0f + out_hstep * 8 + 16, _mm512_extractf32x4_ps(_f2, 2)); + _mm_store_ps(p0f + out_hstep * 8 + 16 + 4, _mm512_extractf32x4_ps(_f3, 2)); + _mm_store_ps(p0f + out_hstep * 8 + 16 + 8, _mm512_extractf32x4_ps(_f2, 3)); + _mm_store_ps(p0f + out_hstep * 8 + 16 + 12, _mm512_extractf32x4_ps(_f3, 3)); + } + if (out_elempack == 4) + { + transpose16x4_ps(_f0, _f1, _f2, _f3); + _mm512_storeu_ps(p0f, _f0); + _mm512_storeu_ps(p0f + out_hstep * 4, _f1); + _mm512_storeu_ps(p0f + out_hstep * 8, _f2); + _mm512_storeu_ps(p0f + out_hstep * 12, _f3); + } + if (out_elempack == 1) + { + _mm_storeu_ps(p0f, _mm512_extractf32x4_ps(_f0, 0)); + _mm_storeu_ps(p0f + out_hstep, _mm512_extractf32x4_ps(_f1, 0)); + _mm_storeu_ps(p0f + out_hstep * 2, _mm512_extractf32x4_ps(_f2, 0)); + _mm_storeu_ps(p0f + out_hstep * 3, _mm512_extractf32x4_ps(_f3, 0)); + _mm_storeu_ps(p0f + out_hstep * 4, _mm512_extractf32x4_ps(_f0, 1)); + _mm_storeu_ps(p0f + out_hstep * 5, _mm512_extractf32x4_ps(_f1, 1)); + _mm_storeu_ps(p0f + out_hstep * 6, _mm512_extractf32x4_ps(_f2, 1)); + _mm_storeu_ps(p0f + out_hstep * 7, _mm512_extractf32x4_ps(_f3, 1)); + _mm_storeu_ps(p0f + out_hstep * 8, _mm512_extractf32x4_ps(_f0, 2)); + _mm_storeu_ps(p0f + out_hstep * 9, _mm512_extractf32x4_ps(_f1, 2)); + _mm_storeu_ps(p0f + out_hstep * 10, _mm512_extractf32x4_ps(_f2, 2)); + _mm_storeu_ps(p0f + out_hstep * 11, _mm512_extractf32x4_ps(_f3, 2)); + _mm_storeu_ps(p0f + out_hstep * 12, _mm512_extractf32x4_ps(_f0, 3)); + _mm_storeu_ps(p0f + out_hstep * 13, _mm512_extractf32x4_ps(_f1, 3)); + _mm_storeu_ps(p0f + out_hstep * 14, _mm512_extractf32x4_ps(_f2, 3)); + _mm_storeu_ps(p0f + out_hstep * 15, _mm512_extractf32x4_ps(_f3, 3)); + } + p0f += out_hstep * 16; + } + else + { + if (out_elempack == 4) + { + __m512 _tmp0 = _mm512_shuffle_f32x4(_f0, _f1, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp1 = _mm512_shuffle_f32x4(_f2, _f3, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp2 = _mm512_shuffle_f32x4(_f0, _f1, _MM_SHUFFLE(3, 2, 3, 2)); + 
__m512 _tmp3 = _mm512_shuffle_f32x4(_f2, _f3, _MM_SHUFFLE(3, 2, 3, 2)); + _f0 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(2, 0, 2, 0)); + _f1 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(3, 1, 3, 1)); + _f2 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(2, 0, 2, 0)); + _f3 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(3, 1, 3, 1)); + + _mm512_storeu_ps(p0f, _f0); + _mm512_storeu_ps(p0f + 16, _f1); + _mm512_storeu_ps(p0f + 32, _f2); + _mm512_storeu_ps(p0f + 48, _f3); + p0f += 64; + } + if (out_elempack == 1) + { + __m512 _tmp0 = _mm512_unpacklo_ps(_f0, _f1); + __m512 _tmp1 = _mm512_unpacklo_ps(_f2, _f3); + __m512 _tmp2 = _mm512_unpackhi_ps(_f0, _f1); + __m512 _tmp3 = _mm512_unpackhi_ps(_f2, _f3); + _f0 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp0), _mm512_castps_pd(_tmp1))); + _f1 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp0), _mm512_castps_pd(_tmp1))); + _f2 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp2), _mm512_castps_pd(_tmp3))); + _f3 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp2), _mm512_castps_pd(_tmp3))); + + _mm512_storeu_ps(p0f, _f0); + _mm512_storeu_ps(p0f + out_hstep, _f1); + _mm512_storeu_ps(p0f + out_hstep * 2, _f2); + _mm512_storeu_ps(p0f + out_hstep * 3, _f3); + p0f += 16; + } + } } else { - if (out_elempack == 4) - { - _mm_storel_epi64((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); - _mm_storel_epi64((__m128i*)(p0 + 4), _mm256_extractf128_si256(_bf1, 0)); - _mm_storel_epi64((__m128i*)(p0 + 4 * 2), _mm256_extractf128_si256(_bf2, 0)); - _mm_storel_epi64((__m128i*)(p0 + 4 * 3), _mm256_extractf128_si256(_bf3, 0)); - _mm_storeh_pd((double*)(p0 + 4 * 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 0))); - _mm_storeh_pd((double*)(p0 + 4 * 5), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 0))); - _mm_storeh_pd((double*)(p0 + 4 * 6), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 0))); - _mm_storeh_pd((double*)(p0 + 4 * 7), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 0))); - _mm_storel_epi64((__m128i*)(p0 + 4 * 8), _mm256_extractf128_si256(_bf0, 1)); - _mm_storel_epi64((__m128i*)(p0 + 4 * 9), _mm256_extractf128_si256(_bf1, 1)); - _mm_storel_epi64((__m128i*)(p0 + 4 * 10), _mm256_extractf128_si256(_bf2, 1)); - _mm_storel_epi64((__m128i*)(p0 + 4 * 11), _mm256_extractf128_si256(_bf3, 1)); - _mm_storeh_pd((double*)(p0 + 4 * 12), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 1))); - _mm_storeh_pd((double*)(p0 + 4 * 13), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 1))); - _mm_storeh_pd((double*)(p0 + 4 * 14), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 1))); - _mm_storeh_pd((double*)(p0 + 4 * 15), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 1))); - p0 += 64; - } - if (out_elempack == 1) - { - __m512i _idx_r0r1 = _mm512_set_epi16(61, 45, 29, 13, 57, 41, 25, 9, 53, 37, 21, 5, 49, 33, 17, 1, 60, 44, 28, 12, 56, 40, 24, 8, 52, 36, 20, 4, 48, 32, 16, 0); - __m512i _idx_r2r3 = _mm512_set_epi16(63, 47, 31, 15, 59, 43, 27, 11, 55, 39, 23, 7, 51, 35, 19, 3, 62, 46, 30, 14, 58, 42, 26, 10, 54, 38, 22, 6, 50, 34, 18, 2); - - __m512i _bf01 = combine8x2_epi32(_bf0, _bf1); - __m512i _bf23 = combine8x2_epi32(_bf2, _bf3); - - __m512i _t01 = _mm512_permutex2var_epi16(_bf01, _idx_r0r1, _bf23); - __m512i _t23 = _mm512_permutex2var_epi16(_bf01, _idx_r2r3, _bf23); - - _mm256_storeu_si256((__m256i*)p0, _mm512_extracti32x8_epi32(_t01, 0)); - _mm256_storeu_si256((__m256i*)(p0 + out_hstep), _mm512_extracti32x8_epi32(_t01, 1)); - _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 2), _mm512_extracti32x8_epi32(_t23, 
0)); - _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 3), _mm512_extracti32x8_epi32(_t23, 1)); - p0 += 16; + __m256i _bf0 = float2bfloat_avx512(_f0); + __m256i _bf1 = float2bfloat_avx512(_f1); + __m256i _bf2 = float2bfloat_avx512(_f2); + __m256i _bf3 = float2bfloat_avx512(_f3); + + if (output_transpose) + { + if (out_elempack == 16) + { + transpose16x4_epi16(_bf0, _bf1, _bf2, _bf3); + _mm_storel_epi64((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4), _mm256_extractf128_si256(_bf1, 0)); + _mm_storel_epi64((__m128i*)(p0 + 8), _mm256_extractf128_si256(_bf2, 0)); + _mm_storel_epi64((__m128i*)(p0 + 12), _mm256_extractf128_si256(_bf3, 0)); + _mm_storeh_pd((double*)(p0 + 16), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 0))); + _mm_storeh_pd((double*)(p0 + 16 + 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 0))); + _mm_storeh_pd((double*)(p0 + 16 + 8), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 0))); + _mm_storeh_pd((double*)(p0 + 16 + 12), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 0))); + _mm_storel_epi64((__m128i*)(p0 + 32), _mm256_extractf128_si256(_bf0, 1)); + _mm_storel_epi64((__m128i*)(p0 + 32 + 4), _mm256_extractf128_si256(_bf1, 1)); + _mm_storel_epi64((__m128i*)(p0 + 32 + 8), _mm256_extractf128_si256(_bf2, 1)); + _mm_storel_epi64((__m128i*)(p0 + 32 + 12), _mm256_extractf128_si256(_bf3, 1)); + _mm_storeh_pd((double*)(p0 + 48), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 1))); + _mm_storeh_pd((double*)(p0 + 48 + 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 1))); + _mm_storeh_pd((double*)(p0 + 48 + 8), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 1))); + _mm_storeh_pd((double*)(p0 + 48 + 12), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 1))); + } + if (out_elempack == 8) + { + transpose16x4_epi16(_bf0, _bf1, _bf2, _bf3); + _mm_storel_epi64((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4), _mm256_extractf128_si256(_bf1, 0)); + _mm_storeh_pd((double*)(p0 + 8), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 0))); + _mm_storeh_pd((double*)(p0 + 12), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 0))); + _mm_storel_epi64((__m128i*)(p0 + 16), _mm256_extractf128_si256(_bf0, 1)); + _mm_storel_epi64((__m128i*)(p0 + 16 + 4), _mm256_extractf128_si256(_bf1, 1)); + _mm_storeh_pd((double*)(p0 + 16 + 8), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 1))); + _mm_storeh_pd((double*)(p0 + 16 + 12), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 1))); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8), _mm256_extractf128_si256(_bf2, 0)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4), _mm256_extractf128_si256(_bf3, 0)); + _mm_storeh_pd((double*)(p0 + out_hstep * 8 + 8), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 8 + 12), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 0))); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 16), _mm256_extractf128_si256(_bf2, 1)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 16 + 4), _mm256_extractf128_si256(_bf3, 1)); + _mm_storeh_pd((double*)(p0 + out_hstep * 8 + 16 + 8), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 8 + 16 + 12), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 1))); + } + if (out_elempack == 4) + { + __m128i _bf0l = _mm256_extractf128_si256(_bf0, 0); + __m128i _bf1l = _mm256_extractf128_si256(_bf1, 0); + __m128i _bf2l = _mm256_extractf128_si256(_bf2, 0); + __m128i _bf3l = _mm256_extractf128_si256(_bf3, 0); + + __m128i _t0 = 
_mm_unpacklo_epi16(_bf0l, _bf1l); + __m128i _t1 = _mm_unpacklo_epi16(_bf2l, _bf3l); + __m128i _d0 = _mm_unpacklo_epi32(_t0, _t1); + __m128i _d1 = _mm_unpackhi_epi32(_t0, _t1); + _mm_storel_epi64((__m128i*)p0, _d0); + _mm_storeh_pd((double*)(p0 + 4), _mm_castsi128_pd(_d0)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 2), _d1); + _mm_storeh_pd((double*)(p0 + 4 * 3), _mm_castsi128_pd(_d1)); + + __m128i _t2 = _mm_unpackhi_epi16(_bf0l, _bf1l); + __m128i _t3 = _mm_unpackhi_epi16(_bf2l, _bf3l); + __m128i _d2 = _mm_unpacklo_epi32(_t2, _t3); + __m128i _d3 = _mm_unpackhi_epi32(_t2, _t3); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 4), _d2); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4), _mm_castsi128_pd(_d2)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 4 + 4 * 2), _d3); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 3), _mm_castsi128_pd(_d3)); + + __m128i _bf0h = _mm256_extractf128_si256(_bf0, 1); + __m128i _bf1h = _mm256_extractf128_si256(_bf1, 1); + __m128i _bf2h = _mm256_extractf128_si256(_bf2, 1); + __m128i _bf3h = _mm256_extractf128_si256(_bf3, 1); + __m128i _t4 = _mm_unpacklo_epi16(_bf0h, _bf1h); + __m128i _t5 = _mm_unpacklo_epi16(_bf2h, _bf3h); + __m128i _d4 = _mm_unpacklo_epi32(_t4, _t5); + __m128i _d5 = _mm_unpackhi_epi32(_t4, _t5); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8), _d4); + _mm_storeh_pd((double*)(p0 + out_hstep * 8 + 4), _mm_castsi128_pd(_d4)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4 * 2), _d5); + _mm_storeh_pd((double*)(p0 + out_hstep * 8 + 4 * 3), _mm_castsi128_pd(_d5)); + + __m128i _t6 = _mm_unpackhi_epi16(_bf0h, _bf1h); + __m128i _t7 = _mm_unpackhi_epi16(_bf2h, _bf3h); + __m128i _d6 = _mm_unpacklo_epi32(_t6, _t7); + __m128i _d7 = _mm_unpackhi_epi32(_t6, _t7); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 12), _d6); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4), _mm_castsi128_pd(_d6)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 12 + 4 * 2), _d7); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4 * 3), _mm_castsi128_pd(_d7)); + } + if (out_elempack == 1) + { + _mm_storel_epi64((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep), _mm256_extractf128_si256(_bf1, 0)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 2), _mm256_extractf128_si256(_bf2, 0)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 3), _mm256_extractf128_si256(_bf3, 0)); + _mm_storeh_pd((double*)(p0 + out_hstep * 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 5), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 6), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 7), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 0))); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8), _mm256_extractf128_si256(_bf0, 1)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 9), _mm256_extractf128_si256(_bf1, 1)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 10), _mm256_extractf128_si256(_bf2, 1)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 11), _mm256_extractf128_si256(_bf3, 1)); + _mm_storeh_pd((double*)(p0 + out_hstep * 12), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 13), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 14), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 15), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 1))); + } + p0 += out_hstep 
* 16; + } + else + { + if (out_elempack == 4) + { + _mm_storel_epi64((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4), _mm256_extractf128_si256(_bf1, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 2), _mm256_extractf128_si256(_bf2, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 3), _mm256_extractf128_si256(_bf3, 0)); + _mm_storeh_pd((double*)(p0 + 4 * 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 0))); + _mm_storeh_pd((double*)(p0 + 4 * 5), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 0))); + _mm_storeh_pd((double*)(p0 + 4 * 6), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 0))); + _mm_storeh_pd((double*)(p0 + 4 * 7), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 0))); + _mm_storel_epi64((__m128i*)(p0 + 4 * 8), _mm256_extractf128_si256(_bf0, 1)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 9), _mm256_extractf128_si256(_bf1, 1)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 10), _mm256_extractf128_si256(_bf2, 1)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 11), _mm256_extractf128_si256(_bf3, 1)); + _mm_storeh_pd((double*)(p0 + 4 * 12), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 1))); + _mm_storeh_pd((double*)(p0 + 4 * 13), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 1))); + _mm_storeh_pd((double*)(p0 + 4 * 14), _mm_castsi128_pd(_mm256_extractf128_si256(_bf2, 1))); + _mm_storeh_pd((double*)(p0 + 4 * 15), _mm_castsi128_pd(_mm256_extractf128_si256(_bf3, 1))); + p0 += 64; + } + if (out_elempack == 1) + { + __m512i _idx_r0r1 = _mm512_set_epi16(61, 45, 29, 13, 57, 41, 25, 9, 53, 37, 21, 5, 49, 33, 17, 1, 60, 44, 28, 12, 56, 40, 24, 8, 52, 36, 20, 4, 48, 32, 16, 0); + __m512i _idx_r2r3 = _mm512_set_epi16(63, 47, 31, 15, 59, 43, 27, 11, 55, 39, 23, 7, 51, 35, 19, 3, 62, 46, 30, 14, 58, 42, 26, 10, 54, 38, 22, 6, 50, 34, 18, 2); + + __m512i _bf01 = combine8x2_epi32(_bf0, _bf1); + __m512i _bf23 = combine8x2_epi32(_bf2, _bf3); + + __m512i _t01 = _mm512_permutex2var_epi16(_bf01, _idx_r0r1, _bf23); + __m512i _t23 = _mm512_permutex2var_epi16(_bf01, _idx_r2r3, _bf23); + + _mm256_storeu_si256((__m256i*)p0, _mm512_extracti32x8_epi32(_t01, 0)); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep), _mm512_extracti32x8_epi32(_t01, 1)); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 2), _mm512_extracti32x8_epi32(_t23, 0)); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep * 3), _mm512_extracti32x8_epi32(_t23, 1)); + p0 += 16; + } } } } @@ -7346,106 +8458,220 @@ static void unpack_output_tile_fp32_to_bf16(const Mat& topT, const Mat& C, Mat& _f7 = _mm_mul_ps(_f7, _alpha); } - __m128i _bf04 = float2bfloat_sse(_f0, _f4); - __m128i _bf15 = float2bfloat_sse(_f1, _f5); - __m128i _bf26 = float2bfloat_sse(_f2, _f6); - __m128i _bf37 = float2bfloat_sse(_f3, _f7); - - if (output_transpose) + if (output_elemtype == 1) { + if (output_transpose) + { #if __AVX__ - if (out_elempack == 8) - { - __m128i _t0 = _mm_unpacklo_epi16(_bf04, _bf15); - __m128i _t1 = _mm_unpacklo_epi16(_bf26, _bf37); - __m128i _t2 = _mm_unpackhi_epi16(_bf04, _bf15); - __m128i _t3 = _mm_unpackhi_epi16(_bf26, _bf37); - _bf04 = _mm_unpacklo_epi32(_t0, _t1); - _bf15 = _mm_unpacklo_epi32(_t2, _t3); - _bf26 = _mm_unpackhi_epi32(_t0, _t1); - _bf37 = _mm_unpackhi_epi32(_t2, _t3); - _t0 = _mm_unpacklo_epi64(_bf04, _bf15); - _t1 = _mm_unpackhi_epi64(_bf04, _bf15); - _t2 = _mm_unpacklo_epi64(_bf26, _bf37); - _t3 = _mm_unpackhi_epi64(_bf26, _bf37); - - _mm_storel_epi64((__m128i*)p0, _t0); - _mm_storeh_pd((double*)(p0 + 4), _mm_castsi128_pd(_t0)); - _mm_storel_epi64((__m128i*)(p0 + 8), _t1); - _mm_storeh_pd((double*)(p0 + 12), 
_mm_castsi128_pd(_t1)); - _mm_storel_epi64((__m128i*)(p0 + 16), _t2); - _mm_storeh_pd((double*)(p0 + 20), _mm_castsi128_pd(_t2)); - _mm_storel_epi64((__m128i*)(p0 + 24), _t3); - _mm_storeh_pd((double*)(p0 + 28), _mm_castsi128_pd(_t3)); - } + if (out_elempack == 8) + { + float tmp0[4], tmp1[4], tmp2[4], tmp3[4]; + float tmp4[4], tmp5[4], tmp6[4], tmp7[4]; + _mm_storeu_ps(tmp0, _f0); + _mm_storeu_ps(tmp1, _f1); + _mm_storeu_ps(tmp2, _f2); + _mm_storeu_ps(tmp3, _f3); + _mm_storeu_ps(tmp4, _f4); + _mm_storeu_ps(tmp5, _f5); + _mm_storeu_ps(tmp6, _f6); + _mm_storeu_ps(tmp7, _f7); + for (int k = 0; k < 4; k++) + { + p0f[k * 8] = tmp0[k]; + p0f[k * 8 + 1] = tmp1[k]; + p0f[k * 8 + 2] = tmp2[k]; + p0f[k * 8 + 3] = tmp3[k]; + p0f[k * 8 + 4] = tmp4[k]; + p0f[k * 8 + 5] = tmp5[k]; + p0f[k * 8 + 6] = tmp6[k]; + p0f[k * 8 + 7] = tmp7[k]; + } + } #endif // __AVX__ - if (out_elempack == 4) - { - __m128i _t0 = _mm_unpacklo_epi16(_bf04, _bf15); - __m128i _t1 = _mm_unpacklo_epi16(_bf26, _bf37); - __m128i _d0 = _mm_unpacklo_epi32(_t0, _t1); - __m128i _d1 = _mm_unpackhi_epi32(_t0, _t1); - _mm_storel_epi64((__m128i*)p0, _d0); - _mm_storeh_pd((double*)(p0 + 4), _mm_castsi128_pd(_d0)); - _mm_storel_epi64((__m128i*)(p0 + 4 * 2), _d1); - _mm_storeh_pd((double*)(p0 + 4 * 3), _mm_castsi128_pd(_d1)); - - __m128i _t2 = _mm_unpackhi_epi16(_bf04, _bf15); - __m128i _t3 = _mm_unpackhi_epi16(_bf26, _bf37); - __m128i _d2 = _mm_unpacklo_epi32(_t2, _t3); - __m128i _d3 = _mm_unpackhi_epi32(_t2, _t3); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 4), _d2); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4), _mm_castsi128_pd(_d2)); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 4 + 4 * 2), _d3); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 3), _mm_castsi128_pd(_d3)); - } - if (out_elempack == 1) - { - _mm_storel_epi64((__m128i*)p0, _bf04); - _mm_storel_epi64((__m128i*)(p0 + out_hstep), _bf15); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 2), _bf26); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 3), _bf37); - _mm_storeh_pd((double*)(p0 + out_hstep * 4), _mm_castsi128_pd(_bf04)); - _mm_storeh_pd((double*)(p0 + out_hstep * 5), _mm_castsi128_pd(_bf15)); - _mm_storeh_pd((double*)(p0 + out_hstep * 6), _mm_castsi128_pd(_bf26)); - _mm_storeh_pd((double*)(p0 + out_hstep * 7), _mm_castsi128_pd(_bf37)); - } - p0 += out_hstep * 8; + if (out_elempack == 4) + { + float tmp0[4], tmp1[4], tmp2[4], tmp3[4]; + float tmp4[4], tmp5[4], tmp6[4], tmp7[4]; + _mm_storeu_ps(tmp0, _f0); + _mm_storeu_ps(tmp1, _f1); + _mm_storeu_ps(tmp2, _f2); + _mm_storeu_ps(tmp3, _f3); + _mm_storeu_ps(tmp4, _f4); + _mm_storeu_ps(tmp5, _f5); + _mm_storeu_ps(tmp6, _f6); + _mm_storeu_ps(tmp7, _f7); + for (int k = 0; k < 4; k++) + { + p0f[k * 4] = tmp0[k]; + p0f[k * 4 + 1] = tmp1[k]; + p0f[k * 4 + 2] = tmp2[k]; + p0f[k * 4 + 3] = tmp3[k]; + } + for (int k = 0; k < 4; k++) + { + p0f[out_hstep * 4 + k * 4] = tmp4[k]; + p0f[out_hstep * 4 + k * 4 + 1] = tmp5[k]; + p0f[out_hstep * 4 + k * 4 + 2] = tmp6[k]; + p0f[out_hstep * 4 + k * 4 + 3] = tmp7[k]; + } + } + if (out_elempack == 1) + { + _mm_storeu_ps(p0f, _f0); + _mm_storeu_ps(p0f + out_hstep, _f1); + _mm_storeu_ps(p0f + out_hstep * 2, _f2); + _mm_storeu_ps(p0f + out_hstep * 3, _f3); + _mm_storeu_ps(p0f + out_hstep * 4, _f4); + _mm_storeu_ps(p0f + out_hstep * 5, _f5); + _mm_storeu_ps(p0f + out_hstep * 6, _f6); + _mm_storeu_ps(p0f + out_hstep * 7, _f7); + } + p0f += out_hstep * 8; + } + else + { + if (out_elempack == 4) + { + _mm_storeu_ps(p0f, _f0); + _mm_storeu_ps(p0f + 4, _f1); + _mm_storeu_ps(p0f + 4 * 2, _f2); + 
_mm_storeu_ps(p0f + 4 * 3, _f3); + _mm_storeu_ps(p0f + 4 * 4, _f4); + _mm_storeu_ps(p0f + 4 * 5, _f5); + _mm_storeu_ps(p0f + 4 * 6, _f6); + _mm_storeu_ps(p0f + 4 * 7, _f7); + p0f += 32; + } + if (out_elempack == 1) + { + float tmp0[4], tmp1[4], tmp2[4], tmp3[4]; + float tmp4[4], tmp5[4], tmp6[4], tmp7[4]; + _mm_storeu_ps(tmp0, _f0); + _mm_storeu_ps(tmp1, _f1); + _mm_storeu_ps(tmp2, _f2); + _mm_storeu_ps(tmp3, _f3); + _mm_storeu_ps(tmp4, _f4); + _mm_storeu_ps(tmp5, _f5); + _mm_storeu_ps(tmp6, _f6); + _mm_storeu_ps(tmp7, _f7); + for (int k = 0; k < 4; k++) + { + p0f[out_hstep * k] = tmp0[k]; + p0f[out_hstep * k + 1] = tmp1[k]; + p0f[out_hstep * k + 2] = tmp2[k]; + p0f[out_hstep * k + 3] = tmp3[k]; + p0f[out_hstep * k + 4] = tmp4[k]; + p0f[out_hstep * k + 5] = tmp5[k]; + p0f[out_hstep * k + 6] = tmp6[k]; + p0f[out_hstep * k + 7] = tmp7[k]; + } + p0f += 8; + } + } } else { - if (out_elempack == 4) - { - _mm_storel_epi64((__m128i*)p0, _bf04); - _mm_storel_epi64((__m128i*)(p0 + 4), _bf15); - _mm_storel_epi64((__m128i*)(p0 + 4 * 2), _bf26); - _mm_storel_epi64((__m128i*)(p0 + 4 * 3), _bf37); - _mm_storeh_pd((double*)(p0 + 4 * 4), _mm_castsi128_pd(_bf04)); - _mm_storeh_pd((double*)(p0 + 4 * 5), _mm_castsi128_pd(_bf15)); - _mm_storeh_pd((double*)(p0 + 4 * 6), _mm_castsi128_pd(_bf26)); - _mm_storeh_pd((double*)(p0 + 4 * 7), _mm_castsi128_pd(_bf37)); - p0 += 32; - } - if (out_elempack == 1) - { - __m128i _t0 = _mm_unpacklo_epi16(_bf04, _bf15); - __m128i _t1 = _mm_unpacklo_epi16(_bf26, _bf37); - __m128i _t2 = _mm_unpackhi_epi16(_bf04, _bf15); - __m128i _t3 = _mm_unpackhi_epi16(_bf26, _bf37); - _bf04 = _mm_unpacklo_epi32(_t0, _t1); - _bf15 = _mm_unpacklo_epi32(_t2, _t3); - _bf26 = _mm_unpackhi_epi32(_t0, _t1); - _bf37 = _mm_unpackhi_epi32(_t2, _t3); - _t0 = _mm_unpacklo_epi64(_bf04, _bf15); - _t1 = _mm_unpackhi_epi64(_bf04, _bf15); - _t2 = _mm_unpacklo_epi64(_bf26, _bf37); - _t3 = _mm_unpackhi_epi64(_bf26, _bf37); - - _mm_storeu_si128((__m128i*)p0, _t0); - _mm_storeu_si128((__m128i*)(p0 + out_hstep), _t1); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 2), _t2); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 3), _t3); - p0 += 8; + __m128i _bf04 = float2bfloat_sse(_f0, _f4); + __m128i _bf15 = float2bfloat_sse(_f1, _f5); + __m128i _bf26 = float2bfloat_sse(_f2, _f6); + __m128i _bf37 = float2bfloat_sse(_f3, _f7); + + if (output_transpose) + { +#if __AVX__ + if (out_elempack == 8) + { + __m128i _t0 = _mm_unpacklo_epi16(_bf04, _bf15); + __m128i _t1 = _mm_unpacklo_epi16(_bf26, _bf37); + __m128i _t2 = _mm_unpackhi_epi16(_bf04, _bf15); + __m128i _t3 = _mm_unpackhi_epi16(_bf26, _bf37); + _bf04 = _mm_unpacklo_epi32(_t0, _t1); + _bf15 = _mm_unpacklo_epi32(_t2, _t3); + _bf26 = _mm_unpackhi_epi32(_t0, _t1); + _bf37 = _mm_unpackhi_epi32(_t2, _t3); + _t0 = _mm_unpacklo_epi64(_bf04, _bf15); + _t1 = _mm_unpackhi_epi64(_bf04, _bf15); + _t2 = _mm_unpacklo_epi64(_bf26, _bf37); + _t3 = _mm_unpackhi_epi64(_bf26, _bf37); + + _mm_storel_epi64((__m128i*)p0, _t0); + _mm_storeh_pd((double*)(p0 + 4), _mm_castsi128_pd(_t0)); + _mm_storel_epi64((__m128i*)(p0 + 8), _t1); + _mm_storeh_pd((double*)(p0 + 12), _mm_castsi128_pd(_t1)); + _mm_storel_epi64((__m128i*)(p0 + 16), _t2); + _mm_storeh_pd((double*)(p0 + 20), _mm_castsi128_pd(_t2)); + _mm_storel_epi64((__m128i*)(p0 + 24), _t3); + _mm_storeh_pd((double*)(p0 + 28), _mm_castsi128_pd(_t3)); + } +#endif // __AVX__ + if (out_elempack == 4) + { + __m128i _t0 = _mm_unpacklo_epi16(_bf04, _bf15); + __m128i _t1 = _mm_unpacklo_epi16(_bf26, _bf37); + __m128i _d0 = _mm_unpacklo_epi32(_t0, _t1); + 
__m128i _d1 = _mm_unpackhi_epi32(_t0, _t1); + _mm_storel_epi64((__m128i*)p0, _d0); + _mm_storeh_pd((double*)(p0 + 4), _mm_castsi128_pd(_d0)); + _mm_storel_epi64((__m128i*)(p0 + 4 * 2), _d1); + _mm_storeh_pd((double*)(p0 + 4 * 3), _mm_castsi128_pd(_d1)); + + __m128i _t2 = _mm_unpackhi_epi16(_bf04, _bf15); + __m128i _t3 = _mm_unpackhi_epi16(_bf26, _bf37); + __m128i _d2 = _mm_unpacklo_epi32(_t2, _t3); + __m128i _d3 = _mm_unpackhi_epi32(_t2, _t3); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 4), _d2); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4), _mm_castsi128_pd(_d2)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 4 + 4 * 2), _d3); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4 * 3), _mm_castsi128_pd(_d3)); + } + if (out_elempack == 1) + { + _mm_storel_epi64((__m128i*)p0, _bf04); + _mm_storel_epi64((__m128i*)(p0 + out_hstep), _bf15); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 2), _bf26); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 3), _bf37); + _mm_storeh_pd((double*)(p0 + out_hstep * 4), _mm_castsi128_pd(_bf04)); + _mm_storeh_pd((double*)(p0 + out_hstep * 5), _mm_castsi128_pd(_bf15)); + _mm_storeh_pd((double*)(p0 + out_hstep * 6), _mm_castsi128_pd(_bf26)); + _mm_storeh_pd((double*)(p0 + out_hstep * 7), _mm_castsi128_pd(_bf37)); + } + p0 += out_hstep * 8; + } + else + { + if (out_elempack == 4) + { + _mm_storel_epi64((__m128i*)p0, _bf04); + _mm_storel_epi64((__m128i*)(p0 + 4), _bf15); + _mm_storel_epi64((__m128i*)(p0 + 4 * 2), _bf26); + _mm_storel_epi64((__m128i*)(p0 + 4 * 3), _bf37); + _mm_storeh_pd((double*)(p0 + 4 * 4), _mm_castsi128_pd(_bf04)); + _mm_storeh_pd((double*)(p0 + 4 * 5), _mm_castsi128_pd(_bf15)); + _mm_storeh_pd((double*)(p0 + 4 * 6), _mm_castsi128_pd(_bf26)); + _mm_storeh_pd((double*)(p0 + 4 * 7), _mm_castsi128_pd(_bf37)); + p0 += 32; + } + if (out_elempack == 1) + { + __m128i _t0 = _mm_unpacklo_epi16(_bf04, _bf15); + __m128i _t1 = _mm_unpacklo_epi16(_bf26, _bf37); + __m128i _t2 = _mm_unpackhi_epi16(_bf04, _bf15); + __m128i _t3 = _mm_unpackhi_epi16(_bf26, _bf37); + _bf04 = _mm_unpacklo_epi32(_t0, _t1); + _bf15 = _mm_unpacklo_epi32(_t2, _t3); + _bf26 = _mm_unpackhi_epi32(_t0, _t1); + _bf37 = _mm_unpackhi_epi32(_t2, _t3); + _t0 = _mm_unpacklo_epi64(_bf04, _bf15); + _t1 = _mm_unpackhi_epi64(_bf04, _bf15); + _t2 = _mm_unpacklo_epi64(_bf26, _bf37); + _t3 = _mm_unpackhi_epi64(_bf26, _bf37); + + _mm_storeu_si128((__m128i*)p0, _t0); + _mm_storeu_si128((__m128i*)(p0 + out_hstep), _t1); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 2), _t2); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 3), _t3); + p0 += 8; + } } } } @@ -7552,83 +8778,154 @@ static void unpack_output_tile_fp32_to_bf16(const Mat& topT, const Mat& C, Mat& _f3 = _mm_mul_ps(_f3, _alpha); } - __m128i _bf02 = float2bfloat_sse(_f0, _f2); - __m128i _bf13 = float2bfloat_sse(_f1, _f3); - - if (output_transpose) + if (output_elemtype == 1) { + if (output_transpose) + { #if !(defined(__x86_64__) || defined(_M_X64)) #if __AVX__ #if __AVX512F__ - if (out_elempack == 16) - { - __m128i _t0 = _mm_unpacklo_epi16(_bf02, _bf13); - __m128i _t1 = _mm_unpackhi_epi16(_bf02, _bf13); - __m128i _d0 = _mm_unpacklo_epi32(_t0, _t1); - __m128i _d1 = _mm_unpackhi_epi32(_t0, _t1); - const int jj_m16 = jj % 16; - unsigned short* p1 = p0 - out_hstep * jj_m16 + jj_m16; - _mm_storel_epi64((__m128i*)p1, _d0); - _mm_storeh_pd((double*)(p1 + 16), _mm_castsi128_pd(_d0)); - _mm_storel_epi64((__m128i*)(p1 + 32), _d1); - _mm_storeh_pd((double*)(p1 + 48), _mm_castsi128_pd(_d1)); - } + if (out_elempack == 16) + { + _MM_TRANSPOSE4_PS(_f0, _f1, 
_f2, _f3); + const int jj_m16 = jj % 16; + float* p1f = p0f - out_hstep * jj_m16 + jj_m16; + _mm_store_ps(p1f, _f0); + _mm_store_ps(p1f + 16, _f1); + _mm_store_ps(p1f + 32, _f2); + _mm_store_ps(p1f + 48, _f3); + } #endif // __AVX512F__ - if (out_elempack == 8) - { - __m128i _t0 = _mm_unpacklo_epi16(_bf02, _bf13); - __m128i _t1 = _mm_unpackhi_epi16(_bf02, _bf13); - __m128i _d0 = _mm_unpacklo_epi32(_t0, _t1); - __m128i _d1 = _mm_unpackhi_epi32(_t0, _t1); - const int jj_m8 = jj % 8; - unsigned short* p1 = p0 - out_hstep * jj_m8 + jj_m8; - _mm_storel_epi64((__m128i*)p1, _d0); - _mm_storeh_pd((double*)(p1 + 8), _mm_castsi128_pd(_d0)); - _mm_storel_epi64((__m128i*)(p1 + 16), _d1); - _mm_storeh_pd((double*)(p1 + 24), _mm_castsi128_pd(_d1)); - } + if (out_elempack == 8) + { + _MM_TRANSPOSE4_PS(_f0, _f1, _f2, _f3); + const int jj_m8 = jj % 8; + float* p1f = p0f - out_hstep * jj_m8 + jj_m8; + _mm_store_ps(p1f, _f0); + _mm_store_ps(p1f + 8, _f1); + _mm_store_ps(p1f + 16, _f2); + _mm_store_ps(p1f + 24, _f3); + } #endif // __AVX__ #endif // !(defined(__x86_64__) || defined(_M_X64)) - if (out_elempack == 4) - { - __m128i _t0 = _mm_unpacklo_epi16(_bf02, _bf13); - __m128i _t1 = _mm_unpackhi_epi16(_bf02, _bf13); - __m128i _d0 = _mm_unpacklo_epi32(_t0, _t1); - __m128i _d1 = _mm_unpackhi_epi32(_t0, _t1); - _mm_storeu_si128((__m128i*)p0, _d0); - _mm_storeu_si128((__m128i*)(p0 + 8), _d1); + if (out_elempack == 4) + { + _MM_TRANSPOSE4_PS(_f0, _f1, _f2, _f3); + _mm_store_ps(p0f, _f0); + _mm_store_ps(p0f + 4, _f1); + _mm_store_ps(p0f + 8, _f2); + _mm_store_ps(p0f + 12, _f3); + } + if (out_elempack == 1) + { + _mm_storeu_ps(p0f, _f0); + _mm_storeu_ps(p0f + out_hstep, _f1); + _mm_storeu_ps(p0f + out_hstep * 2, _f2); + _mm_storeu_ps(p0f + out_hstep * 3, _f3); + } + p0f += out_hstep * 4; } - if (out_elempack == 1) + else { - _mm_storel_epi64((__m128i*)p0, _bf02); - _mm_storel_epi64((__m128i*)(p0 + out_hstep), _bf13); - _mm_storeh_pd((double*)(p0 + out_hstep * 2), _mm_castsi128_pd(_bf02)); - _mm_storeh_pd((double*)(p0 + out_hstep * 3), _mm_castsi128_pd(_bf13)); - } - p0 += out_hstep * 4; + if (out_elempack == 4) + { + _mm_storeu_ps(p0f, _f0); + _mm_storeu_ps(p0f + 4, _f1); + _mm_storeu_ps(p0f + 4 * 2, _f2); + _mm_storeu_ps(p0f + 4 * 3, _f3); + p0f += 16; + } + if (out_elempack == 1) + { + _MM_TRANSPOSE4_PS(_f0, _f1, _f2, _f3); + _mm_storeu_ps(p0f, _f0); + _mm_storeu_ps(p0f + out_hstep, _f1); + _mm_storeu_ps(p0f + out_hstep * 2, _f2); + _mm_storeu_ps(p0f + out_hstep * 3, _f3); + p0f += 4; + } + } } else { - if (out_elempack == 4) + __m128i _bf02 = float2bfloat_sse(_f0, _f2); + __m128i _bf13 = float2bfloat_sse(_f1, _f3); + + if (output_transpose) { - _mm_storel_epi64((__m128i*)p0, _bf02); - _mm_storel_epi64((__m128i*)(p0 + 4), _bf13); - _mm_storeh_pd((double*)(p0 + 4 * 2), _mm_castsi128_pd(_bf02)); - _mm_storeh_pd((double*)(p0 + 4 * 3), _mm_castsi128_pd(_bf13)); - p0 += 16; +#if !(defined(__x86_64__) || defined(_M_X64)) +#if __AVX__ +#if __AVX512F__ + if (out_elempack == 16) + { + __m128i _t0 = _mm_unpacklo_epi16(_bf02, _bf13); + __m128i _t1 = _mm_unpackhi_epi16(_bf02, _bf13); + __m128i _d0 = _mm_unpacklo_epi32(_t0, _t1); + __m128i _d1 = _mm_unpackhi_epi32(_t0, _t1); + const int jj_m16 = jj % 16; + unsigned short* p1 = p0 - out_hstep * jj_m16 + jj_m16; + _mm_storel_epi64((__m128i*)p1, _d0); + _mm_storeh_pd((double*)(p1 + 16), _mm_castsi128_pd(_d0)); + _mm_storel_epi64((__m128i*)(p1 + 32), _d1); + _mm_storeh_pd((double*)(p1 + 48), _mm_castsi128_pd(_d1)); + } +#endif // __AVX512F__ + if (out_elempack == 8) + { + __m128i 
_t0 = _mm_unpacklo_epi16(_bf02, _bf13); + __m128i _t1 = _mm_unpackhi_epi16(_bf02, _bf13); + __m128i _d0 = _mm_unpacklo_epi32(_t0, _t1); + __m128i _d1 = _mm_unpackhi_epi32(_t0, _t1); + const int jj_m8 = jj % 8; + unsigned short* p1 = p0 - out_hstep * jj_m8 + jj_m8; + _mm_storel_epi64((__m128i*)p1, _d0); + _mm_storeh_pd((double*)(p1 + 8), _mm_castsi128_pd(_d0)); + _mm_storel_epi64((__m128i*)(p1 + 16), _d1); + _mm_storeh_pd((double*)(p1 + 24), _mm_castsi128_pd(_d1)); + } +#endif // __AVX__ +#endif // !(defined(__x86_64__) || defined(_M_X64)) + if (out_elempack == 4) + { + __m128i _t0 = _mm_unpacklo_epi16(_bf02, _bf13); + __m128i _t1 = _mm_unpackhi_epi16(_bf02, _bf13); + __m128i _d0 = _mm_unpacklo_epi32(_t0, _t1); + __m128i _d1 = _mm_unpackhi_epi32(_t0, _t1); + _mm_storeu_si128((__m128i*)p0, _d0); + _mm_storeu_si128((__m128i*)(p0 + 8), _d1); + } + if (out_elempack == 1) + { + _mm_storel_epi64((__m128i*)p0, _bf02); + _mm_storel_epi64((__m128i*)(p0 + out_hstep), _bf13); + _mm_storeh_pd((double*)(p0 + out_hstep * 2), _mm_castsi128_pd(_bf02)); + _mm_storeh_pd((double*)(p0 + out_hstep * 3), _mm_castsi128_pd(_bf13)); + } + p0 += out_hstep * 4; } - if (out_elempack == 1) + else { - __m128i _t0 = _mm_unpacklo_epi16(_bf02, _bf13); - __m128i _t1 = _mm_unpackhi_epi16(_bf02, _bf13); - _bf02 = _mm_unpacklo_epi32(_t0, _t1); - _bf13 = _mm_unpackhi_epi32(_t0, _t1); - - _mm_storel_epi64((__m128i*)(p0), _bf02); - _mm_storeh_pd((double*)(p0 + out_hstep), _mm_castsi128_pd(_bf02)); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 2), _bf13); - _mm_storeh_pd((double*)(p0 + out_hstep * 3), _mm_castsi128_pd(_bf13)); - p0 += 4; + if (out_elempack == 4) + { + _mm_storel_epi64((__m128i*)p0, _bf02); + _mm_storel_epi64((__m128i*)(p0 + 4), _bf13); + _mm_storeh_pd((double*)(p0 + 4 * 2), _mm_castsi128_pd(_bf02)); + _mm_storeh_pd((double*)(p0 + 4 * 3), _mm_castsi128_pd(_bf13)); + p0 += 16; + } + if (out_elempack == 1) + { + __m128i _t0 = _mm_unpacklo_epi16(_bf02, _bf13); + __m128i _t1 = _mm_unpackhi_epi16(_bf02, _bf13); + _bf02 = _mm_unpacklo_epi32(_t0, _t1); + _bf13 = _mm_unpackhi_epi32(_t0, _t1); + + _mm_storel_epi64((__m128i*)(p0), _bf02); + _mm_storeh_pd((double*)(p0 + out_hstep), _mm_castsi128_pd(_bf02)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 2), _bf13); + _mm_storeh_pd((double*)(p0 + out_hstep * 3), _mm_castsi128_pd(_bf13)); + p0 += 4; + } } } } @@ -7703,53 +9000,108 @@ static void unpack_output_tile_fp32_to_bf16(const Mat& topT, const Mat& C, Mat& _f1 = _mm_mul_ps(_f1, _alpha); } - __m128i _bf01 = float2bfloat_sse(_f0, _f1); - - if (output_transpose) + if (output_elemtype == 1) { - if (out_elempack == 4) + if (output_transpose) { - unsigned short sum0[8]; - _mm_storeu_si128((__m128i*)sum0, _bf01); - - p0[0] = sum0[0]; - p0[1] = sum0[4]; - p0[4] = sum0[1]; - p0[5] = sum0[5]; - p0[8] = sum0[2]; - p0[9] = sum0[6]; - p0[12] = sum0[3]; - p0[13] = sum0[7]; + if (out_elempack == 4) + { + float sum0[8]; + _mm_storeu_ps(sum0, _f0); + _mm_storeu_ps(sum0 + 4, _f1); + + p0f[0] = sum0[0]; + p0f[1] = sum0[4]; + p0f[4] = sum0[1]; + p0f[5] = sum0[5]; + p0f[8] = sum0[2]; + p0f[9] = sum0[6]; + p0f[12] = sum0[3]; + p0f[13] = sum0[7]; + } + if (out_elempack == 1) + { + _mm_storeu_ps(p0f, _f0); + _mm_storeu_ps(p0f + out_hstep, _f1); + } + p0f += out_hstep * 2; } - if (out_elempack == 1) + else { - _mm_storel_epi64((__m128i*)p0, _bf01); - _mm_storeh_pd((double*)(p0 + out_hstep), _mm_castsi128_pd(_bf01)); + if (out_elempack == 4) + { + _mm_storeu_ps(p0f, _f0); + _mm_storeu_ps(p0f + 4, _f1); + p0f += 8; + } + if (out_elempack == 1) + { 
+ float sum0[8]; + _mm_storeu_ps(sum0, _f0); + _mm_storeu_ps(sum0 + 4, _f1); + + p0f[0] = sum0[0]; + p0f[1] = sum0[4]; + p0f[out_hstep] = sum0[1]; + p0f[out_hstep + 1] = sum0[5]; + p0f[out_hstep * 2] = sum0[2]; + p0f[out_hstep * 2 + 1] = sum0[6]; + p0f[out_hstep * 3] = sum0[3]; + p0f[out_hstep * 3 + 1] = sum0[7]; + p0f += 2; + } } - p0 += out_hstep * 2; } else { - if (out_elempack == 4) + __m128i _bf01 = float2bfloat_sse(_f0, _f1); + + if (output_transpose) { - _mm_storel_epi64((__m128i*)p0, _bf01); - _mm_storeh_pd((double*)(p0 + 4), _mm_castsi128_pd(_bf01)); - p0 += 8; + if (out_elempack == 4) + { + unsigned short sum0[8]; + _mm_storeu_si128((__m128i*)sum0, _bf01); + + p0[0] = sum0[0]; + p0[1] = sum0[4]; + p0[4] = sum0[1]; + p0[5] = sum0[5]; + p0[8] = sum0[2]; + p0[9] = sum0[6]; + p0[12] = sum0[3]; + p0[13] = sum0[7]; + } + if (out_elempack == 1) + { + _mm_storel_epi64((__m128i*)p0, _bf01); + _mm_storeh_pd((double*)(p0 + out_hstep), _mm_castsi128_pd(_bf01)); + } + p0 += out_hstep * 2; } - if (out_elempack == 1) + else { - unsigned short sum0[8]; - _mm_storeu_si128((__m128i*)sum0, _bf01); + if (out_elempack == 4) + { + _mm_storel_epi64((__m128i*)p0, _bf01); + _mm_storeh_pd((double*)(p0 + 4), _mm_castsi128_pd(_bf01)); + p0 += 8; + } + if (out_elempack == 1) + { + unsigned short sum0[8]; + _mm_storeu_si128((__m128i*)sum0, _bf01); - p0[0] = sum0[0]; - p0[1] = sum0[4]; - p0[out_hstep] = sum0[1]; - p0[out_hstep + 1] = sum0[5]; - p0[out_hstep * 2] = sum0[2]; - p0[out_hstep * 2 + 1] = sum0[6]; - p0[out_hstep * 3] = sum0[3]; - p0[out_hstep * 3 + 1] = sum0[7]; - p0 += 2; + p0[0] = sum0[0]; + p0[1] = sum0[4]; + p0[out_hstep] = sum0[1]; + p0[out_hstep + 1] = sum0[5]; + p0[out_hstep * 2] = sum0[2]; + p0[out_hstep * 2 + 1] = sum0[6]; + p0[out_hstep * 3] = sum0[3]; + p0[out_hstep * 3 + 1] = sum0[7]; + p0 += 2; + } } } } @@ -7795,58 +9147,104 @@ static void unpack_output_tile_fp32_to_bf16(const Mat& topT, const Mat& C, Mat& _f = _mm_mul_ps(_f, _mm_set1_ps(alpha)); } - __m128i _bf = float2bfloat_sse(_f); - - if (output_transpose) + if (output_elemtype == 1) { - if (out_elempack == 4) + if (output_transpose) { - unsigned short sum0[4]; - _mm_storel_epi64((__m128i*)sum0, _bf); + if (out_elempack == 4) + { + float sum0[4]; + _mm_storeu_ps(sum0, _f); - p0[0] = sum0[0]; - p0[4] = sum0[1]; - p0[4 * 2] = sum0[2]; - p0[4 * 3] = sum0[3]; + p0f[0] = sum0[0]; + p0f[4] = sum0[1]; + p0f[4 * 2] = sum0[2]; + p0f[4 * 3] = sum0[3]; + } + if (out_elempack == 1) + { + _mm_storeu_ps(p0f, _f); + } + p0f += out_hstep; } - if (out_elempack == 1) + else { - _mm_storel_epi64((__m128i*)p0, _bf); + if (out_elempack == 4) + { + _mm_storeu_ps(p0f, _f); + p0f += 4; + } + if (out_elempack == 1) + { + float sum0[4]; + _mm_storeu_ps(sum0, _f); + + p0f[0] = sum0[0]; + p0f[out_hstep] = sum0[1]; + p0f[out_hstep * 2] = sum0[2]; + p0f[out_hstep * 3] = sum0[3]; + p0f++; + } } - p0 += out_hstep; } else { - if (out_elempack == 4) + __m128i _bf = float2bfloat_sse(_f); + + if (output_transpose) { - _mm_storel_epi64((__m128i*)p0, _bf); - p0 += 4; + if (out_elempack == 4) + { + unsigned short sum0[4]; + _mm_storel_epi64((__m128i*)sum0, _bf); + + p0[0] = sum0[0]; + p0[4] = sum0[1]; + p0[4 * 2] = sum0[2]; + p0[4 * 3] = sum0[3]; + } + if (out_elempack == 1) + { + _mm_storel_epi64((__m128i*)p0, _bf); + } + p0 += out_hstep; } - if (out_elempack == 1) + else { - unsigned short sum0[4]; - _mm_storel_epi64((__m128i*)sum0, _bf); + if (out_elempack == 4) + { + _mm_storel_epi64((__m128i*)p0, _bf); + p0 += 4; + } + if (out_elempack == 1) + { + unsigned short 
sum0[4]; + _mm_storel_epi64((__m128i*)sum0, _bf); - p0[0] = sum0[0]; - p0[out_hstep] = sum0[1]; - p0[out_hstep * 2] = sum0[2]; - p0[out_hstep * 3] = sum0[3]; - p0++; + p0[0] = sum0[0]; + p0[out_hstep] = sum0[1]; + p0[out_hstep * 2] = sum0[2]; + p0[out_hstep * 3] = sum0[3]; + p0++; + } } - } + } // else output_elemtype } } #endif // __SSE2__ for (; ii + 1 < max_ii; ii += 2) { unsigned short* p0; + float* p0f; if (output_transpose) { p0 = (unsigned short*)top_blob + j * out_hstep + (i + ii) * out_elempack; + p0f = (float*)top_blob + j * out_hstep + (i + ii) * out_elempack; } else { p0 = (unsigned short*)top_blob + (i + ii) * out_hstep + j * out_elempack; + p0f = (float*)top_blob + (i + ii) * out_hstep + j * out_elempack; } float c0 = 0.f; @@ -7962,81 +9360,126 @@ static void unpack_output_tile_fp32_to_bf16(const Mat& topT, const Mat& C, Mat& _f1 = _mm512_mul_ps(_f1, _alpha); } - __m256i _bf0 = float2bfloat_avx512(_f0); - __m256i _bf1 = float2bfloat_avx512(_f1); - - if (output_transpose) + if (output_elemtype == 1) { - if (out_elempack == 16) + if (output_transpose) { - _mm256_storeu_si256((__m256i*)p0, _bf0); - _mm256_storeu_si256((__m256i*)(p0 + 16), _bf1); - } - if (out_elempack == 8) - { - _mm_storeu_si128((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); - _mm_storeu_si128((__m128i*)(p0 + 8), _mm256_extractf128_si256(_bf1, 0)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8), _mm256_extractf128_si256(_bf0, 1)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8), _mm256_extractf128_si256(_bf1, 1)); - } - if (out_elempack == 4) - { - _mm_storel_epi64((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); - _mm_storel_epi64((__m128i*)(p0 + 4), _mm256_extractf128_si256(_bf1, 0)); - _mm_storeh_pd((double*)(p0 + out_hstep * 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 0))); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 0))); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8), _mm256_extractf128_si256(_bf0, 1)); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4), _mm256_extractf128_si256(_bf1, 1)); - _mm_storeh_pd((double*)(p0 + out_hstep * 12), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 1))); - _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 1))); - } - if (out_elempack == 1) - { - unsigned short sum0[16]; - unsigned short sum1[16]; - _mm256_storeu_si256((__m256i*)sum0, _bf0); - _mm256_storeu_si256((__m256i*)sum1, _bf1); - - p0[0] = sum0[0]; - p0[1] = sum1[0]; - p0[out_hstep] = sum0[1]; - p0[out_hstep + 1] = sum1[1]; - p0[out_hstep * 2] = sum0[2]; - p0[out_hstep * 2 + 1] = sum1[2]; - p0[out_hstep * 3] = sum0[3]; - p0[out_hstep * 3 + 1] = sum1[3]; - p0[out_hstep * 4] = sum0[4]; - p0[out_hstep * 4 + 1] = sum1[4]; - p0[out_hstep * 5] = sum0[5]; - p0[out_hstep * 5 + 1] = sum1[5]; - p0[out_hstep * 6] = sum0[6]; - p0[out_hstep * 6 + 1] = sum1[6]; - p0[out_hstep * 7] = sum0[7]; - p0[out_hstep * 7 + 1] = sum1[7]; - p0[out_hstep * 8] = sum0[8]; - p0[out_hstep * 8 + 1] = sum1[8]; - p0[out_hstep * 9] = sum0[9]; - p0[out_hstep * 9 + 1] = sum1[9]; - p0[out_hstep * 10] = sum0[10]; - p0[out_hstep * 10 + 1] = sum1[10]; - p0[out_hstep * 11] = sum0[11]; - p0[out_hstep * 11 + 1] = sum1[11]; - p0[out_hstep * 12] = sum0[12]; - p0[out_hstep * 12 + 1] = sum1[12]; - p0[out_hstep * 13] = sum0[13]; - p0[out_hstep * 13 + 1] = sum1[13]; - p0[out_hstep * 14] = sum0[14]; - p0[out_hstep * 14 + 1] = sum1[14]; - p0[out_hstep * 15] = sum0[15]; - p0[out_hstep * 15 + 1] = sum1[15]; - } - p0 += out_hstep * 16; + if 
(out_elempack == 16) + { + _mm512_store_ps(p0f, _f0); + _mm512_store_ps(p0f + 16, _f1); + } + if (out_elempack == 8) + { + _mm256_store_ps(p0f, _mm512_extractf32x8_ps(_f0, 0)); + _mm256_store_ps(p0f + 8, _mm512_extractf32x8_ps(_f1, 0)); + _mm256_store_ps(p0f + out_hstep * 8, _mm512_extractf32x8_ps(_f0, 1)); + _mm256_store_ps(p0f + out_hstep * 8 + 8, _mm512_extractf32x8_ps(_f1, 1)); + } + if (out_elempack == 4) + { + _mm_store_ps(p0f, _mm512_extractf32x4_ps(_f0, 0)); + _mm_store_ps(p0f + 4, _mm512_extractf32x4_ps(_f1, 0)); + _mm_store_ps(p0f + out_hstep * 4, _mm512_extractf32x4_ps(_f0, 1)); + _mm_store_ps(p0f + out_hstep * 4 + 4, _mm512_extractf32x4_ps(_f1, 1)); + _mm_store_ps(p0f + out_hstep * 8, _mm512_extractf32x4_ps(_f0, 2)); + _mm_store_ps(p0f + out_hstep * 8 + 4, _mm512_extractf32x4_ps(_f1, 2)); + _mm_store_ps(p0f + out_hstep * 12, _mm512_extractf32x4_ps(_f0, 3)); + _mm_store_ps(p0f + out_hstep * 12 + 4, _mm512_extractf32x4_ps(_f1, 3)); + } + if (out_elempack == 1) + { + __m512i _vindex = _mm512_mullo_epi32(_mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15), _mm512_set1_epi32(out_hstep)); + _mm512_i32scatter_ps(p0f, _vindex, _f0, sizeof(float)); + _mm512_i32scatter_ps(p0f + 1, _vindex, _f1, sizeof(float)); + } + p0f += out_hstep * 16; + } + else + { + _mm512_storeu_ps(p0f, _f0); + _mm512_storeu_ps(p0f + out_hstep, _f1); + p0f += 16; + } } else { - _mm256_storeu_si256((__m256i*)p0, _bf0); - _mm256_storeu_si256((__m256i*)(p0 + out_hstep), _bf1); - p0 += 16; + __m256i _bf0 = float2bfloat_avx512(_f0); + __m256i _bf1 = float2bfloat_avx512(_f1); + + if (output_transpose) + { + if (out_elempack == 16) + { + _mm256_storeu_si256((__m256i*)p0, _bf0); + _mm256_storeu_si256((__m256i*)(p0 + 16), _bf1); + } + if (out_elempack == 8) + { + _mm_storeu_si128((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); + _mm_storeu_si128((__m128i*)(p0 + 8), _mm256_extractf128_si256(_bf1, 0)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8), _mm256_extractf128_si256(_bf0, 1)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8 + 8), _mm256_extractf128_si256(_bf1, 1)); + } + if (out_elempack == 4) + { + _mm_storel_epi64((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); + _mm_storel_epi64((__m128i*)(p0 + 4), _mm256_extractf128_si256(_bf1, 0)); + _mm_storeh_pd((double*)(p0 + out_hstep * 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 0))); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 0))); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8), _mm256_extractf128_si256(_bf0, 1)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8 + 4), _mm256_extractf128_si256(_bf1, 1)); + _mm_storeh_pd((double*)(p0 + out_hstep * 12), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 1))); + _mm_storeh_pd((double*)(p0 + out_hstep * 12 + 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf1, 1))); + } + if (out_elempack == 1) + { + unsigned short sum0[16]; + unsigned short sum1[16]; + _mm256_storeu_si256((__m256i*)sum0, _bf0); + _mm256_storeu_si256((__m256i*)sum1, _bf1); + + p0[0] = sum0[0]; + p0[1] = sum1[0]; + p0[out_hstep] = sum0[1]; + p0[out_hstep + 1] = sum1[1]; + p0[out_hstep * 2] = sum0[2]; + p0[out_hstep * 2 + 1] = sum1[2]; + p0[out_hstep * 3] = sum0[3]; + p0[out_hstep * 3 + 1] = sum1[3]; + p0[out_hstep * 4] = sum0[4]; + p0[out_hstep * 4 + 1] = sum1[4]; + p0[out_hstep * 5] = sum0[5]; + p0[out_hstep * 5 + 1] = sum1[5]; + p0[out_hstep * 6] = sum0[6]; + p0[out_hstep * 6 + 1] = sum1[6]; + p0[out_hstep * 7] = sum0[7]; + p0[out_hstep * 7 + 1] = sum1[7]; + p0[out_hstep 
* 8] = sum0[8]; + p0[out_hstep * 8 + 1] = sum1[8]; + p0[out_hstep * 9] = sum0[9]; + p0[out_hstep * 9 + 1] = sum1[9]; + p0[out_hstep * 10] = sum0[10]; + p0[out_hstep * 10 + 1] = sum1[10]; + p0[out_hstep * 11] = sum0[11]; + p0[out_hstep * 11 + 1] = sum1[11]; + p0[out_hstep * 12] = sum0[12]; + p0[out_hstep * 12 + 1] = sum1[12]; + p0[out_hstep * 13] = sum0[13]; + p0[out_hstep * 13 + 1] = sum1[13]; + p0[out_hstep * 14] = sum0[14]; + p0[out_hstep * 14 + 1] = sum1[14]; + p0[out_hstep * 15] = sum0[15]; + p0[out_hstep * 15 + 1] = sum1[15]; + } + p0 += out_hstep * 16; + } + else + { + _mm256_storeu_si256((__m256i*)p0, _bf0); + _mm256_storeu_si256((__m256i*)(p0 + out_hstep), _bf1); + p0 += 16; + } } } #endif // __AVX512F__ @@ -8131,54 +9574,112 @@ static void unpack_output_tile_fp32_to_bf16(const Mat& topT, const Mat& C, Mat& _f3 = _mm_mul_ps(_f3, _alpha); } - __m128i _bf0 = float2bfloat_sse(_f0, _f1); - __m128i _bf1 = float2bfloat_sse(_f2, _f3); - - if (output_transpose) + if (output_elemtype == 1) { - if (out_elempack == 8) + if (output_transpose) { - _mm_storeu_si128((__m128i*)p0, _bf0); - _mm_storeu_si128((__m128i*)(p0 + 8), _bf1); + if (out_elempack == 8) + { + _mm_storeu_ps(p0f, _f0); + _mm_storeu_ps(p0f + 4, _f1); + _mm_storeu_ps(p0f + 8, _f2); + _mm_storeu_ps(p0f + 8 + 4, _f3); + } + if (out_elempack == 4) + { + _mm_storeu_ps(p0f, _f0); + _mm_storeu_ps(p0f + 4, _f2); + _mm_storeu_ps(p0f + out_hstep * 4, _f1); + _mm_storeu_ps(p0f + out_hstep * 4 + 4, _f3); + } + if (out_elempack == 1) + { + float sum0[8]; + float sum1[8]; + _mm_storeu_ps(sum0, _f0); + _mm_storeu_ps(sum0 + 4, _f1); + _mm_storeu_ps(sum1, _f2); + _mm_storeu_ps(sum1 + 4, _f3); + + p0f[0] = sum0[0]; + p0f[1] = sum1[0]; + p0f[out_hstep] = sum0[1]; + p0f[out_hstep + 1] = sum1[1]; + p0f[out_hstep * 2] = sum0[2]; + p0f[out_hstep * 2 + 1] = sum1[2]; + p0f[out_hstep * 3] = sum0[3]; + p0f[out_hstep * 3 + 1] = sum1[3]; + p0f[out_hstep * 4] = sum0[4]; + p0f[out_hstep * 4 + 1] = sum1[4]; + p0f[out_hstep * 5] = sum0[5]; + p0f[out_hstep * 5 + 1] = sum1[5]; + p0f[out_hstep * 6] = sum0[6]; + p0f[out_hstep * 6 + 1] = sum1[6]; + p0f[out_hstep * 7] = sum0[7]; + p0f[out_hstep * 7 + 1] = sum1[7]; + } + p0f += out_hstep * 8; } - if (out_elempack == 4) + else { - _mm_storel_epi64((__m128i*)p0, _bf0); - _mm_storel_epi64((__m128i*)(p0 + 4), _bf1); - _mm_storeh_pd((double*)(p0 + out_hstep * 4), _mm_castsi128_pd(_bf0)); - _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4), _mm_castsi128_pd(_bf1)); - } - if (out_elempack == 1) - { - unsigned short sum0[8]; - unsigned short sum1[8]; - _mm_storeu_si128((__m128i*)sum0, _bf0); - _mm_storeu_si128((__m128i*)sum1, _bf1); - - p0[0] = sum0[0]; - p0[1] = sum1[0]; - p0[out_hstep] = sum0[1]; - p0[out_hstep + 1] = sum1[1]; - p0[out_hstep * 2] = sum0[2]; - p0[out_hstep * 2 + 1] = sum1[2]; - p0[out_hstep * 3] = sum0[3]; - p0[out_hstep * 3 + 1] = sum1[3]; - p0[out_hstep * 4] = sum0[4]; - p0[out_hstep * 4 + 1] = sum1[4]; - p0[out_hstep * 5] = sum0[5]; - p0[out_hstep * 5 + 1] = sum1[5]; - p0[out_hstep * 6] = sum0[6]; - p0[out_hstep * 6 + 1] = sum1[6]; - p0[out_hstep * 7] = sum0[7]; - p0[out_hstep * 7 + 1] = sum1[7]; - } - p0 += out_hstep * 8; + _mm_storeu_ps(p0f, _f0); + _mm_storeu_ps(p0f + 4, _f1); + _mm_storeu_ps(p0f + out_hstep, _f2); + _mm_storeu_ps(p0f + out_hstep + 4, _f3); + p0f += 8; + } } else { - _mm_storeu_si128((__m128i*)p0, _bf0); - _mm_storeu_si128((__m128i*)(p0 + out_hstep), _bf1); - p0 += 8; + __m128i _bf0 = float2bfloat_sse(_f0, _f1); + __m128i _bf1 = float2bfloat_sse(_f2, _f3); + + if (output_transpose) + { + 
if (out_elempack == 8) + { + _mm_storeu_si128((__m128i*)p0, _bf0); + _mm_storeu_si128((__m128i*)(p0 + 8), _bf1); + } + if (out_elempack == 4) + { + _mm_storel_epi64((__m128i*)p0, _bf0); + _mm_storel_epi64((__m128i*)(p0 + 4), _bf1); + _mm_storeh_pd((double*)(p0 + out_hstep * 4), _mm_castsi128_pd(_bf0)); + _mm_storeh_pd((double*)(p0 + out_hstep * 4 + 4), _mm_castsi128_pd(_bf1)); + } + if (out_elempack == 1) + { + unsigned short sum0[8]; + unsigned short sum1[8]; + _mm_storeu_si128((__m128i*)sum0, _bf0); + _mm_storeu_si128((__m128i*)sum1, _bf1); + + p0[0] = sum0[0]; + p0[1] = sum1[0]; + p0[out_hstep] = sum0[1]; + p0[out_hstep + 1] = sum1[1]; + p0[out_hstep * 2] = sum0[2]; + p0[out_hstep * 2 + 1] = sum1[2]; + p0[out_hstep * 3] = sum0[3]; + p0[out_hstep * 3 + 1] = sum1[3]; + p0[out_hstep * 4] = sum0[4]; + p0[out_hstep * 4 + 1] = sum1[4]; + p0[out_hstep * 5] = sum0[5]; + p0[out_hstep * 5 + 1] = sum1[5]; + p0[out_hstep * 6] = sum0[6]; + p0[out_hstep * 6 + 1] = sum1[6]; + p0[out_hstep * 7] = sum0[7]; + p0[out_hstep * 7 + 1] = sum1[7]; + } + p0 += out_hstep * 8; + } + else + { + _mm_storeu_si128((__m128i*)p0, _bf0); + _mm_storeu_si128((__m128i*)(p0 + out_hstep), _bf1); + p0 += 8; + } } } #endif // defined(__x86_64__) || defined(_M_X64) @@ -8243,55 +9744,111 @@ static void unpack_output_tile_fp32_to_bf16(const Mat& topT, const Mat& C, Mat& _f1 = _mm_mul_ps(_f1, _alpha); } - __m128i _bf0 = float2bfloat_sse(_f0, _f1); - - if (output_transpose) + if (output_elemtype == 1) { + if (output_transpose) + { #if !(defined(__x86_64__) || defined(_M_X64)) #if __AVX__ #if __AVX512F__ - if (out_elempack == 16) - { - const int jj_m16 = jj % 16; - unsigned short* p1 = p0 - out_hstep * jj_m16 + jj_m16; - _mm_storel_epi64((__m128i*)p1, _bf0); - _mm_storeh_pd((double*)(p1 + 16), _mm_castsi128_pd(_bf0)); - } + if (out_elempack == 16) + { + const int jj_m16 = jj % 16; + float* p1f = p0f - out_hstep * jj_m16 + jj_m16; + _mm_storeu_ps(p1f, _f0); + _mm_storeu_ps(p1f + 16, _f1); + } #endif // __AVX512F__ - if (out_elempack == 8) - { - const int jj_m8 = jj % 8; - unsigned short* p1 = p0 - out_hstep * jj_m8 + jj_m8; - _mm_storel_epi64((__m128i*)p1, _bf0); - _mm_storeh_pd((double*)(p1 + 8), _mm_castsi128_pd(_bf0)); - } + if (out_elempack == 8) + { + const int jj_m8 = jj % 8; + float* p1f = p0f - out_hstep * jj_m8 + jj_m8; + _mm_storeu_ps(p1f, _f0); + _mm_storeu_ps(p1f + 8, _f1); + } #endif // __AVX__ #endif // !(defined(__x86_64__) || defined(_M_X64)) - if (out_elempack == 4) - { - _mm_storeu_si128((__m128i*)p0, _bf0); + if (out_elempack == 4) + { + _mm_storeu_ps(p0f, _f0); + _mm_storeu_ps(p0f + 4, _f1); + } + if (out_elempack == 1) + { + float sum0[8]; + _mm_storeu_ps(sum0, _f0); + _mm_storeu_ps(sum0 + 4, _f1); + + p0f[0] = sum0[0]; + p0f[1] = sum0[4]; + p0f[out_hstep] = sum0[1]; + p0f[out_hstep + 1] = sum0[5]; + p0f[out_hstep * 2] = sum0[2]; + p0f[out_hstep * 2 + 1] = sum0[6]; + p0f[out_hstep * 3] = sum0[3]; + p0f[out_hstep * 3 + 1] = sum0[7]; + } + p0f += out_hstep * 4; } - if (out_elempack == 1) + else { - unsigned short sum0[8]; - _mm_storeu_si128((__m128i*)sum0, _bf0); - - p0[0] = sum0[0]; - p0[1] = sum0[4]; - p0[out_hstep] = sum0[1]; - p0[out_hstep + 1] = sum0[5]; - p0[out_hstep * 2] = sum0[2]; - p0[out_hstep * 2 + 1] = sum0[6]; - p0[out_hstep * 3] = sum0[3]; - p0[out_hstep * 3 + 1] = sum0[7]; + _mm_storeu_ps(p0f, _f0); + _mm_storeu_ps(p0f + out_hstep, _f1); + p0f += 4; } - p0 += out_hstep * 4; } else { - _mm_storel_epi64((__m128i*)p0, _bf0); - _mm_storel_epi64((__m128i*)(p0 + out_hstep), _mm_srli_si128(_bf0, 8)); - p0 
+= 4; + __m128i _bf0 = float2bfloat_sse(_f0, _f1); + + if (output_transpose) + { +#if !(defined(__x86_64__) || defined(_M_X64)) +#if __AVX__ +#if __AVX512F__ + if (out_elempack == 16) + { + const int jj_m16 = jj % 16; + unsigned short* p1 = p0 - out_hstep * jj_m16 + jj_m16; + _mm_storel_epi64((__m128i*)p1, _bf0); + _mm_storeh_pd((double*)(p1 + 16), _mm_castsi128_pd(_bf0)); + } +#endif // __AVX512F__ + if (out_elempack == 8) + { + const int jj_m8 = jj % 8; + unsigned short* p1 = p0 - out_hstep * jj_m8 + jj_m8; + _mm_storel_epi64((__m128i*)p1, _bf0); + _mm_storeh_pd((double*)(p1 + 8), _mm_castsi128_pd(_bf0)); + } +#endif // __AVX__ +#endif // !(defined(__x86_64__) || defined(_M_X64)) + if (out_elempack == 4) + { + _mm_storeu_si128((__m128i*)p0, _bf0); + } + if (out_elempack == 1) + { + unsigned short sum0[8]; + _mm_storeu_si128((__m128i*)sum0, _bf0); + + p0[0] = sum0[0]; + p0[1] = sum0[4]; + p0[out_hstep] = sum0[1]; + p0[out_hstep + 1] = sum0[5]; + p0[out_hstep * 2] = sum0[2]; + p0[out_hstep * 2 + 1] = sum0[6]; + p0[out_hstep * 3] = sum0[3]; + p0[out_hstep * 3 + 1] = sum0[7]; + } + p0 += out_hstep * 4; + } + else + { + _mm_storel_epi64((__m128i*)p0, _bf0); + _mm_storel_epi64((__m128i*)(p0 + out_hstep), _mm_srli_si128(_bf0, 8)); + p0 += 4; + } } } #endif // __SSE2__ @@ -8343,26 +9900,48 @@ static void unpack_output_tile_fp32_to_bf16(const Mat& topT, const Mat& C, Mat& f10 *= alpha; f11 *= alpha; - unsigned short bf00 = float32_to_bfloat16(f00); - unsigned short bf01 = float32_to_bfloat16(f01); - unsigned short bf10 = float32_to_bfloat16(f10); - unsigned short bf11 = float32_to_bfloat16(f11); - - if (output_transpose) + if (output_elemtype == 1) { - p0[0] = bf00; - p0[1] = bf10; - p0[out_hstep] = bf01; - p0[out_hstep + 1] = bf11; - p0 += out_hstep * 2; + if (output_transpose) + { + p0f[0] = f00; + p0f[1] = f10; + p0f[out_hstep] = f01; + p0f[out_hstep + 1] = f11; + p0f += out_hstep * 2; + } + else + { + p0f[0] = f00; + p0f[1] = f01; + p0f[out_hstep] = f10; + p0f[out_hstep + 1] = f11; + p0f += 2; + } } else { - p0[0] = bf00; - p0[1] = bf01; - p0[out_hstep] = bf10; - p0[out_hstep + 1] = bf11; - p0 += 2; + unsigned short bf00 = float32_to_bfloat16(f00); + unsigned short bf01 = float32_to_bfloat16(f01); + unsigned short bf10 = float32_to_bfloat16(f10); + unsigned short bf11 = float32_to_bfloat16(f11); + + if (output_transpose) + { + p0[0] = bf00; + p0[1] = bf10; + p0[out_hstep] = bf01; + p0[out_hstep + 1] = bf11; + p0 += out_hstep * 2; + } + else + { + p0[0] = bf00; + p0[1] = bf01; + p0[out_hstep] = bf10; + p0[out_hstep + 1] = bf11; + p0 += 2; + } } } for (; jj < max_jj; jj++) @@ -8401,33 +9980,54 @@ static void unpack_output_tile_fp32_to_bf16(const Mat& topT, const Mat& C, Mat& f0 *= alpha; f1 *= alpha; - unsigned short bf0 = float32_to_bfloat16(f0); - unsigned short bf1 = float32_to_bfloat16(f1); - - if (output_transpose) + if (output_elemtype == 1) { - p0[0] = bf0; - p0[1] = bf1; - p0 += out_hstep; + if (output_transpose) + { + p0f[0] = f0; + p0f[1] = f1; + p0f += out_hstep; + } + else + { + p0f[0] = f0; + p0f[out_hstep] = f1; + p0f++; + } } else { - p0[0] = bf0; - p0[out_hstep] = bf1; - p0++; + unsigned short bf0 = float32_to_bfloat16(f0); + unsigned short bf1 = float32_to_bfloat16(f1); + + if (output_transpose) + { + p0[0] = bf0; + p0[1] = bf1; + p0 += out_hstep; + } + else + { + p0[0] = bf0; + p0[out_hstep] = bf1; + p0++; + } } } } for (; ii < max_ii; ii++) { unsigned short* p0; + float* p0f; if (output_transpose) { p0 = (unsigned short*)top_blob + j * out_hstep + (i + ii) * out_elempack; + 
p0f = (float*)top_blob + j * out_hstep + (i + ii) * out_elempack; } else { p0 = (unsigned short*)top_blob + (i + ii) * out_hstep + j * out_elempack; + p0f = (float*)top_blob + (i + ii) * out_hstep + j * out_elempack; } float c0 = 0.f; @@ -8500,61 +10100,121 @@ static void unpack_output_tile_fp32_to_bf16(const Mat& topT, const Mat& C, Mat& _f0 = _mm512_mul_ps(_f0, _mm512_set1_ps(alpha)); } - __m256i _bf0 = float2bfloat_avx512(_f0); - - if (output_transpose) + if (output_elemtype == 1) { - if (out_hstep == 1) - { - _mm256_storeu_si256((__m256i*)p0, _bf0); - } - else + if (output_transpose) { - if (out_elempack == 16) + if (out_hstep == 1) { - _mm256_storeu_si256((__m256i*)p0, _bf0); + _mm512_storeu_ps(p0f, _f0); } - if (out_elempack == 8) + else { - _mm_storeu_si128((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); - _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8), _mm256_extractf128_si256(_bf0, 1)); + if (out_elempack == 16) + { + _mm512_storeu_ps(p0f, _f0); + } + if (out_elempack == 8) + { + _mm256_storeu_ps(p0f, _mm512_castps512_ps256(_f0)); + _mm256_storeu_ps(p0f + out_hstep * 8, _mm512_extractf32x8_ps(_f0, 1)); + } + if (out_elempack == 4) + { + _mm_storeu_ps(p0f, _mm512_castps512_ps128(_f0)); + _mm_storeu_ps(p0f + out_hstep * 4, _mm256_extractf128_ps(_mm512_castps512_ps256(_f0), 1)); + _mm_storeu_ps(p0f + out_hstep * 8, _mm256_castps256_ps128(_mm512_extractf32x8_ps(_f0, 1))); + _mm_storeu_ps(p0f + out_hstep * 12, _mm256_extractf128_ps(_mm512_extractf32x8_ps(_f0, 1), 1)); + } + if (out_elempack == 1) + { + float sum0[16]; + _mm512_storeu_ps(sum0, _f0); + + p0f[0] = sum0[0]; + p0f[out_hstep] = sum0[1]; + p0f[out_hstep * 2] = sum0[2]; + p0f[out_hstep * 3] = sum0[3]; + p0f[out_hstep * 4] = sum0[4]; + p0f[out_hstep * 5] = sum0[5]; + p0f[out_hstep * 6] = sum0[6]; + p0f[out_hstep * 7] = sum0[7]; + p0f[out_hstep * 8] = sum0[8]; + p0f[out_hstep * 9] = sum0[9]; + p0f[out_hstep * 10] = sum0[10]; + p0f[out_hstep * 11] = sum0[11]; + p0f[out_hstep * 12] = sum0[12]; + p0f[out_hstep * 13] = sum0[13]; + p0f[out_hstep * 14] = sum0[14]; + p0f[out_hstep * 15] = sum0[15]; + } } - if (out_elempack == 4) + p0f += out_hstep * 16; + } + else + { + _mm512_storeu_ps(p0f, _f0); + p0f += 16; + } + } + else + { + __m256i _bf0 = float2bfloat_avx512(_f0); + + if (output_transpose) + { + if (out_hstep == 1) { - _mm_storel_epi64((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); - _mm_storeh_pd((double*)(p0 + out_hstep * 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 0))); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8), _mm256_extractf128_si256(_bf0, 1)); - _mm_storeh_pd((double*)(p0 + out_hstep * 12), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 1))); + _mm256_storeu_si256((__m256i*)p0, _bf0); } - if (out_elempack == 1) + else { - unsigned short sum0[16]; - _mm256_storeu_si256((__m256i*)sum0, _bf0); - - p0[0] = sum0[0]; - p0[out_hstep] = sum0[1]; - p0[out_hstep * 2] = sum0[2]; - p0[out_hstep * 3] = sum0[3]; - p0[out_hstep * 4] = sum0[4]; - p0[out_hstep * 5] = sum0[5]; - p0[out_hstep * 6] = sum0[6]; - p0[out_hstep * 7] = sum0[7]; - p0[out_hstep * 8] = sum0[8]; - p0[out_hstep * 9] = sum0[9]; - p0[out_hstep * 10] = sum0[10]; - p0[out_hstep * 11] = sum0[11]; - p0[out_hstep * 12] = sum0[12]; - p0[out_hstep * 13] = sum0[13]; - p0[out_hstep * 14] = sum0[14]; - p0[out_hstep * 15] = sum0[15]; + if (out_elempack == 16) + { + _mm256_storeu_si256((__m256i*)p0, _bf0); + } + if (out_elempack == 8) + { + _mm_storeu_si128((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); + _mm_storeu_si128((__m128i*)(p0 + out_hstep * 8), 
_mm256_extractf128_si256(_bf0, 1)); + } + if (out_elempack == 4) + { + _mm_storel_epi64((__m128i*)p0, _mm256_extractf128_si256(_bf0, 0)); + _mm_storeh_pd((double*)(p0 + out_hstep * 4), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 0))); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 8), _mm256_extractf128_si256(_bf0, 1)); + _mm_storeh_pd((double*)(p0 + out_hstep * 12), _mm_castsi128_pd(_mm256_extractf128_si256(_bf0, 1))); + } + if (out_elempack == 1) + { + unsigned short sum0[16]; + _mm256_storeu_si256((__m256i*)sum0, _bf0); + + p0[0] = sum0[0]; + p0[out_hstep] = sum0[1]; + p0[out_hstep * 2] = sum0[2]; + p0[out_hstep * 3] = sum0[3]; + p0[out_hstep * 4] = sum0[4]; + p0[out_hstep * 5] = sum0[5]; + p0[out_hstep * 6] = sum0[6]; + p0[out_hstep * 7] = sum0[7]; + p0[out_hstep * 8] = sum0[8]; + p0[out_hstep * 9] = sum0[9]; + p0[out_hstep * 10] = sum0[10]; + p0[out_hstep * 11] = sum0[11]; + p0[out_hstep * 12] = sum0[12]; + p0[out_hstep * 13] = sum0[13]; + p0[out_hstep * 14] = sum0[14]; + p0[out_hstep * 15] = sum0[15]; + } } + p0 += out_hstep * 16; + } + else + { + _mm256_storeu_si256((__m256i*)p0, _bf0); + p0 += 16; } - p0 += out_hstep * 16; - } - else - { - _mm256_storeu_si256((__m256i*)p0, _bf0); - p0 += 16; } } #endif // __AVX512F__ @@ -8589,48 +10249,99 @@ static void unpack_output_tile_fp32_to_bf16(const Mat& topT, const Mat& C, Mat& _f1 = _mm_mul_ps(_f1, _alpha); } - __m128i _bf0 = float2bfloat_sse(_f0, _f1); - - if (output_transpose) + if (output_elemtype == 1) { - if (out_hstep == 1) + if (output_transpose) { - _mm_storeu_si128((__m128i*)p0, _bf0); + if (out_hstep == 1) + { + _mm_storeu_ps(p0f, _f0); + _mm_storeu_ps(p0f + 4, _f1); + } + else + { +#if __AVX__ + if (out_elempack == 8) + { + _mm_storeu_ps(p0f, _f0); + _mm_storeu_ps(p0f + 4, _f1); + } +#endif // __AVX__ + if (out_elempack == 4) + { + _mm_storeu_ps(p0f, _f0); + _mm_storeu_ps(p0f + out_hstep * 4, _f1); + } + if (out_elempack == 1) + { + float sum0[8]; + _mm_storeu_ps(sum0, _f0); + _mm_storeu_ps(sum0 + 4, _f1); + + p0f[0] = sum0[0]; + p0f[out_hstep] = sum0[1]; + p0f[out_hstep * 2] = sum0[2]; + p0f[out_hstep * 3] = sum0[3]; + p0f[out_hstep * 4] = sum0[4]; + p0f[out_hstep * 5] = sum0[5]; + p0f[out_hstep * 6] = sum0[6]; + p0f[out_hstep * 7] = sum0[7]; + } + } + p0f += out_hstep * 8; } else { -#if __AVX__ - if (out_elempack == 8) + _mm_storeu_ps(p0f, _f0); + _mm_storeu_ps(p0f + 4, _f1); + p0f += 8; + } + } + else + { + __m128i _bf0 = float2bfloat_sse(_f0, _f1); + + if (output_transpose) + { + if (out_hstep == 1) { _mm_storeu_si128((__m128i*)p0, _bf0); } -#endif // __AVX__ - if (out_elempack == 4) - { - _mm_storel_epi64((__m128i*)p0, float2bfloat_sse(_f0)); - _mm_storel_epi64((__m128i*)(p0 + out_hstep * 4), float2bfloat_sse(_f1)); - } - if (out_elempack == 1) + else { - unsigned short sum0[8]; - _mm_storeu_si128((__m128i*)sum0, _bf0); - - p0[0] = sum0[0]; - p0[out_hstep] = sum0[1]; - p0[out_hstep * 2] = sum0[2]; - p0[out_hstep * 3] = sum0[3]; - p0[out_hstep * 4] = sum0[4]; - p0[out_hstep * 5] = sum0[5]; - p0[out_hstep * 6] = sum0[6]; - p0[out_hstep * 7] = sum0[7]; +#if __AVX__ + if (out_elempack == 8) + { + _mm_storeu_si128((__m128i*)p0, _bf0); + } +#endif // __AVX__ + if (out_elempack == 4) + { + _mm_storel_epi64((__m128i*)p0, float2bfloat_sse(_f0)); + _mm_storel_epi64((__m128i*)(p0 + out_hstep * 4), float2bfloat_sse(_f1)); + } + if (out_elempack == 1) + { + unsigned short sum0[8]; + _mm_storeu_si128((__m128i*)sum0, _bf0); + + p0[0] = sum0[0]; + p0[out_hstep] = sum0[1]; + p0[out_hstep * 2] = sum0[2]; + p0[out_hstep * 3] = sum0[3]; + 
p0[out_hstep * 4] = sum0[4]; + p0[out_hstep * 5] = sum0[5]; + p0[out_hstep * 6] = sum0[6]; + p0[out_hstep * 7] = sum0[7]; + } } + p0 += out_hstep * 8; + } + else + { + _mm_storeu_si128((__m128i*)p0, _bf0); + p0 += 8; } - p0 += out_hstep * 8; - } - else - { - _mm_storeu_si128((__m128i*)p0, _bf0); - p0 += 8; } } #endif // defined(__x86_64__) || defined(_M_X64) @@ -8656,51 +10367,101 @@ static void unpack_output_tile_fp32_to_bf16(const Mat& topT, const Mat& C, Mat& _f0 = _mm_mul_ps(_f0, _mm_set1_ps(alpha)); - __m128i _bf0 = float2bfloat_sse(_f0); - - if (output_transpose) + if (output_elemtype == 1) { - if (out_hstep == 1) - { - _mm_storel_epi64((__m128i*)p0, _bf0); - } - else + if (output_transpose) { + if (out_hstep == 1) + { + _mm_storeu_ps(p0f, _f0); + } + else + { #if !(defined(__x86_64__) || defined(_M_X64)) #if __AVX__ #if __AVX512F__ - if (out_elempack == 16) - { - _mm_storel_epi64((__m128i*)(p0 - (jj % 16) / 4 * out_hstep * 4 + (jj % 16) / 4 * 4), _bf0); - } + if (out_elempack == 16) + { + _mm_storeu_ps(p0f - (jj % 16) / 4 * out_hstep * 4 + (jj % 16) / 4 * 4, _f0); + } #endif // __AVX512F__ - if (out_elempack == 8) - { - _mm_storel_epi64((__m128i*)(p0 - (jj % 8) / 4 * out_hstep * 4 + (jj % 8) / 4 * 4), _bf0); - } + if (out_elempack == 8) + { + _mm_storeu_ps(p0f - (jj % 8) / 4 * out_hstep * 4 + (jj % 8) / 4 * 4, _f0); + } #endif // __AVX__ #endif // !(defined(__x86_64__) || defined(_M_X64)) - if (out_elempack == 4) + if (out_elempack == 4) + { + _mm_storeu_ps(p0f, _f0); + } + if (out_elempack == 1) + { + float sum0[4]; + _mm_storeu_ps(sum0, _f0); + + p0f[0] = sum0[0]; + p0f[out_hstep] = sum0[1]; + p0f[out_hstep * 2] = sum0[2]; + p0f[out_hstep * 3] = sum0[3]; + } + } + p0f += out_hstep * 4; + } + else + { + _mm_storeu_ps(p0f, _f0); + p0f += 4; + } + } + else + { + __m128i _bf0 = float2bfloat_sse(_f0); + + if (output_transpose) + { + if (out_hstep == 1) { _mm_storel_epi64((__m128i*)p0, _bf0); } - if (out_elempack == 1) + else { - unsigned short sum0[4]; - _mm_storel_epi64((__m128i*)sum0, _bf0); - - p0[0] = sum0[0]; - p0[out_hstep] = sum0[1]; - p0[out_hstep * 2] = sum0[2]; - p0[out_hstep * 3] = sum0[3]; +#if !(defined(__x86_64__) || defined(_M_X64)) +#if __AVX__ +#if __AVX512F__ + if (out_elempack == 16) + { + _mm_storel_epi64((__m128i*)(p0 - (jj % 16) / 4 * out_hstep * 4 + (jj % 16) / 4 * 4), _bf0); + } +#endif // __AVX512F__ + if (out_elempack == 8) + { + _mm_storel_epi64((__m128i*)(p0 - (jj % 8) / 4 * out_hstep * 4 + (jj % 8) / 4 * 4), _bf0); + } +#endif // __AVX__ +#endif // !(defined(__x86_64__) || defined(_M_X64)) + if (out_elempack == 4) + { + _mm_storel_epi64((__m128i*)p0, _bf0); + } + if (out_elempack == 1) + { + unsigned short sum0[4]; + _mm_storel_epi64((__m128i*)sum0, _bf0); + + p0[0] = sum0[0]; + p0[out_hstep] = sum0[1]; + p0[out_hstep * 2] = sum0[2]; + p0[out_hstep * 3] = sum0[3]; + } } + p0 += out_hstep * 4; + } + else + { + _mm_storel_epi64((__m128i*)p0, _bf0); + p0 += 4; } - p0 += out_hstep * 4; - } - else - { - _mm_storel_epi64((__m128i*)p0, _bf0); - p0 += 4; } } #endif // __SSE2__ @@ -8729,20 +10490,38 @@ static void unpack_output_tile_fp32_to_bf16(const Mat& topT, const Mat& C, Mat& f0 *= alpha; f1 *= alpha; - unsigned short bf0 = float32_to_bfloat16(f0); - unsigned short bf1 = float32_to_bfloat16(f1); - - if (output_transpose) + if (output_elemtype == 1) { - p0[0] = bf0; - p0[out_hstep] = bf1; - p0 += out_hstep * 2; + if (output_transpose) + { + p0f[0] = f0; + p0f[out_hstep] = f1; + p0f += out_hstep * 2; + } + else + { + p0f[0] = f0; + p0f[1] = f1; + p0f += 2; + } } 
else { - p0[0] = bf0; - p0[1] = bf1; - p0 += 2; + unsigned short bf0 = float32_to_bfloat16(f0); + unsigned short bf1 = float32_to_bfloat16(f1); + + if (output_transpose) + { + p0[0] = bf0; + p0[out_hstep] = bf1; + p0 += out_hstep * 2; + } + else + { + p0[0] = bf0; + p0[1] = bf1; + p0 += 2; + } } } for (; jj < max_jj; jj++) @@ -8766,15 +10545,31 @@ static void unpack_output_tile_fp32_to_bf16(const Mat& topT, const Mat& C, Mat& f0 *= alpha; - p0[0] = float32_to_bfloat16(f0); - - if (output_transpose) + if (output_elemtype == 1) { - p0 += out_hstep; + p0f[0] = f0; + + if (output_transpose) + { + p0f += out_hstep; + } + else + { + p0f++; + } } else { - p0++; + p0[0] = float32_to_bfloat16(f0); + + if (output_transpose) + { + p0 += out_hstep; + } + else + { + p0++; + } } } } diff --git a/src/layer/x86/gemm_x86.cpp b/src/layer/x86/gemm_x86.cpp index a5acea649f9..14009d45cae 100644 --- a/src/layer/x86/gemm_x86.cpp +++ b/src/layer/x86/gemm_x86.cpp @@ -8518,7 +8518,7 @@ int Gemm_x86::create_pipeline_bf16s(const Option& opt) return 0; } -static int gemm_AT_x86_bf16s(const Mat& AT, const Mat& B, const Mat& C, Mat& top_blob, int broadcast_type_C, int M, int K, int transB, int output_transpose, float alpha, float beta, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt) +static int gemm_AT_x86_bf16s(const Mat& AT, const Mat& B, const Mat& C, Mat& top_blob, int broadcast_type_C, int M, int K, int transB, int output_transpose, float alpha, float beta, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, int output_elemtype, const Option& opt) { const int N = transB ? (B.dims == 3 ? B.c : B.h) * B.elempack : B.w; @@ -8587,14 +8587,14 @@ static int gemm_AT_x86_bf16s(const Mat& AT, const Mat& B, const Mat& C, Mat& top gemm_transB_packed_tile_bf16s(AT_tile, BT_tile, topT_tile, i, max_ii, j, max_jj, k, max_kk); } - unpack_output_tile_fp32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, alpha, beta, output_transpose); + unpack_output_tile_fp32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, alpha, beta, output_transpose, output_elemtype); } } return 0; } -static int gemm_BT_x86_bf16s(const Mat& A, const Mat& BT, const Mat& C, Mat& top_blob, int broadcast_type_C, int N, int K, int transA, int output_transpose, float alpha, float beta, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt) +static int gemm_BT_x86_bf16s(const Mat& A, const Mat& BT, const Mat& C, Mat& top_blob, int broadcast_type_C, int N, int K, int transA, int output_transpose, float alpha, float beta, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, int output_elemtype, const Option& opt) { const int M = transA ? A.w : (A.dims == 3 ? 
A.c : A.h) * A.elempack; @@ -8652,14 +8652,14 @@ static int gemm_BT_x86_bf16s(const Mat& A, const Mat& BT, const Mat& C, Mat& top gemm_transB_packed_tile_bf16s(AT_tile, BT_tile, topT_tile, i, max_ii, j, max_jj, k, max_kk); } - unpack_output_tile_fp32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, alpha, beta, output_transpose); + unpack_output_tile_fp32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, alpha, beta, output_transpose, output_elemtype); } } return 0; } -static int gemm_AT_BT_x86_bf16s(const Mat& AT, const Mat& BT, const Mat& C, Mat& top_blob, int broadcast_type_C, int M, int N, int K, int output_transpose, float alpha, float beta, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt) +static int gemm_AT_BT_x86_bf16s(const Mat& AT, const Mat& BT, const Mat& C, Mat& top_blob, int broadcast_type_C, int M, int N, int K, int output_transpose, float alpha, float beta, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, int output_elemtype, const Option& opt) { int TILE_M, TILE_N, TILE_K; get_optimal_tile_mnk_bf16(M, N, K, constant_TILE_M, constant_TILE_N, constant_TILE_K, TILE_M, TILE_N, TILE_K, nT); @@ -8693,14 +8693,14 @@ static int gemm_AT_BT_x86_bf16s(const Mat& AT, const Mat& BT, const Mat& C, Mat& gemm_transB_packed_tile_bf16s(AT_tile, BT_tile, topT_tile, i, max_ii, j, max_jj, k, max_kk); } - unpack_output_tile_fp32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, alpha, beta, output_transpose); + unpack_output_tile_fp32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, alpha, beta, output_transpose, output_elemtype); } } return 0; } -static int gemm_x86_bf16s(const Mat& A, const Mat& B, const Mat& C, Mat& top_blob, int broadcast_type_C, int transA, int transB, int output_transpose, float alpha, float beta, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt) +static int gemm_x86_bf16s(const Mat& A, const Mat& B, const Mat& C, Mat& top_blob, int broadcast_type_C, int transA, int transB, int output_transpose, float alpha, float beta, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, int output_elemtype, const Option& opt) { const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; const int K = transA ? (A.dims == 3 ? A.c : A.h) * A.elempack : A.w; @@ -8788,7 +8788,7 @@ static int gemm_x86_bf16s(const Mat& A, const Mat& B, const Mat& C, Mat& top_blo gemm_transB_packed_tile_bf16s(AT_tile, BT_tile, topT_tile, i, max_ii, j, max_jj, k, max_kk); } - unpack_output_tile_fp32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, alpha, beta, output_transpose); + unpack_output_tile_fp32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, alpha, beta, output_transpose, output_elemtype); } } @@ -8908,7 +8908,7 @@ int Gemm_x86::forward_bf16s(const std::vector& bottom_blobs, std::vector& bottom_blobs, std::vectorload_model(ModelBinFromMatArray(weights)); - o_gemm->create_pipeline(opt); + Option opt_fp32 = opt; + opt_fp32.use_bf16_packed = false; + opt_fp32.use_bf16_storage = false; + o_gemm->create_pipeline(opt_fp32); if (opt.lightmode) { @@ -203,12 +211,15 @@ int MultiHeadAttention_x86::create_pipeline(const Option& _opt) pd.set(10, attn_mask ? 
3 : -1); // constant_broadcast_type_C pd.set(11, 0); // output_N1M pd.set(12, 1); // output_elempack + pd.set(13, 1); // output_elemtype = fp32 #if NCNN_INT8 pd.set(18, int8_scale_term); #endif qk_gemm->load_param(pd); qk_gemm->load_model(ModelBinFromMatArray(0)); Option opt1 = opt; + opt1.use_bf16_packed = false; + opt1.use_bf16_storage = false; opt1.num_threads = 1; qk_gemm->create_pipeline(opt1); } @@ -227,6 +238,7 @@ int MultiHeadAttention_x86::create_pipeline(const Option& _opt) pd.set(10, -1); // constant_broadcast_type_C pd.set(11, 0); // output_N1M pd.set(12, 1); // output_elempack + pd.set(13, 1); // output_elemtype = fp32 pd.set(14, 1); // output_transpose #if NCNN_INT8 pd.set(18, int8_scale_term); @@ -234,6 +246,8 @@ int MultiHeadAttention_x86::create_pipeline(const Option& _opt) qkv_gemm->load_param(pd); qkv_gemm->load_model(ModelBinFromMatArray(0)); Option opt1 = opt; + opt1.use_bf16_packed = false; + opt1.use_bf16_storage = false; opt1.num_threads = 1; qkv_gemm->create_pipeline(opt1); } @@ -488,6 +502,17 @@ int MultiHeadAttention_x86::forward(const std::vector& bottom_blobs, std::v return retv; } + Mat v_affine_fp32 = v_affine; +#if NCNN_BF16 + if (opt.use_bf16_storage && v_affine.elembits() == 16) + { + // qkv_gemm need fp32 inputs + cast_bfloat16_to_float32(v_affine, v_affine_fp32, opt); + if (v_affine_fp32.empty()) + return -100; + } +#endif + Mat qkv_cross(src_seqlen, embed_dim_per_head * num_heads, 4u, opt.blob_allocator); if (qkv_cross.empty()) return -100; @@ -499,7 +524,7 @@ int MultiHeadAttention_x86::forward(const std::vector& bottom_blobs, std::v { std::vector qkv_bottom_blobs(2); qkv_bottom_blobs[0] = qk_cross.row_range(i * src_seqlen, src_seqlen); - qkv_bottom_blobs[1] = v_affine.row_range(i * embed_dim_per_head, embed_dim_per_head); + qkv_bottom_blobs[1] = v_affine_fp32.row_range(i * embed_dim_per_head, embed_dim_per_head); std::vector qkv_top_blobs(1); qkv_top_blobs[0] = qkv_cross.row_range(i * embed_dim_per_head, embed_dim_per_head); Option opt1 = opt; @@ -512,6 +537,8 @@ int MultiHeadAttention_x86::forward(const std::vector& bottom_blobs, std::v return retqkvs[i]; } + v_affine_fp32.release(); + if (!kv_cache) { v_affine.release(); diff --git a/src/layer/x86/sdpa_x86.cpp b/src/layer/x86/sdpa_x86.cpp index e8eb8de6f17..a319913c3f9 100644 --- a/src/layer/x86/sdpa_x86.cpp +++ b/src/layer/x86/sdpa_x86.cpp @@ -9,6 +9,10 @@ namespace ncnn { SDPA_x86::SDPA_x86() { +#if NCNN_BF16 + support_bf16_storage = true; +#endif + qk_gemm = 0; qkv_gemm = 0; qk_softmax = 0; @@ -20,6 +24,7 @@ int SDPA_x86::create_pipeline(const Option& _opt) if (int8_scale_term) { opt.use_packing_layout = false; // TODO enable packing + support_bf16_storage = false; } { @@ -51,6 +56,7 @@ int SDPA_x86::create_pipeline(const Option& _opt) pd.set(10, attn_mask ? 3 : -1); // constant_broadcast_type_C (MxN) pd.set(11, 0); // output_N1M pd.set(12, 1); // output_elempack + pd.set(13, 1); // output_elemtype = fp32 #if NCNN_INT8 pd.set(18, int8_scale_term); #endif @@ -78,6 +84,7 @@ int SDPA_x86::create_pipeline(const Option& _opt) pd.set(10, -1); // constant_broadcast_type_C pd.set(11, 0); // output_N1M pd.set(12, 1); // output_elempack + pd.set(13, 1); // output_elemtype = fp32 pd.set(14, 0); // output_transpose #if NCNN_INT8 pd.set(18, int8_scale_term); @@ -148,10 +155,12 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to const int past_seqlen = kv_cache ? 
past_key.h : 0; const int dst_seqlen = past_seqlen + cur_seqlen; + const size_t elemsize = query.elemsize; + Mat key; if (past_seqlen > 0) { - key.create(embed_dim, dst_seqlen, num_group, 4u, opt.blob_allocator); + key.create(embed_dim, dst_seqlen, num_group, elemsize, opt.blob_allocator); if (key.empty()) return -100; @@ -162,8 +171,8 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to const Mat cur_key_head = cur_key.channel(q); Mat key_head = key.channel(q); - memcpy(key_head.row(0), past_key_head, embed_dim * past_seqlen * sizeof(float)); - memcpy(key_head.row(past_seqlen), cur_key_head, embed_dim * cur_seqlen * sizeof(float)); + memcpy(key_head.row(0), past_key_head, embed_dim * past_seqlen * elemsize); + memcpy(key_head.row(past_seqlen), cur_key_head, embed_dim * cur_seqlen * elemsize); } } else @@ -174,7 +183,7 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to Mat value; if (past_seqlen > 0) { - value.create(out_embed_dim, dst_seqlen, num_group, 4u, opt.blob_allocator); + value.create(out_embed_dim, dst_seqlen, num_group, elemsize, opt.blob_allocator); if (value.empty()) return -100; @@ -185,8 +194,8 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to const Mat cur_value_head = cur_value.channel(q); Mat value_head = value.channel(q); - memcpy(value_head.row(0), past_value_head, out_embed_dim * past_seqlen * sizeof(float)); - memcpy(value_head.row(past_seqlen), cur_value_head, out_embed_dim * cur_seqlen * sizeof(float)); + memcpy(value_head.row(0), past_value_head, out_embed_dim * past_seqlen * elemsize); + memcpy(value_head.row(past_seqlen), cur_value_head, out_embed_dim * cur_seqlen * elemsize); } } else @@ -194,11 +203,6 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to value = cur_value; } - Mat& top_blob = top_blobs[0]; - top_blob.create(out_embed_dim, src_seqlen, num_heads, 4u, opt.blob_allocator); - if (top_blob.empty()) - return -100; - const int num_heads_per_group = num_heads / num_group; Mat qk_cross(dst_seqlen, src_seqlen, num_heads, 4u, opt.workspace_allocator); @@ -229,6 +233,7 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to pd.set(10, attn_mask ? 3 : -1); // constant_broadcast_type_C (MxN) pd.set(11, 0); // output_N1M pd.set(12, 1); // output_elempack + pd.set(13, 1); // output_elemtype = fp32 #if NCNN_INT8 pd.set(18, int8_scale_term); #endif @@ -290,6 +295,22 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to if (retqk != 0) return retqk; + Mat value_fp32 = value; +#if NCNN_BF16 + if (opt.use_bf16_storage && value.elembits() == 16) + { + // qkv_gemm need fp32 inputs + cast_bfloat16_to_float32(value, value_fp32, opt); + if (value_fp32.empty()) + return -100; + } +#endif + + Mat& top_blob = top_blobs[0]; + top_blob.create(out_embed_dim, src_seqlen, num_heads, 4u, opt.blob_allocator); + if (top_blob.empty()) + return -100; + // 3. 
Attn * V std::vector retqkvs(num_heads); @@ -297,8 +318,8 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to for (int i = 0; i < num_heads; i++) { std::vector qkv_bottom_blobs(2); - qkv_bottom_blobs[0] = qk_cross.channel(i); // Attn: [DstSeq, Seq] - qkv_bottom_blobs[1] = value.channel(i / num_heads_per_group); // V: [DstSeq, OutEmbed] + qkv_bottom_blobs[0] = qk_cross.channel(i); // Attn: [DstSeq, Seq] + qkv_bottom_blobs[1] = value_fp32.channel(i / num_heads_per_group); // V: [DstSeq, OutEmbed] std::vector qkv_top_blobs(1); qkv_top_blobs[0] = top_blob.channel(i); // Output @@ -314,6 +335,8 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to return retqkvs[i]; } + value_fp32.release(); + if (kv_cache) { top_blobs[1] = key; diff --git a/tests/test_gemm_2e.cpp b/tests/test_gemm_2e.cpp index 01ec5c691a7..693e6c5b8c7 100644 --- a/tests/test_gemm_2e.cpp +++ b/tests/test_gemm_2e.cpp @@ -14,7 +14,11 @@ int main() {44, 19, 7}, {47, 35, 48}, {47, 48, 47}, - {48, 35, 47} + {48, 35, 47}, + {32, 24, 5}, + {20, 24, 5}, + {32, 20, 5}, + {24, 20, 5}, }; int mnk_count = sizeof(mnk) / sizeof(int) / 3; diff --git a/tests/test_gemm_5.cpp b/tests/test_gemm_5.cpp new file mode 100644 index 00000000000..e3ee89aa655 --- /dev/null +++ b/tests/test_gemm_5.cpp @@ -0,0 +1,327 @@ +// Copyright 2025 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#include "testutil.h" + +static int test_gemm(int M, int N, int K, int output_transpose, int output_N1M = 0) +{ + ncnn::ParamDict pd; + pd.set(0, 1.f); // alpha + pd.set(1, 1.f); // beta + pd.set(2, 0); // transA + pd.set(3, 0); // transB + pd.set(4, 0); // constantA + pd.set(5, 0); // constantB + pd.set(6, 1); + pd.set(7, M); + pd.set(8, N); + pd.set(9, K); + pd.set(10, -1); + pd.set(11, output_N1M); + pd.set(13, 1); // output_elemtype = fp32 + pd.set(14, output_transpose); + + std::vector weights; + + std::vector a; + a.push_back(output_N1M ? ncnn::Mat(K, 1, M) : ncnn::Mat(K, M)); + a.push_back(output_N1M ? ncnn::Mat(N, 1, K) : ncnn::Mat(N, K)); + + for (size_t i = 0; i < a.size(); i++) + { + Randomize(a[i]); + } + + int ret = test_layer("Gemm", pd, weights, a); + if (ret != 0) + { + fprintf(stderr, "test_gemm failed M=%d N=%d K=%d output_transpose=%d output_N1M=%d\n", M, N, K, output_transpose, output_N1M); + } + + return ret; +} + +static int test_gemm_0(int M, int N, int K) +{ + return 0 + || test_gemm(M, N, K, 0, 0) + || test_gemm(M, N, K, 0, 1) + || test_gemm(M, N, K, 1, 0) + || test_gemm(M, N, K, 1, 1); +} + +// Test with forced output_elempack, fp32 paths only. +// Uses test_layer_opt to avoid running bf16 paths where this elempack may be invalid. +static int test_gemm_ep(int M, int N, int K, int output_elempack, int output_transpose, int output_N1M = 0) +{ + ncnn::ParamDict pd; + pd.set(0, 1.f); // alpha + pd.set(1, 1.f); // beta + pd.set(2, 0); // transA + pd.set(3, 0); // transB + pd.set(4, 0); // constantA + pd.set(5, 0); // constantB + pd.set(6, 1); + pd.set(7, M); + pd.set(8, N); + pd.set(9, K); + pd.set(10, -1); + pd.set(11, output_N1M); + pd.set(12, output_elempack); + pd.set(13, 1); // output_elemtype = fp32 + pd.set(14, output_transpose); + + std::vector weights; + + std::vector a; + a.push_back(output_N1M ? ncnn::Mat(K, 1, M) : ncnn::Mat(K, M)); + a.push_back(output_N1M ? 
ncnn::Mat(N, 1, K) : ncnn::Mat(N, K)); + + for (size_t i = 0; i < a.size(); i++) + { + Randomize(a[i]); + } + + // Only run fp32-safe option combos (use_bf16_storage=0) + // pack fp16p fp16s fp16a bf16 + const int options[][2] = { + {0, 0}, + {1, 0}, + {1, 1}, + }; + + for (int i = 0; i < 3; i++) + { + ncnn::Option opt; + opt.num_threads = 1; + opt.use_packing_layout = options[i][0]; + opt.use_fp16_packed = false; + opt.use_fp16_storage = false; + opt.use_fp16_arithmetic = false; + opt.use_bf16_packed = options[i][1]; + opt.use_bf16_storage = options[i][1]; + + int ret = test_layer_opt("Gemm", pd, weights, opt, a, 1, 0.001, TEST_LAYER_DISABLE_GPU_TESTING); + if (ret != 0) + { + fprintf(stderr, "test_gemm_ep failed M=%d N=%d K=%d output_elempack=%d output_transpose=%d output_N1M=%d\n", M, N, K, output_elempack, output_transpose, output_N1M); + return ret; + } + } + + return 0; +} + +// Test with forced output_elempack, bf16 paths only. +// Uses test_layer_opt to avoid running fp32 paths where this elempack may be invalid. +static int test_gemm_ep_bf16(int M, int N, int K, int output_elempack, int output_transpose) +{ + ncnn::ParamDict pd; + pd.set(0, 1.f); // alpha + pd.set(1, 1.f); // beta + pd.set(2, 0); // transA + pd.set(3, 0); // transB + pd.set(4, 0); // constantA + pd.set(5, 0); // constantB + pd.set(6, 1); + pd.set(7, M); + pd.set(8, N); + pd.set(9, K); + pd.set(10, -1); + pd.set(12, output_elempack); + // pd.set(13, ...); // output_elemtype = 0 (bf16, default) + pd.set(14, output_transpose); + + std::vector weights; + + std::vector a; + a.push_back(ncnn::Mat(K, M)); + a.push_back(ncnn::Mat(N, K)); + + for (size_t i = 0; i < a.size(); i++) + { + Randomize(a[i]); + } + + // Only run bf16 option combos + // pack bf16 + const int options[][2] = { + {0, 1}, + {1, 1}, + }; + + for (int i = 0; i < 2; i++) + { + ncnn::Option opt; + opt.num_threads = 1; + opt.use_packing_layout = options[i][0]; + opt.use_fp16_packed = false; + opt.use_fp16_storage = false; + opt.use_fp16_arithmetic = false; + opt.use_bf16_packed = options[i][1]; + opt.use_bf16_storage = options[i][1]; + + int ret = test_layer_opt("Gemm", pd, weights, opt, a, 1, 0.001, TEST_LAYER_DISABLE_GPU_TESTING); + if (ret != 0) + { + fprintf(stderr, "test_gemm_ep_bf16 failed M=%d N=%d K=%d output_elempack=%d output_transpose=%d\n", M, N, K, output_elempack, output_transpose); + return ret; + } + } + + return 0; +} + +static int test_gemm_1(int M, int N, int K, int fp32_min_elempack, int fp32_max_elempack, int bf16_min_elempack, int bf16_max_elempack) +{ + const int elempacks[] = {1, 4, 8, 16}; + + for (int ei = 0; ei < 4; ei++) + { + int ep = elempacks[ei]; + + for (int output_transpose = 0; output_transpose < 2; output_transpose++) + { + int outh = output_transpose ? 
N : M; + if (outh % ep != 0) + continue; + + // fp32 path + if (ep == 1 || (ep <= fp32_max_elempack && ep % fp32_min_elempack == 0)) + { + for (int output_N1M = 0; output_N1M < 2; output_N1M++) + { + int ret = test_gemm_ep(M, N, K, ep, output_transpose, output_N1M); + if (ret != 0) + return ret; + } + } + + // bf16 path (only when bf16 supports larger elempack than fp32) + if ((ep == 1 || (ep <= bf16_max_elempack && ep % fp32_min_elempack == 0)) && ep > fp32_max_elempack) + { + int ret = test_gemm_ep_bf16(M, N, K, ep, output_transpose); + if (ret != 0) + return ret; + } + } + } + + return 0; +} + +int main() +{ + SRAND(7767517); + + int mnk_scalar[][3] = { + {1, 1, 1}, + {2, 2, 2}, + {3, 3, 3}, + {4, 4, 4}, + {5, 5, 5}, + {7, 7, 7}, + {8, 8, 8}, + {15, 15, 15}, + {16, 16, 16}, + {24, 24, 24}, + {31, 32, 31}, + {32, 32, 32}, + {47, 48, 47}, + {64, 64, 64}, + }; + + for (int i = 0; i < 14; i++) + { + int ret = test_gemm_0(mnk_scalar[i][0], mnk_scalar[i][1], mnk_scalar[i][2]); + if (ret != 0) + return ret; + } + + int fp32_min_elempack = 1; + int fp32_max_elempack = 1; +#if __SSE2__ || __ARM_NEON + fp32_min_elempack = 4; + fp32_max_elempack = 4; +#endif + +#if NCNN_AVX + if (ncnn::cpu_support_x86_avx()) + fp32_max_elempack = 8; +#if NCNN_AVX512 + if (ncnn::cpu_support_x86_avx512()) + fp32_max_elempack = 16; +#endif +#endif + +#if NCNN_RVV || NCNN_XTHEADVECTOR + if (ncnn::cpu_support_riscv_v() || ncnn::cpu_support_riscv_xtheadvector()) + { + fp32_min_elempack = ncnn::cpu_riscv_vlenb() / 4; + fp32_max_elempack = ncnn::cpu_riscv_vlenb() / 4; + } +#endif + + int bf16_min_elempack = fp32_min_elempack; + int bf16_max_elempack = fp32_max_elempack; + + for (int i = 0; i < 14; i++) + { + int ret = test_gemm_1(mnk_scalar[i][0], mnk_scalar[i][1], mnk_scalar[i][2], fp32_min_elempack, fp32_max_elempack, bf16_min_elempack, bf16_max_elempack); + if (ret != 0) + return ret; + } + + // Asymmetric M/N to cover output_transpose paths with various ii/jj + // remainder blocks. In unpack_output_tile_fp32_to_bf16: + // ii iterates M dimension, jj iterates N dimension. 
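+    // e.g. {17, 16, 4} leaves a single-row ii tail and {33, 17, 4} leaves both
+    // ii and jj tails, so the scalar tail store paths are covered as well.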
+ int mnk_asym[][3] = { + {1, 16, 4}, + {2, 16, 4}, + {3, 16, 4}, + {5, 32, 4}, + {1, 8, 4}, + {2, 8, 4}, + {3, 8, 4}, + {5, 8, 4}, + {16, 1, 4}, + {16, 3, 4}, + {16, 5, 4}, + {17, 16, 4}, + {33, 17, 4}, + {8, 4, 4}, + {8, 2, 4}, + {8, 1, 4}, + {4, 2, 4}, + {4, 1, 4}, + {2, 32, 4}, + {16, 8, 4}, + {4, 16, 4}, + {4, 8, 4}, + }; + + int num_asym = sizeof(mnk_asym) / sizeof(mnk_asym[0]); + for (int i = 0; i < num_asym; i++) + { + int ret = test_gemm_1(mnk_asym[i][0], mnk_asym[i][1], mnk_asym[i][2], fp32_min_elempack, fp32_max_elempack, bf16_min_elempack, bf16_max_elempack); + if (ret != 0) + return ret; + } + + if (bf16_max_elempack >= 4 && 4 % bf16_min_elempack == 0) + { + // bf16 output (output_elemtype=0) with out_elempack=4, output_transpose=1 + // to cover the bf16 store paths in unpack_output_tile_fp32_to_bf16 + int ret = 0 + || test_gemm_ep_bf16(4, 16, 4, 4, 1) + || test_gemm_ep_bf16(4, 8, 4, 4, 1) + || test_gemm_ep_bf16(2, 16, 4, 4, 1) + || test_gemm_ep_bf16(1, 16, 4, 4, 1) + || test_gemm_ep_bf16(3, 16, 4, 4, 1) + || test_gemm_ep_bf16(8, 16, 4, 4, 1); + if (ret != 0) + return ret; + } + + return 0; +} diff --git a/tests/testutil.cpp b/tests/testutil.cpp index 878142a2d2c..92803642d29 100644 --- a/tests/testutil.cpp +++ b/tests/testutil.cpp @@ -668,6 +668,16 @@ int test_layer_cpu(int typeindex, const ncnn::ParamDict& pd, const std::vectortypeindex == ncnn::LayerType::MultiHeadAttention) + { + fprintf(stderr, "fixme: skip bf16 test for MultiHeadAttention\n"); + delete op; + return 233; + } + } + ncnn::ModelBinFromMatArray mb(weights.data()); op->load_model(mb); @@ -809,21 +819,11 @@ int test_layer_gpu(int typeindex, const ncnn::ParamDict& pd, const std::vectorvkdev = vkdev; - ncnn::VkWeightAllocator g_weight_vkallocator(vkdev); - ncnn::VkWeightStagingAllocator g_weight_staging_vkallocator(vkdev); - - ncnn::VkAllocator* blob_vkallocator = vkdev->acquire_blob_allocator(); - ncnn::VkAllocator* staging_vkallocator = vkdev->acquire_staging_allocator(); - if (flag & TEST_LAYER_ENABLE_THREADING) opt.num_threads = ncnn::get_physical_big_cpu_count(); else opt.num_threads = 1; - opt.blob_vkallocator = blob_vkallocator; - opt.workspace_vkallocator = blob_vkallocator; - opt.staging_vkallocator = staging_vkallocator; - if (!vkdev->info.support_fp16_packed()) opt.use_fp16_packed = false; if (!vkdev->info.support_fp16_storage()) opt.use_fp16_storage = false; if (!vkdev->info.support_fp16_uniform()) opt.use_fp16_uniform = false; @@ -848,6 +848,7 @@ int test_layer_gpu(int typeindex, const ncnn::ParamDict& pd, const std::vectortypeindex == ncnn::LayerType::MultiHeadAttention) { fprintf(stderr, "fixme: skip gpu bf16 test for MultiHeadAttention\n"); + delete op; return 233; } } @@ -931,12 +932,24 @@ int test_layer_gpu(int typeindex, const ncnn::ParamDict& pd, const std::vectorload_model(mb); + ncnn::VkWeightAllocator g_weight_vkallocator(vkdev); + ncnn::VkWeightStagingAllocator g_weight_staging_vkallocator(vkdev); + + ncnn::VkAllocator* blob_vkallocator = vkdev->acquire_blob_allocator(); + ncnn::VkAllocator* staging_vkallocator = vkdev->acquire_staging_allocator(); + + opt.blob_vkallocator = blob_vkallocator; + opt.workspace_vkallocator = blob_vkallocator; + opt.staging_vkallocator = staging_vkallocator; + op->create_pipeline(opt); if (!op->support_vulkan) { op->destroy_pipeline(opt); delete op; + vkdev->reclaim_blob_allocator(blob_vkallocator); + vkdev->reclaim_staging_allocator(staging_vkallocator); return 233; } @@ -1239,6 +1252,16 @@ int test_layer_cpu(int typeindex, const ncnn::ParamDict& pd, const 
std::vectorload_param(pd); + if (_opt.use_bf16_packed || _opt.use_bf16_storage) + { + if (op->typeindex == ncnn::LayerType::MultiHeadAttention) + { + fprintf(stderr, "fixme: skip bf16 test for MultiHeadAttention\n"); + delete op; + return 233; + } + } + ncnn::ModelBinFromMatArray mb(weights.data()); op->load_model(mb); @@ -1354,12 +1377,6 @@ int test_layer_gpu(int typeindex, const ncnn::ParamDict& pd, const std::vectorvkdev = vkdev; - ncnn::VkWeightAllocator g_weight_vkallocator(vkdev); - ncnn::VkWeightStagingAllocator g_weight_staging_vkallocator(vkdev); - - ncnn::VkAllocator* blob_vkallocator = vkdev->acquire_blob_allocator(); - ncnn::VkAllocator* staging_vkallocator = vkdev->acquire_staging_allocator(); - opt.use_vulkan_compute = true; if (flag & TEST_LAYER_ENABLE_THREADING) @@ -1367,10 +1384,6 @@ int test_layer_gpu(int typeindex, const ncnn::ParamDict& pd, const std::vectorinfo.support_fp16_packed()) opt.use_fp16_packed = false; if (!vkdev->info.support_fp16_storage()) opt.use_fp16_storage = false; if (!vkdev->info.support_fp16_uniform()) opt.use_fp16_uniform = false; @@ -1395,6 +1408,7 @@ int test_layer_gpu(int typeindex, const ncnn::ParamDict& pd, const std::vectortypeindex == ncnn::LayerType::MultiHeadAttention) { fprintf(stderr, "fixme: skip gpu bf16 test for MultiHeadAttention\n"); + delete op; return 233; } } @@ -1471,12 +1485,24 @@ int test_layer_gpu(int typeindex, const ncnn::ParamDict& pd, const std::vectorload_model(mb); + ncnn::VkWeightAllocator g_weight_vkallocator(vkdev); + ncnn::VkWeightStagingAllocator g_weight_staging_vkallocator(vkdev); + + ncnn::VkAllocator* blob_vkallocator = vkdev->acquire_blob_allocator(); + ncnn::VkAllocator* staging_vkallocator = vkdev->acquire_staging_allocator(); + + opt.blob_vkallocator = blob_vkallocator; + opt.workspace_vkallocator = blob_vkallocator; + opt.staging_vkallocator = staging_vkallocator; + op->create_pipeline(opt); if (!op->support_vulkan) { op->destroy_pipeline(opt); delete op; + vkdev->reclaim_blob_allocator(blob_vkallocator); + vkdev->reclaim_staging_allocator(staging_vkallocator); return 233; } From f182b416672b82e454b39050af51385713f27072 Mon Sep 17 00:00:00 2001 From: nihui Date: Wed, 1 Apr 2026 10:40:56 +0800 Subject: [PATCH 31/36] rotaryembed/tanh/selu/mish/hardswish/hardsigmoid/gelu/erf/elu/eltwise/dropout/quantize/dequantize/bnll x86 support bf16 storage (#6624) --- src/layer/x86/bnll_bf16s.h | 113 ++++ src/layer/x86/bnll_x86.cpp | 25 + src/layer/x86/bnll_x86.h | 5 +- src/layer/x86/bnll_x86_avx512bf16.cpp | 32 ++ src/layer/x86/dequantize_bf16s.h | 286 ++++++++++ src/layer/x86/dequantize_x86.cpp | 21 + src/layer/x86/dequantize_x86.h | 5 + src/layer/x86/dequantize_x86_avx512bf16.cpp | 27 + src/layer/x86/dropout_bf16s.h | 88 ++++ src/layer/x86/dropout_x86.cpp | 30 ++ src/layer/x86/dropout_x86.h | 5 + src/layer/x86/dropout_x86_avx512bf16.cpp | 27 + src/layer/x86/eltwise_bf16s.h | 516 +++++++++++++++++++ src/layer/x86/eltwise_x86.cpp | 30 ++ src/layer/x86/eltwise_x86.h | 5 + src/layer/x86/eltwise_x86_avx512bf16.cpp | 27 + src/layer/x86/elu_bf16s.h | 86 ++++ src/layer/x86/elu_x86.cpp | 23 + src/layer/x86/elu_x86.h | 5 + src/layer/x86/elu_x86_avx512bf16.cpp | 21 + src/layer/x86/erf_bf16s.h | 84 +++ src/layer/x86/erf_x86.cpp | 25 + src/layer/x86/erf_x86.h | 5 + src/layer/x86/erf_x86_avx512bf16.cpp | 32 ++ src/layer/x86/gelu_bf16s.h | 269 ++++++++++ src/layer/x86/gelu_x86.cpp | 25 + src/layer/x86/gelu_x86.h | 5 + src/layer/x86/gelu_x86_avx512bf16.cpp | 32 ++ src/layer/x86/hardsigmoid_bf16s.h | 110 ++++ 
src/layer/x86/hardsigmoid_x86.cpp | 23 + src/layer/x86/hardsigmoid_x86.h | 5 + src/layer/x86/hardsigmoid_x86_avx512bf16.cpp | 20 + src/layer/x86/hardswish_bf16s.h | 120 +++++ src/layer/x86/hardswish_x86.cpp | 23 + src/layer/x86/hardswish_x86.h | 5 + src/layer/x86/hardswish_x86_avx512bf16.cpp | 20 + src/layer/x86/mish_bf16s.h | 66 +++ src/layer/x86/mish_x86.cpp | 23 + src/layer/x86/mish_x86.h | 5 + src/layer/x86/mish_x86_avx512bf16.cpp | 19 + src/layer/x86/quantize_bf16s.h | 461 +++++++++++++++++ src/layer/x86/quantize_x86.cpp | 21 + src/layer/x86/quantize_x86.h | 5 + src/layer/x86/quantize_x86_avx512bf16.cpp | 27 + src/layer/x86/rotaryembed_bf16s.h | 364 +++++++++++++ src/layer/x86/rotaryembed_x86.cpp | 29 ++ src/layer/x86/rotaryembed_x86.h | 5 + src/layer/x86/rotaryembed_x86_avx512bf16.cpp | 17 + src/layer/x86/selu_bf16s.h | 108 ++++ src/layer/x86/selu_x86.cpp | 26 + src/layer/x86/selu_x86.h | 5 + src/layer/x86/selu_x86_avx512bf16.cpp | 32 ++ src/layer/x86/tanh_bf16s.h | 84 +++ src/layer/x86/tanh_x86.cpp | 23 + src/layer/x86/tanh_x86.h | 5 + src/layer/x86/tanh_x86_avx512bf16.cpp | 19 + 56 files changed, 3523 insertions(+), 1 deletion(-) create mode 100644 src/layer/x86/bnll_bf16s.h create mode 100644 src/layer/x86/bnll_x86_avx512bf16.cpp create mode 100644 src/layer/x86/dequantize_bf16s.h create mode 100644 src/layer/x86/dequantize_x86_avx512bf16.cpp create mode 100644 src/layer/x86/dropout_bf16s.h create mode 100644 src/layer/x86/dropout_x86_avx512bf16.cpp create mode 100644 src/layer/x86/eltwise_bf16s.h create mode 100644 src/layer/x86/eltwise_x86_avx512bf16.cpp create mode 100644 src/layer/x86/elu_bf16s.h create mode 100644 src/layer/x86/elu_x86_avx512bf16.cpp create mode 100644 src/layer/x86/erf_bf16s.h create mode 100644 src/layer/x86/erf_x86_avx512bf16.cpp create mode 100644 src/layer/x86/gelu_bf16s.h create mode 100644 src/layer/x86/gelu_x86_avx512bf16.cpp create mode 100644 src/layer/x86/hardsigmoid_bf16s.h create mode 100644 src/layer/x86/hardsigmoid_x86_avx512bf16.cpp create mode 100644 src/layer/x86/hardswish_bf16s.h create mode 100644 src/layer/x86/hardswish_x86_avx512bf16.cpp create mode 100644 src/layer/x86/mish_bf16s.h create mode 100644 src/layer/x86/mish_x86_avx512bf16.cpp create mode 100644 src/layer/x86/quantize_bf16s.h create mode 100644 src/layer/x86/quantize_x86_avx512bf16.cpp create mode 100644 src/layer/x86/rotaryembed_bf16s.h create mode 100644 src/layer/x86/rotaryembed_x86_avx512bf16.cpp create mode 100644 src/layer/x86/selu_bf16s.h create mode 100644 src/layer/x86/selu_x86_avx512bf16.cpp create mode 100644 src/layer/x86/tanh_bf16s.h create mode 100644 src/layer/x86/tanh_x86_avx512bf16.cpp diff --git a/src/layer/x86/bnll_bf16s.h b/src/layer/x86/bnll_bf16s.h new file mode 100644 index 00000000000..c085c27e93c --- /dev/null +++ b/src/layer/x86/bnll_bf16s.h @@ -0,0 +1,113 @@ +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ +void bnll_bf16s_avx512bf16(Mat& a, const Option& opt); +#endif + +static void bnll_bf16s(Mat& a, const Option& opt) +{ +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ + if (ncnn::cpu_support_x86_avx512_bf16()) + { + bnll_bf16s_avx512bf16(a, opt); + return; + } +#endif + + int w = a.w; + int h = a.h; + int d = a.d; + int channels = a.c; + int elempack = a.elempack; + int size = w * h * d * elempack; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + unsigned short* ptr = 
a.channel(q); + + int i = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + __m512 _one_avx512 = _mm512_set1_ps(1.f); + __m512 _zero_avx512 = _mm512_setzero_ps(); + for (; i + 15 < size; i += 16) + { + __m512 _p = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr)); + __mmask16 mask = _mm512_cmp_ps_mask(_p, _zero_avx512, _CMP_GT_OQ); + __m512 _abs_p = _mm512_castsi512_ps(_mm512_and_epi32(_mm512_castps_si512(_p), _mm512_set1_epi32(0x7fffffff))); + __m512 _tmp = log512_ps(_mm512_add_ps(_one_avx512, exp512_ps(_mm512_sub_ps(_zero_avx512, _abs_p)))); + _p = _mm512_mask_add_ps(_tmp, mask, _tmp, _p); + _mm256_storeu_si256((__m256i*)ptr, float2bfloat_avx512(_p)); + ptr += 16; + } + if (i < size) + { + const unsigned int remain = size - i; + __mmask16 _mask = (__mmask16)((1u << remain) - 1); + __m512 _p = bfloat2float_avx512(_mm256_maskz_loadu_epi16(_mask, ptr)); + __mmask16 mask = _mm512_cmp_ps_mask(_p, _zero_avx512, _CMP_GT_OQ); + __m512 _abs_p = _mm512_castsi512_ps(_mm512_and_epi32(_mm512_castps_si512(_p), _mm512_set1_epi32(0x7fffffff))); + __m512 _tmp = log512_ps(_mm512_add_ps(_one_avx512, exp512_ps(_mm512_sub_ps(_zero_avx512, _abs_p)))); + _p = _mm512_mask_add_ps(_tmp, mask, _tmp, _p); + _mm256_mask_storeu_epi16(ptr, _mask, float2bfloat_avx512(_p)); + i += remain; + } +#else // __AVX512F__ + __m256 _one_avx = _mm256_set1_ps(1.f); + __m256 _zero_avx = _mm256_setzero_ps(); + for (; i + 7 < size; i += 8) + { + __m256 _p = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr)); + __m256 mask = _mm256_cmp_ps(_p, _mm256_setzero_ps(), _CMP_GT_OQ); + __m256 _abs_p = _mm256_and_ps(_p, *(__m256*)_ps256_inv_sign_mask); + __m256 _tmp = log256_ps(_mm256_add_ps(_one_avx, exp256_ps(_mm256_sub_ps(_zero_avx, _abs_p)))); + __m256 _x = _mm256_and_ps(_p, mask); + _p = _mm256_add_ps(_x, _tmp); + _mm_storeu_si128((__m128i*)ptr, float2bfloat_avx(_p)); + ptr += 8; + } + __m128 _one = _mm_set1_ps(1.f); + __m128 _zero = _mm_setzero_ps(); + for (; i + 3 < size; i += 4) + { + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr)); + __m128 mask = _mm_cmpgt_ps(_p, _zero); + __m128 _abs_p = _mm_and_ps(_p, *(__m128*)_ps_inv_sign_mask); + __m128 _tmp = log_ps(_mm_add_ps(_one, exp_ps(_mm_sub_ps(_zero, _abs_p)))); + __m128 _x = _mm_and_ps(_p, mask); + _p = _mm_add_ps(_x, _tmp); + _mm_storel_epi64((__m128i*)ptr, float2bfloat_sse(_p, _p)); + ptr += 4; + } +#endif // __AVX512F__ +#else // __AVX__ + __m128 _one = _mm_set1_ps(1.f); + __m128 _zero = _mm_setzero_ps(); + for (; i + 3 < size; i += 4) + { + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr)); + __m128 mask = _mm_cmpgt_ps(_p, _zero); + __m128 _abs_p = _mm_and_ps(_p, *(__m128*)_ps_inv_sign_mask); + __m128 _tmp = log_ps(_mm_add_ps(_one, exp_ps(_mm_sub_ps(_zero, _abs_p)))); + __m128 _x = _mm_and_ps(_p, mask); + _p = _mm_add_ps(_x, _tmp); + _mm_storel_epi64((__m128i*)ptr, float2bfloat_sse(_p, _p)); + ptr += 4; + } +#endif // __AVX__ +#endif // __SSE2__ + for (; i < size; i++) + { + float v = bfloat16_to_float32(*ptr); + if (v > 0) + v = v + logf(1.f + expf(-v)); + else + v = logf(1.f + expf(v)); + *ptr = float32_to_bfloat16(v); + ptr++; + } + } +} diff --git a/src/layer/x86/bnll_x86.cpp b/src/layer/x86/bnll_x86.cpp index 2e566c01c07..f30bc74595f 100644 --- a/src/layer/x86/bnll_x86.cpp +++ b/src/layer/x86/bnll_x86.cpp @@ -15,13 +15,24 @@ #endif // __AVX__ #endif // __SSE2__ +#include "x86_usability.h" + +#include "cpu.h" + namespace ncnn { +#if NCNN_BF16 +#include "bnll_bf16s.h" +#endif + BNLL_x86::BNLL_x86() { #if __SSE2__ support_packing = 
true; #endif // __SSE2__ +#if NCNN_BF16 + support_bf16_storage = true; +#endif } int BNLL_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const @@ -33,6 +44,11 @@ int BNLL_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const int elempack = bottom_top_blob.elempack; int size = w * h * d * elempack; +#if NCNN_BF16 + if (opt.use_bf16_storage && bottom_top_blob.elembits() == 16) + return forward_inplace_bf16s(bottom_top_blob, opt); +#endif + #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; q++) { @@ -95,4 +111,13 @@ int BNLL_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const return 0; } +#if NCNN_BF16 +int BNLL_x86::forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt) const +{ + bnll_bf16s(bottom_top_blob, opt); + + return 0; +} +#endif // NCNN_BF16 + } // namespace ncnn diff --git a/src/layer/x86/bnll_x86.h b/src/layer/x86/bnll_x86.h index 521dfe4949c..88ae7c1c4ce 100644 --- a/src/layer/x86/bnll_x86.h +++ b/src/layer/x86/bnll_x86.h @@ -14,7 +14,10 @@ class BNLL_x86 : public BNLL BNLL_x86(); virtual int forward_inplace(Mat& bottom_top_blob, const Option& opt) const; -public: +protected: +#if NCNN_BF16 + int forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt) const; +#endif }; } // namespace ncnn diff --git a/src/layer/x86/bnll_x86_avx512bf16.cpp b/src/layer/x86/bnll_x86_avx512bf16.cpp new file mode 100644 index 00000000000..15d0d3b58bd --- /dev/null +++ b/src/layer/x86/bnll_x86_avx512bf16.cpp @@ -0,0 +1,32 @@ +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#include "bnll_x86.h" + +#if __SSE2__ +#include +#include "sse_mathfun.h" +#if __AVX__ +#include +#include "avx_mathfun.h" +#if __AVX512F__ +#include "avx512_mathfun.h" +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ + +#include "x86_usability.h" + +#include "cpu.h" +#include "mat.h" + +namespace ncnn { + +#include "bnll_bf16s.h" + +void bnll_bf16s_avx512bf16(Mat& a, const Option& opt) +{ + bnll_bf16s(a, opt); +} + +} // namespace ncnn diff --git a/src/layer/x86/dequantize_bf16s.h b/src/layer/x86/dequantize_bf16s.h new file mode 100644 index 00000000000..603db082f36 --- /dev/null +++ b/src/layer/x86/dequantize_bf16s.h @@ -0,0 +1,286 @@ +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ +void dequantize_forward_bf16s_avx512bf16(const Mat& bottom_blob, Mat& top_blob, const Mat& scale_data, int scale_data_size, const Mat& bias_data, int bias_data_size, const Option& opt); +#endif + +static void dequantize_bf16(const int* intptr, unsigned short* ptr, const Mat& scale_data, const Mat& bias_data, int elemcount, int elempack) +{ + const int scale_data_size = scale_data.w; + const int bias_data_size = bias_data.w; + const int size = elemcount * elempack; + + float scale = scale_data[0]; +#if __SSE2__ + __m128 _scale0 = _mm_set1_ps(scale); +#if __AVX__ + __m256 _scale_avx = _mm256_set1_ps(scale); +#if __AVX512F__ + __m512 _scale_avx512 = _mm512_set1_ps(scale); +#endif // __AVX512F__ +#else // __AVX__ + __m128 _scale1 = _scale0; +#endif // __AVX__ + if (scale_data_size > 1) + { +#if __AVX512F__ + if (elempack == 16) + { + _scale_avx512 = _mm512_loadu_ps((const float*)scale_data); + } +#endif // __AVX512F__ + if (elempack == 8) + { +#if __AVX__ + _scale_avx = _mm256_loadu_ps((const float*)scale_data); +#if __AVX512F__ + _scale_avx512 = combine8x2_ps(_scale_avx, _scale_avx); +#endif // __AVX512F__ +#else // __AVX__ + 
_scale0 = _mm_loadu_ps((const float*)scale_data); + _scale1 = _mm_loadu_ps((const float*)scale_data + 4); +#endif // __AVX__ + } + if (elempack == 4) + { + _scale0 = _mm_loadu_ps((const float*)scale_data); +#if __AVX__ + _scale_avx = combine4x2_ps(_scale0, _scale0); +#if __AVX512F__ + _scale_avx512 = combine8x2_ps(_scale_avx, _scale_avx); +#endif // __AVX512F__ +#else // __AVX__ + _scale1 = _scale0; +#endif // __AVX__ + } + } +#endif // __SSE2__ + + if (bias_data_size == 0) + { + int i = 0; +#if __SSE2__ +#if __AVX512F__ + for (; i + 15 < size; i += 16) + { + __m512 _v = _mm512_cvtepi32_ps(_mm512_loadu_si512((const __m512i*)intptr)); + _v = _mm512_mul_ps(_v, _scale_avx512); + _mm256_storeu_si256((__m256i*)ptr, float2bfloat_avx512(_v)); + intptr += 16; + ptr += 16; + } +#endif // __AVX512F__ + for (; i + 7 < size; i += 8) + { +#if __AVX__ + __m256 _v = _mm256_cvtepi32_ps(_mm256_loadu_si256((const __m256i*)intptr)); + _v = _mm256_mul_ps(_v, _scale_avx); + _mm_storeu_si128((__m128i*)ptr, float2bfloat_avx(_v)); +#else // __AVX__ + __m128 _v0 = _mm_cvtepi32_ps(_mm_loadu_si128((const __m128i*)intptr)); + __m128 _v1 = _mm_cvtepi32_ps(_mm_loadu_si128((const __m128i*)(intptr + 4))); + _v0 = _mm_mul_ps(_v0, _scale0); + _v1 = _mm_mul_ps(_v1, _scale1); + _mm_storeu_si128((__m128i*)ptr, float2bfloat_sse(_v0, _v1)); +#endif // __AVX__ + intptr += 8; + ptr += 8; + } + for (; i + 3 < size; i += 4) + { + __m128 _v = _mm_cvtepi32_ps(_mm_loadu_si128((const __m128i*)intptr)); + _v = _mm_mul_ps(_v, _scale0); + _mm_storel_epi64((__m128i*)ptr, float2bfloat_sse(_v, _v)); + intptr += 4; + ptr += 4; + } +#endif // __SSE2__ + for (; i < size; i++) + { + *ptr = float32_to_bfloat16(*intptr * scale); + intptr++; + ptr++; + } + } + else + { + float bias = bias_data[0]; +#if __SSE2__ + __m128 _bias0 = _mm_set1_ps(bias); +#if __AVX__ + __m256 _bias_avx = _mm256_set1_ps(bias); +#if __AVX512F__ + __m512 _bias_avx512 = _mm512_set1_ps(bias); +#endif // __AVX512F__ +#else // __AVX__ + __m128 _bias1 = _bias0; +#endif // __AVX__ + if (bias_data_size > 1) + { +#if __AVX512F__ + if (elempack == 16) + { + _bias_avx512 = _mm512_loadu_ps((const float*)bias_data); + } +#endif // __AVX512F__ + if (elempack == 8) + { +#if __AVX__ + _bias_avx = _mm256_loadu_ps((const float*)bias_data); +#if __AVX512F__ + _bias_avx512 = combine8x2_ps(_bias_avx, _bias_avx); +#endif // __AVX512F__ +#else // __AVX__ + _bias0 = _mm_loadu_ps((const float*)bias_data); + _bias1 = _mm_loadu_ps((const float*)bias_data + 4); +#endif // __AVX__ + } + if (elempack == 4) + { + _bias0 = _mm_loadu_ps((const float*)bias_data); +#if __AVX__ + _bias_avx = combine4x2_ps(_bias0, _bias0); +#if __AVX512F__ + _bias_avx512 = combine8x2_ps(_bias_avx, _bias_avx); +#endif // __AVX512F__ +#else // __AVX__ + _bias1 = _bias0; +#endif // __AVX__ + } + } +#endif // __SSE2__ + + int i = 0; +#if __SSE2__ +#if __AVX512F__ + for (; i + 15 < size; i += 16) + { + __m512 _v = _mm512_cvtepi32_ps(_mm512_loadu_si512((const __m512i*)intptr)); + _v = _mm512_fmadd_ps(_v, _scale_avx512, _bias_avx512); + _mm256_storeu_si256((__m256i*)ptr, float2bfloat_avx512(_v)); + intptr += 16; + ptr += 16; + } +#endif // __AVX512F__ + for (; i + 7 < size; i += 8) + { +#if __AVX__ + __m256 _v = _mm256_cvtepi32_ps(_mm256_loadu_si256((const __m256i*)intptr)); + _v = _mm256_comp_fmadd_ps(_v, _scale_avx, _bias_avx); + _mm_storeu_si128((__m128i*)ptr, float2bfloat_avx(_v)); +#else // __AVX__ + __m128 _v0 = _mm_cvtepi32_ps(_mm_loadu_si128((const __m128i*)intptr)); + __m128 _v1 = _mm_cvtepi32_ps(_mm_loadu_si128((const 
__m128i*)(intptr + 4))); + _v0 = _mm_comp_fmadd_ps(_v0, _scale0, _bias0); + _v1 = _mm_comp_fmadd_ps(_v1, _scale1, _bias1); + _mm_storeu_si128((__m128i*)ptr, float2bfloat_sse(_v0, _v1)); +#endif // __AVX__ + intptr += 8; + ptr += 8; + } + for (; i + 3 < size; i += 4) + { + __m128 _v = _mm_cvtepi32_ps(_mm_loadu_si128((const __m128i*)intptr)); + _v = _mm_comp_fmadd_ps(_v, _scale0, _bias0); + _mm_storel_epi64((__m128i*)ptr, float2bfloat_sse(_v, _v)); + intptr += 4; + ptr += 4; + } +#endif // __SSE2__ + for (; i < size; i++) + { + *ptr = float32_to_bfloat16(*intptr * scale + bias); + intptr++; + ptr++; + } + } +} + +static int dequantize_forward_bf16s(const Mat& bottom_blob, Mat& top_blob, const Mat& scale_data, int scale_data_size, const Mat& bias_data, int bias_data_size, const Option& opt) +{ +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ + if (ncnn::cpu_support_x86_avx512_bf16()) + { + dequantize_forward_bf16s_avx512bf16(bottom_blob, top_blob, scale_data, scale_data_size, bias_data, bias_data_size, opt); + return 0; + } +#endif + + const int dims = bottom_blob.dims; + const int w = bottom_blob.w; + const int h = bottom_blob.h; + const int channels = bottom_blob.c; + const int elempack = bottom_blob.elempack; + + const size_t out_elemsize = 2u * elempack; + + if (dims == 1) + { + top_blob.create(w, out_elemsize, elempack, opt.blob_allocator); + } + else if (dims == 2) + { + top_blob.create(w, h, out_elemsize, elempack, opt.blob_allocator); + } + else if (dims == 3) + { + top_blob.create(w, h, channels, out_elemsize, elempack, opt.blob_allocator); + } + if (top_blob.empty()) + return -100; + + if (dims == 1) + { + const int wp = std::max(1, w / opt.num_threads); + const int nn_w = (w + wp - 1) / wp; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int ii = 0; ii < nn_w; ii++) + { + const int i = ii * wp; + + const int* intptr = (const int*)bottom_blob + i * elempack; + unsigned short* ptr = (unsigned short*)top_blob + i * elempack; + + // assert scale_data_size == 1 + // assert bias_data_size == 0 || bias_data_size == 1 + + const int size = std::min(w - i, wp) * elempack; + + dequantize_bf16(intptr, ptr, scale_data, bias_data, size, 1); + } + } + + if (dims == 2) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = 0; i < h; i++) + { + const int* intptr = bottom_blob.row(i); + unsigned short* ptr = top_blob.row(i); + + const Mat scale_data_i = scale_data_size > 1 ? scale_data.range(i * elempack, elempack) : scale_data; + const Mat bias_data_i = bias_data_size > 1 ? bias_data.range(i * elempack, elempack) : bias_data; + + dequantize_bf16(intptr, ptr, scale_data_i, bias_data_i, w, elempack); + } + } + + if (dims == 3) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const int* intptr = bottom_blob.channel(q); + unsigned short* ptr = top_blob.channel(q); + + const Mat scale_data_q = scale_data_size > 1 ? scale_data.range(q * elempack, elempack) : scale_data; + const Mat bias_data_q = bias_data_size > 1 ? 
bias_data.range(q * elempack, elempack) : bias_data; + + dequantize_bf16(intptr, ptr, scale_data_q, bias_data_q, w * h, elempack); + } + } + + return 0; +} diff --git a/src/layer/x86/dequantize_x86.cpp b/src/layer/x86/dequantize_x86.cpp index 81625da9d9e..a8f044cd091 100644 --- a/src/layer/x86/dequantize_x86.cpp +++ b/src/layer/x86/dequantize_x86.cpp @@ -12,13 +12,22 @@ #include "x86_usability.h" +#include "cpu.h" + namespace ncnn { +#if NCNN_BF16 +#include "dequantize_bf16s.h" +#endif + Dequantize_x86::Dequantize_x86() { #if __SSE2__ support_packing = true; #endif // __SSE2__ +#if NCNN_BF16 + support_bf16_storage = true; +#endif } static void dequantize(const int* intptr, float* ptr, const Mat& scale_data, const Mat& bias_data, int elemcount, int elempack) @@ -219,6 +228,11 @@ static void dequantize(const int* intptr, float* ptr, const Mat& scale_data, con int Dequantize_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const { +#if NCNN_BF16 + if (opt.use_bf16_storage) + return forward_bf16s(bottom_blob, top_blob, opt); +#endif + const int dims = bottom_blob.dims; const int w = bottom_blob.w; const int h = bottom_blob.h; @@ -284,4 +298,11 @@ int Dequantize_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& return 0; } +#if NCNN_BF16 +int Dequantize_x86::forward_bf16s(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const +{ + return dequantize_forward_bf16s(bottom_blob, top_blob, scale_data, scale_data_size, bias_data, bias_data_size, opt); +} +#endif // NCNN_BF16 + } // namespace ncnn diff --git a/src/layer/x86/dequantize_x86.h b/src/layer/x86/dequantize_x86.h index f2f1d2f049d..0dd5b76220d 100644 --- a/src/layer/x86/dequantize_x86.h +++ b/src/layer/x86/dequantize_x86.h @@ -14,6 +14,11 @@ class Dequantize_x86 : public Dequantize Dequantize_x86(); virtual int forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const; + +protected: +#if NCNN_BF16 + int forward_bf16s(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const; +#endif }; } // namespace ncnn diff --git a/src/layer/x86/dequantize_x86_avx512bf16.cpp b/src/layer/x86/dequantize_x86_avx512bf16.cpp new file mode 100644 index 00000000000..0d821807f3a --- /dev/null +++ b/src/layer/x86/dequantize_x86_avx512bf16.cpp @@ -0,0 +1,27 @@ +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#include "dequantize_x86.h" + +#if __SSE2__ +#include +#if __AVX__ +#include +#endif // __AVX__ +#endif // __SSE2__ + +#include "x86_usability.h" + +#include "cpu.h" +#include "mat.h" + +namespace ncnn { + +#include "dequantize_bf16s.h" + +void dequantize_forward_bf16s_avx512bf16(const Mat& bottom_blob, Mat& top_blob, const Mat& scale_data, int scale_data_size, const Mat& bias_data, int bias_data_size, const Option& opt) +{ + dequantize_forward_bf16s(bottom_blob, top_blob, scale_data, scale_data_size, bias_data, bias_data_size, opt); +} + +} // namespace ncnn diff --git a/src/layer/x86/dropout_bf16s.h b/src/layer/x86/dropout_bf16s.h new file mode 100644 index 00000000000..9145b28e5e3 --- /dev/null +++ b/src/layer/x86/dropout_bf16s.h @@ -0,0 +1,88 @@ +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ +void dropout_bf16s_avx512bf16(Mat& a, float scale, const Option& opt); +#endif + +static void dropout_bf16s(Mat& a, float scale, const Option& opt) +{ +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ + if (ncnn::cpu_support_x86_avx512_bf16()) + { + 
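// ---------------------------------------------------------------------------
// Context for the dispatch guard used here and in every other *_bf16s.h header
// added by this patch (NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ &&
// !__AVX512BF16__): the shared static implementation is compiled twice, once in
// the generic *_x86.cpp translation unit and once in the new
// *_x86_avx512bf16.cpp file built with AVX-512 BF16 compiler flags, and
// cpu_support_x86_avx512_bf16() picks the faster copy at runtime (presumably so
// the bf16<->fp32 conversions can use native BF16 instructions there). A hedged
// outline with placeholder names, not ncnn's real flags or build setup:
//
//   // kernel_bf16s.h  (shared, included by both translation units)
//   #if RUNTIME_DISPATCH && HAVE_AVX512BF16_TU && !__AVX512BF16__
//   void kernel_bf16s_avx512bf16(unsigned short* ptr, int n); // defined in the fast TU
//   #endif
//   static void kernel_bf16s(unsigned short* ptr, int n)
//   {
//   #if RUNTIME_DISPATCH && HAVE_AVX512BF16_TU && !__AVX512BF16__
//       if (cpu_supports_avx512bf16())
//       {
//           kernel_bf16s_avx512bf16(ptr, n); // same source, better codegen
//           return;
//       }
//   #endif
//       // ... generic bf16 -> fp32 -> bf16 loop ...
//   }
//
//   // kernel_x86_avx512bf16.cpp  (compiled with AVX-512 BF16 flags)
//   // #include "kernel_bf16s.h"
//   // void kernel_bf16s_avx512bf16(unsigned short* ptr, int n) { kernel_bf16s(ptr, n); }
// ---------------------------------------------------------------------------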
dropout_bf16s_avx512bf16(a, scale, opt); + return; + } +#endif + + int w = a.w; + int h = a.h; + int d = a.d; + int channels = a.c; + int elempack = a.elempack; + int size = w * h * d * elempack; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + unsigned short* ptr = a.channel(q); + + int i = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + __m512 _scale_avx512 = _mm512_set1_ps(scale); + for (; i + 15 < size; i += 16) + { + __m512 _p = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr)); + _p = _mm512_mul_ps(_p, _scale_avx512); + _mm256_storeu_si256((__m256i*)ptr, float2bfloat_avx512(_p)); + ptr += 16; + } + if (i < size) + { + const unsigned int remain = size - i; + __mmask16 _mask = (__mmask16)((1u << remain) - 1); + __m512 _p = bfloat2float_avx512(_mm256_maskz_loadu_epi16(_mask, ptr)); + _p = _mm512_mul_ps(_p, _scale_avx512); + _mm256_mask_storeu_epi16(ptr, _mask, float2bfloat_avx512(_p)); + i += remain; + } +#else // __AVX512F__ + __m256 _scale_avx = _mm256_set1_ps(scale); + for (; i + 7 < size; i += 8) + { + __m256 _p = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr)); + _p = _mm256_mul_ps(_p, _scale_avx); + _mm_storeu_si128((__m128i*)ptr, float2bfloat_avx(_p)); + ptr += 8; + } + __m128 _scale = _mm_set1_ps(scale); + for (; i + 3 < size; i += 4) + { + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr)); + _p = _mm_mul_ps(_p, _scale); + _mm_storel_epi64((__m128i*)ptr, float2bfloat_sse(_p, _p)); + ptr += 4; + } +#endif // __AVX512F__ +#else // __AVX__ + __m128 _scale = _mm_set1_ps(scale); + for (; i + 3 < size; i += 4) + { + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr)); + _p = _mm_mul_ps(_p, _scale); + _mm_storel_epi64((__m128i*)ptr, float2bfloat_sse(_p, _p)); + ptr += 4; + } +#endif // __AVX__ +#endif // __SSE2__ + for (; i < size; i++) + { + float v = bfloat16_to_float32(*ptr); + v = v * scale; + *ptr = float32_to_bfloat16(v); + ptr++; + } + } +} diff --git a/src/layer/x86/dropout_x86.cpp b/src/layer/x86/dropout_x86.cpp index 938dd8321af..1198c8f08d1 100644 --- a/src/layer/x86/dropout_x86.cpp +++ b/src/layer/x86/dropout_x86.cpp @@ -9,14 +9,24 @@ #include #endif // __AVX__ #endif // __SSE2__ +#include "x86_usability.h" + +#include "cpu.h" namespace ncnn { +#if NCNN_BF16 +#include "dropout_bf16s.h" +#endif + Dropout_x86::Dropout_x86() { #if __SSE2__ support_packing = true; #endif // __SSE2__ +#if NCNN_BF16 + support_bf16_storage = true; +#endif } int Dropout_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const @@ -26,6 +36,11 @@ int Dropout_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const return 0; } +#if NCNN_BF16 + if (opt.use_bf16_storage && bottom_top_blob.elembits() == 16) + return forward_inplace_bf16s(bottom_top_blob, opt); +#endif + #if __SSE2__ int dims = bottom_top_blob.dims; int elempack = bottom_top_blob.elempack; @@ -166,4 +181,19 @@ int Dropout_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const return Dropout::forward_inplace(bottom_top_blob, opt); } +#if NCNN_BF16 +int Dropout_x86::forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt) const +{ + if (scale == 1.f) + { + return 0; + } + + dropout_bf16s(bottom_top_blob, scale, opt); + + return 0; +} + +#endif // NCNN_BF16 + } // namespace ncnn diff --git a/src/layer/x86/dropout_x86.h b/src/layer/x86/dropout_x86.h index b13f441c93c..5c3a9bf9d5d 100644 --- a/src/layer/x86/dropout_x86.h +++ b/src/layer/x86/dropout_x86.h @@ -14,6 +14,11 @@ class Dropout_x86 : public 
Dropout Dropout_x86(); virtual int forward_inplace(Mat& bottom_top_blob, const Option& opt) const; + +protected: +#if NCNN_BF16 + int forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt) const; +#endif }; } // namespace ncnn diff --git a/src/layer/x86/dropout_x86_avx512bf16.cpp b/src/layer/x86/dropout_x86_avx512bf16.cpp new file mode 100644 index 00000000000..52cf5052219 --- /dev/null +++ b/src/layer/x86/dropout_x86_avx512bf16.cpp @@ -0,0 +1,27 @@ +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#include "dropout_x86.h" + +#if __SSE2__ +#include +#if __AVX__ +#include +#endif // __AVX__ +#endif // __SSE2__ + +#include "x86_usability.h" + +#include "cpu.h" +#include "mat.h" + +namespace ncnn { + +#include "dropout_bf16s.h" + +void dropout_bf16s_avx512bf16(Mat& a, float scale, const Option& opt) +{ + dropout_bf16s(a, scale, opt); +} + +} // namespace ncnn diff --git a/src/layer/x86/eltwise_bf16s.h b/src/layer/x86/eltwise_bf16s.h new file mode 100644 index 00000000000..c79078630e5 --- /dev/null +++ b/src/layer/x86/eltwise_bf16s.h @@ -0,0 +1,516 @@ +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ +void eltwise_bf16s_avx512bf16(const std::vector& bottom_blobs, Mat& top_blob, int op_type, const Mat& coeffs, const Option& opt); +#endif + +static void eltwise_bf16s(const std::vector& bottom_blobs, Mat& top_blob, int op_type, const Mat& coeffs, const Option& opt) +{ +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ + if (ncnn::cpu_support_x86_avx512_bf16()) + { + eltwise_bf16s_avx512bf16(bottom_blobs, top_blob, op_type, coeffs, opt); + return; + } +#endif + + const Mat& bottom_blob = bottom_blobs[0]; + int w = bottom_blob.w; + int h = bottom_blob.h; + int d = bottom_blob.d; + int channels = bottom_blob.c; + int elempack = bottom_blob.elempack; + int size = w * h * d * elempack; + + if (op_type == 0) // Operation_PROD + { + // first blob + const Mat& bottom_blob1 = bottom_blobs[1]; + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const unsigned short* ptr = bottom_blob.channel(q); + const unsigned short* ptr1 = bottom_blob1.channel(q); + unsigned short* outptr = top_blob.channel(q); + + int i = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + for (; i + 15 < size; i += 16) + { + __m512 _p = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr)); + __m512 _p1 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr1)); + _p = _mm512_mul_ps(_p, _p1); + _mm256_storeu_si256((__m256i*)outptr, float2bfloat_avx512(_p)); + + ptr += 16; + ptr1 += 16; + outptr += 16; + } +#endif // __AVX512F__ + for (; i + 7 < size; i += 8) + { + __m256 _p = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr)); + __m256 _p1 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr1)); + _p = _mm256_mul_ps(_p, _p1); + _mm_storeu_si128((__m128i*)outptr, float2bfloat_avx(_p)); + + ptr += 8; + ptr1 += 8; + outptr += 8; + } +#endif // __AVX__ + for (; i + 3 < size; i += 4) + { + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr)); + __m128 _p1 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr1)); + _p = _mm_mul_ps(_p, _p1); + _mm_storel_epi64((__m128i*)outptr, float2bfloat_sse(_p, _p)); + + ptr += 4; + ptr1 += 4; + outptr += 4; + } +#endif // __SSE2__ + for (; i < size; i++) + { + *outptr = float32_to_bfloat16(bfloat16_to_float32(*ptr) * bfloat16_to_float32(*ptr1)); + + ptr++; + ptr1++; + outptr++; + 
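// ---------------------------------------------------------------------------
// Reference-only sketch of what the bf16 Eltwise kernels in the following hunk
// compute per element: PROD, SUM (optionally weighted by per-blob coeffs) and
// MAX, folded left-to-right across the input blobs. The real code round-trips
// every value through bf16; this plain-fp32 model (illustrative names, not
// ncnn API) is only a readable statement of the semantics.
#include <algorithm>
#include <cstdio>
#include <vector>

// op: 0 = PROD, 1 = SUM (weighted when coeffs is non-empty), 2 = MAX
static void eltwise_reference(const std::vector<std::vector<float> >& blobs,
                              const std::vector<float>& coeffs,
                              int op, std::vector<float>& out)
{
    const size_t n = blobs[0].size();
    out = blobs[0];
    if (op == 1 && !coeffs.empty())
    {
        for (size_t i = 0; i < n; i++)
            out[i] *= coeffs[0];
    }
    for (size_t b = 1; b < blobs.size(); b++)
    {
        for (size_t i = 0; i < n; i++)
        {
            if (op == 0) out[i] *= blobs[b][i];
            if (op == 1) out[i] += coeffs.empty() ? blobs[b][i] : blobs[b][i] * coeffs[b];
            if (op == 2) out[i] = std::max(out[i], blobs[b][i]);
        }
    }
}

int main()
{
    std::vector<std::vector<float> > blobs(2, std::vector<float>(4, 0.f));
    for (int i = 0; i < 4; i++) { blobs[0][i] = 2.f; blobs[1][i] = 3.f; }
    std::vector<float> coeffs; // empty => unweighted sum
    std::vector<float> out;
    eltwise_reference(blobs, coeffs, 1, out);
    std::printf("%f\n", out[0]); // 5.000000
    return 0;
}
// ---------------------------------------------------------------------------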
} + } + + for (size_t b = 2; b < bottom_blobs.size(); b++) + { + const Mat& bottom_blob2 = bottom_blobs[b]; + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const unsigned short* ptr = bottom_blob2.channel(q); + unsigned short* outptr = top_blob.channel(q); + + int i = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + for (; i + 15 < size; i += 16) + { + __m512 _p = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)outptr)); + __m512 _p1 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr)); + _p = _mm512_mul_ps(_p, _p1); + _mm256_storeu_si256((__m256i*)outptr, float2bfloat_avx512(_p)); + + ptr += 16; + outptr += 16; + } +#endif // __AVX512F__ + for (; i + 7 < size; i += 8) + { + __m256 _p = bfloat2float_avx(_mm_loadu_si128((const __m128i*)outptr)); + __m256 _p1 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr)); + _p = _mm256_mul_ps(_p, _p1); + _mm_storeu_si128((__m128i*)outptr, float2bfloat_avx(_p)); + + ptr += 8; + outptr += 8; + } +#endif // __AVX__ + for (; i + 3 < size; i += 4) + { + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)outptr)); + __m128 _p1 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr)); + _p = _mm_mul_ps(_p, _p1); + _mm_storel_epi64((__m128i*)outptr, float2bfloat_sse(_p, _p)); + + ptr += 4; + outptr += 4; + } +#endif // __SSE2__ + for (; i < size; i++) + { + *outptr = float32_to_bfloat16(bfloat16_to_float32(*outptr) * bfloat16_to_float32(*ptr)); + + ptr++; + outptr++; + } + } + } + } + if (op_type == 1) // Operation_SUM + { + if (coeffs.w == 0) + { + // first blob + const Mat& bottom_blob1 = bottom_blobs[1]; + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const unsigned short* ptr = bottom_blob.channel(q); + const unsigned short* ptr1 = bottom_blob1.channel(q); + unsigned short* outptr = top_blob.channel(q); + + int i = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + for (; i + 15 < size; i += 16) + { + __m512 _p = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr)); + __m512 _p1 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr1)); + _p = _mm512_add_ps(_p, _p1); + _mm256_storeu_si256((__m256i*)outptr, float2bfloat_avx512(_p)); + + ptr += 16; + ptr1 += 16; + outptr += 16; + } +#endif // __AVX512F__ + for (; i + 7 < size; i += 8) + { + __m256 _p = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr)); + __m256 _p1 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr1)); + _p = _mm256_add_ps(_p, _p1); + _mm_storeu_si128((__m128i*)outptr, float2bfloat_avx(_p)); + + ptr += 8; + ptr1 += 8; + outptr += 8; + } +#endif // __AVX__ + for (; i + 3 < size; i += 4) + { + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr)); + __m128 _p1 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr1)); + _p = _mm_add_ps(_p, _p1); + _mm_storel_epi64((__m128i*)outptr, float2bfloat_sse(_p, _p)); + + ptr += 4; + ptr1 += 4; + outptr += 4; + } +#endif // __SSE2__ + for (; i < size; i++) + { + *outptr = float32_to_bfloat16(bfloat16_to_float32(*ptr) + bfloat16_to_float32(*ptr1)); + + ptr++; + ptr1++; + outptr++; + } + } + + for (size_t b = 2; b < bottom_blobs.size(); b++) + { + const Mat& bottom_blob2 = bottom_blobs[b]; + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const unsigned short* ptr = bottom_blob2.channel(q); + unsigned short* outptr = top_blob.channel(q); + + int i = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + for (; i + 15 < size; i += 16) + { + __m512 
_p = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)outptr)); + __m512 _p1 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr)); + _p = _mm512_add_ps(_p, _p1); + _mm256_storeu_si256((__m256i*)outptr, float2bfloat_avx512(_p)); + + ptr += 16; + outptr += 16; + } +#endif // __AVX512F__ + for (; i + 7 < size; i += 8) + { + __m256 _p = bfloat2float_avx(_mm_loadu_si128((const __m128i*)outptr)); + __m256 _p1 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr)); + _p = _mm256_add_ps(_p, _p1); + _mm_storeu_si128((__m128i*)outptr, float2bfloat_avx(_p)); + + ptr += 8; + outptr += 8; + } +#endif // __AVX__ + for (; i + 3 < size; i += 4) + { + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)outptr)); + __m128 _p1 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr)); + _p = _mm_add_ps(_p, _p1); + _mm_storel_epi64((__m128i*)outptr, float2bfloat_sse(_p, _p)); + + ptr += 4; + outptr += 4; + } +#endif // __SSE2__ + for (; i < size; i++) + { + *outptr = float32_to_bfloat16(bfloat16_to_float32(*outptr) + bfloat16_to_float32(*ptr)); + + ptr++; + outptr++; + } + } + } + } + else + { + // first blob + const Mat& bottom_blob1 = bottom_blobs[1]; + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const unsigned short* ptr = bottom_blob.channel(q); + const unsigned short* ptr1 = bottom_blob1.channel(q); + unsigned short* outptr = top_blob.channel(q); + + const float coeff0 = coeffs[0]; + const float coeff1 = coeffs[1]; + + int i = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + __m512 _coeff0_avx512 = _mm512_set1_ps(coeff0); + __m512 _coeff1_avx512 = _mm512_set1_ps(coeff1); + for (; i + 15 < size; i += 16) + { + __m512 _p = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr)); + __m512 _p1 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr1)); + _p = _mm512_mul_ps(_p, _coeff0_avx512); + _p = _mm512_fmadd_ps(_p1, _coeff1_avx512, _p); + _mm256_storeu_si256((__m256i*)outptr, float2bfloat_avx512(_p)); + + ptr += 16; + ptr1 += 16; + outptr += 16; + } +#endif // __AVX512F__ + __m256 _coeff0_avx = _mm256_set1_ps(coeff0); + __m256 _coeff1_avx = _mm256_set1_ps(coeff1); + for (; i + 7 < size; i += 8) + { + __m256 _p = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr)); + __m256 _p1 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr1)); + _p = _mm256_mul_ps(_p, _coeff0_avx); + _p = _mm256_comp_fmadd_ps(_p1, _coeff1_avx, _p); + _mm_storeu_si128((__m128i*)outptr, float2bfloat_avx(_p)); + + ptr += 8; + ptr1 += 8; + outptr += 8; + } +#endif // __AVX__ + __m128 _coeff0 = _mm_set1_ps(coeff0); + __m128 _coeff1 = _mm_set1_ps(coeff1); + for (; i + 3 < size; i += 4) + { + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr)); + __m128 _p1 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr1)); + _p = _mm_mul_ps(_p, _coeff0); + _p1 = _mm_mul_ps(_p1, _coeff1); + _p = _mm_add_ps(_p1, _p); + _mm_storel_epi64((__m128i*)outptr, float2bfloat_sse(_p, _p)); + + ptr += 4; + ptr1 += 4; + outptr += 4; + } +#endif // __SSE2__ + for (; i < size; i++) + { + *outptr = float32_to_bfloat16(bfloat16_to_float32(*ptr) * coeff0 + bfloat16_to_float32(*ptr1) * coeff1); + + ptr++; + ptr1++; + outptr++; + } + } + + for (size_t b = 2; b < bottom_blobs.size(); b++) + { + const Mat& bottom_blob2 = bottom_blobs[b]; + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const unsigned short* ptr = bottom_blob2.channel(q); + unsigned short* outptr = top_blob.channel(q); + + const float coeff = 
coeffs[b]; + + int i = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + __m512 _coeff_avx512 = _mm512_set1_ps(coeff); + for (; i + 15 < size; i += 16) + { + __m512 _p = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)outptr)); + __m512 _p1 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr)); + _p = _mm512_fmadd_ps(_p1, _coeff_avx512, _p); + _mm256_storeu_si256((__m256i*)outptr, float2bfloat_avx512(_p)); + + ptr += 16; + outptr += 16; + } +#endif // __AVX512F__ + __m256 _coeff_avx = _mm256_set1_ps(coeff); + for (; i + 7 < size; i += 8) + { + __m256 _p = bfloat2float_avx(_mm_loadu_si128((const __m128i*)outptr)); + __m256 _p1 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr)); + _p = _mm256_comp_fmadd_ps(_p1, _coeff_avx, _p); + _mm_storeu_si128((__m128i*)outptr, float2bfloat_avx(_p)); + + ptr += 8; + outptr += 8; + } +#endif // __AVX__ + __m128 _coeff = _mm_set1_ps(coeff); + for (; i + 3 < size; i += 4) + { + __m128 _p1 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr)); + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)outptr)); + _p1 = _mm_mul_ps(_p1, _coeff); + _p = _mm_add_ps(_p1, _p); + _mm_storel_epi64((__m128i*)outptr, float2bfloat_sse(_p, _p)); + + ptr += 4; + outptr += 4; + } +#endif // __SSE2__ + for (; i < size; i++) + { + *outptr = float32_to_bfloat16(bfloat16_to_float32(*outptr) + bfloat16_to_float32(*ptr) * coeff); + + ptr++; + outptr++; + } + } + } + } + } + if (op_type == 2) // Operation_MAX + { + // first blob + const Mat& bottom_blob1 = bottom_blobs[1]; + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const unsigned short* ptr = bottom_blob.channel(q); + const unsigned short* ptr1 = bottom_blob1.channel(q); + unsigned short* outptr = top_blob.channel(q); + + int i = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + for (; i + 15 < size; i += 16) + { + __m512 _p = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr)); + __m512 _p1 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr1)); + _p = _mm512_max_ps(_p, _p1); + _mm256_storeu_si256((__m256i*)outptr, float2bfloat_avx512(_p)); + + ptr += 16; + ptr1 += 16; + outptr += 16; + } +#endif // __AVX512F__ + for (; i + 7 < size; i += 8) + { + __m256 _p = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr)); + __m256 _p1 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr1)); + _p = _mm256_max_ps(_p, _p1); + _mm_storeu_si128((__m128i*)outptr, float2bfloat_avx(_p)); + + ptr += 8; + ptr1 += 8; + outptr += 8; + } +#endif // __AVX__ + for (; i + 3 < size; i += 4) + { + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr)); + __m128 _p1 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr1)); + _p = _mm_max_ps(_p, _p1); + _mm_storel_epi64((__m128i*)outptr, float2bfloat_sse(_p, _p)); + + ptr += 4; + ptr1 += 4; + outptr += 4; + } +#endif // __SSE2__ + for (; i < size; i++) + { + *outptr = float32_to_bfloat16(std::max(bfloat16_to_float32(*ptr), bfloat16_to_float32(*ptr1))); + + ptr++; + ptr1++; + outptr++; + } + } + + for (size_t b = 2; b < bottom_blobs.size(); b++) + { + const Mat& bottom_blob2 = bottom_blobs[b]; + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const unsigned short* ptr = bottom_blob2.channel(q); + unsigned short* outptr = top_blob.channel(q); + + int i = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + for (; i + 15 < size; i += 16) + { + __m512 _p = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)outptr)); + __m512 _p1 = 
bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr)); + _p = _mm512_max_ps(_p, _p1); + _mm256_storeu_si256((__m256i*)outptr, float2bfloat_avx512(_p)); + + ptr += 16; + outptr += 16; + } +#endif // __AVX512F__ + for (; i + 7 < size; i += 8) + { + __m256 _p = bfloat2float_avx(_mm_loadu_si128((const __m128i*)outptr)); + __m256 _p1 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr)); + _p = _mm256_max_ps(_p, _p1); + _mm_storeu_si128((__m128i*)outptr, float2bfloat_avx(_p)); + + ptr += 8; + outptr += 8; + } +#endif // __AVX__ + for (; i + 3 < size; i += 4) + { + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)outptr)); + __m128 _p1 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr)); + _p = _mm_max_ps(_p, _p1); + _mm_storel_epi64((__m128i*)outptr, float2bfloat_sse(_p, _p)); + + ptr += 4; + outptr += 4; + } +#endif // __SSE2__ + for (; i < size; i++) + { + *outptr = float32_to_bfloat16(std::max(bfloat16_to_float32(*ptr), bfloat16_to_float32(*outptr))); + + ptr++; + outptr++; + } + } + } + } +} diff --git a/src/layer/x86/eltwise_x86.cpp b/src/layer/x86/eltwise_x86.cpp index c85b7f32e6b..22b7f56a567 100644 --- a/src/layer/x86/eltwise_x86.cpp +++ b/src/layer/x86/eltwise_x86.cpp @@ -11,17 +11,30 @@ #endif // __SSE2__ #include "x86_usability.h" +#include "cpu.h" + namespace ncnn { +#if NCNN_BF16 +#include "eltwise_bf16s.h" +#endif + Eltwise_x86::Eltwise_x86() { #if __SSE2__ support_packing = true; #endif // __SSE2__ +#if NCNN_BF16 + support_bf16_storage = true; +#endif } int Eltwise_x86::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const { +#if NCNN_BF16 + if (opt.use_bf16_storage && bottom_blobs[0].elembits() == 16) + return forward_bf16s(bottom_blobs, top_blobs, opt); +#endif const Mat& bottom_blob = bottom_blobs[0]; int w = bottom_blob.w; int h = bottom_blob.h; @@ -529,4 +542,21 @@ int Eltwise_x86::forward(const std::vector& bottom_blobs, std::vector& return 0; } +#if NCNN_BF16 +int Eltwise_x86::forward_bf16s(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const +{ + const Mat& bottom_blob = bottom_blobs[0]; + + Mat& top_blob = top_blobs[0]; + top_blob.create_like(bottom_blob, opt.blob_allocator); + if (top_blob.empty()) + return -100; + + eltwise_bf16s(bottom_blobs, top_blob, op_type, coeffs, opt); + + return 0; +} + +#endif // NCNN_BF16 + } // namespace ncnn diff --git a/src/layer/x86/eltwise_x86.h b/src/layer/x86/eltwise_x86.h index b9729fa29a4..abedd25c7e7 100644 --- a/src/layer/x86/eltwise_x86.h +++ b/src/layer/x86/eltwise_x86.h @@ -14,6 +14,11 @@ class Eltwise_x86 : public Eltwise Eltwise_x86(); virtual int forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; + +protected: +#if NCNN_BF16 + int forward_bf16s(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; +#endif }; } // namespace ncnn diff --git a/src/layer/x86/eltwise_x86_avx512bf16.cpp b/src/layer/x86/eltwise_x86_avx512bf16.cpp new file mode 100644 index 00000000000..f70951fe8f7 --- /dev/null +++ b/src/layer/x86/eltwise_x86_avx512bf16.cpp @@ -0,0 +1,27 @@ +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#include "eltwise_x86.h" + +#if __SSE2__ +#include +#if __AVX__ +#include +#endif // __AVX__ +#endif // __SSE2__ + +#include "x86_usability.h" + +#include "cpu.h" +#include "mat.h" + +namespace ncnn { + +#include "eltwise_bf16s.h" + +void eltwise_bf16s_avx512bf16(const std::vector& bottom_blobs, Mat& top_blob, int op_type, const Mat& coeffs, const 
Option& opt) +{ + eltwise_bf16s(bottom_blobs, top_blob, op_type, coeffs, opt); +} + +} // namespace ncnn diff --git a/src/layer/x86/elu_bf16s.h b/src/layer/x86/elu_bf16s.h new file mode 100644 index 00000000000..e522a6f63ff --- /dev/null +++ b/src/layer/x86/elu_bf16s.h @@ -0,0 +1,86 @@ +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ +void elu_bf16s_avx512bf16(Mat& a, float alpha, const Option& opt); +#endif + +static void elu_bf16s(Mat& a, float alpha, const Option& opt) +{ +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ + if (ncnn::cpu_support_x86_avx512_bf16()) + { + elu_bf16s_avx512bf16(a, alpha, opt); + return; + } +#endif + + int w = a.w; + int h = a.h; + int d = a.d; + int channels = a.c; + int elempack = a.elempack; + int size = w * h * d * elempack; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + unsigned short* ptr = a.channel(q); + + int i = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + __m512 _alpha_avx512 = _mm512_set1_ps(alpha); + for (; i + 15 < size; i += 16) + { + __m512 _p = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr)); + _mm256_storeu_si256((__m256i*)ptr, float2bfloat_avx512(elu_avx512(_p, _alpha_avx512))); + ptr += 16; + } + if (i < size) + { + const unsigned int remain = size - i; + __mmask16 _mask = (__mmask16)((1u << remain) - 1); + __m512 _p = bfloat2float_avx512(_mm256_maskz_loadu_epi16(_mask, ptr)); + _mm256_mask_storeu_epi16(ptr, _mask, float2bfloat_avx512(elu_avx512(_p, _alpha_avx512))); + i += remain; + } +#else // __AVX512F__ + __m256 _alpha_avx = _mm256_set1_ps(alpha); + for (; i + 7 < size; i += 8) + { + __m256 _p = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr)); + _mm_storeu_si128((__m128i*)ptr, float2bfloat_avx(elu_avx(_p, _alpha_avx))); + ptr += 8; + } + __m128 _alpha_sse = _mm_set1_ps(alpha); + for (; i + 3 < size; i += 4) + { + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr)); + _p = elu_sse(_p, _alpha_sse); + _mm_storel_epi64((__m128i*)ptr, float2bfloat_sse(_p, _p)); + ptr += 4; + } +#endif // __AVX512F__ +#else // __AVX__ + __m128 _alpha_sse = _mm_set1_ps(alpha); + for (; i + 3 < size; i += 4) + { + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr)); + _p = elu_sse(_p, _alpha_sse); + _mm_storel_epi64((__m128i*)ptr, float2bfloat_sse(_p, _p)); + ptr += 4; + } +#endif // __AVX__ +#endif // __SSE2__ + for (; i < size; i++) + { + float v = bfloat16_to_float32(*ptr); + if (v < 0.f) + v = alpha * (expf(v) - 1.f); + *ptr = float32_to_bfloat16(v); + ptr++; + } + } +} diff --git a/src/layer/x86/elu_x86.cpp b/src/layer/x86/elu_x86.cpp index a683bc58f91..eaf8710df21 100644 --- a/src/layer/x86/elu_x86.cpp +++ b/src/layer/x86/elu_x86.cpp @@ -5,17 +5,31 @@ #include "x86_activation.h" +#include "cpu.h" + namespace ncnn { +#if NCNN_BF16 +#include "elu_bf16s.h" +#endif + ELU_x86::ELU_x86() { #if __SSE2__ support_packing = true; #endif // __SSE2__ +#if NCNN_BF16 + support_bf16_storage = true; +#endif } int ELU_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const { +#if NCNN_BF16 + if (opt.use_bf16_storage && bottom_top_blob.elembits() == 16) + return forward_inplace_bf16s(bottom_top_blob, opt); +#endif + int w = bottom_top_blob.w; int h = bottom_top_blob.h; int d = bottom_top_blob.d; @@ -71,4 +85,13 @@ int ELU_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const return 0; } +#if NCNN_BF16 +int 
ELU_x86::forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt) const +{ + elu_bf16s(bottom_top_blob, alpha, opt); + + return 0; +} +#endif // NCNN_BF16 + } // namespace ncnn diff --git a/src/layer/x86/elu_x86.h b/src/layer/x86/elu_x86.h index 799ca5feb6b..a7c31313841 100644 --- a/src/layer/x86/elu_x86.h +++ b/src/layer/x86/elu_x86.h @@ -14,6 +14,11 @@ class ELU_x86 : public ELU ELU_x86(); virtual int forward_inplace(Mat& bottom_top_blob, const Option& opt) const; + +protected: +#if NCNN_BF16 + int forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt) const; +#endif }; } // namespace ncnn diff --git a/src/layer/x86/elu_x86_avx512bf16.cpp b/src/layer/x86/elu_x86_avx512bf16.cpp new file mode 100644 index 00000000000..c8e3aef0259 --- /dev/null +++ b/src/layer/x86/elu_x86_avx512bf16.cpp @@ -0,0 +1,21 @@ +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#include "elu_x86.h" + +#include "x86_activation.h" +#include "x86_usability.h" + +#include "cpu.h" +#include "mat.h" + +namespace ncnn { + +#include "elu_bf16s.h" + +void elu_bf16s_avx512bf16(Mat& a, float alpha, const Option& opt) +{ + elu_bf16s(a, alpha, opt); +} + +} // namespace ncnn diff --git a/src/layer/x86/erf_bf16s.h b/src/layer/x86/erf_bf16s.h new file mode 100644 index 00000000000..5f61bc6e1e2 --- /dev/null +++ b/src/layer/x86/erf_bf16s.h @@ -0,0 +1,84 @@ +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ +void erf_bf16s_avx512bf16(Mat& a, const Option& opt); +#endif + +static void erf_bf16s(Mat& a, const Option& opt) +{ +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ + if (ncnn::cpu_support_x86_avx512_bf16()) + { + erf_bf16s_avx512bf16(a, opt); + return; + } +#endif + + int w = a.w; + int h = a.h; + int d = a.d; + int channels = a.c; + int elempack = a.elempack; + int size = w * h * d * elempack; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + unsigned short* ptr = a.channel(q); + + int i = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + for (; i + 15 < size; i += 16) + { + __m512 _p = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr)); + _p = erf512_ps(_p); + _mm256_storeu_si256((__m256i*)ptr, float2bfloat_avx512(_p)); + ptr += 16; + } + if (i < size) + { + const unsigned int remain = size - i; + __mmask16 _mask = (__mmask16)((1u << remain) - 1); + __m512 _p = bfloat2float_avx512(_mm256_maskz_loadu_epi16(_mask, ptr)); + _p = erf512_ps(_p); + _mm256_mask_storeu_epi16(ptr, _mask, float2bfloat_avx512(_p)); + i += remain; + } +#else // __AVX512F__ + for (; i + 7 < size; i += 8) + { + __m256 _p = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr)); + _p = erf256_ps(_p); + _mm_storeu_si128((__m128i*)ptr, float2bfloat_avx(_p)); + ptr += 8; + } + for (; i + 3 < size; i += 4) + { + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr)); + _p = erf_ps(_p); + _mm_storel_epi64((__m128i*)ptr, float2bfloat_sse(_p, _p)); + ptr += 4; + } +#endif // __AVX512F__ +#else // __AVX__ + for (; i + 3 < size; i += 4) + { + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr)); + _p = erf_ps(_p); + _mm_storel_epi64((__m128i*)ptr, float2bfloat_sse(_p, _p)); + ptr += 4; + } +#endif // __AVX__ +#endif // __SSE2__ + for (; i < size; i++) + { + float v = bfloat16_to_float32(*ptr); + v = erff(v); + *ptr = float32_to_bfloat16(v); + ptr++; + } + } +} diff --git a/src/layer/x86/erf_x86.cpp b/src/layer/x86/erf_x86.cpp 
index 5a1c396cef4..0d5eb693d79 100644 --- a/src/layer/x86/erf_x86.cpp +++ b/src/layer/x86/erf_x86.cpp @@ -15,13 +15,24 @@ #endif // __AVX__ #endif // __SSE2__ +#include "x86_usability.h" + +#include "cpu.h" + namespace ncnn { +#if NCNN_BF16 +#include "erf_bf16s.h" +#endif + Erf_x86::Erf_x86() { #if __SSE2__ support_packing = true; #endif // __SSE2__ +#if NCNN_BF16 + support_bf16_storage = true; +#endif } int Erf_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const @@ -33,6 +44,11 @@ int Erf_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const int elempack = bottom_top_blob.elempack; int size = w * h * d * elempack; +#if NCNN_BF16 + if (opt.use_bf16_storage && bottom_top_blob.elembits() == 16) + return forward_inplace_bf16s(bottom_top_blob, opt); +#endif + #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; q++) { @@ -86,4 +102,13 @@ int Erf_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const return 0; } +#if NCNN_BF16 +int Erf_x86::forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt) const +{ + erf_bf16s(bottom_top_blob, opt); + + return 0; +} +#endif // NCNN_BF16 + } // namespace ncnn diff --git a/src/layer/x86/erf_x86.h b/src/layer/x86/erf_x86.h index f8366bf50b3..635bacfe0c6 100644 --- a/src/layer/x86/erf_x86.h +++ b/src/layer/x86/erf_x86.h @@ -14,6 +14,11 @@ class Erf_x86 : public Erf Erf_x86(); virtual int forward_inplace(Mat& bottom_top_blob, const Option& opt) const; + +protected: +#if NCNN_BF16 + int forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt) const; +#endif }; } // namespace ncnn diff --git a/src/layer/x86/erf_x86_avx512bf16.cpp b/src/layer/x86/erf_x86_avx512bf16.cpp new file mode 100644 index 00000000000..1840dd87241 --- /dev/null +++ b/src/layer/x86/erf_x86_avx512bf16.cpp @@ -0,0 +1,32 @@ +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#include "erf_x86.h" + +#if __SSE2__ +#include +#include "sse_mathfun.h" +#if __AVX__ +#include +#include "avx_mathfun.h" +#if __AVX512F__ +#include "avx512_mathfun.h" +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ + +#include "x86_usability.h" + +#include "cpu.h" +#include "mat.h" + +namespace ncnn { + +#include "erf_bf16s.h" + +void erf_bf16s_avx512bf16(Mat& a, const Option& opt) +{ + erf_bf16s(a, opt); +} + +} // namespace ncnn diff --git a/src/layer/x86/gelu_bf16s.h b/src/layer/x86/gelu_bf16s.h new file mode 100644 index 00000000000..941a18052a1 --- /dev/null +++ b/src/layer/x86/gelu_bf16s.h @@ -0,0 +1,269 @@ +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ +void gelu_bf16s_avx512bf16(Mat& a, int fast_gelu, const Option& opt); +#endif + +static void gelu_bf16s(Mat& a, int fast_gelu, const Option& opt) +{ +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ + if (ncnn::cpu_support_x86_avx512_bf16()) + { + gelu_bf16s_avx512bf16(a, fast_gelu, opt); + return; + } +#endif + + int w = a.w; + int h = a.h; + int d = a.d; + int channels = a.c; + int elempack = a.elempack; + int size = w * h * d * elempack; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + unsigned short* ptr = a.channel(q); + + int i = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + if (fast_gelu) + { + __m512 _half512 = _mm512_set1_ps(0.5f); + __m512 _one512 = _mm512_set1_ps(1.f); + __m512 _fast1c512 = _mm512_set1_ps(0.79788452f); + __m512 _fast2c512 = 
_mm512_set1_ps(0.044715f); + for (; i + 15 < size; i += 16) + { + __m512 _pLoad = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr)); + + __m512 _cube = _mm512_mul_ps(_pLoad, _pLoad); + _cube = _mm512_mul_ps(_pLoad, _cube); + + __m512 _blob = _mm512_mul_ps(_fast2c512, _cube); + _blob = _mm512_add_ps(_pLoad, _blob); + _blob = _mm512_mul_ps(_fast1c512, _blob); + _blob = tanh512_ps(_blob); + _blob = _mm512_add_ps(_one512, _blob); + + _blob = _mm512_mul_ps(_half512, _mm512_mul_ps(_blob, _pLoad)); + + _mm256_storeu_si256((__m256i*)ptr, float2bfloat_avx512(_blob)); + + ptr += 16; + } + if (i < size) + { + const unsigned int remain = size - i; + __mmask16 _mask = (__mmask16)((1u << remain) - 1); + __m512 _pLoad = bfloat2float_avx512(_mm256_maskz_loadu_epi16(_mask, ptr)); + + __m512 _cube = _mm512_mul_ps(_pLoad, _pLoad); + _cube = _mm512_mul_ps(_pLoad, _cube); + + __m512 _blob = _mm512_mul_ps(_fast2c512, _cube); + _blob = _mm512_add_ps(_pLoad, _blob); + _blob = _mm512_mul_ps(_fast1c512, _blob); + _blob = tanh512_ps(_blob); + _blob = _mm512_add_ps(_one512, _blob); + + _blob = _mm512_mul_ps(_half512, _mm512_mul_ps(_blob, _pLoad)); + + _mm256_mask_storeu_epi16(ptr, _mask, float2bfloat_avx512(_blob)); + i += remain; + } + } + else + { + __m512 _half512 = _mm512_set1_ps(0.5f); + __m512 _one512 = _mm512_set1_ps(1.f); + __m512 _inv_sqrt2_512 = _mm512_set1_ps(0.70710678f); + for (; i + 15 < size; i += 16) + { + __m512 _pLoad = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr)); + + __m512 _erf = erf512_ps(_mm512_mul_ps(_pLoad, _inv_sqrt2_512)); + __m512 _blob = _mm512_add_ps(_one512, _erf); + _blob = _mm512_mul_ps(_half512, _mm512_mul_ps(_blob, _pLoad)); + + _mm256_storeu_si256((__m256i*)ptr, float2bfloat_avx512(_blob)); + + ptr += 16; + } + if (i < size) + { + const unsigned int remain = size - i; + __mmask16 _mask = (__mmask16)((1u << remain) - 1); + __m512 _pLoad = bfloat2float_avx512(_mm256_maskz_loadu_epi16(_mask, ptr)); + + __m512 _erf = erf512_ps(_mm512_mul_ps(_pLoad, _inv_sqrt2_512)); + __m512 _blob = _mm512_add_ps(_one512, _erf); + _blob = _mm512_mul_ps(_half512, _mm512_mul_ps(_blob, _pLoad)); + + _mm256_mask_storeu_epi16(ptr, _mask, float2bfloat_avx512(_blob)); + i += remain; + } + } +#else // __AVX512F__ + if (fast_gelu) + { + __m256 _half256 = _mm256_set1_ps(0.5f); + __m256 _one256 = _mm256_set1_ps(1.f); + __m256 _fast1c256 = _mm256_set1_ps(0.79788452f); + __m256 _fast2c256 = _mm256_set1_ps(0.044715f); + for (; i + 7 < size; i += 8) + { + __m256 _pLoad = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr)); + + __m256 _cube = _mm256_mul_ps(_pLoad, _pLoad); + _cube = _mm256_mul_ps(_pLoad, _cube); + + __m256 _blob = _mm256_mul_ps(_fast2c256, _cube); + _blob = _mm256_add_ps(_pLoad, _blob); + _blob = _mm256_mul_ps(_fast1c256, _blob); + _blob = tanh256_ps(_blob); + _blob = _mm256_add_ps(_one256, _blob); + + _blob = _mm256_mul_ps(_half256, _mm256_mul_ps(_blob, _pLoad)); + + _mm_storeu_si128((__m128i*)ptr, float2bfloat_avx(_blob)); + + ptr += 8; + } + } + else + { + __m256 _half256 = _mm256_set1_ps(0.5f); + __m256 _one256 = _mm256_set1_ps(1.f); + __m256 _inv_sqrt2_256 = _mm256_set1_ps(0.70710678f); + for (; i + 7 < size; i += 8) + { + __m256 _pLoad = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr)); + + __m256 _erf = erf256_ps(_mm256_mul_ps(_pLoad, _inv_sqrt2_256)); + __m256 _blob = _mm256_add_ps(_one256, _erf); + _blob = _mm256_mul_ps(_half256, _mm256_mul_ps(_blob, _pLoad)); + + _mm_storeu_si128((__m128i*)ptr, float2bfloat_avx(_blob)); + + ptr += 8; + } + } + if 
(fast_gelu) + { + __m128 _half128 = _mm_set1_ps(0.5f); + __m128 _one128 = _mm_set1_ps(1.f); + __m128 _fast1c128 = _mm_set1_ps(0.79788452f); + __m128 _fast2c128 = _mm_set1_ps(0.044715f); + for (; i + 3 < size; i += 4) + { + __m128 _pLoad = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr)); + + __m128 _cube = _mm_mul_ps(_pLoad, _pLoad); + _cube = _mm_mul_ps(_pLoad, _cube); + + __m128 _blob = _mm_mul_ps(_fast2c128, _cube); + _blob = _mm_add_ps(_pLoad, _blob); + _blob = _mm_mul_ps(_fast1c128, _blob); + _blob = tanh_ps(_blob); + _blob = _mm_add_ps(_one128, _blob); + + _blob = _mm_mul_ps(_half128, _mm_mul_ps(_blob, _pLoad)); + + _mm_storel_epi64((__m128i*)ptr, float2bfloat_sse(_blob, _blob)); + + ptr += 4; + } + } + else + { + __m128 _half128 = _mm_set1_ps(0.5f); + __m128 _one128 = _mm_set1_ps(1.f); + __m128 _inv_sqrt2_128 = _mm_set1_ps(0.70710678f); + for (; i + 3 < size; i += 4) + { + __m128 _pLoad = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr)); + + __m128 _erf = erf_ps(_mm_mul_ps(_pLoad, _inv_sqrt2_128)); + __m128 _blob = _mm_add_ps(_one128, _erf); + _blob = _mm_mul_ps(_half128, _mm_mul_ps(_blob, _pLoad)); + + _mm_storel_epi64((__m128i*)ptr, float2bfloat_sse(_blob, _blob)); + + ptr += 4; + } + } +#endif // __AVX512F__ +#else // __AVX__ + if (fast_gelu) + { + __m128 _half128 = _mm_set1_ps(0.5f); + __m128 _one128 = _mm_set1_ps(1.f); + __m128 _fast1c128 = _mm_set1_ps(0.79788452f); + __m128 _fast2c128 = _mm_set1_ps(0.044715f); + for (; i + 3 < size; i += 4) + { + __m128 _pLoad = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr)); + + __m128 _cube = _mm_mul_ps(_pLoad, _pLoad); + _cube = _mm_mul_ps(_pLoad, _cube); + + __m128 _blob = _mm_mul_ps(_fast2c128, _cube); + _blob = _mm_add_ps(_pLoad, _blob); + _blob = _mm_mul_ps(_fast1c128, _blob); + _blob = tanh_ps(_blob); + _blob = _mm_add_ps(_one128, _blob); + + _blob = _mm_mul_ps(_half128, _mm_mul_ps(_blob, _pLoad)); + + _mm_storel_epi64((__m128i*)ptr, float2bfloat_sse(_blob, _blob)); + + ptr += 4; + } + } + else + { + __m128 _half128 = _mm_set1_ps(0.5f); + __m128 _one128 = _mm_set1_ps(1.f); + __m128 _inv_sqrt2_128 = _mm_set1_ps(0.70710678f); + for (; i + 3 < size; i += 4) + { + __m128 _pLoad = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr)); + + __m128 _erf = erf_ps(_mm_mul_ps(_pLoad, _inv_sqrt2_128)); + __m128 _blob = _mm_add_ps(_one128, _erf); + _blob = _mm_mul_ps(_half128, _mm_mul_ps(_blob, _pLoad)); + + _mm_storel_epi64((__m128i*)ptr, float2bfloat_sse(_blob, _blob)); + + ptr += 4; + } + } +#endif // __AVX__ +#endif // __SSE2__ + if (fast_gelu) + { + for (; i < size; i++) + { + float v = bfloat16_to_float32(*ptr); + v = 0.5f * v * (1.0f + tanhf(0.79788452f * (v + 0.044715f * v * v * v))); + *ptr = float32_to_bfloat16(v); + ptr++; + } + } + else + { + for (; i < size; i++) + { + float v = bfloat16_to_float32(*ptr); + v = 0.5f * v * (1.0f + erff(0.70710678f * v)); + *ptr = float32_to_bfloat16(v); + ptr++; + } + } + } +} diff --git a/src/layer/x86/gelu_x86.cpp b/src/layer/x86/gelu_x86.cpp index c4cdb36b7b9..dfcc3bd104f 100644 --- a/src/layer/x86/gelu_x86.cpp +++ b/src/layer/x86/gelu_x86.cpp @@ -15,13 +15,24 @@ #endif // __AVX__ #endif // __SSE2__ +#include "x86_usability.h" + +#include "cpu.h" + namespace ncnn { +#if NCNN_BF16 +#include "gelu_bf16s.h" +#endif + GELU_x86::GELU_x86() { #if __SSE2__ support_packing = true; #endif // __SSE2__ +#if NCNN_BF16 + support_bf16_storage = true; +#endif } int GELU_x86::create_pipeline(const Option& /*opt*/) @@ -38,6 +49,11 @@ int GELU_x86::forward_inplace(Mat& bottom_top_blob, const 
Option& opt) const int channels = bottom_top_blob.c; int size = w * h * d * elempack; +#if NCNN_BF16 + if (opt.use_bf16_storage && bottom_top_blob.elembits() == 16) + return forward_inplace_bf16s(bottom_top_blob, opt); +#endif + #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; q++) { @@ -206,4 +222,13 @@ int GELU_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const return 0; } +#if NCNN_BF16 +int GELU_x86::forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt) const +{ + gelu_bf16s(bottom_top_blob, fast_gelu, opt); + + return 0; +} +#endif // NCNN_BF16 + } // namespace ncnn diff --git a/src/layer/x86/gelu_x86.h b/src/layer/x86/gelu_x86.h index df65c6f3e31..8523d9dfb7b 100644 --- a/src/layer/x86/gelu_x86.h +++ b/src/layer/x86/gelu_x86.h @@ -15,6 +15,11 @@ class GELU_x86 : public GELU virtual int create_pipeline(const Option& opt); virtual int forward_inplace(Mat& bottom_top_blob, const Option& opt) const; + +protected: +#if NCNN_BF16 + int forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt) const; +#endif }; } // namespace ncnn diff --git a/src/layer/x86/gelu_x86_avx512bf16.cpp b/src/layer/x86/gelu_x86_avx512bf16.cpp new file mode 100644 index 00000000000..739917ab6ac --- /dev/null +++ b/src/layer/x86/gelu_x86_avx512bf16.cpp @@ -0,0 +1,32 @@ +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#include "gelu_x86.h" + +#if __SSE2__ +#include +#include "sse_mathfun.h" +#if __AVX__ +#include +#include "avx_mathfun.h" +#if __AVX512F__ +#include "avx512_mathfun.h" +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ + +#include "x86_usability.h" + +#include "cpu.h" +#include "mat.h" + +namespace ncnn { + +#include "gelu_bf16s.h" + +void gelu_bf16s_avx512bf16(Mat& a, int fast_gelu, const Option& opt) +{ + gelu_bf16s(a, fast_gelu, opt); +} + +} // namespace ncnn diff --git a/src/layer/x86/hardsigmoid_bf16s.h b/src/layer/x86/hardsigmoid_bf16s.h new file mode 100644 index 00000000000..7636267e5c2 --- /dev/null +++ b/src/layer/x86/hardsigmoid_bf16s.h @@ -0,0 +1,110 @@ +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ +void hardsigmoid_bf16s_avx512bf16(Mat& a, float alpha, float beta, const Option& opt); +#endif + +static void hardsigmoid_bf16s(Mat& a, float alpha, float beta, const Option& opt) +{ +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ + if (ncnn::cpu_support_x86_avx512_bf16()) + { + hardsigmoid_bf16s_avx512bf16(a, alpha, beta, opt); + return; + } +#endif + + int w = a.w; + int h = a.h; + int d = a.d; + int channels = a.c; + int elempack = a.elempack; + int size = w * h * d * elempack; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + unsigned short* ptr = a.channel(q); + + int i = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + __m512 _alpha_avx512 = _mm512_set1_ps(alpha); + __m512 _beta_avx512 = _mm512_set1_ps(beta); + __m512 _zero_avx512 = _mm512_setzero_ps(); + __m512 _one_avx512 = _mm512_set1_ps(1.f); + for (; i + 15 < size; i += 16) + { + __m512 _p = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr)); + _p = _mm512_fmadd_ps(_p, _alpha_avx512, _beta_avx512); + _p = _mm512_max_ps(_p, _zero_avx512); + _p = _mm512_min_ps(_p, _one_avx512); + _mm256_storeu_si256((__m256i*)ptr, float2bfloat_avx512(_p)); + ptr += 16; + } + if (i < size) + { + const unsigned int remain = size - i; + __mmask16 _mask = (__mmask16)((1u 
<< remain) - 1); + __m512 _p = bfloat2float_avx512(_mm256_maskz_loadu_epi16(_mask, ptr)); + _p = _mm512_fmadd_ps(_p, _alpha_avx512, _beta_avx512); + _p = _mm512_max_ps(_p, _zero_avx512); + _p = _mm512_min_ps(_p, _one_avx512); + _mm256_mask_storeu_epi16(ptr, _mask, float2bfloat_avx512(_p)); + i += remain; + } +#else // __AVX512F__ + __m256 _alpha_avx = _mm256_set1_ps(alpha); + __m256 _beta_avx = _mm256_set1_ps(beta); + __m256 _zero_avx = _mm256_setzero_ps(); + __m256 _one_avx = _mm256_set1_ps(1.f); + for (; i + 7 < size; i += 8) + { + __m256 _p = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr)); + _p = _mm256_comp_fmadd_ps(_p, _alpha_avx, _beta_avx); + _p = _mm256_max_ps(_p, _zero_avx); + _p = _mm256_min_ps(_p, _one_avx); + _mm_storeu_si128((__m128i*)ptr, float2bfloat_avx(_p)); + ptr += 8; + } + __m128 _alpha_sse = _mm_set1_ps(alpha); + __m128 _beta_sse = _mm_set1_ps(beta); + __m128 _zero = _mm_setzero_ps(); + __m128 _one = _mm_set1_ps(1.f); + for (; i + 3 < size; i += 4) + { + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr)); + _p = _mm_comp_fmadd_ps(_p, _alpha_sse, _beta_sse); + _p = _mm_max_ps(_p, _zero); + _p = _mm_min_ps(_p, _one); + _mm_storel_epi64((__m128i*)ptr, float2bfloat_sse(_p, _p)); + ptr += 4; + } +#endif // __AVX512F__ +#else // __AVX__ + __m128 _alpha_sse = _mm_set1_ps(alpha); + __m128 _beta_sse = _mm_set1_ps(beta); + __m128 _zero = _mm_setzero_ps(); + __m128 _one = _mm_set1_ps(1.f); + for (; i + 3 < size; i += 4) + { + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr)); + _p = _mm_comp_fmadd_ps(_p, _alpha_sse, _beta_sse); + _p = _mm_max_ps(_p, _zero); + _p = _mm_min_ps(_p, _one); + _mm_storel_epi64((__m128i*)ptr, float2bfloat_sse(_p, _p)); + ptr += 4; + } +#endif // __AVX__ +#endif // __SSE2__ + for (; i < size; i++) + { + float v = bfloat16_to_float32(*ptr); + v = std::min(std::max(v * alpha + beta, 0.f), 1.f); + *ptr = float32_to_bfloat16(v); + ptr++; + } + } +} diff --git a/src/layer/x86/hardsigmoid_x86.cpp b/src/layer/x86/hardsigmoid_x86.cpp index f2a1f818d16..df70ee196aa 100644 --- a/src/layer/x86/hardsigmoid_x86.cpp +++ b/src/layer/x86/hardsigmoid_x86.cpp @@ -12,18 +12,32 @@ #include "x86_usability.h" +#include "cpu.h" + namespace ncnn { +#if NCNN_BF16 +#include "hardsigmoid_bf16s.h" +#endif + HardSigmoid_x86::HardSigmoid_x86() { #if __SSE2__ support_packing = true; support_any_packing = true; #endif // __SSE2__ +#if NCNN_BF16 + support_bf16_storage = true; +#endif } int HardSigmoid_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const { +#if NCNN_BF16 + if (opt.use_bf16_storage && bottom_top_blob.elembits() == 16) + return forward_inplace_bf16s(bottom_top_blob, opt); +#endif + const int w = bottom_top_blob.w; const int h = bottom_top_blob.h; const int d = bottom_top_blob.d; @@ -94,4 +108,13 @@ int HardSigmoid_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) co return 0; } +#if NCNN_BF16 +int HardSigmoid_x86::forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt) const +{ + hardsigmoid_bf16s(bottom_top_blob, alpha, beta, opt); + + return 0; +} +#endif // NCNN_BF16 + } // namespace ncnn diff --git a/src/layer/x86/hardsigmoid_x86.h b/src/layer/x86/hardsigmoid_x86.h index f52063c4cd7..1875cf3e6f3 100644 --- a/src/layer/x86/hardsigmoid_x86.h +++ b/src/layer/x86/hardsigmoid_x86.h @@ -14,6 +14,11 @@ class HardSigmoid_x86 : public HardSigmoid HardSigmoid_x86(); virtual int forward_inplace(Mat& bottom_top_blob, const Option& opt) const; + +protected: +#if NCNN_BF16 + int forward_inplace_bf16s(Mat& 
bottom_top_blob, const Option& opt) const; +#endif }; } // namespace ncnn diff --git a/src/layer/x86/hardsigmoid_x86_avx512bf16.cpp b/src/layer/x86/hardsigmoid_x86_avx512bf16.cpp new file mode 100644 index 00000000000..284733672b9 --- /dev/null +++ b/src/layer/x86/hardsigmoid_x86_avx512bf16.cpp @@ -0,0 +1,20 @@ +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#include "hardsigmoid_x86.h" + +#include "x86_usability.h" + +#include "cpu.h" +#include "mat.h" + +namespace ncnn { + +#include "hardsigmoid_bf16s.h" + +void hardsigmoid_bf16s_avx512bf16(Mat& a, float alpha, float beta, const Option& opt) +{ + hardsigmoid_bf16s(a, alpha, beta, opt); +} + +} // namespace ncnn diff --git a/src/layer/x86/hardswish_bf16s.h b/src/layer/x86/hardswish_bf16s.h new file mode 100644 index 00000000000..03943a3d3e7 --- /dev/null +++ b/src/layer/x86/hardswish_bf16s.h @@ -0,0 +1,120 @@ +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ +void hardswish_bf16s_avx512bf16(Mat& a, float alpha, float beta, float lower, float upper, const Option& opt); +#endif + +static void hardswish_bf16s(Mat& a, float alpha, float beta, float lower, float upper, const Option& opt) +{ +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ + if (ncnn::cpu_support_x86_avx512_bf16()) + { + hardswish_bf16s_avx512bf16(a, alpha, beta, lower, upper, opt); + return; + } +#endif + + int w = a.w; + int h = a.h; + int d = a.d; + int channels = a.c; + int elempack = a.elempack; + int size = w * h * d * elempack; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + unsigned short* ptr = a.channel(q); + + int i = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + __m512 _alpha_avx512 = _mm512_set1_ps(alpha); + __m512 _beta_avx512 = _mm512_set1_ps(beta); + __m512 _zero_avx512 = _mm512_setzero_ps(); + __m512 _one_avx512 = _mm512_set1_ps(1.f); + for (; i + 15 < size; i += 16) + { + __m512 _p = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr)); + __m512 _ans = _mm512_fmadd_ps(_p, _alpha_avx512, _beta_avx512); + _ans = _mm512_max_ps(_ans, _zero_avx512); + _ans = _mm512_min_ps(_ans, _one_avx512); + _ans = _mm512_mul_ps(_ans, _p); + _mm256_storeu_si256((__m256i*)ptr, float2bfloat_avx512(_ans)); + ptr += 16; + } + if (i < size) + { + const unsigned int remain = size - i; + __mmask16 _mask = (__mmask16)((1u << remain) - 1); + __m512 _p = bfloat2float_avx512(_mm256_maskz_loadu_epi16(_mask, ptr)); + __m512 _ans = _mm512_fmadd_ps(_p, _alpha_avx512, _beta_avx512); + _ans = _mm512_max_ps(_ans, _zero_avx512); + _ans = _mm512_min_ps(_ans, _one_avx512); + _ans = _mm512_mul_ps(_ans, _p); + _mm256_mask_storeu_epi16(ptr, _mask, float2bfloat_avx512(_ans)); + i += remain; + } +#else // __AVX512F__ + __m256 _alpha_avx = _mm256_set1_ps(alpha); + __m256 _beta_avx = _mm256_set1_ps(beta); + __m256 _zero_avx = _mm256_setzero_ps(); + __m256 _one_avx = _mm256_set1_ps(1.f); + for (; i + 7 < size; i += 8) + { + __m256 _p = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr)); + __m256 _ans = _mm256_comp_fmadd_ps(_p, _alpha_avx, _beta_avx); + _ans = _mm256_max_ps(_ans, _zero_avx); + _ans = _mm256_min_ps(_ans, _one_avx); + _ans = _mm256_mul_ps(_ans, _p); + _mm_storeu_si128((__m128i*)ptr, float2bfloat_avx(_ans)); + ptr += 8; + } + __m128 _alpha_sse = _mm_set1_ps(alpha); + __m128 _beta_sse = _mm_set1_ps(beta); + __m128 _zero = _mm_setzero_ps(); + __m128 _one = _mm_set1_ps(1.f); + for (; i + 
3 < size; i += 4) + { + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr)); + __m128 _ans = _mm_comp_fmadd_ps(_p, _alpha_sse, _beta_sse); + _ans = _mm_max_ps(_ans, _zero); + _ans = _mm_min_ps(_ans, _one); + _ans = _mm_mul_ps(_ans, _p); + _mm_storel_epi64((__m128i*)ptr, float2bfloat_sse(_ans, _ans)); + ptr += 4; + } +#endif // __AVX512F__ +#else // __AVX__ + __m128 _alpha_sse = _mm_set1_ps(alpha); + __m128 _beta_sse = _mm_set1_ps(beta); + __m128 _zero = _mm_setzero_ps(); + __m128 _one = _mm_set1_ps(1.f); + for (; i + 3 < size; i += 4) + { + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr)); + __m128 _ans = _mm_comp_fmadd_ps(_p, _alpha_sse, _beta_sse); + _ans = _mm_max_ps(_ans, _zero); + _ans = _mm_min_ps(_ans, _one); + _ans = _mm_mul_ps(_ans, _p); + _mm_storel_epi64((__m128i*)ptr, float2bfloat_sse(_ans, _ans)); + ptr += 4; + } +#endif // __AVX__ +#endif // __SSE2__ + for (; i < size; i++) + { + float v = bfloat16_to_float32(*ptr); + if (v < lower) + v = 0.f; + else if (v > upper) + ; + else + v = v * (v * alpha + beta); + *ptr = float32_to_bfloat16(v); + ptr++; + } + } +} diff --git a/src/layer/x86/hardswish_x86.cpp b/src/layer/x86/hardswish_x86.cpp index 08697be9139..220050a1620 100644 --- a/src/layer/x86/hardswish_x86.cpp +++ b/src/layer/x86/hardswish_x86.cpp @@ -12,17 +12,31 @@ #include "x86_usability.h" +#include "cpu.h" + namespace ncnn { +#if NCNN_BF16 +#include "hardswish_bf16s.h" +#endif + HardSwish_x86::HardSwish_x86() { #if __SSE2__ support_packing = true; #endif // __SSE2__ +#if NCNN_BF16 + support_bf16_storage = true; +#endif } int HardSwish_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const { +#if NCNN_BF16 + if (opt.use_bf16_storage && bottom_top_blob.elembits() == 16) + return forward_inplace_bf16s(bottom_top_blob, opt); +#endif + int w = bottom_top_blob.w; int h = bottom_top_blob.h; int d = bottom_top_blob.d; @@ -96,4 +110,13 @@ int HardSwish_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) cons return 0; } +#if NCNN_BF16 +int HardSwish_x86::forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt) const +{ + hardswish_bf16s(bottom_top_blob, alpha, beta, lower, upper, opt); + + return 0; +} +#endif // NCNN_BF16 + } // namespace ncnn diff --git a/src/layer/x86/hardswish_x86.h b/src/layer/x86/hardswish_x86.h index dd359f6d9e6..8271875af88 100644 --- a/src/layer/x86/hardswish_x86.h +++ b/src/layer/x86/hardswish_x86.h @@ -14,6 +14,11 @@ class HardSwish_x86 : public HardSwish HardSwish_x86(); virtual int forward_inplace(Mat& bottom_top_blob, const Option& opt) const; + +protected: +#if NCNN_BF16 + int forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt) const; +#endif }; } // namespace ncnn diff --git a/src/layer/x86/hardswish_x86_avx512bf16.cpp b/src/layer/x86/hardswish_x86_avx512bf16.cpp new file mode 100644 index 00000000000..32cbde6887e --- /dev/null +++ b/src/layer/x86/hardswish_x86_avx512bf16.cpp @@ -0,0 +1,20 @@ +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#include "hardswish_x86.h" + +#include "x86_usability.h" + +#include "cpu.h" +#include "mat.h" + +namespace ncnn { + +#include "hardswish_bf16s.h" + +void hardswish_bf16s_avx512bf16(Mat& a, float alpha, float beta, float lower, float upper, const Option& opt) +{ + hardswish_bf16s(a, alpha, beta, lower, upper, opt); +} + +} // namespace ncnn diff --git a/src/layer/x86/mish_bf16s.h b/src/layer/x86/mish_bf16s.h new file mode 100644 index 00000000000..d2af355a80b --- /dev/null +++ b/src/layer/x86/mish_bf16s.h @@ -0,0 
+1,66 @@ +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ +void mish_bf16s_avx512bf16(Mat& a, const Option& opt); +#endif + +static void mish_bf16s(Mat& a, const Option& opt) +{ +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ + if (ncnn::cpu_support_x86_avx512_bf16()) + { + mish_bf16s_avx512bf16(a, opt); + return; + } +#endif + + int w = a.w; + int h = a.h; + int d = a.d; + int channels = a.c; + int elempack = a.elempack; + int size = w * h * d * elempack; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + unsigned short* ptr = a.channel(q); + + int i = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + for (; i + 15 < size; i += 16) + { + __m512 _p = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr)); + _p = mish_avx512(_p); + _mm256_storeu_si256((__m256i*)ptr, float2bfloat_avx512(_p)); + ptr += 16; + } +#endif // __AVX512F__ + for (; i + 7 < size; i += 8) + { + __m256 _p = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr)); + _p = mish_avx(_p); + _mm_storeu_si128((__m128i*)ptr, float2bfloat_avx(_p)); + ptr += 8; + } +#endif // __AVX__ + for (; i + 3 < size; i += 4) + { + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr)); + _p = mish_sse(_p); + _mm_storel_epi64((__m128i*)ptr, float2bfloat_sse(_p, _p)); + ptr += 4; + } +#endif // __SSE2__ + for (; i < size; i++) + { + float v = bfloat16_to_float32(*ptr); + v = v * tanhf(logf(expf(v) + 1.f)); + *ptr = float32_to_bfloat16(v); + ptr++; + } + } +} diff --git a/src/layer/x86/mish_x86.cpp b/src/layer/x86/mish_x86.cpp index 279faba3af6..b59b7d8190e 100644 --- a/src/layer/x86/mish_x86.cpp +++ b/src/layer/x86/mish_x86.cpp @@ -5,13 +5,22 @@ #include "x86_activation.h" +#include "cpu.h" + namespace ncnn { +#if NCNN_BF16 +#include "mish_bf16s.h" +#endif + Mish_x86::Mish_x86() { #if __SSE2__ support_packing = true; #endif // __SSE2__ +#if NCNN_BF16 + support_bf16_storage = true; +#endif } int Mish_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const @@ -23,6 +32,11 @@ int Mish_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const int elempack = bottom_top_blob.elempack; int size = w * h * d * elempack; +#if NCNN_BF16 + if (opt.use_bf16_storage && bottom_top_blob.elembits() == 16) + return forward_inplace_bf16s(bottom_top_blob, opt); +#endif + #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; q++) { @@ -66,4 +80,13 @@ int Mish_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const return 0; } +#if NCNN_BF16 +int Mish_x86::forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt) const +{ + mish_bf16s(bottom_top_blob, opt); + + return 0; +} +#endif // NCNN_BF16 + } // namespace ncnn diff --git a/src/layer/x86/mish_x86.h b/src/layer/x86/mish_x86.h index dc545339c7e..28139e9745b 100644 --- a/src/layer/x86/mish_x86.h +++ b/src/layer/x86/mish_x86.h @@ -14,6 +14,11 @@ class Mish_x86 : public Mish Mish_x86(); virtual int forward_inplace(Mat& bottom_top_blob, const Option& opt) const; + +protected: +#if NCNN_BF16 + int forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt) const; +#endif }; } // namespace ncnn diff --git a/src/layer/x86/mish_x86_avx512bf16.cpp b/src/layer/x86/mish_x86_avx512bf16.cpp new file mode 100644 index 00000000000..32cf22af607 --- /dev/null +++ b/src/layer/x86/mish_x86_avx512bf16.cpp @@ -0,0 +1,19 @@ +// Copyright 2026 Tencent +// 
SPDX-License-Identifier: BSD-3-Clause + +#include "mish_x86.h" + +#include "x86_activation.h" + +#include "cpu.h" + +namespace ncnn { + +#include "mish_bf16s.h" + +void mish_bf16s_avx512bf16(Mat& a, const Option& opt) +{ + mish_bf16s(a, opt); +} + +} // namespace ncnn diff --git a/src/layer/x86/quantize_bf16s.h b/src/layer/x86/quantize_bf16s.h new file mode 100644 index 00000000000..48cc64bd68d --- /dev/null +++ b/src/layer/x86/quantize_bf16s.h @@ -0,0 +1,461 @@ +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ +int quantize_forward_bf16s_avx512bf16(const Mat& bottom_blob, Mat& top_blob, const Mat& scale_data, int scale_data_size, const Option& opt); +#endif + +static void quantize_bf16(const unsigned short* ptr, signed char* s8ptr, const Mat& scale_data, int elemcount, int elempack) +{ + const int scale_data_size = scale_data.w; + const int size = elemcount * elempack; + + float scale = scale_data[0]; +#if __SSE2__ + __m128 _scale = _mm_set1_ps(scale); +#if __AVX__ + __m256 _scale_avx = _mm256_set1_ps(scale); +#if __AVX512F__ + __m512 _scale_avx512 = _mm512_set1_ps(scale); +#endif // __AVX512F__ +#endif // __AVX__ + if (scale_data_size > 1) + { +#if __AVX__ +#if __AVX512F__ + if (elempack == 16) + { + _scale_avx512 = _mm512_loadu_ps((const float*)scale_data); + } +#endif // __AVX512F__ + if (elempack == 8) + { + _scale_avx = _mm256_loadu_ps((const float*)scale_data); +#if __AVX512F__ + _scale_avx512 = combine8x2_ps(_scale_avx, _scale_avx); +#endif // __AVX512F__ + } +#endif // __AVX__ + if (elempack == 4) + { + _scale = _mm_loadu_ps((const float*)scale_data); +#if __AVX__ + _scale_avx = combine4x2_ps(_scale, _scale); +#if __AVX512F__ + _scale_avx512 = combine8x2_ps(_scale_avx, _scale_avx); +#endif // __AVX512F__ +#endif // __AVX__ + } + } +#endif // __SSE2__ + + int i = 0; +#if __SSE2__ +#if __AVX__ + for (; i + 15 < size; i += 16) + { +#if __AVX512F__ + __m512 _v = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr)); + _v = _mm512_mul_ps(_v, _scale_avx512); + _mm_storeu_si128((__m128i*)s8ptr, float2int8_avx512(_v)); +#else // __AVX512F__ + __m256 _v0 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr)); + __m256 _v1 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)(ptr + 8))); + _v0 = _mm256_mul_ps(_v0, _scale_avx); + _v1 = _mm256_mul_ps(_v1, _scale_avx); + _mm_storeu_si128((__m128i*)s8ptr, float2int8_avx(_v0, _v1)); +#endif // __AVX512F__ + ptr += 16; + s8ptr += 16; + } +#endif // __AVX__ + for (; i + 7 < size; i += 8) + { +#if __AVX__ + __m256 _v = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr)); + _v = _mm256_mul_ps(_v, _scale_avx); + *(int64_t*)s8ptr = float2int8_avx(_v); +#else // __AVX__ + __m128 _v0 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr)); + __m128 _v1 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(ptr + 4))); + _v0 = _mm_mul_ps(_v0, _scale); + _v1 = _mm_mul_ps(_v1, _scale); + *(int64_t*)s8ptr = float2int8_sse(_v0, _v1); +#endif // __AVX__ + ptr += 8; + s8ptr += 8; + } + for (; i + 3 < size; i += 4) + { + __m128 _v = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr)); + _v = _mm_mul_ps(_v, _scale); + int32_t v = float2int8_sse(_v); + s8ptr[0] = (v >> 0) & 0xff; + s8ptr[1] = (v >> 8) & 0xff; + s8ptr[2] = (v >> 16) & 0xff; + s8ptr[3] = (v >> 24) & 0xff; + ptr += 4; + s8ptr += 4; + } +#endif // __SSE2__ + for (; i < size; i++) + { + float v = bfloat16_to_float32(*ptr) * scale; + *s8ptr = float2int8(v); + ptr++; + s8ptr++; + } +} + +#if 
__SSE2__ +#if __AVX512F__ +static void quantize_bf16_pack16to8(const unsigned short* ptr, signed char* s8ptr0, signed char* s8ptr1, const Mat& scale_data, int elemcount) +{ + const int scale_data_size = scale_data.w; + + float scale = scale_data[0]; + __m512 _scale = _mm512_set1_ps(scale); + if (scale_data_size > 1) + { + _scale = _mm512_loadu_ps((const float*)scale_data); + } + + int i = 0; + for (; i < elemcount; i++) + { + __m512 _v = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr)); + _v = _mm512_mul_ps(_v, _scale); + __m128i v = float2int8_avx512(_v); + _mm_storel_pd((double*)s8ptr0, _mm_castsi128_pd(v)); + _mm_storeh_pd((double*)s8ptr1, _mm_castsi128_pd(v)); + ptr += 16; + s8ptr0 += 8; + s8ptr1 += 8; + } +} +#endif // __AVX512F__ + +#if !__AVX__ +static void quantize_bf16_pack4to8(const unsigned short* ptr0, const unsigned short* ptr1, signed char* s8ptr, const Mat& scale_data, int elemcount) +{ + const int scale_data_size = scale_data.w; + + float scale = scale_data[0]; + __m128 _scale0 = _mm_set1_ps(scale); + __m128 _scale1 = _scale0; + if (scale_data_size > 1) + { + _scale0 = _mm_loadu_ps((const float*)scale_data); + _scale1 = _mm_loadu_ps((const float*)scale_data + 4); + } + + int i = 0; + for (; i + 1 < elemcount; i += 2) + { + __m128 _v0 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr0)); + __m128 _v1 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr1)); + __m128 _v2 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(ptr0 + 4))); + __m128 _v3 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(ptr1 + 4))); + _v0 = _mm_mul_ps(_v0, _scale0); + _v1 = _mm_mul_ps(_v1, _scale1); + _v2 = _mm_mul_ps(_v2, _scale0); + _v3 = _mm_mul_ps(_v3, _scale1); + _mm_storeu_si128((__m128i*)s8ptr, float2int8_sse(_v0, _v1, _v2, _v3)); + ptr0 += 8; + ptr1 += 8; + s8ptr += 16; + } + for (; i < elemcount; i++) + { + __m128 _v0 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr0)); + __m128 _v1 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr1)); + _v0 = _mm_mul_ps(_v0, _scale0); + _v1 = _mm_mul_ps(_v1, _scale1); + *(int64_t*)s8ptr = float2int8_sse(_v0, _v1); + ptr0 += 4; + ptr1 += 4; + s8ptr += 8; + } +} +#endif // !__AVX__ + +static void quantize_bf16_pack4to1(const unsigned short* ptr, signed char* s8ptr0, signed char* s8ptr1, signed char* s8ptr2, signed char* s8ptr3, const Mat& scale_data, int elemcount) +{ + const int scale_data_size = scale_data.w; + + float scale = scale_data[0]; + __m128 _scale = _mm_set1_ps(scale); + if (scale_data_size > 1) + { + _scale = _mm_loadu_ps((const float*)scale_data); + } + + int i = 0; + for (; i + 7 < elemcount; i += 8) + { + __m128 _v0 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr)); + __m128 _v1 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(ptr + 4))); + __m128 _v2 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(ptr + 8))); + __m128 _v3 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(ptr + 12))); + __m128 _v4 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(ptr + 16))); + __m128 _v5 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(ptr + 20))); + __m128 _v6 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(ptr + 24))); + __m128 _v7 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(ptr + 28))); + _v0 = _mm_mul_ps(_v0, _scale); + _v1 = _mm_mul_ps(_v1, _scale); + _v2 = _mm_mul_ps(_v2, _scale); + _v3 = _mm_mul_ps(_v3, _scale); + _v4 = _mm_mul_ps(_v4, _scale); + _v5 = _mm_mul_ps(_v5, _scale); + _v6 = _mm_mul_ps(_v6, _scale); + _v7 = _mm_mul_ps(_v7, _scale); + __m128i v0426 = float2int8_sse(_v0, _v4, _v2, 
_v6); + __m128i v1537 = float2int8_sse(_v1, _v5, _v3, _v7); + __m128i v0145 = _mm_unpacklo_epi8(v0426, v1537); + __m128i v2367 = _mm_unpackhi_epi8(v0426, v1537); + __m128i v0123 = _mm_unpacklo_epi16(v0145, v2367); + __m128i v4567 = _mm_unpackhi_epi16(v0145, v2367); + __m128i v01 = _mm_unpacklo_epi32(v0123, v4567); + __m128i v23 = _mm_unpackhi_epi32(v0123, v4567); + _mm_storel_pd((double*)s8ptr0, _mm_castsi128_pd(v01)); + _mm_storeh_pd((double*)s8ptr1, _mm_castsi128_pd(v01)); + _mm_storel_pd((double*)s8ptr2, _mm_castsi128_pd(v23)); + _mm_storeh_pd((double*)s8ptr3, _mm_castsi128_pd(v23)); + ptr += 32; + s8ptr0 += 8; + s8ptr1 += 8; + s8ptr2 += 8; + s8ptr3 += 8; + } + for (; i < elemcount; i++) + { + __m128 _v = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr)); + _v = _mm_mul_ps(_v, _scale); + int64_t v = float2int8_sse(_v, _v); + s8ptr0[0] = (v >> 32) & 0xff; + s8ptr1[0] = (v >> 40) & 0xff; + s8ptr2[0] = (v >> 48) & 0xff; + s8ptr3[0] = (v >> 56) & 0xff; + ptr += 4; + s8ptr0 += 1; + s8ptr1 += 1; + s8ptr2 += 1; + s8ptr3 += 1; + } +} +#endif // __SSE2__ + +static int quantize_forward_bf16s(const Mat& bottom_blob, Mat& top_blob, const Mat& scale_data, int scale_data_size, const Option& opt) +{ +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ + if (ncnn::cpu_support_x86_avx512_bf16()) + { + return quantize_forward_bf16s_avx512bf16(bottom_blob, top_blob, scale_data, scale_data_size, opt); + } +#endif + + const int dims = bottom_blob.dims; + const int w = bottom_blob.w; + const int h = bottom_blob.h; + const int channels = bottom_blob.c; + const int elempack = bottom_blob.elempack; + + if (dims == 1) + { + int out_elempack = 1; +#if __SSE2__ + if (opt.use_packing_layout) + { + out_elempack = w * elempack % 8 == 0 ? 8 : 1; + } +#endif + const int outw = w * elempack / out_elempack; + const size_t out_elemsize = out_elempack * 1u; + + top_blob.create(outw, out_elemsize, out_elempack, opt.blob_allocator); + if (top_blob.empty()) + return -100; + + const int wp = std::max(1, w / opt.num_threads); + const int nn_w = (w + wp - 1) / wp; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int ii = 0; ii < nn_w; ii++) + { + const int i = ii * wp; + + const unsigned short* ptr = (const unsigned short*)bottom_blob + i * elempack; + signed char* s8ptr = (signed char*)top_blob + i * elempack; + + // assert scale_data_size == 1 + + const int size = std::min(w - i, wp) * elempack; + + quantize_bf16(ptr, s8ptr, scale_data, size, 1); + } + } + + if (dims == 2) + { + int out_elempack = 1; +#if __SSE2__ + if (opt.use_packing_layout) + { + out_elempack = h * elempack % 8 == 0 ? 8 : 1; + } +#endif + const int outh = h * elempack / out_elempack; + const size_t out_elemsize = out_elempack * 1u; + + top_blob.create(w, outh, out_elemsize, out_elempack, opt.blob_allocator); + if (top_blob.empty()) + return -100; + +#if __SSE2__ +#if __AVX512F__ + if (elempack == 16 && out_elempack == 8) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = 0; i < h; i++) + { + const unsigned short* ptr = bottom_blob.row(i); + signed char* s8ptr0 = top_blob.row(i * 2); + signed char* s8ptr1 = top_blob.row(i * 2 + 1); + + const Mat scale_data_i = scale_data_size > 1 ? 
scale_data.range(i * elempack, elempack) : scale_data; + + quantize_bf16_pack16to8(ptr, s8ptr0, s8ptr1, scale_data_i, w); + } + } +#endif // __AVX512F__ +#if !__AVX__ + if (elempack == 4 && out_elempack == 8) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = 0; i < outh; i++) + { + const unsigned short* ptr0 = bottom_blob.row(i * 2); + const unsigned short* ptr1 = bottom_blob.row(i * 2 + 1); + signed char* s8ptr = top_blob.row(i); + + const Mat scale_data_i = scale_data_size > 1 ? scale_data.range(i * out_elempack, out_elempack) : scale_data; + + quantize_bf16_pack4to8(ptr0, ptr1, s8ptr, scale_data_i, w); + } + } +#endif // !__AVX__ + if (elempack == 4 && out_elempack == 1) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = 0; i < h; i++) + { + const unsigned short* ptr = bottom_blob.row(i); + signed char* s8ptr0 = top_blob.row(i * 4); + signed char* s8ptr1 = top_blob.row(i * 4 + 1); + signed char* s8ptr2 = top_blob.row(i * 4 + 2); + signed char* s8ptr3 = top_blob.row(i * 4 + 3); + + const Mat scale_data_i = scale_data_size > 1 ? scale_data.range(i * elempack, elempack) : scale_data; + + quantize_bf16_pack4to1(ptr, s8ptr0, s8ptr1, s8ptr2, s8ptr3, scale_data_i, w); + } + } +#endif // __SSE2__ + if (elempack == out_elempack) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = 0; i < h; i++) + { + const unsigned short* ptr = bottom_blob.row(i); + signed char* s8ptr = top_blob.row(i); + + const Mat scale_data_i = scale_data_size > 1 ? scale_data.range(i * elempack, elempack) : scale_data; + + quantize_bf16(ptr, s8ptr, scale_data_i, w, elempack); + } + } + } + + if (dims == 3) + { + int out_elempack = 1; +#if __SSE2__ + if (opt.use_packing_layout) + { + out_elempack = channels * elempack % 8 == 0 ? 8 : 1; + } +#endif + const int outc = channels * elempack / out_elempack; + const size_t out_elemsize = out_elempack * 1u; + + top_blob.create(w, h, outc, out_elemsize, out_elempack, opt.blob_allocator); + if (top_blob.empty()) + return -100; + +#if __SSE2__ +#if __AVX512F__ + if (elempack == 16 && out_elempack == 8) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const unsigned short* ptr = bottom_blob.channel(q); + signed char* s8ptr0 = top_blob.channel(q * 2); + signed char* s8ptr1 = top_blob.channel(q * 2 + 1); + + const Mat scale_data_q = scale_data_size > 1 ? scale_data.range(q * elempack, elempack) : scale_data; + + quantize_bf16_pack16to8(ptr, s8ptr0, s8ptr1, scale_data_q, w * h); + } + } +#endif // __AVX512F__ +#if !__AVX__ + if (elempack == 4 && out_elempack == 8) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < outc; q++) + { + const unsigned short* ptr0 = bottom_blob.channel(q * 2); + const unsigned short* ptr1 = bottom_blob.channel(q * 2 + 1); + signed char* s8ptr = top_blob.channel(q); + + const Mat scale_data_q = scale_data_size > 1 ? 
scale_data.range(q * out_elempack, out_elempack) : scale_data; + + quantize_bf16_pack4to8(ptr0, ptr1, s8ptr, scale_data_q, w * h); + } + } +#endif // !__AVX__ + if (elempack == 4 && out_elempack == 1) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const unsigned short* ptr = bottom_blob.channel(q); + signed char* s8ptr0 = top_blob.channel(q * 4); + signed char* s8ptr1 = top_blob.channel(q * 4 + 1); + signed char* s8ptr2 = top_blob.channel(q * 4 + 2); + signed char* s8ptr3 = top_blob.channel(q * 4 + 3); + + const Mat scale_data_q = scale_data_size > 1 ? scale_data.range(q * elempack, elempack) : scale_data; + + quantize_bf16_pack4to1(ptr, s8ptr0, s8ptr1, s8ptr2, s8ptr3, scale_data_q, w * h); + } + } +#endif // __SSE2__ + if (elempack == out_elempack) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const unsigned short* ptr = bottom_blob.channel(q); + signed char* s8ptr = top_blob.channel(q); + + const Mat scale_data_q = scale_data_size > 1 ? scale_data.range(q * elempack, elempack) : scale_data; + + quantize_bf16(ptr, s8ptr, scale_data_q, w * h, elempack); + } + } + } + + return 0; +} diff --git a/src/layer/x86/quantize_x86.cpp b/src/layer/x86/quantize_x86.cpp index f91c5dcf4c0..46c87831f9a 100644 --- a/src/layer/x86/quantize_x86.cpp +++ b/src/layer/x86/quantize_x86.cpp @@ -12,13 +12,22 @@ #include "x86_usability.h" +#include "cpu.h" + namespace ncnn { +#if NCNN_BF16 +#include "quantize_bf16s.h" +#endif + Quantize_x86::Quantize_x86() { #if __SSE2__ support_packing = true; #endif // __SSE2__ +#if NCNN_BF16 + support_bf16_storage = true; +#endif } static void quantize(const float* ptr, signed char* s8ptr, const Mat& scale_data, int elemcount, int elempack) @@ -271,6 +280,11 @@ static void quantize_pack4to1(const float* ptr, signed char* s8ptr0, signed char int Quantize_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const { +#if NCNN_BF16 + if (opt.use_bf16_storage && bottom_blob.elembits() == 16) + return forward_bf16s(bottom_blob, top_blob, opt); +#endif + const int dims = bottom_blob.dims; const int w = bottom_blob.w; const int h = bottom_blob.h; @@ -477,4 +491,11 @@ int Quantize_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& o return 0; } +#if NCNN_BF16 +int Quantize_x86::forward_bf16s(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const +{ + return quantize_forward_bf16s(bottom_blob, top_blob, scale_data, scale_data_size, opt); +} +#endif // NCNN_BF16 + } // namespace ncnn diff --git a/src/layer/x86/quantize_x86.h b/src/layer/x86/quantize_x86.h index ac3631cce7b..8dd13ebdc20 100644 --- a/src/layer/x86/quantize_x86.h +++ b/src/layer/x86/quantize_x86.h @@ -14,6 +14,11 @@ class Quantize_x86 : public Quantize Quantize_x86(); virtual int forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const; + +protected: +#if NCNN_BF16 + int forward_bf16s(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const; +#endif }; } // namespace ncnn diff --git a/src/layer/x86/quantize_x86_avx512bf16.cpp b/src/layer/x86/quantize_x86_avx512bf16.cpp new file mode 100644 index 00000000000..1b3e617be69 --- /dev/null +++ b/src/layer/x86/quantize_x86_avx512bf16.cpp @@ -0,0 +1,27 @@ +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#include "quantize_x86.h" + +#if __SSE2__ +#include <emmintrin.h> +#if __AVX__ +#include <immintrin.h> +#endif // __AVX__ +#endif // __SSE2__ + +#include "x86_usability.h" + +#include "cpu.h" +#include "mat.h" + 
+namespace ncnn { + +#include "quantize_bf16s.h" + +int quantize_forward_bf16s_avx512bf16(const Mat& bottom_blob, Mat& top_blob, const Mat& scale_data, int scale_data_size, const Option& opt) +{ + return quantize_forward_bf16s(bottom_blob, top_blob, scale_data, scale_data_size, opt); +} + +} // namespace ncnn diff --git a/src/layer/x86/rotaryembed_bf16s.h b/src/layer/x86/rotaryembed_bf16s.h new file mode 100644 index 00000000000..cda65daadbe --- /dev/null +++ b/src/layer/x86/rotaryembed_bf16s.h @@ -0,0 +1,364 @@ +// Copyright 2026 pchar.cn +// SPDX-License-Identifier: BSD-3-Clause + +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ +void rotaryembed_bf16s_avx512bf16(const Mat& bottom_blob, const Mat& cos_cache, const Mat& sin_cache, Mat& top_blob, int interleaved, const Option& opt); +#endif + +static void rotaryembed_bf16s(const Mat& bottom_blob, const Mat& cos_cache, const Mat& sin_cache, Mat& top_blob, int interleaved, const Option& opt) +{ +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ + if (ncnn::cpu_support_x86_avx512_bf16()) + { + rotaryembed_bf16s_avx512bf16(bottom_blob, cos_cache, sin_cache, top_blob, interleaved, opt); + return; + } +#endif + + const int embed_dim = bottom_blob.w; + const int seqlen = bottom_blob.h; + const int num_heads = bottom_blob.c; + + top_blob.create_like(bottom_blob, opt.blob_allocator); + if (top_blob.empty()) + return; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < num_heads; q++) + { + const Mat head = bottom_blob.channel(q); + Mat out_head = top_blob.channel(q); + + for (int i = 0; i < seqlen; i++) + { + if (interleaved) + { + const unsigned short* ptr = head.row(i); + const unsigned short* cos_ptr = cos_cache.row(i); + const unsigned short* sin_ptr = sin_cache.row(i); + unsigned short* outptr = out_head.row(i); + + int j = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + const __m512i dupidx = _mm512_set_epi32(15, 15, 14, 14, 13, 13, 12, 12, 11, 11, 10, 10, 9, 9, 8, 8); + const __m512i dupidx_lo = _mm512_set_epi32(7, 7, 6, 6, 5, 5, 4, 4, 3, 3, 2, 2, 1, 1, 0, 0); + for (; j + 15 < embed_dim / 2; j += 16) + { + __m512 a0 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr)); + __m512 a1 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(ptr + 16))); + + __m512 cs_src = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)cos_ptr)); + __m512 ss_src = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)sin_ptr)); + + __m512 c0 = _mm512_permutexvar_ps(dupidx_lo, cs_src); + __m512 c1 = _mm512_permutexvar_ps(dupidx, cs_src); + __m512 s0 = _mm512_permutexvar_ps(dupidx_lo, ss_src); + __m512 s1 = _mm512_permutexvar_ps(dupidx, ss_src); + + __m512 swap0 = _mm512_shuffle_ps(a0, a0, _MM_SHUFFLE(2, 3, 0, 1)); + __m512 swap1 = _mm512_shuffle_ps(a1, a1, _MM_SHUFFLE(2, 3, 0, 1)); + + __m512 ss0 = _mm512_mul_ps(swap0, s0); + __m512 ss1 = _mm512_mul_ps(swap1, s1); + + __m512 y0 = _mm512_fmaddsub_ps(a0, c0, ss0); + __m512 y1 = _mm512_fmaddsub_ps(a1, c1, ss1); + + _mm256_storeu_si256((__m256i*)outptr, float2bfloat_avx512(y0)); + _mm256_storeu_si256((__m256i*)(outptr + 16), float2bfloat_avx512(y1)); + + ptr += 32; + outptr += 32; + cos_ptr += 16; + sin_ptr += 16; + } +#endif // __AVX512F__ +#if __AVX2__ + const __m256i dupidx256 = _mm256_set_epi32(7, 7, 6, 6, 5, 5, 4, 4); + const __m256i dupidx256_lo = _mm256_set_epi32(3, 3, 2, 2, 1, 1, 0, 0); + for (; j + 7 < embed_dim / 2; j += 8) + { + __m256 a0 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr)); + __m256 a1 
= bfloat2float_avx(_mm_loadu_si128((const __m128i*)(ptr + 8))); + + __m256 c_src = bfloat2float_avx(_mm_loadu_si128((const __m128i*)cos_ptr)); + __m256 s_src = bfloat2float_avx(_mm_loadu_si128((const __m128i*)sin_ptr)); + + __m256 c0 = _mm256_permutevar8x32_ps(c_src, dupidx256_lo); + __m256 c1 = _mm256_permutevar8x32_ps(c_src, dupidx256); + __m256 s0 = _mm256_permutevar8x32_ps(s_src, dupidx256_lo); + __m256 s1 = _mm256_permutevar8x32_ps(s_src, dupidx256); + + __m256 swap0 = _mm256_shuffle_ps(a0, a0, _MM_SHUFFLE(2, 3, 0, 1)); + __m256 swap1 = _mm256_shuffle_ps(a1, a1, _MM_SHUFFLE(2, 3, 0, 1)); + + __m256 ss0 = _mm256_mul_ps(swap0, s0); + __m256 ss1 = _mm256_mul_ps(swap1, s1); + + __m256 y0 = _mm256_fmaddsub_ps(a0, c0, ss0); + __m256 y1 = _mm256_fmaddsub_ps(a1, c1, ss1); + + _mm_storeu_si128((__m128i*)outptr, float2bfloat_avx(y0)); + _mm_storeu_si128((__m128i*)(outptr + 8), float2bfloat_avx(y1)); + + ptr += 16; + outptr += 16; + cos_ptr += 8; + sin_ptr += 8; + } +#else // __AVX2__ + for (; j + 7 < embed_dim / 2; j += 8) + { + __m256 a0 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr)); + __m256 a1 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)(ptr + 8))); + + __m128 clo4 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)cos_ptr)); + __m128 chi4 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(cos_ptr + 4))); + __m128 slo4 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)sin_ptr)); + __m128 shi4 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(sin_ptr + 4))); + + __m128 clo_lo = _mm_unpacklo_ps(clo4, clo4); // [c0,c0,c1,c1] + __m128 clo_hi = _mm_unpackhi_ps(clo4, clo4); // [c2,c2,c3,c3] + __m128 chi_lo = _mm_unpacklo_ps(chi4, chi4); // [c4,c4,c5,c5] + __m128 chi_hi = _mm_unpackhi_ps(chi4, chi4); // [c6,c6,c7,c7] + + __m256 c0 = combine4x2_ps(clo_lo, clo_hi); + __m256 c1 = combine4x2_ps(chi_lo, chi_hi); + + __m128 slo_lo = _mm_unpacklo_ps(slo4, slo4); // [s0,s0,s1,s1] + __m128 slo_hi = _mm_unpackhi_ps(slo4, slo4); // [s2,s2,s3,s3] + __m128 shi_lo = _mm_unpacklo_ps(shi4, shi4); // [s4,s4,s5,s5] + __m128 shi_hi = _mm_unpackhi_ps(shi4, shi4); // [s6,s6,s7,s7] + + __m256 s0 = combine4x2_ps(slo_lo, slo_hi); + __m256 s1 = combine4x2_ps(shi_lo, shi_hi); + + __m256 swap0 = _mm256_shuffle_ps(a0, a0, _MM_SHUFFLE(2, 3, 0, 1)); + __m256 swap1 = _mm256_shuffle_ps(a1, a1, _MM_SHUFFLE(2, 3, 0, 1)); + + __m256 ss0 = _mm256_mul_ps(swap0, s0); + __m256 ss1 = _mm256_mul_ps(swap1, s1); + +#if __FMA__ + __m256 y0 = _mm256_fmaddsub_ps(a0, c0, ss0); + __m256 y1 = _mm256_fmaddsub_ps(a1, c1, ss1); +#else + __m256 ac0 = _mm256_mul_ps(a0, c0); + __m256 ac1 = _mm256_mul_ps(a1, c1); + + __m256 y0 = _mm256_addsub_ps(ac0, ss0); + __m256 y1 = _mm256_addsub_ps(ac1, ss1); +#endif + _mm_storeu_si128((__m128i*)outptr, float2bfloat_avx(y0)); + _mm_storeu_si128((__m128i*)(outptr + 8), float2bfloat_avx(y1)); + + ptr += 16; + outptr += 16; + cos_ptr += 8; + sin_ptr += 8; + } +#endif // __AVX2__ +#endif // __AVX__ + for (; j + 3 < embed_dim / 2; j += 4) + { + __m128 a0 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr)); + __m128 a1 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(ptr + 4))); + + __m128i c4_raw = _mm_loadl_epi64((const __m128i*)cos_ptr); + __m128i s4_raw = _mm_loadl_epi64((const __m128i*)sin_ptr); + __m128 c4 = bfloat2float_sse(c4_raw); + __m128 s4 = bfloat2float_sse(s4_raw); + + __m128 clo = _mm_unpacklo_ps(c4, c4); // [c0,c0,c1,c1] + __m128 chi = _mm_unpackhi_ps(c4, c4); // [c2,c2,c3,c3] + __m128 slo = _mm_unpacklo_ps(s4, s4); // [s0,s0,s1,s1] + __m128 shi = _mm_unpackhi_ps(s4, s4); // 
[s2,s2,s3,s3] + + __m128 swap0 = _mm_shuffle_ps(a0, a0, _MM_SHUFFLE(2, 3, 0, 1)); + __m128 swap1 = _mm_shuffle_ps(a1, a1, _MM_SHUFFLE(2, 3, 0, 1)); + + __m128 ss0 = _mm_mul_ps(swap0, slo); + __m128 ss1 = _mm_mul_ps(swap1, shi); +#if __FMA__ + __m128 y0 = _mm_fmaddsub_ps(a0, clo, ss0); + __m128 y1 = _mm_fmaddsub_ps(a1, chi, ss1); +#else + __m128 ac0 = _mm_mul_ps(a0, clo); + __m128 ac1 = _mm_mul_ps(a1, chi); +#if __SSE3__ + __m128 y0 = _mm_addsub_ps(ac0, ss0); + __m128 y1 = _mm_addsub_ps(ac1, ss1); +#else +#if defined(__MINGW32__) && !defined(__x86_64__) + __attribute__((aligned(16))) + const float signmask128_array[4] + = {-0.f, 0.f, -0.f, 0.f}; + const __m128 signmask128 = _mm_load_ps(signmask128_array); +#else + const __m128 signmask128 = _mm_set_ps(0.f, -0.f, 0.f, -0.f); +#endif + ss0 = _mm_xor_ps(ss0, signmask128); + ss1 = _mm_xor_ps(ss1, signmask128); + __m128 y0 = _mm_add_ps(ac0, ss0); + __m128 y1 = _mm_add_ps(ac1, ss1); +#endif +#endif + __m128i y01_bf16 = float2bfloat_sse(y0, y1); + _mm_storeu_si128((__m128i*)outptr, y01_bf16); + + ptr += 8; + outptr += 8; + cos_ptr += 4; + sin_ptr += 4; + } + for (; j + 1 < embed_dim / 2; j += 2) + { + __m128i a_raw = _mm_loadl_epi64((const __m128i*)ptr); + __m128 a = bfloat2float_sse(a_raw); + + float cos0 = bfloat16_to_float32(cos_ptr[0]); + float cos1 = bfloat16_to_float32(cos_ptr[1]); + float sin0 = bfloat16_to_float32(sin_ptr[0]); + float sin1 = bfloat16_to_float32(sin_ptr[1]); + + __m128 c = _mm_set_ps(cos1, cos1, cos0, cos0); + __m128 s = _mm_set_ps(sin1, sin1, sin0, sin0); + + __m128 swap = _mm_shuffle_ps(a, a, _MM_SHUFFLE(2, 3, 0, 1)); + __m128 ss = _mm_mul_ps(swap, s); + +#if __FMA__ + __m128 y = _mm_fmaddsub_ps(a, c, ss); +#else + __m128 ac = _mm_mul_ps(a, c); +#if __SSE3__ + __m128 y = _mm_addsub_ps(ac, ss); +#else +#if defined(__MINGW32__) && !defined(__x86_64__) + __attribute__((aligned(16))) + const float signmask128_array[4] + = {-0.f, 0.f, -0.f, 0.f}; + const __m128 signmask128 = _mm_load_ps(signmask128_array); +#else + const __m128 signmask128 = _mm_set_ps(0.f, -0.f, 0.f, -0.f); +#endif + ss = _mm_xor_ps(ss, signmask128); + __m128 y = _mm_add_ps(ac, ss); +#endif +#endif + __m128i y_bf16 = float2bfloat_sse(y, y); + _mm_storel_epi64((__m128i*)outptr, y_bf16); + + ptr += 4; + outptr += 4; + cos_ptr += 2; + sin_ptr += 2; + } +#endif // __SSE2__ + for (; j < embed_dim / 2; j++) + { + const float x0 = bfloat16_to_float32(ptr[0]); + const float x1 = bfloat16_to_float32(ptr[1]); + const float cos_val = bfloat16_to_float32(*cos_ptr++); + const float sin_val = bfloat16_to_float32(*sin_ptr++); + + outptr[0] = float32_to_bfloat16(x0 * cos_val - x1 * sin_val); + outptr[1] = float32_to_bfloat16(x0 * sin_val + x1 * cos_val); + + ptr += 2; + outptr += 2; + } + } + else + { + const unsigned short* ptr0 = head.row(i); + const unsigned short* ptr1 = ptr0 + embed_dim / 2; + const unsigned short* cos_ptr = cos_cache.row(i); + const unsigned short* sin_ptr = sin_cache.row(i); + + unsigned short* outptr0 = out_head.row(i); + unsigned short* outptr1 = outptr0 + embed_dim / 2; + + int j = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + for (; j + 15 < embed_dim / 2; j += 16) + { + __m512 x0 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr0)); + __m512 x1 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr1)); + __m512 c = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)cos_ptr)); + __m512 s = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)sin_ptr)); + + __m512 y0 = _mm512_fnmadd_ps(x1, s, _mm512_mul_ps(x0, c)); + 
__m512 y1 = _mm512_fmadd_ps(x0, s, _mm512_mul_ps(x1, c)); + + _mm256_storeu_si256((__m256i*)outptr0, float2bfloat_avx512(y0)); + _mm256_storeu_si256((__m256i*)outptr1, float2bfloat_avx512(y1)); + + ptr0 += 16; + ptr1 += 16; + cos_ptr += 16; + sin_ptr += 16; + outptr0 += 16; + outptr1 += 16; + } +#endif // __AVX512F__ + for (; j + 7 < embed_dim / 2; j += 8) + { + __m256 x0 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr0)); + __m256 x1 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr1)); + __m256 c = bfloat2float_avx(_mm_loadu_si128((const __m128i*)cos_ptr)); + __m256 s = bfloat2float_avx(_mm_loadu_si128((const __m128i*)sin_ptr)); + + __m256 y0 = _mm256_comp_fnmadd_ps(x1, s, _mm256_mul_ps(x0, c)); + __m256 y1 = _mm256_comp_fmadd_ps(x0, s, _mm256_mul_ps(x1, c)); + + _mm_storeu_si128((__m128i*)outptr0, float2bfloat_avx(y0)); + _mm_storeu_si128((__m128i*)outptr1, float2bfloat_avx(y1)); + + ptr0 += 8; + ptr1 += 8; + cos_ptr += 8; + sin_ptr += 8; + outptr0 += 8; + outptr1 += 8; + } +#endif // __AVX__ + for (; j + 3 < embed_dim / 2; j += 4) + { + __m128 x0 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr0)); + __m128 x1 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr1)); + __m128 c = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)cos_ptr)); + __m128 s = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)sin_ptr)); + + __m128 y0 = _mm_comp_fnmadd_ps(x1, s, _mm_mul_ps(x0, c)); + __m128 y1 = _mm_comp_fmadd_ps(x0, s, _mm_mul_ps(x1, c)); + + _mm_storel_epi64((__m128i*)outptr0, float2bfloat_sse(y0, y0)); + _mm_storel_epi64((__m128i*)outptr1, float2bfloat_sse(y1, y1)); + + ptr0 += 4; + ptr1 += 4; + cos_ptr += 4; + sin_ptr += 4; + outptr0 += 4; + outptr1 += 4; + } +#endif // __SSE2__ + for (; j < embed_dim / 2; j++) + { + const float x0 = bfloat16_to_float32(*ptr0++); + const float x1 = bfloat16_to_float32(*ptr1++); + const float cos_val = bfloat16_to_float32(*cos_ptr++); + const float sin_val = bfloat16_to_float32(*sin_ptr++); + + *outptr0++ = float32_to_bfloat16(x0 * cos_val - x1 * sin_val); + *outptr1++ = float32_to_bfloat16(x0 * sin_val + x1 * cos_val); + } + } + } + } +} diff --git a/src/layer/x86/rotaryembed_x86.cpp b/src/layer/x86/rotaryembed_x86.cpp index 75ac77e3a20..982af51df3e 100644 --- a/src/layer/x86/rotaryembed_x86.cpp +++ b/src/layer/x86/rotaryembed_x86.cpp @@ -15,14 +15,28 @@ #include "x86_usability.h" +#include "cpu.h" + namespace ncnn { +#if NCNN_BF16 +#include "rotaryembed_bf16s.h" +#endif + RotaryEmbed_x86::RotaryEmbed_x86() { +#if NCNN_BF16 + support_bf16_storage = true; +#endif } int RotaryEmbed_x86::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const { +#if NCNN_BF16 + if (opt.use_bf16_storage && bottom_blobs[0].elembits() == 16) + return forward_bf16s(bottom_blobs, top_blobs, opt); +#endif + const Mat& bottom_blob = bottom_blobs[0]; const Mat& cos_cache = bottom_blobs[1]; const Mat& sin_cache = bottom_blobs[2]; @@ -366,4 +380,19 @@ int RotaryEmbed_x86::forward(const std::vector& bottom_blobs, std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const +{ + const Mat& bottom_blob = bottom_blobs[0]; + const Mat& cos_cache = bottom_blobs[1]; + const Mat& sin_cache = bottom_blobs[2]; + + Mat& top_blob = top_blobs[0]; + + rotaryembed_bf16s(bottom_blob, cos_cache, sin_cache, top_blob, interleaved, opt); + + return top_blob.empty() ? 
-100 : 0; +} +#endif // NCNN_BF16 + } // namespace ncnn diff --git a/src/layer/x86/rotaryembed_x86.h b/src/layer/x86/rotaryembed_x86.h index 1015ff351cf..8bbfc239e92 100644 --- a/src/layer/x86/rotaryembed_x86.h +++ b/src/layer/x86/rotaryembed_x86.h @@ -14,6 +14,11 @@ class RotaryEmbed_x86 : public RotaryEmbed RotaryEmbed_x86(); virtual int forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; + +protected: +#if NCNN_BF16 + int forward_bf16s(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; +#endif }; } // namespace ncnn diff --git a/src/layer/x86/rotaryembed_x86_avx512bf16.cpp b/src/layer/x86/rotaryembed_x86_avx512bf16.cpp new file mode 100644 index 00000000000..c1f5904c046 --- /dev/null +++ b/src/layer/x86/rotaryembed_x86_avx512bf16.cpp @@ -0,0 +1,17 @@ +// Copyright 2026 pchar.cn +// SPDX-License-Identifier: BSD-3-Clause + +#include "cpu.h" +#include "mat.h" +#include "x86_usability.h" + +namespace ncnn { + +#include "rotaryembed_bf16s.h" + +void rotaryembed_bf16s_avx512bf16(const Mat& bottom_blob, const Mat& cos_cache, const Mat& sin_cache, Mat& top_blob, int interleaved, const Option& opt) +{ + rotaryembed_bf16s(bottom_blob, cos_cache, sin_cache, top_blob, interleaved, opt); +} + +} // namespace ncnn diff --git a/src/layer/x86/selu_bf16s.h b/src/layer/x86/selu_bf16s.h new file mode 100644 index 00000000000..26e8d03e2f3 --- /dev/null +++ b/src/layer/x86/selu_bf16s.h @@ -0,0 +1,108 @@ +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ +void selu_bf16s_avx512bf16(Mat& a, float alphaxlambda, float lambda, const Option& opt); +#endif + +static void selu_bf16s(Mat& a, float alphaxlambda, float lambda, const Option& opt) +{ +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ + if (ncnn::cpu_support_x86_avx512_bf16()) + { + selu_bf16s_avx512bf16(a, alphaxlambda, lambda, opt); + return; + } +#endif + + int w = a.w; + int h = a.h; + int d = a.d; + int channels = a.c; + int elempack = a.elempack; + int size = w * h * d * elempack; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + unsigned short* ptr = a.channel(q); + + int i = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + __m512 _zero512 = _mm512_setzero_ps(); + __m512 _one512 = _mm512_set1_ps(1.f); + __m512 _alpha512 = _mm512_set1_ps(alphaxlambda / lambda); + __m512 _lambda512 = _mm512_set1_ps(lambda); + for (; i + 15 < size; i += 16) + { + __m512 _p = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr)); + + __m512 _pos = _mm512_max_ps(_zero512, _p); + __m512 _neg = _mm512_min_ps(_zero512, _p); + + __m512 _blob = exp512_ps(_neg); + _blob = _mm512_sub_ps(_blob, _one512); + _blob = _mm512_mul_ps(_alpha512, _blob); + _blob = _mm512_mul_ps(_lambda512, _mm512_add_ps(_pos, _blob)); + + _mm256_storeu_si256((__m256i*)ptr, float2bfloat_avx512(_blob)); + + ptr += 16; + } +#endif // __AVX512F__ + __m256 _zero256 = _mm256_setzero_ps(); + __m256 _one256 = _mm256_set1_ps(1.f); + __m256 _alpha256 = _mm256_set1_ps(alphaxlambda / lambda); + __m256 _lambda256 = _mm256_set1_ps(lambda); + for (; i + 7 < size; i += 8) + { + __m256 _p = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr)); + + __m256 _pos = _mm256_max_ps(_zero256, _p); + __m256 _neg = _mm256_min_ps(_zero256, _p); + + __m256 _blob = exp256_ps(_neg); + _blob = _mm256_sub_ps(_blob, _one256); + _blob = _mm256_mul_ps(_alpha256, _blob); + _blob = 
_mm256_mul_ps(_lambda256, _mm256_add_ps(_pos, _blob)); + + _mm_storeu_si128((__m128i*)ptr, float2bfloat_avx(_blob)); + + ptr += 8; + } +#endif // __AVX__ + __m128 _zero128 = _mm_setzero_ps(); + __m128 _one128 = _mm_set1_ps(1.f); + __m128 _alpha128 = _mm_set1_ps(alphaxlambda / lambda); + __m128 _lambda128 = _mm_set1_ps(lambda); + for (; i + 3 < size; i += 4) + { + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr)); + + __m128 _pos = _mm_max_ps(_zero128, _p); + __m128 _neg = _mm_min_ps(_zero128, _p); + + __m128 _blob = exp_ps(_neg); + _blob = _mm_sub_ps(_blob, _one128); + _blob = _mm_mul_ps(_alpha128, _blob); + _blob = _mm_mul_ps(_lambda128, _mm_add_ps(_pos, _blob)); + + _mm_storel_epi64((__m128i*)ptr, float2bfloat_sse(_blob, _blob)); + + ptr += 4; + } +#endif // __SSE2__ + for (; i < size; i++) + { + float v = bfloat16_to_float32(*ptr); + if (v < 0) + v = (expf(v) - 1.f) * alphaxlambda; + else + v = v * lambda; + *ptr = float32_to_bfloat16(v); + ptr++; + } + } +} diff --git a/src/layer/x86/selu_x86.cpp b/src/layer/x86/selu_x86.cpp index 1a39a864b42..8d7c2aa3fc4 100644 --- a/src/layer/x86/selu_x86.cpp +++ b/src/layer/x86/selu_x86.cpp @@ -15,13 +15,24 @@ #endif // __AVX__ #endif // __SSE2__ +#include "x86_usability.h" + +#include "cpu.h" + namespace ncnn { +#if NCNN_BF16 +#include "selu_bf16s.h" +#endif + SELU_x86::SELU_x86() { #if __SSE2__ support_packing = true; #endif // __SSE2__ +#if NCNN_BF16 + support_bf16_storage = true; +#endif } int SELU_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const @@ -33,6 +44,11 @@ int SELU_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const int channels = bottom_top_blob.c; int size = w * h * d * elempack; +#if NCNN_BF16 + if (opt.use_bf16_storage && bottom_top_blob.elembits() == 16) + return forward_inplace_bf16s(bottom_top_blob, opt); +#endif + #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; q++) { @@ -121,4 +137,14 @@ int SELU_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const return 0; } +#if NCNN_BF16 +int SELU_x86::forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt) const +{ + float alphaxlambda = alpha * lambda; + selu_bf16s(bottom_top_blob, alphaxlambda, lambda, opt); + + return 0; +} +#endif // NCNN_BF16 + } // namespace ncnn diff --git a/src/layer/x86/selu_x86.h b/src/layer/x86/selu_x86.h index 59b85b8101d..35974ce9b04 100644 --- a/src/layer/x86/selu_x86.h +++ b/src/layer/x86/selu_x86.h @@ -14,6 +14,11 @@ class SELU_x86 : public SELU SELU_x86(); virtual int forward_inplace(Mat& bottom_top_blob, const Option& opt) const; + +protected: +#if NCNN_BF16 + int forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt) const; +#endif }; } // namespace ncnn diff --git a/src/layer/x86/selu_x86_avx512bf16.cpp b/src/layer/x86/selu_x86_avx512bf16.cpp new file mode 100644 index 00000000000..37c0362e08b --- /dev/null +++ b/src/layer/x86/selu_x86_avx512bf16.cpp @@ -0,0 +1,32 @@ +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#include "selu_x86.h" + +#if __SSE2__ +#include <emmintrin.h> +#include "sse_mathfun.h" +#if __AVX__ +#include <immintrin.h> +#include "avx_mathfun.h" +#if __AVX512F__ +#include "avx512_mathfun.h" +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ + +#include "x86_usability.h" + +#include "cpu.h" +#include "mat.h" + +namespace ncnn { + +#include "selu_bf16s.h" + +void selu_bf16s_avx512bf16(Mat& a, float alphaxlambda, float lambda, const Option& opt) +{ + selu_bf16s(a, alphaxlambda, lambda, opt); +} + +} // namespace ncnn 
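All of the new *_bf16s.h kernels in this series share the same two-translation-unit dispatch scheme: the header body is compiled once in the plain *_x86.cpp object and once more in a *_x86_avx512bf16.cpp object (which, judging by the !__AVX512BF16__ guard in the generic path, is the build with AVX512 BF16 enabled), and the generic entry point hands off to that build at runtime via ncnn::cpu_support_x86_avx512_bf16(). Below is a minimal sketch of that pattern only; the names my_op_bf16s and my_op_bf16s_avx512bf16 are placeholders for illustration, not real ncnn symbols.

// my_op_bf16s.h -- hypothetical kernel header following the pattern used above
#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__
// variant provided by the companion *_avx512bf16.cpp translation unit
void my_op_bf16s_avx512bf16(Mat& a, const Option& opt);
#endif

static void my_op_bf16s(Mat& a, const Option& opt)
{
#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__
    // runtime check: if the CPU has AVX512 BF16, jump to the specialized build
    if (ncnn::cpu_support_x86_avx512_bf16())
    {
        my_op_bf16s_avx512bf16(a, opt);
        return;
    }
#endif
    // generic SSE2/AVX/AVX512F body operating on bf16 storage goes here
    (void)a;
    (void)opt;
}

// my_op_x86_avx512bf16.cpp -- hypothetical companion translation unit: it
// re-includes the same header inside namespace ncnn and exports the wrapper
//   void my_op_bf16s_avx512bf16(Mat& a, const Option& opt) { my_op_bf16s(a, opt); }
// so the identical body gets compiled with the BF16 instruction set available.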
diff --git a/src/layer/x86/tanh_bf16s.h b/src/layer/x86/tanh_bf16s.h new file mode 100644 index 00000000000..c161626bdc2 --- /dev/null +++ b/src/layer/x86/tanh_bf16s.h @@ -0,0 +1,84 @@ +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ +void tanh_bf16s_avx512bf16(Mat& a, const Option& opt); +#endif + +static void tanh_bf16s(Mat& a, const Option& opt) +{ +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ + if (ncnn::cpu_support_x86_avx512_bf16()) + { + tanh_bf16s_avx512bf16(a, opt); + return; + } +#endif + + int w = a.w; + int h = a.h; + int d = a.d; + int channels = a.c; + int elempack = a.elempack; + int size = w * h * d * elempack; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + unsigned short* ptr = a.channel(q); + + int i = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + for (; i + 15 < size; i += 16) + { + __m512 _p = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr)); + _p = tanh_avx512(_p); + _mm256_storeu_si256((__m256i*)ptr, float2bfloat_avx512(_p)); + ptr += 16; + } + if (i < size) + { + const unsigned int remain = size - i; + __mmask16 _mask = (__mmask16)((1u << remain) - 1); + __m512 _p = bfloat2float_avx512(_mm256_maskz_loadu_epi16(_mask, ptr)); + _p = tanh_avx512(_p); + _mm256_mask_storeu_epi16(ptr, _mask, float2bfloat_avx512(_p)); + i += remain; + } +#else // __AVX512F__ + for (; i + 7 < size; i += 8) + { + __m256 _p = bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr)); + _p = tanh_avx(_p); + _mm_storeu_si128((__m128i*)ptr, float2bfloat_avx(_p)); + ptr += 8; + } + for (; i + 3 < size; i += 4) + { + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr)); + _p = tanh_sse(_p); + _mm_storel_epi64((__m128i*)ptr, float2bfloat_sse(_p, _p)); + ptr += 4; + } +#endif // __AVX512F__ +#else // __AVX__ + for (; i + 3 < size; i += 4) + { + __m128 _p = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr)); + _p = tanh_sse(_p); + _mm_storel_epi64((__m128i*)ptr, float2bfloat_sse(_p, _p)); + ptr += 4; + } +#endif // __AVX__ +#endif // __SSE2__ + for (; i < size; i++) + { + float v = bfloat16_to_float32(*ptr); + v = tanhf(v); + *ptr = float32_to_bfloat16(v); + ptr++; + } + } +} diff --git a/src/layer/x86/tanh_x86.cpp b/src/layer/x86/tanh_x86.cpp index 1ae5a0f28b6..9c846804237 100644 --- a/src/layer/x86/tanh_x86.cpp +++ b/src/layer/x86/tanh_x86.cpp @@ -5,13 +5,22 @@ #include "x86_activation.h" +#include "cpu.h" + namespace ncnn { +#if NCNN_BF16 +#include "tanh_bf16s.h" +#endif + TanH_x86::TanH_x86() { #if __SSE2__ support_packing = true; #endif // __SSE2__ +#if NCNN_BF16 + support_bf16_storage = true; +#endif } int TanH_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const @@ -23,6 +32,11 @@ int TanH_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const int elempack = bottom_top_blob.elempack; int size = w * h * d * elempack; +#if NCNN_BF16 + if (opt.use_bf16_storage && bottom_top_blob.elembits() == 16) + return forward_inplace_bf16s(bottom_top_blob, opt); +#endif + #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; q++) { @@ -75,4 +89,13 @@ int TanH_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const return 0; } +#if NCNN_BF16 +int TanH_x86::forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt) const +{ + tanh_bf16s(bottom_top_blob, opt); + + return 0; +} +#endif // NCNN_BF16 + } // namespace ncnn diff --git 
a/src/layer/x86/tanh_x86.h b/src/layer/x86/tanh_x86.h index 06c28f4ae46..c64ed010e42 100644 --- a/src/layer/x86/tanh_x86.h +++ b/src/layer/x86/tanh_x86.h @@ -14,6 +14,11 @@ class TanH_x86 : public TanH TanH_x86(); virtual int forward_inplace(Mat& bottom_top_blob, const Option& opt) const; + +protected: +#if NCNN_BF16 + int forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt) const; +#endif }; } // namespace ncnn diff --git a/src/layer/x86/tanh_x86_avx512bf16.cpp b/src/layer/x86/tanh_x86_avx512bf16.cpp new file mode 100644 index 00000000000..50684126b16 --- /dev/null +++ b/src/layer/x86/tanh_x86_avx512bf16.cpp @@ -0,0 +1,19 @@ +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#include "tanh_x86.h" + +#include "x86_activation.h" + +#include "cpu.h" + +namespace ncnn { + +#include "tanh_bf16s.h" + +void tanh_bf16s_avx512bf16(Mat& a, const Option& opt) +{ + tanh_bf16s(a, opt); +} + +} // namespace ncnn From f6a11a564c298fb63842efcef9b8cb5c15de5a95 Mon Sep 17 00:00:00 2001 From: nihui Date: Wed, 1 Apr 2026 15:33:45 +0800 Subject: [PATCH 32/36] deconvolution x86 support bf16 storage, clean includes (#6627) --- src/layer/x86/binaryop_x86_avx512bf16.cpp | 17 +- src/layer/x86/bnll_x86_avx512bf16.cpp | 20 +- src/layer/x86/clip_x86_avx512bf16.cpp | 5 +- src/layer/x86/deconvolution_packed_bf16s.h | 2541 +++++++++++++++++ src/layer/x86/deconvolution_x86.cpp | 106 + src/layer/x86/deconvolution_x86.h | 6 + .../x86/deconvolution_x86_avx512bf16.cpp | 25 + src/layer/x86/dequantize_x86_avx512bf16.cpp | 12 +- src/layer/x86/dropout_x86_avx512bf16.cpp | 12 +- src/layer/x86/eltwise_x86_avx512bf16.cpp | 12 +- src/layer/x86/elu_x86_avx512bf16.cpp | 9 +- src/layer/x86/erf_x86_avx512bf16.cpp | 20 +- src/layer/x86/gelu_x86_avx512bf16.cpp | 20 +- src/layer/x86/hardsigmoid_x86_avx512bf16.cpp | 5 +- src/layer/x86/hardswish_x86_avx512bf16.cpp | 5 +- src/layer/x86/mish_x86_avx512bf16.cpp | 9 +- src/layer/x86/quantize_x86_avx512bf16.cpp | 12 +- src/layer/x86/relu_x86_avx512bf16.cpp | 5 +- src/layer/x86/selu_x86_avx512bf16.cpp | 20 +- src/layer/x86/sigmoid_x86_avx512bf16.cpp | 20 +- src/layer/x86/softmax_x86_avx512bf16.cpp | 20 +- src/layer/x86/swish_x86_avx512bf16.cpp | 20 +- src/layer/x86/tanh_x86_avx512bf16.cpp | 9 +- src/layer/x86/unaryop_x86_avx512bf16.cpp | 21 +- 24 files changed, 2738 insertions(+), 213 deletions(-) create mode 100644 src/layer/x86/deconvolution_packed_bf16s.h create mode 100644 src/layer/x86/deconvolution_x86_avx512bf16.cpp diff --git a/src/layer/x86/binaryop_x86_avx512bf16.cpp b/src/layer/x86/binaryop_x86_avx512bf16.cpp index 6f6ebba76b4..672f66b25d0 100644 --- a/src/layer/x86/binaryop_x86_avx512bf16.cpp +++ b/src/layer/x86/binaryop_x86_avx512bf16.cpp @@ -3,18 +3,11 @@ #include "binaryop_x86.h" -#if __SSE2__ -#include -#include "sse_mathfun.h" -#if __AVX__ -#include -#include "avx_mathfun.h" -#if __AVX512F__ -#include "avx512_mathfun.h" -#endif // __AVX512F__ -#endif // __AVX__ -#endif // __SSE2__ - +#include "cpu.h" +#include "layer.h" +#include "layer_type.h" +#include "mat.h" +#include "x86_activation.h" #include "x86_usability.h" namespace ncnn { diff --git a/src/layer/x86/bnll_x86_avx512bf16.cpp b/src/layer/x86/bnll_x86_avx512bf16.cpp index 15d0d3b58bd..7c28e4ebd9c 100644 --- a/src/layer/x86/bnll_x86_avx512bf16.cpp +++ b/src/layer/x86/bnll_x86_avx512bf16.cpp @@ -1,24 +1,12 @@ // Copyright 2026 Tencent // SPDX-License-Identifier: BSD-3-Clause -#include "bnll_x86.h" - -#if __SSE2__ -#include -#include "sse_mathfun.h" -#if __AVX__ -#include -#include 
"avx_mathfun.h" -#if __AVX512F__ -#include "avx512_mathfun.h" -#endif // __AVX512F__ -#endif // __AVX__ -#endif // __SSE2__ - -#include "x86_usability.h" - #include "cpu.h" +#include "layer.h" +#include "layer_type.h" #include "mat.h" +#include "x86_activation.h" +#include "x86_usability.h" namespace ncnn { diff --git a/src/layer/x86/clip_x86_avx512bf16.cpp b/src/layer/x86/clip_x86_avx512bf16.cpp index 3641ef761c8..9c66e1da038 100644 --- a/src/layer/x86/clip_x86_avx512bf16.cpp +++ b/src/layer/x86/clip_x86_avx512bf16.cpp @@ -1,12 +1,9 @@ // Copyright 2026 Tencent // SPDX-License-Identifier: BSD-3-Clause -#include "clip_x86.h" - -#include "x86_usability.h" - #include "cpu.h" #include "mat.h" +#include "x86_usability.h" namespace ncnn { diff --git a/src/layer/x86/deconvolution_packed_bf16s.h b/src/layer/x86/deconvolution_packed_bf16s.h new file mode 100644 index 00000000000..a14171aa81b --- /dev/null +++ b/src/layer/x86/deconvolution_packed_bf16s.h @@ -0,0 +1,2541 @@ +// Copyright 2022 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ +void deconvolution_transform_kernel_packed_bf16s_avx512bf16(const Mat& weight_data, Mat& weight_data_tm, int num_input, int num_output, int kernel_w, int kernel_h); +#endif + +static void deconvolution_transform_kernel_packed_bf16s(const Mat& weight_data, Mat& weight_data_tm, int num_input, int num_output, int kernel_w, int kernel_h) +{ +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ + if (ncnn::cpu_support_x86_avx512_bf16()) + { + deconvolution_transform_kernel_packed_bf16s_avx512bf16(weight_data, weight_data_tm, num_input, num_output, kernel_w, kernel_h); + return; + } +#endif + const int maxk = kernel_w * kernel_h; + + // src = maxk-inch-outch + // dst = pb-pa-maxk-inch/pa-outch/pb + + // transpose kernel (reverse k order for deconvolution) + Mat weight_data_transposed(weight_data.w); + { + float* pt = weight_data_transposed; + const float* p = weight_data; + + for (int i = 0; i < num_input * num_output; i++) + { + for (int k = 0; k < maxk; k++) + { + pt[maxk - 1 - k] = p[k]; + } + + p += maxk; + pt += maxk; + } + } + + // src = kw-kh-inch-outch + // dst = pb-pa-kw-kh-inch/pa-outch/pb + Mat weight_data_r2 = weight_data_transposed.reshape(maxk, num_input, num_output); + + // clang-format off + // *INDENT-OFF* +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + if (num_output >= 16) + { + if (num_input >= 16) + weight_data_tm.create(16 * 16 * maxk, num_input / 16 + (num_input % 16) / 8 + (num_input % 8) / 4 + (num_input % 4) / 2 + num_input % 2, num_output / 16 + (num_output % 16) / 8 + (num_output % 8) / 4 + (num_output % 4) / 2 + num_output % 2, (size_t)2u); + else if (num_input >= 8) + weight_data_tm.create(16 * 8 * maxk, num_input / 8 + (num_input % 8) / 4 + (num_input % 4) / 2 + num_input % 2, num_output / 16 + (num_output % 16) / 8 + (num_output % 8) / 4 + (num_output % 4) / 2 + num_output % 2, (size_t)2u); + else if (num_input >= 4) + weight_data_tm.create(16 * 4 * maxk, num_input / 4 + (num_input % 4) / 2 + num_input % 2, num_output / 16 + (num_output % 16) / 8 + (num_output % 8) / 4 + (num_output % 4) / 2 + num_output % 2, (size_t)2u); + else if (num_input >= 2) + weight_data_tm.create(16 * 2 * maxk, num_input / 2 + num_input % 2, num_output / 16 + (num_output % 16) / 8 + (num_output % 8) / 4 + (num_output % 4) / 2 + num_output % 2, (size_t)2u); + else + weight_data_tm.create(16 * maxk, num_input, num_output / 16 + (num_output % 16) / 8 + (num_output % 8) 
/ 4 + (num_output % 4) / 2 + num_output % 2, (size_t)2u); + } + else +#endif // __AVX512F__ + if (num_output >= 8) + { +#if __AVX512F__ + if (num_input >= 16) + weight_data_tm.create(8 * 16 * maxk, num_input / 16 + (num_input % 16) / 8 + (num_input % 8) / 4 + (num_input % 4) / 2 + num_input % 2, num_output / 8 + (num_output % 8) / 4 + (num_output % 4) / 2 + num_output % 2, (size_t)2u); + else +#endif // __AVX512F__ + if (num_input >= 8) + weight_data_tm.create(8 * 8 * maxk, num_input / 8 + (num_input % 8) / 4 + (num_input % 4) / 2 + num_input % 2, num_output / 8 + (num_output % 8) / 4 + (num_output % 4) / 2 + num_output % 2, (size_t)2u); + else if (num_input >= 4) + weight_data_tm.create(8 * 4 * maxk, num_input / 4 + (num_input % 4) / 2 + num_input % 2, num_output / 8 + (num_output % 8) / 4 + (num_output % 4) / 2 + num_output % 2, (size_t)2u); + else if (num_input >= 2) + weight_data_tm.create(8 * 2 * maxk, num_input / 2 + num_input % 2, num_output / 8 + (num_output % 8) / 4 + (num_output % 4) / 2 + num_output % 2, (size_t)2u); + else + weight_data_tm.create(8 * maxk, num_input, num_output / 8 + (num_output % 8) / 4 + (num_output % 4) / 2 + num_output % 2, (size_t)2u); + } + else +#endif // __AVX__ + if (num_output >= 4) + { +#if __AVX__ +#if __AVX512F__ + if (num_input >= 16) + weight_data_tm.create(4 * 16 * maxk, num_input / 16 + (num_input % 16) / 8 + (num_input % 8) / 4 + (num_input % 4) / 2 + num_input % 2, num_output / 4 + (num_output % 4) / 2 + num_output % 2, (size_t)2u); + else +#endif // __AVX512F__ + if (num_input >= 8) + weight_data_tm.create(4 * 8 * maxk, num_input / 8 + (num_input % 8) / 4 + (num_input % 4) / 2 + num_input % 2, num_output / 4 + (num_output % 4) / 2 + num_output % 2, (size_t)2u); + else +#endif // __AVX__ + if (num_input >= 4) + weight_data_tm.create(4 * 4 * maxk, num_input / 4 + (num_input % 4) / 2 + num_input % 2, num_output / 4 + (num_output % 4) / 2 + num_output % 2, (size_t)2u); + else if (num_input >= 2) + weight_data_tm.create(4 * 2 * maxk, num_input / 2 + num_input % 2, num_output / 4 + (num_output % 4) / 2 + num_output % 2, (size_t)2u); + else + weight_data_tm.create(4 * maxk, num_input, num_output / 4 + (num_output % 4) / 2 + num_output % 2, (size_t)2u); + } + else +#endif // __SSE2__ + if (num_output >= 2) + { +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + if (num_input >= 16) + weight_data_tm.create(2 * 16 * maxk, num_input / 16 + (num_input % 16) / 8 + (num_input % 8) / 4 + (num_input % 4) / 2 + num_input % 2, num_output / 2 + num_output % 2, (size_t)2u); + else +#endif // __AVX512F__ + if (num_input >= 8) + weight_data_tm.create(2 * 8 * maxk, num_input / 8 + (num_input % 8) / 4 + (num_input % 4) / 2 + num_input % 2, num_output / 2 + num_output % 2, (size_t)2u); + else +#endif // __AVX__ + if (num_input >= 4) + weight_data_tm.create(2 * 4 * maxk, num_input / 4 + (num_input % 4) / 2 + num_input % 2, num_output / 2 + num_output % 2, (size_t)2u); + else +#endif // __SSE2__ + if (num_input >= 2) + weight_data_tm.create(2 * 2 * maxk, num_input / 2 + num_input % 2, num_output / 2 + num_output % 2, (size_t)2u); + else + weight_data_tm.create(2 * maxk, num_input, num_output / 2 + num_output % 2, (size_t)2u); + } + else + { +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + if (num_input >= 16) + weight_data_tm.create(16 * maxk, num_input / 16 + (num_input % 16) / 8 + (num_input % 8) / 4 + (num_input % 4) / 2 + num_input % 2, num_output, (size_t)2u); + else +#endif // __AVX512F__ + if (num_input >= 8) + weight_data_tm.create(8 * maxk, num_input / 8 + (num_input 
% 8) / 4 + (num_input % 4) / 2 + num_input % 2, num_output, (size_t)2u); + else +#endif // __AVX__ + if (num_input >= 4) + weight_data_tm.create(4 * maxk, num_input / 4 + (num_input % 4) / 2 + num_input % 2, num_output, (size_t)2u); + else +#endif // __SSE2__ + if (num_input >= 2) + weight_data_tm.create(2 * maxk, num_input / 2 + num_input % 2, num_output, (size_t)2u); + else + weight_data_tm.create(maxk, num_input, num_output, (size_t)2u); + } + // *INDENT-ON* + // clang-format on + + int q = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + for (; q + 15 < num_output; q += 16) + { + unsigned short* g00 = weight_data_tm.channel(q / 16); + + int p = 0; + for (; p + 15 < num_input; p += 16) + { + for (int k = 0; k < maxk; k++) + { + for (int i = 0; i < 16; i++) + { + for (int j = 0; j < 16; j++) + { + const float* k00 = weight_data_r2.channel(q + j).row(p + i); + g00[0] = float32_to_bfloat16(k00[k]); + g00++; + } + } + } + } + for (; p + 7 < num_input; p += 8) + { + for (int k = 0; k < maxk; k++) + { + for (int i = 0; i < 8; i++) + { + for (int j = 0; j < 16; j++) + { + const float* k00 = weight_data_r2.channel(q + j).row(p + i); + g00[0] = float32_to_bfloat16(k00[k]); + g00++; + } + } + } + } + for (; p + 3 < num_input; p += 4) + { + for (int k = 0; k < maxk; k++) + { + for (int i = 0; i < 4; i++) + { + for (int j = 0; j < 16; j++) + { + const float* k00 = weight_data_r2.channel(q + j).row(p + i); + g00[0] = float32_to_bfloat16(k00[k]); + g00++; + } + } + } + } + for (; p + 1 < num_input; p += 2) + { + for (int k = 0; k < maxk; k++) + { + for (int i = 0; i < 2; i++) + { + for (int j = 0; j < 16; j++) + { + const float* k00 = weight_data_r2.channel(q + j).row(p + i); + g00[0] = float32_to_bfloat16(k00[k]); + g00++; + } + } + } + } + for (; p < num_input; p++) + { + for (int k = 0; k < maxk; k++) + { + for (int j = 0; j < 16; j++) + { + const float* k00 = weight_data_r2.channel(q + j).row(p); + g00[0] = float32_to_bfloat16(k00[k]); + g00++; + } + } + } + } +#endif // __AVX512F__ + for (; q + 7 < num_output; q += 8) + { +#if __AVX512F__ + unsigned short* g00 = weight_data_tm.channel(q / 16 + (q % 16) / 8); +#else + unsigned short* g00 = weight_data_tm.channel(q / 8); +#endif + + int p = 0; +#if __AVX512F__ + for (; p + 15 < num_input; p += 16) + { + for (int k = 0; k < maxk; k++) + { + for (int i = 0; i < 16; i++) + { + for (int j = 0; j < 8; j++) + { + const float* k00 = weight_data_r2.channel(q + j).row(p + i); + g00[0] = float32_to_bfloat16(k00[k]); + g00++; + } + } + } + } +#endif // __AVX512F__ + for (; p + 7 < num_input; p += 8) + { + for (int k = 0; k < maxk; k++) + { + for (int i = 0; i < 8; i++) + { + for (int j = 0; j < 8; j++) + { + const float* k00 = weight_data_r2.channel(q + j).row(p + i); + g00[0] = float32_to_bfloat16(k00[k]); + g00++; + } + } + } + } + for (; p + 3 < num_input; p += 4) + { + for (int k = 0; k < maxk; k++) + { + for (int i = 0; i < 4; i++) + { + for (int j = 0; j < 8; j++) + { + const float* k00 = weight_data_r2.channel(q + j).row(p + i); + g00[0] = float32_to_bfloat16(k00[k]); + g00++; + } + } + } + } + for (; p + 1 < num_input; p += 2) + { + for (int k = 0; k < maxk; k++) + { + for (int i = 0; i < 2; i++) + { + for (int j = 0; j < 8; j++) + { + const float* k00 = weight_data_r2.channel(q + j).row(p + i); + g00[0] = float32_to_bfloat16(k00[k]); + g00++; + } + } + } + } + for (; p < num_input; p++) + { + for (int k = 0; k < maxk; k++) + { + for (int j = 0; j < 8; j++) + { + const float* k00 = weight_data_r2.channel(q + j).row(p); + g00[0] = 
float32_to_bfloat16(k00[k]); + g00++; + } + } + } + } +#endif // __AVX__ + for (; q + 3 < num_output; q += 4) + { +#if __AVX512F__ + unsigned short* g00 = weight_data_tm.channel(q / 16 + (q % 16) / 8 + (q % 8) / 4); +#elif __AVX__ + unsigned short* g00 = weight_data_tm.channel(q / 8 + (q % 8) / 4); +#else + unsigned short* g00 = weight_data_tm.channel(q / 4); +#endif + + int p = 0; +#if __AVX__ +#if __AVX512F__ + for (; p + 15 < num_input; p += 16) + { + for (int k = 0; k < maxk; k++) + { + for (int i = 0; i < 16; i++) + { + for (int j = 0; j < 4; j++) + { + const float* k00 = weight_data_r2.channel(q + j).row(p + i); + g00[0] = float32_to_bfloat16(k00[k]); + g00++; + } + } + } + } +#endif // __AVX512F__ + for (; p + 7 < num_input; p += 8) + { + for (int k = 0; k < maxk; k++) + { + for (int i = 0; i < 8; i++) + { + for (int j = 0; j < 4; j++) + { + const float* k00 = weight_data_r2.channel(q + j).row(p + i); + g00[0] = float32_to_bfloat16(k00[k]); + g00++; + } + } + } + } +#endif // __AVX__ + for (; p + 3 < num_input; p += 4) + { + for (int k = 0; k < maxk; k++) + { + for (int i = 0; i < 4; i++) + { + for (int j = 0; j < 4; j++) + { + const float* k00 = weight_data_r2.channel(q + j).row(p + i); + g00[0] = float32_to_bfloat16(k00[k]); + g00++; + } + } + } + } + for (; p + 1 < num_input; p += 2) + { + for (int k = 0; k < maxk; k++) + { + for (int i = 0; i < 2; i++) + { + for (int j = 0; j < 4; j++) + { + const float* k00 = weight_data_r2.channel(q + j).row(p + i); + g00[0] = float32_to_bfloat16(k00[k]); + g00++; + } + } + } + } + for (; p < num_input; p++) + { + for (int k = 0; k < maxk; k++) + { + for (int j = 0; j < 4; j++) + { + const float* k00 = weight_data_r2.channel(q + j).row(p); + g00[0] = float32_to_bfloat16(k00[k]); + g00++; + } + } + } + } +#endif // __SSE2__ + for (; q + 1 < num_output; q += 2) + { +#if __AVX512F__ + unsigned short* g00 = weight_data_tm.channel(q / 16 + (q % 16) / 8 + (q % 8) / 4 + (q % 4) / 2); +#elif __AVX__ + unsigned short* g00 = weight_data_tm.channel(q / 8 + (q % 8) / 4 + (q % 4) / 2); +#elif __SSE2__ + unsigned short* g00 = weight_data_tm.channel(q / 4 + (q % 4) / 2); +#else + unsigned short* g00 = weight_data_tm.channel(q / 2); +#endif + + int p = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + for (; p + 15 < num_input; p += 16) + { + for (int k = 0; k < maxk; k++) + { + for (int j = 0; j < 2; j++) + { + for (int i = 0; i < 16; i++) + { + const float* k00 = weight_data_r2.channel(q + j).row(p + i); + g00[0] = float32_to_bfloat16(k00[k]); + g00++; + } + } + } + } +#endif // __AVX512F__ + for (; p + 7 < num_input; p += 8) + { + for (int k = 0; k < maxk; k++) + { + for (int j = 0; j < 2; j++) + { + for (int i = 0; i < 8; i++) + { + const float* k00 = weight_data_r2.channel(q + j).row(p + i); + g00[0] = float32_to_bfloat16(k00[k]); + g00++; + } + } + } + } +#endif // __AVX__ + for (; p + 3 < num_input; p += 4) + { + for (int k = 0; k < maxk; k++) + { + for (int j = 0; j < 2; j++) + { + for (int i = 0; i < 4; i++) + { + const float* k00 = weight_data_r2.channel(q + j).row(p + i); + g00[0] = float32_to_bfloat16(k00[k]); + g00++; + } + } + } + } +#endif // __SSE2__ + for (; p + 1 < num_input; p += 2) + { + for (int k = 0; k < maxk; k++) + { + for (int j = 0; j < 2; j++) + { + for (int i = 0; i < 2; i++) + { + const float* k00 = weight_data_r2.channel(q + j).row(p + i); + g00[0] = float32_to_bfloat16(k00[k]); + g00++; + } + } + } + } + for (; p < num_input; p++) + { + for (int k = 0; k < maxk; k++) + { + for (int j = 0; j < 2; j++) + { + const float* k00 = 
weight_data_r2.channel(q + j).row(p); + g00[0] = float32_to_bfloat16(k00[k]); + g00++; + } + } + } + } + for (; q < num_output; q++) + { +#if __AVX512F__ + unsigned short* g00 = weight_data_tm.channel(q / 16 + (q % 16) / 8 + (q % 8) / 4 + (q % 4) / 2 + q % 2); +#elif __AVX__ + unsigned short* g00 = weight_data_tm.channel(q / 8 + (q % 8) / 4 + (q % 4) / 2 + q % 2); +#elif __SSE2__ + unsigned short* g00 = weight_data_tm.channel(q / 4 + (q % 4) / 2 + q % 2); +#else + unsigned short* g00 = weight_data_tm.channel(q / 2 + q % 2); +#endif + + int p = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + for (; p + 15 < num_input; p += 16) + { + for (int k = 0; k < maxk; k++) + { + for (int i = 0; i < 16; i++) + { + const float* k00 = weight_data_r2.channel(q).row(p + i); + g00[0] = float32_to_bfloat16(k00[k]); + g00++; + } + } + } +#endif // __AVX512F__ + for (; p + 7 < num_input; p += 8) + { + for (int k = 0; k < maxk; k++) + { + for (int i = 0; i < 8; i++) + { + const float* k00 = weight_data_r2.channel(q).row(p + i); + g00[0] = float32_to_bfloat16(k00[k]); + g00++; + } + } + } +#endif // __AVX__ + for (; p + 3 < num_input; p += 4) + { + for (int k = 0; k < maxk; k++) + { + for (int i = 0; i < 4; i++) + { + const float* k00 = weight_data_r2.channel(q).row(p + i); + g00[0] = float32_to_bfloat16(k00[k]); + g00++; + } + } + } +#endif // __SSE2__ + for (; p + 1 < num_input; p += 2) + { + for (int k = 0; k < maxk; k++) + { + for (int i = 0; i < 2; i++) + { + const float* k00 = weight_data_r2.channel(q).row(p + i); + g00[0] = float32_to_bfloat16(k00[k]); + g00++; + } + } + } + for (; p < num_input; p++) + { + for (int k = 0; k < maxk; k++) + { + const float* k00 = weight_data_r2.channel(q).row(p); + g00[0] = float32_to_bfloat16(k00[k]); + g00++; + } + } + } +} + +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ +void deconvolution_packed_bf16s_avx512bf16(const Mat& bottom_blob, Mat& top_blob, const Mat& weight_data_tm, const Mat& bias_data, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h, int activation_type, const Mat& activation_params, const Option& opt); +#endif + +static void deconvolution_packed_bf16s(const Mat& bottom_blob, Mat& top_blob, const Mat& weight_data_tm, const Mat& bias_data, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h, int activation_type, const Mat& activation_params, const Option& opt) +{ +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ + if (ncnn::cpu_support_x86_avx512_bf16()) + { + deconvolution_packed_bf16s_avx512bf16(bottom_blob, top_blob, weight_data_tm, bias_data, kernel_w, kernel_h, dilation_w, dilation_h, stride_w, stride_h, activation_type, activation_params, opt); + return; + } +#endif + const int out_elempack = top_blob.elempack; + + const int outch = top_blob.c * out_elempack; + + const size_t M = top_blob.cstep * out_elempack; + + const int kernel_extent_w = dilation_w * (kernel_w - 1) + 1; + const int kernel_extent_h = dilation_h * (kernel_h - 1) + 1; + + const int maxk = kernel_w * kernel_h; + + const float* bias_data_ptr = bias_data; + + int nn_outch = 0; + int remain_outch_start = 0; + +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + nn_outch = outch / 16; + #pragma omp parallel for num_threads(opt.num_threads) + for (int pp = 0; pp < nn_outch; pp++) + { + const int p = pp * 16; + + // shadowed variable for less openmp task args + const int elempack = bottom_blob.elempack; + const int inch = bottom_blob.c * elempack; + const int w = 
bottom_blob.w; + const int h = bottom_blob.h; + const int outw = top_blob.w; + const int outh = top_blob.h; + const int out_elempack = top_blob.elempack; + + unsigned short* outptr = top_blob.channel(p / out_elempack); + + for (int i = 0; i < outh; i++) + { + for (int j = 0; j < outw; j++) + { + __m512 _sum0 = _mm512_setzero_ps(); + __m512 _sum1 = _mm512_setzero_ps(); + __m512 _sum2 = _mm512_setzero_ps(); + __m512 _sum3 = _mm512_setzero_ps(); + + if (bias_data_ptr) + { + _sum0 = _mm512_loadu_ps(bias_data_ptr + p); + } + + const unsigned short* kptr = weight_data_tm.channel(p / 16); + + int q = 0; + for (; q + 15 < inch; q += 16) + { + for (int y = 0; y < kernel_h; y++) + { + int sys = (i + y * dilation_h - (kernel_extent_h - 1)); + if (sys < 0 || sys % stride_h != 0) + continue; + int sy = sys / stride_h; + if (sy >= h) + continue; + + for (int x = 0; x < kernel_w; x++) + { + int sxs = (j + x * dilation_w - (kernel_extent_w - 1)); + if (sxs < 0 || sxs % stride_w != 0) + continue; + int sx = sxs / stride_w; + if (sx >= w) + continue; + + int k = y * kernel_w + x; + const unsigned short* kptr0 = kptr + k * 16 * 16; + + if (elempack == 16) + { + const unsigned short* sptr = bottom_blob.channel(q / 16).row(sy) + sx * 16; + + __m512 _val0 = _mm512_set1_ps(bfloat16_to_float32(sptr[0])); + __m512 _val1 = _mm512_set1_ps(bfloat16_to_float32(sptr[1])); + __m512 _val2 = _mm512_set1_ps(bfloat16_to_float32(sptr[2])); + __m512 _val3 = _mm512_set1_ps(bfloat16_to_float32(sptr[3])); + __m512 _val4 = _mm512_set1_ps(bfloat16_to_float32(sptr[4])); + __m512 _val5 = _mm512_set1_ps(bfloat16_to_float32(sptr[5])); + __m512 _val6 = _mm512_set1_ps(bfloat16_to_float32(sptr[6])); + __m512 _val7 = _mm512_set1_ps(bfloat16_to_float32(sptr[7])); + __m512 _val8 = _mm512_set1_ps(bfloat16_to_float32(sptr[8])); + __m512 _val9 = _mm512_set1_ps(bfloat16_to_float32(sptr[9])); + __m512 _vala = _mm512_set1_ps(bfloat16_to_float32(sptr[10])); + __m512 _valb = _mm512_set1_ps(bfloat16_to_float32(sptr[11])); + __m512 _valc = _mm512_set1_ps(bfloat16_to_float32(sptr[12])); + __m512 _vald = _mm512_set1_ps(bfloat16_to_float32(sptr[13])); + __m512 _vale = _mm512_set1_ps(bfloat16_to_float32(sptr[14])); + __m512 _valf = _mm512_set1_ps(bfloat16_to_float32(sptr[15])); + _sum0 = _mm512_fmadd_ps(_val0, bfloat2float_avx512(_mm256_load_si256((const __m256i*)kptr0)), _sum0); + _sum1 = _mm512_fmadd_ps(_val1, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16))), _sum1); + _sum2 = _mm512_fmadd_ps(_val2, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 2))), _sum2); + _sum3 = _mm512_fmadd_ps(_val3, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 3))), _sum3); + _sum0 = _mm512_fmadd_ps(_val4, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 4))), _sum0); + _sum1 = _mm512_fmadd_ps(_val5, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 5))), _sum1); + _sum2 = _mm512_fmadd_ps(_val6, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 6))), _sum2); + _sum3 = _mm512_fmadd_ps(_val7, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 7))), _sum3); + _sum0 = _mm512_fmadd_ps(_val8, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 8))), _sum0); + _sum1 = _mm512_fmadd_ps(_val9, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 9))), _sum1); + _sum2 = _mm512_fmadd_ps(_vala, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 10))), _sum2); + _sum3 = _mm512_fmadd_ps(_valb, 
bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 11))), _sum3); + _sum0 = _mm512_fmadd_ps(_valc, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 12))), _sum0); + _sum1 = _mm512_fmadd_ps(_vald, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 13))), _sum1); + _sum2 = _mm512_fmadd_ps(_vale, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 14))), _sum2); + _sum3 = _mm512_fmadd_ps(_valf, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 15))), _sum3); + } + if (elempack == 8) + { + const unsigned short* sptr0 = bottom_blob.channel(q / 8).row(sy) + sx * 8; + const unsigned short* sptr1 = bottom_blob.channel(q / 8 + 1).row(sy) + sx * 8; + + __m512 _val0 = _mm512_set1_ps(bfloat16_to_float32(sptr0[0])); + __m512 _val1 = _mm512_set1_ps(bfloat16_to_float32(sptr0[1])); + __m512 _val2 = _mm512_set1_ps(bfloat16_to_float32(sptr0[2])); + __m512 _val3 = _mm512_set1_ps(bfloat16_to_float32(sptr0[3])); + __m512 _val4 = _mm512_set1_ps(bfloat16_to_float32(sptr0[4])); + __m512 _val5 = _mm512_set1_ps(bfloat16_to_float32(sptr0[5])); + __m512 _val6 = _mm512_set1_ps(bfloat16_to_float32(sptr0[6])); + __m512 _val7 = _mm512_set1_ps(bfloat16_to_float32(sptr0[7])); + __m512 _val8 = _mm512_set1_ps(bfloat16_to_float32(sptr1[0])); + __m512 _val9 = _mm512_set1_ps(bfloat16_to_float32(sptr1[1])); + __m512 _vala = _mm512_set1_ps(bfloat16_to_float32(sptr1[2])); + __m512 _valb = _mm512_set1_ps(bfloat16_to_float32(sptr1[3])); + __m512 _valc = _mm512_set1_ps(bfloat16_to_float32(sptr1[4])); + __m512 _vald = _mm512_set1_ps(bfloat16_to_float32(sptr1[5])); + __m512 _vale = _mm512_set1_ps(bfloat16_to_float32(sptr1[6])); + __m512 _valf = _mm512_set1_ps(bfloat16_to_float32(sptr1[7])); + _sum0 = _mm512_fmadd_ps(_val0, bfloat2float_avx512(_mm256_load_si256((const __m256i*)kptr0)), _sum0); + _sum1 = _mm512_fmadd_ps(_val1, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16))), _sum1); + _sum2 = _mm512_fmadd_ps(_val2, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 2))), _sum2); + _sum3 = _mm512_fmadd_ps(_val3, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 3))), _sum3); + _sum0 = _mm512_fmadd_ps(_val4, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 4))), _sum0); + _sum1 = _mm512_fmadd_ps(_val5, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 5))), _sum1); + _sum2 = _mm512_fmadd_ps(_val6, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 6))), _sum2); + _sum3 = _mm512_fmadd_ps(_val7, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 7))), _sum3); + _sum0 = _mm512_fmadd_ps(_val8, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 8))), _sum0); + _sum1 = _mm512_fmadd_ps(_val9, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 9))), _sum1); + _sum2 = _mm512_fmadd_ps(_vala, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 10))), _sum2); + _sum3 = _mm512_fmadd_ps(_valb, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 11))), _sum3); + _sum0 = _mm512_fmadd_ps(_valc, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 12))), _sum0); + _sum1 = _mm512_fmadd_ps(_vald, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 13))), _sum1); + _sum2 = _mm512_fmadd_ps(_vale, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 14))), _sum2); + _sum3 = _mm512_fmadd_ps(_valf, bfloat2float_avx512(_mm256_load_si256((const 
__m256i*)(kptr0 + 16 * 15))), _sum3); + } + if (elempack == 4) + { + const unsigned short* sptr0 = bottom_blob.channel(q / 4).row(sy) + sx * 4; + const unsigned short* sptr1 = bottom_blob.channel(q / 4 + 1).row(sy) + sx * 4; + const unsigned short* sptr2 = bottom_blob.channel(q / 4 + 2).row(sy) + sx * 4; + const unsigned short* sptr3 = bottom_blob.channel(q / 4 + 3).row(sy) + sx * 4; + + __m512 _val0 = _mm512_set1_ps(bfloat16_to_float32(sptr0[0])); + __m512 _val1 = _mm512_set1_ps(bfloat16_to_float32(sptr0[1])); + __m512 _val2 = _mm512_set1_ps(bfloat16_to_float32(sptr0[2])); + __m512 _val3 = _mm512_set1_ps(bfloat16_to_float32(sptr0[3])); + __m512 _val4 = _mm512_set1_ps(bfloat16_to_float32(sptr1[0])); + __m512 _val5 = _mm512_set1_ps(bfloat16_to_float32(sptr1[1])); + __m512 _val6 = _mm512_set1_ps(bfloat16_to_float32(sptr1[2])); + __m512 _val7 = _mm512_set1_ps(bfloat16_to_float32(sptr1[3])); + __m512 _val8 = _mm512_set1_ps(bfloat16_to_float32(sptr2[0])); + __m512 _val9 = _mm512_set1_ps(bfloat16_to_float32(sptr2[1])); + __m512 _vala = _mm512_set1_ps(bfloat16_to_float32(sptr2[2])); + __m512 _valb = _mm512_set1_ps(bfloat16_to_float32(sptr2[3])); + __m512 _valc = _mm512_set1_ps(bfloat16_to_float32(sptr3[0])); + __m512 _vald = _mm512_set1_ps(bfloat16_to_float32(sptr3[1])); + __m512 _vale = _mm512_set1_ps(bfloat16_to_float32(sptr3[2])); + __m512 _valf = _mm512_set1_ps(bfloat16_to_float32(sptr3[3])); + _sum0 = _mm512_fmadd_ps(_val0, bfloat2float_avx512(_mm256_load_si256((const __m256i*)kptr0)), _sum0); + _sum1 = _mm512_fmadd_ps(_val1, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16))), _sum1); + _sum2 = _mm512_fmadd_ps(_val2, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 2))), _sum2); + _sum3 = _mm512_fmadd_ps(_val3, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 3))), _sum3); + _sum0 = _mm512_fmadd_ps(_val4, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 4))), _sum0); + _sum1 = _mm512_fmadd_ps(_val5, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 5))), _sum1); + _sum2 = _mm512_fmadd_ps(_val6, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 6))), _sum2); + _sum3 = _mm512_fmadd_ps(_val7, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 7))), _sum3); + _sum0 = _mm512_fmadd_ps(_val8, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 8))), _sum0); + _sum1 = _mm512_fmadd_ps(_val9, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 9))), _sum1); + _sum2 = _mm512_fmadd_ps(_vala, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 10))), _sum2); + _sum3 = _mm512_fmadd_ps(_valb, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 11))), _sum3); + _sum0 = _mm512_fmadd_ps(_valc, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 12))), _sum0); + _sum1 = _mm512_fmadd_ps(_vald, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 13))), _sum1); + _sum2 = _mm512_fmadd_ps(_vale, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 14))), _sum2); + _sum3 = _mm512_fmadd_ps(_valf, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 15))), _sum3); + } + if (elempack == 1) + { + __m512 _val0 = _mm512_set1_ps(bfloat16_to_float32(bottom_blob.channel(q).row(sy)[sx])); + __m512 _val1 = _mm512_set1_ps(bfloat16_to_float32(bottom_blob.channel(q + 1).row(sy)[sx])); + __m512 _val2 = _mm512_set1_ps(bfloat16_to_float32(bottom_blob.channel(q + 2).row(sy)[sx])); + 
__m512 _val3 = _mm512_set1_ps(bfloat16_to_float32(bottom_blob.channel(q + 3).row(sy)[sx])); + __m512 _val4 = _mm512_set1_ps(bfloat16_to_float32(bottom_blob.channel(q + 4).row(sy)[sx])); + __m512 _val5 = _mm512_set1_ps(bfloat16_to_float32(bottom_blob.channel(q + 5).row(sy)[sx])); + __m512 _val6 = _mm512_set1_ps(bfloat16_to_float32(bottom_blob.channel(q + 6).row(sy)[sx])); + __m512 _val7 = _mm512_set1_ps(bfloat16_to_float32(bottom_blob.channel(q + 7).row(sy)[sx])); + __m512 _val8 = _mm512_set1_ps(bfloat16_to_float32(bottom_blob.channel(q + 8).row(sy)[sx])); + __m512 _val9 = _mm512_set1_ps(bfloat16_to_float32(bottom_blob.channel(q + 9).row(sy)[sx])); + __m512 _vala = _mm512_set1_ps(bfloat16_to_float32(bottom_blob.channel(q + 10).row(sy)[sx])); + __m512 _valb = _mm512_set1_ps(bfloat16_to_float32(bottom_blob.channel(q + 11).row(sy)[sx])); + __m512 _valc = _mm512_set1_ps(bfloat16_to_float32(bottom_blob.channel(q + 12).row(sy)[sx])); + __m512 _vald = _mm512_set1_ps(bfloat16_to_float32(bottom_blob.channel(q + 13).row(sy)[sx])); + __m512 _vale = _mm512_set1_ps(bfloat16_to_float32(bottom_blob.channel(q + 14).row(sy)[sx])); + __m512 _valf = _mm512_set1_ps(bfloat16_to_float32(bottom_blob.channel(q + 15).row(sy)[sx])); + _sum0 = _mm512_fmadd_ps(_val0, bfloat2float_avx512(_mm256_load_si256((const __m256i*)kptr0)), _sum0); + _sum1 = _mm512_fmadd_ps(_val1, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16))), _sum1); + _sum2 = _mm512_fmadd_ps(_val2, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 2))), _sum2); + _sum3 = _mm512_fmadd_ps(_val3, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 3))), _sum3); + _sum0 = _mm512_fmadd_ps(_val4, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 4))), _sum0); + _sum1 = _mm512_fmadd_ps(_val5, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 5))), _sum1); + _sum2 = _mm512_fmadd_ps(_val6, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 6))), _sum2); + _sum3 = _mm512_fmadd_ps(_val7, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 7))), _sum3); + _sum0 = _mm512_fmadd_ps(_val8, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 8))), _sum0); + _sum1 = _mm512_fmadd_ps(_val9, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 9))), _sum1); + _sum2 = _mm512_fmadd_ps(_vala, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 10))), _sum2); + _sum3 = _mm512_fmadd_ps(_valb, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 11))), _sum3); + _sum0 = _mm512_fmadd_ps(_valc, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 12))), _sum0); + _sum1 = _mm512_fmadd_ps(_vald, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 13))), _sum1); + _sum2 = _mm512_fmadd_ps(_vale, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 14))), _sum2); + _sum3 = _mm512_fmadd_ps(_valf, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 15))), _sum3); + } + } + } + + kptr += maxk * 16 * 16; + } + for (; q + 7 < inch; q += 8) + { + for (int y = 0; y < kernel_h; y++) + { + int sys = (i + y * dilation_h - (kernel_extent_h - 1)); + if (sys < 0 || sys % stride_h != 0) + continue; + int sy = sys / stride_h; + if (sy >= h) + continue; + + for (int x = 0; x < kernel_w; x++) + { + int sxs = (j + x * dilation_w - (kernel_extent_w - 1)); + if (sxs < 0 || sxs % stride_w != 0) + continue; + int sx = sxs / stride_w; + if (sx >= w) + continue; + + int k 
= y * kernel_w + x; + const unsigned short* kptr0 = kptr + k * 8 * 16; + + if (elempack == 8) + { + const unsigned short* sptr = bottom_blob.channel(q / 8).row(sy) + sx * 8; + + __m512 _val0 = _mm512_set1_ps(bfloat16_to_float32(sptr[0])); + __m512 _val1 = _mm512_set1_ps(bfloat16_to_float32(sptr[1])); + __m512 _val2 = _mm512_set1_ps(bfloat16_to_float32(sptr[2])); + __m512 _val3 = _mm512_set1_ps(bfloat16_to_float32(sptr[3])); + __m512 _val4 = _mm512_set1_ps(bfloat16_to_float32(sptr[4])); + __m512 _val5 = _mm512_set1_ps(bfloat16_to_float32(sptr[5])); + __m512 _val6 = _mm512_set1_ps(bfloat16_to_float32(sptr[6])); + __m512 _val7 = _mm512_set1_ps(bfloat16_to_float32(sptr[7])); + _sum0 = _mm512_fmadd_ps(_val0, bfloat2float_avx512(_mm256_load_si256((const __m256i*)kptr0)), _sum0); + _sum1 = _mm512_fmadd_ps(_val1, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16))), _sum1); + _sum2 = _mm512_fmadd_ps(_val2, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 2))), _sum2); + _sum3 = _mm512_fmadd_ps(_val3, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 3))), _sum3); + _sum0 = _mm512_fmadd_ps(_val4, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 4))), _sum0); + _sum1 = _mm512_fmadd_ps(_val5, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 5))), _sum1); + _sum2 = _mm512_fmadd_ps(_val6, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 6))), _sum2); + _sum3 = _mm512_fmadd_ps(_val7, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 7))), _sum3); + } + if (elempack == 4) + { + const unsigned short* sptr0 = bottom_blob.channel(q / 4).row(sy) + sx * 4; + const unsigned short* sptr1 = bottom_blob.channel(q / 4 + 1).row(sy) + sx * 4; + + __m512 _val0 = _mm512_set1_ps(bfloat16_to_float32(sptr0[0])); + __m512 _val1 = _mm512_set1_ps(bfloat16_to_float32(sptr0[1])); + __m512 _val2 = _mm512_set1_ps(bfloat16_to_float32(sptr0[2])); + __m512 _val3 = _mm512_set1_ps(bfloat16_to_float32(sptr0[3])); + __m512 _val4 = _mm512_set1_ps(bfloat16_to_float32(sptr1[0])); + __m512 _val5 = _mm512_set1_ps(bfloat16_to_float32(sptr1[1])); + __m512 _val6 = _mm512_set1_ps(bfloat16_to_float32(sptr1[2])); + __m512 _val7 = _mm512_set1_ps(bfloat16_to_float32(sptr1[3])); + _sum0 = _mm512_fmadd_ps(_val0, bfloat2float_avx512(_mm256_load_si256((const __m256i*)kptr0)), _sum0); + _sum1 = _mm512_fmadd_ps(_val1, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16))), _sum1); + _sum2 = _mm512_fmadd_ps(_val2, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 2))), _sum2); + _sum3 = _mm512_fmadd_ps(_val3, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 3))), _sum3); + _sum0 = _mm512_fmadd_ps(_val4, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 4))), _sum0); + _sum1 = _mm512_fmadd_ps(_val5, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 5))), _sum1); + _sum2 = _mm512_fmadd_ps(_val6, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 6))), _sum2); + _sum3 = _mm512_fmadd_ps(_val7, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 7))), _sum3); + } + if (elempack == 1) + { + __m512 _val0 = _mm512_set1_ps(bfloat16_to_float32(bottom_blob.channel(q).row(sy)[sx])); + __m512 _val1 = _mm512_set1_ps(bfloat16_to_float32(bottom_blob.channel(q + 1).row(sy)[sx])); + __m512 _val2 = _mm512_set1_ps(bfloat16_to_float32(bottom_blob.channel(q + 2).row(sy)[sx])); + __m512 _val3 = 
_mm512_set1_ps(bfloat16_to_float32(bottom_blob.channel(q + 3).row(sy)[sx])); + __m512 _val4 = _mm512_set1_ps(bfloat16_to_float32(bottom_blob.channel(q + 4).row(sy)[sx])); + __m512 _val5 = _mm512_set1_ps(bfloat16_to_float32(bottom_blob.channel(q + 5).row(sy)[sx])); + __m512 _val6 = _mm512_set1_ps(bfloat16_to_float32(bottom_blob.channel(q + 6).row(sy)[sx])); + __m512 _val7 = _mm512_set1_ps(bfloat16_to_float32(bottom_blob.channel(q + 7).row(sy)[sx])); + _sum0 = _mm512_fmadd_ps(_val0, bfloat2float_avx512(_mm256_load_si256((const __m256i*)kptr0)), _sum0); + _sum1 = _mm512_fmadd_ps(_val1, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16))), _sum1); + _sum2 = _mm512_fmadd_ps(_val2, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 2))), _sum2); + _sum3 = _mm512_fmadd_ps(_val3, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 3))), _sum3); + _sum0 = _mm512_fmadd_ps(_val4, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 4))), _sum0); + _sum1 = _mm512_fmadd_ps(_val5, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 5))), _sum1); + _sum2 = _mm512_fmadd_ps(_val6, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 6))), _sum2); + _sum3 = _mm512_fmadd_ps(_val7, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 7))), _sum3); + } + } + } + + kptr += maxk * 8 * 16; + } + for (; q + 3 < inch; q += 4) + { + for (int y = 0; y < kernel_h; y++) + { + int sys = (i + y * dilation_h - (kernel_extent_h - 1)); + if (sys < 0 || sys % stride_h != 0) + continue; + int sy = sys / stride_h; + if (sy >= h) + continue; + + for (int x = 0; x < kernel_w; x++) + { + int sxs = (j + x * dilation_w - (kernel_extent_w - 1)); + if (sxs < 0 || sxs % stride_w != 0) + continue; + int sx = sxs / stride_w; + if (sx >= w) + continue; + + int k = y * kernel_w + x; + const unsigned short* kptr0 = kptr + k * 4 * 16; + + if (elempack == 4) + { + const unsigned short* sptr = bottom_blob.channel(q / 4).row(sy) + sx * 4; + + __m512 _val0 = _mm512_set1_ps(bfloat16_to_float32(sptr[0])); + __m512 _val1 = _mm512_set1_ps(bfloat16_to_float32(sptr[1])); + __m512 _val2 = _mm512_set1_ps(bfloat16_to_float32(sptr[2])); + __m512 _val3 = _mm512_set1_ps(bfloat16_to_float32(sptr[3])); + _sum0 = _mm512_fmadd_ps(_val0, bfloat2float_avx512(_mm256_load_si256((const __m256i*)kptr0)), _sum0); + _sum1 = _mm512_fmadd_ps(_val1, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16))), _sum1); + _sum2 = _mm512_fmadd_ps(_val2, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 2))), _sum2); + _sum3 = _mm512_fmadd_ps(_val3, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 3))), _sum3); + } + if (elempack == 1) + { + __m512 _val0 = _mm512_set1_ps(bfloat16_to_float32(bottom_blob.channel(q).row(sy)[sx])); + __m512 _val1 = _mm512_set1_ps(bfloat16_to_float32(bottom_blob.channel(q + 1).row(sy)[sx])); + __m512 _val2 = _mm512_set1_ps(bfloat16_to_float32(bottom_blob.channel(q + 2).row(sy)[sx])); + __m512 _val3 = _mm512_set1_ps(bfloat16_to_float32(bottom_blob.channel(q + 3).row(sy)[sx])); + _sum0 = _mm512_fmadd_ps(_val0, bfloat2float_avx512(_mm256_load_si256((const __m256i*)kptr0)), _sum0); + _sum1 = _mm512_fmadd_ps(_val1, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16))), _sum1); + _sum2 = _mm512_fmadd_ps(_val2, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16 * 2))), _sum2); + _sum3 = _mm512_fmadd_ps(_val3, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 
16 * 3))), _sum3); + } + } + } + + kptr += maxk * 4 * 16; + } + for (; q + 1 < inch; q += 2) + { + for (int y = 0; y < kernel_h; y++) + { + int sys = (i + y * dilation_h - (kernel_extent_h - 1)); + if (sys < 0 || sys % stride_h != 0) + continue; + int sy = sys / stride_h; + if (sy >= h) + continue; + + for (int x = 0; x < kernel_w; x++) + { + int sxs = (j + x * dilation_w - (kernel_extent_w - 1)); + if (sxs < 0 || sxs % stride_w != 0) + continue; + int sx = sxs / stride_w; + if (sx >= w) + continue; + + int k = y * kernel_w + x; + const unsigned short* kptr0 = kptr + k * 2 * 16; + + const unsigned short* sptr0 = bottom_blob.channel(q).row(sy) + sx; + const unsigned short* sptr1 = bottom_blob.channel(q + 1).row(sy) + sx; + __m512 _val0 = _mm512_set1_ps(bfloat16_to_float32(sptr0[0])); + __m512 _val1 = _mm512_set1_ps(bfloat16_to_float32(sptr1[0])); + _sum0 = _mm512_fmadd_ps(_val0, bfloat2float_avx512(_mm256_load_si256((const __m256i*)kptr0)), _sum0); + _sum1 = _mm512_fmadd_ps(_val1, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16))), _sum1); + } + } + + kptr += maxk * 2 * 16; + } + for (; q < inch; q++) + { + for (int y = 0; y < kernel_h; y++) + { + int sys = (i + y * dilation_h - (kernel_extent_h - 1)); + if (sys < 0 || sys % stride_h != 0) + continue; + int sy = sys / stride_h; + if (sy >= h) + continue; + + for (int x = 0; x < kernel_w; x++) + { + int sxs = (j + x * dilation_w - (kernel_extent_w - 1)); + if (sxs < 0 || sxs % stride_w != 0) + continue; + int sx = sxs / stride_w; + if (sx >= w) + continue; + + int k = y * kernel_w + x; + const unsigned short* kptr0 = kptr + k * 16; + + const unsigned short* sptr = bottom_blob.channel(q).row(sy) + sx; + __m512 _val = _mm512_set1_ps(bfloat16_to_float32(sptr[0])); + _sum0 = _mm512_fmadd_ps(_val, bfloat2float_avx512(_mm256_load_si256((const __m256i*)kptr0)), _sum0); + } + } + + kptr += maxk * 16; + } + + _sum0 = _mm512_add_ps(_sum0, _sum1); + _sum2 = _mm512_add_ps(_sum2, _sum3); + _sum0 = _mm512_add_ps(_sum0, _sum2); + + _sum0 = activation_avx512(_sum0, activation_type, activation_params); + + if (out_elempack == 16) + { + _mm256_store_si256((__m256i*)outptr, float2bfloat_avx512(_sum0)); + outptr += 16; + } + if (out_elempack == 8) + { + _mm_store_si128((__m128i*)outptr, float2bfloat_avx(_mm512_extractf32x8_ps(_sum0, 0))); + _mm_store_si128((__m128i*)(outptr + M), float2bfloat_avx(_mm512_extractf32x8_ps(_sum0, 1))); + outptr += 8; + } + if (out_elempack == 4) + { + _mm_storel_epi64((__m128i*)outptr, float2bfloat_sse(_mm512_extractf32x4_ps(_sum0, 0))); + _mm_storel_epi64((__m128i*)(outptr + M), float2bfloat_sse(_mm512_extractf32x4_ps(_sum0, 1))); + _mm_storel_epi64((__m128i*)(outptr + M * 2), float2bfloat_sse(_mm512_extractf32x4_ps(_sum0, 2))); + _mm_storel_epi64((__m128i*)(outptr + M * 3), float2bfloat_sse(_mm512_extractf32x4_ps(_sum0, 3))); + outptr += 4; + } + if (out_elempack == 1) + { + float sum[16]; + _mm512_storeu_ps(sum, _sum0); + + outptr[0] = float32_to_bfloat16(sum[0]); + outptr[M] = float32_to_bfloat16(sum[1]); + outptr[M * 2] = float32_to_bfloat16(sum[2]); + outptr[M * 3] = float32_to_bfloat16(sum[3]); + outptr[M * 4] = float32_to_bfloat16(sum[4]); + outptr[M * 5] = float32_to_bfloat16(sum[5]); + outptr[M * 6] = float32_to_bfloat16(sum[6]); + outptr[M * 7] = float32_to_bfloat16(sum[7]); + outptr[M * 8] = float32_to_bfloat16(sum[8]); + outptr[M * 9] = float32_to_bfloat16(sum[9]); + outptr[M * 10] = float32_to_bfloat16(sum[10]); + outptr[M * 11] = float32_to_bfloat16(sum[11]); + outptr[M * 12] = 
float32_to_bfloat16(sum[12]); + outptr[M * 13] = float32_to_bfloat16(sum[13]); + outptr[M * 14] = float32_to_bfloat16(sum[14]); + outptr[M * 15] = float32_to_bfloat16(sum[15]); + outptr += 1; + } + } + } + } + remain_outch_start += nn_outch * 16; + nn_outch = (outch - remain_outch_start) / 8; +#else // __AVX512F__ + nn_outch = (outch - remain_outch_start) / 8; + #pragma omp parallel for num_threads(opt.num_threads) +#endif // __AVX512F__ + for (int pp = 0; pp < nn_outch; pp++) + { + const int p = remain_outch_start + pp * 8; + + // shadowed variable for less openmp task args + const int elempack = bottom_blob.elempack; + const int inch = bottom_blob.c * elempack; + const int w = bottom_blob.w; + const int h = bottom_blob.h; + const int outw = top_blob.w; + const int outh = top_blob.h; + const int out_elempack = top_blob.elempack; + + unsigned short* outptr = top_blob.channel(p / out_elempack); + + for (int i = 0; i < outh; i++) + { + for (int j = 0; j < outw; j++) + { + __m256 _sum0 = _mm256_setzero_ps(); + __m256 _sum1 = _mm256_setzero_ps(); + __m256 _sum2 = _mm256_setzero_ps(); + __m256 _sum3 = _mm256_setzero_ps(); + + if (bias_data_ptr) + { + _sum0 = _mm256_loadu_ps(bias_data_ptr + p); + } + +#if __AVX512F__ + const unsigned short* kptr = weight_data_tm.channel(p / 16 + (p % 16) / 8); +#else + const unsigned short* kptr = weight_data_tm.channel(p / 8); +#endif + + int q = 0; +#if __AVX512F__ + for (; q + 15 < inch; q += 16) + { + for (int y = 0; y < kernel_h; y++) + { + int sys = (i + y * dilation_h - (kernel_extent_h - 1)); + if (sys < 0 || sys % stride_h != 0) + continue; + int sy = sys / stride_h; + if (sy >= h) + continue; + + for (int x = 0; x < kernel_w; x++) + { + int sxs = (j + x * dilation_w - (kernel_extent_w - 1)); + if (sxs < 0 || sxs % stride_w != 0) + continue; + int sx = sxs / stride_w; + if (sx >= w) + continue; + + int k = y * kernel_w + x; + const unsigned short* kptr0 = kptr + k * 16 * 8; + + if (elempack == 16) + { + const unsigned short* sptr = bottom_blob.channel(q / 16).row(sy) + sx * 16; + + __m256 _val0 = _mm256_set1_ps(bfloat16_to_float32(sptr[0])); + __m256 _val1 = _mm256_set1_ps(bfloat16_to_float32(sptr[1])); + __m256 _val2 = _mm256_set1_ps(bfloat16_to_float32(sptr[2])); + __m256 _val3 = _mm256_set1_ps(bfloat16_to_float32(sptr[3])); + __m256 _val4 = _mm256_set1_ps(bfloat16_to_float32(sptr[4])); + __m256 _val5 = _mm256_set1_ps(bfloat16_to_float32(sptr[5])); + __m256 _val6 = _mm256_set1_ps(bfloat16_to_float32(sptr[6])); + __m256 _val7 = _mm256_set1_ps(bfloat16_to_float32(sptr[7])); + __m256 _val8 = _mm256_set1_ps(bfloat16_to_float32(sptr[8])); + __m256 _val9 = _mm256_set1_ps(bfloat16_to_float32(sptr[9])); + __m256 _vala = _mm256_set1_ps(bfloat16_to_float32(sptr[10])); + __m256 _valb = _mm256_set1_ps(bfloat16_to_float32(sptr[11])); + __m256 _valc = _mm256_set1_ps(bfloat16_to_float32(sptr[12])); + __m256 _vald = _mm256_set1_ps(bfloat16_to_float32(sptr[13])); + __m256 _vale = _mm256_set1_ps(bfloat16_to_float32(sptr[14])); + __m256 _valf = _mm256_set1_ps(bfloat16_to_float32(sptr[15])); + _sum0 = _mm256_fmadd_ps(_val0, bfloat2float_avx(_mm_load_si128((const __m128i*)kptr0)), _sum0); + _sum1 = _mm256_fmadd_ps(_val1, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8))), _sum1); + _sum2 = _mm256_fmadd_ps(_val2, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 2))), _sum2); + _sum3 = _mm256_fmadd_ps(_val3, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 3))), _sum3); + _sum0 = _mm256_fmadd_ps(_val4, 
bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 4))), _sum0); + _sum1 = _mm256_fmadd_ps(_val5, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 5))), _sum1); + _sum2 = _mm256_fmadd_ps(_val6, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 6))), _sum2); + _sum3 = _mm256_fmadd_ps(_val7, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 7))), _sum3); + _sum0 = _mm256_fmadd_ps(_val8, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 8))), _sum0); + _sum1 = _mm256_fmadd_ps(_val9, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 9))), _sum1); + _sum2 = _mm256_fmadd_ps(_vala, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 10))), _sum2); + _sum3 = _mm256_fmadd_ps(_valb, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 11))), _sum3); + _sum0 = _mm256_fmadd_ps(_valc, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 12))), _sum0); + _sum1 = _mm256_fmadd_ps(_vald, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 13))), _sum1); + _sum2 = _mm256_fmadd_ps(_vale, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 14))), _sum2); + _sum3 = _mm256_fmadd_ps(_valf, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 15))), _sum3); + } + if (elempack == 8) + { + const unsigned short* sptr0 = bottom_blob.channel(q / 8).row(sy) + sx * 8; + const unsigned short* sptr1 = bottom_blob.channel(q / 8 + 1).row(sy) + sx * 8; + + __m256 _val0 = _mm256_set1_ps(bfloat16_to_float32(sptr0[0])); + __m256 _val1 = _mm256_set1_ps(bfloat16_to_float32(sptr0[1])); + __m256 _val2 = _mm256_set1_ps(bfloat16_to_float32(sptr0[2])); + __m256 _val3 = _mm256_set1_ps(bfloat16_to_float32(sptr0[3])); + __m256 _val4 = _mm256_set1_ps(bfloat16_to_float32(sptr0[4])); + __m256 _val5 = _mm256_set1_ps(bfloat16_to_float32(sptr0[5])); + __m256 _val6 = _mm256_set1_ps(bfloat16_to_float32(sptr0[6])); + __m256 _val7 = _mm256_set1_ps(bfloat16_to_float32(sptr0[7])); + __m256 _val8 = _mm256_set1_ps(bfloat16_to_float32(sptr1[0])); + __m256 _val9 = _mm256_set1_ps(bfloat16_to_float32(sptr1[1])); + __m256 _vala = _mm256_set1_ps(bfloat16_to_float32(sptr1[2])); + __m256 _valb = _mm256_set1_ps(bfloat16_to_float32(sptr1[3])); + __m256 _valc = _mm256_set1_ps(bfloat16_to_float32(sptr1[4])); + __m256 _vald = _mm256_set1_ps(bfloat16_to_float32(sptr1[5])); + __m256 _vale = _mm256_set1_ps(bfloat16_to_float32(sptr1[6])); + __m256 _valf = _mm256_set1_ps(bfloat16_to_float32(sptr1[7])); + _sum0 = _mm256_fmadd_ps(_val0, bfloat2float_avx(_mm_load_si128((const __m128i*)kptr0)), _sum0); + _sum1 = _mm256_fmadd_ps(_val1, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8))), _sum1); + _sum2 = _mm256_fmadd_ps(_val2, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 2))), _sum2); + _sum3 = _mm256_fmadd_ps(_val3, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 3))), _sum3); + _sum0 = _mm256_fmadd_ps(_val4, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 4))), _sum0); + _sum1 = _mm256_fmadd_ps(_val5, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 5))), _sum1); + _sum2 = _mm256_fmadd_ps(_val6, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 6))), _sum2); + _sum3 = _mm256_fmadd_ps(_val7, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 7))), _sum3); + _sum0 = _mm256_fmadd_ps(_val8, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 8))), _sum0); + _sum1 = _mm256_fmadd_ps(_val9, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 9))), _sum1); + 
_sum2 = _mm256_fmadd_ps(_vala, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 10))), _sum2); + _sum3 = _mm256_fmadd_ps(_valb, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 11))), _sum3); + _sum0 = _mm256_fmadd_ps(_valc, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 12))), _sum0); + _sum1 = _mm256_fmadd_ps(_vald, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 13))), _sum1); + _sum2 = _mm256_fmadd_ps(_vale, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 14))), _sum2); + _sum3 = _mm256_fmadd_ps(_valf, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 15))), _sum3); + } + if (elempack == 4) + { + const unsigned short* sptr0 = bottom_blob.channel(q / 4).row(sy) + sx * 4; + const unsigned short* sptr1 = bottom_blob.channel(q / 4 + 1).row(sy) + sx * 4; + const unsigned short* sptr2 = bottom_blob.channel(q / 4 + 2).row(sy) + sx * 4; + const unsigned short* sptr3 = bottom_blob.channel(q / 4 + 3).row(sy) + sx * 4; + + __m256 _val0 = _mm256_set1_ps(bfloat16_to_float32(sptr0[0])); + __m256 _val1 = _mm256_set1_ps(bfloat16_to_float32(sptr0[1])); + __m256 _val2 = _mm256_set1_ps(bfloat16_to_float32(sptr0[2])); + __m256 _val3 = _mm256_set1_ps(bfloat16_to_float32(sptr0[3])); + __m256 _val4 = _mm256_set1_ps(bfloat16_to_float32(sptr1[0])); + __m256 _val5 = _mm256_set1_ps(bfloat16_to_float32(sptr1[1])); + __m256 _val6 = _mm256_set1_ps(bfloat16_to_float32(sptr1[2])); + __m256 _val7 = _mm256_set1_ps(bfloat16_to_float32(sptr1[3])); + __m256 _val8 = _mm256_set1_ps(bfloat16_to_float32(sptr2[0])); + __m256 _val9 = _mm256_set1_ps(bfloat16_to_float32(sptr2[1])); + __m256 _vala = _mm256_set1_ps(bfloat16_to_float32(sptr2[2])); + __m256 _valb = _mm256_set1_ps(bfloat16_to_float32(sptr2[3])); + __m256 _valc = _mm256_set1_ps(bfloat16_to_float32(sptr3[0])); + __m256 _vald = _mm256_set1_ps(bfloat16_to_float32(sptr3[1])); + __m256 _vale = _mm256_set1_ps(bfloat16_to_float32(sptr3[2])); + __m256 _valf = _mm256_set1_ps(bfloat16_to_float32(sptr3[3])); + _sum0 = _mm256_fmadd_ps(_val0, bfloat2float_avx(_mm_load_si128((const __m128i*)kptr0)), _sum0); + _sum1 = _mm256_fmadd_ps(_val1, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8))), _sum1); + _sum2 = _mm256_fmadd_ps(_val2, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 2))), _sum2); + _sum3 = _mm256_fmadd_ps(_val3, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 3))), _sum3); + _sum0 = _mm256_fmadd_ps(_val4, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 4))), _sum0); + _sum1 = _mm256_fmadd_ps(_val5, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 5))), _sum1); + _sum2 = _mm256_fmadd_ps(_val6, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 6))), _sum2); + _sum3 = _mm256_fmadd_ps(_val7, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 7))), _sum3); + _sum0 = _mm256_fmadd_ps(_val8, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 8))), _sum0); + _sum1 = _mm256_fmadd_ps(_val9, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 9))), _sum1); + _sum2 = _mm256_fmadd_ps(_vala, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 10))), _sum2); + _sum3 = _mm256_fmadd_ps(_valb, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 11))), _sum3); + _sum0 = _mm256_fmadd_ps(_valc, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 12))), _sum0); + _sum1 = _mm256_fmadd_ps(_vald, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 13))), _sum1); + _sum2 = 
_mm256_fmadd_ps(_vale, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 14))), _sum2); + _sum3 = _mm256_fmadd_ps(_valf, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 15))), _sum3); + } + if (elempack == 1) + { + __m256 _val0 = _mm256_set1_ps(bfloat16_to_float32(bottom_blob.channel(q).row(sy)[sx])); + __m256 _val1 = _mm256_set1_ps(bfloat16_to_float32(bottom_blob.channel(q + 1).row(sy)[sx])); + __m256 _val2 = _mm256_set1_ps(bfloat16_to_float32(bottom_blob.channel(q + 2).row(sy)[sx])); + __m256 _val3 = _mm256_set1_ps(bfloat16_to_float32(bottom_blob.channel(q + 3).row(sy)[sx])); + __m256 _val4 = _mm256_set1_ps(bfloat16_to_float32(bottom_blob.channel(q + 4).row(sy)[sx])); + __m256 _val5 = _mm256_set1_ps(bfloat16_to_float32(bottom_blob.channel(q + 5).row(sy)[sx])); + __m256 _val6 = _mm256_set1_ps(bfloat16_to_float32(bottom_blob.channel(q + 6).row(sy)[sx])); + __m256 _val7 = _mm256_set1_ps(bfloat16_to_float32(bottom_blob.channel(q + 7).row(sy)[sx])); + __m256 _val8 = _mm256_set1_ps(bfloat16_to_float32(bottom_blob.channel(q + 8).row(sy)[sx])); + __m256 _val9 = _mm256_set1_ps(bfloat16_to_float32(bottom_blob.channel(q + 9).row(sy)[sx])); + __m256 _vala = _mm256_set1_ps(bfloat16_to_float32(bottom_blob.channel(q + 10).row(sy)[sx])); + __m256 _valb = _mm256_set1_ps(bfloat16_to_float32(bottom_blob.channel(q + 11).row(sy)[sx])); + __m256 _valc = _mm256_set1_ps(bfloat16_to_float32(bottom_blob.channel(q + 12).row(sy)[sx])); + __m256 _vald = _mm256_set1_ps(bfloat16_to_float32(bottom_blob.channel(q + 13).row(sy)[sx])); + __m256 _vale = _mm256_set1_ps(bfloat16_to_float32(bottom_blob.channel(q + 14).row(sy)[sx])); + __m256 _valf = _mm256_set1_ps(bfloat16_to_float32(bottom_blob.channel(q + 15).row(sy)[sx])); + _sum0 = _mm256_fmadd_ps(_val0, bfloat2float_avx(_mm_load_si128((const __m128i*)kptr0)), _sum0); + _sum1 = _mm256_fmadd_ps(_val1, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8))), _sum1); + _sum2 = _mm256_fmadd_ps(_val2, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 2))), _sum2); + _sum3 = _mm256_fmadd_ps(_val3, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 3))), _sum3); + _sum0 = _mm256_fmadd_ps(_val4, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 4))), _sum0); + _sum1 = _mm256_fmadd_ps(_val5, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 5))), _sum1); + _sum2 = _mm256_fmadd_ps(_val6, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 6))), _sum2); + _sum3 = _mm256_fmadd_ps(_val7, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 7))), _sum3); + _sum0 = _mm256_fmadd_ps(_val8, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 8))), _sum0); + _sum1 = _mm256_fmadd_ps(_val9, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 9))), _sum1); + _sum2 = _mm256_fmadd_ps(_vala, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 10))), _sum2); + _sum3 = _mm256_fmadd_ps(_valb, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 11))), _sum3); + _sum0 = _mm256_fmadd_ps(_valc, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 12))), _sum0); + _sum1 = _mm256_fmadd_ps(_vald, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 13))), _sum1); + _sum2 = _mm256_fmadd_ps(_vale, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 14))), _sum2); + _sum3 = _mm256_fmadd_ps(_valf, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 15))), _sum3); + } + } + } + + kptr += maxk * 16 * 8; + } +#endif // __AVX512F__ + for (; q + 7 < inch; q 
+= 8) + { + for (int y = 0; y < kernel_h; y++) + { + int sys = (i + y * dilation_h - (kernel_extent_h - 1)); + if (sys < 0 || sys % stride_h != 0) + continue; + int sy = sys / stride_h; + if (sy >= h) + continue; + + for (int x = 0; x < kernel_w; x++) + { + int sxs = (j + x * dilation_w - (kernel_extent_w - 1)); + if (sxs < 0 || sxs % stride_w != 0) + continue; + int sx = sxs / stride_w; + if (sx >= w) + continue; + + int k = y * kernel_w + x; + const unsigned short* kptr0 = kptr + k * 8 * 8; + + if (elempack == 8) + { + const unsigned short* sptr = bottom_blob.channel(q / 8).row(sy) + sx * 8; + + __m256 _val0 = _mm256_set1_ps(bfloat16_to_float32(sptr[0])); + __m256 _val1 = _mm256_set1_ps(bfloat16_to_float32(sptr[1])); + __m256 _val2 = _mm256_set1_ps(bfloat16_to_float32(sptr[2])); + __m256 _val3 = _mm256_set1_ps(bfloat16_to_float32(sptr[3])); + __m256 _val4 = _mm256_set1_ps(bfloat16_to_float32(sptr[4])); + __m256 _val5 = _mm256_set1_ps(bfloat16_to_float32(sptr[5])); + __m256 _val6 = _mm256_set1_ps(bfloat16_to_float32(sptr[6])); + __m256 _val7 = _mm256_set1_ps(bfloat16_to_float32(sptr[7])); + _sum0 = _mm256_comp_fmadd_ps(_val0, bfloat2float_avx(_mm_load_si128((const __m128i*)kptr0)), _sum0); + _sum1 = _mm256_comp_fmadd_ps(_val1, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8))), _sum1); + _sum2 = _mm256_comp_fmadd_ps(_val2, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 2))), _sum2); + _sum3 = _mm256_comp_fmadd_ps(_val3, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 3))), _sum3); + _sum0 = _mm256_comp_fmadd_ps(_val4, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 4))), _sum0); + _sum1 = _mm256_comp_fmadd_ps(_val5, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 5))), _sum1); + _sum2 = _mm256_comp_fmadd_ps(_val6, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 6))), _sum2); + _sum3 = _mm256_comp_fmadd_ps(_val7, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 7))), _sum3); + } + if (elempack == 4) + { + const unsigned short* sptr0 = bottom_blob.channel(q / 4).row(sy) + sx * 4; + const unsigned short* sptr1 = bottom_blob.channel(q / 4 + 1).row(sy) + sx * 4; + + __m256 _val0 = _mm256_set1_ps(bfloat16_to_float32(sptr0[0])); + __m256 _val1 = _mm256_set1_ps(bfloat16_to_float32(sptr0[1])); + __m256 _val2 = _mm256_set1_ps(bfloat16_to_float32(sptr0[2])); + __m256 _val3 = _mm256_set1_ps(bfloat16_to_float32(sptr0[3])); + __m256 _val4 = _mm256_set1_ps(bfloat16_to_float32(sptr1[0])); + __m256 _val5 = _mm256_set1_ps(bfloat16_to_float32(sptr1[1])); + __m256 _val6 = _mm256_set1_ps(bfloat16_to_float32(sptr1[2])); + __m256 _val7 = _mm256_set1_ps(bfloat16_to_float32(sptr1[3])); + _sum0 = _mm256_comp_fmadd_ps(_val0, bfloat2float_avx(_mm_load_si128((const __m128i*)kptr0)), _sum0); + _sum1 = _mm256_comp_fmadd_ps(_val1, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8))), _sum1); + _sum2 = _mm256_comp_fmadd_ps(_val2, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 2))), _sum2); + _sum3 = _mm256_comp_fmadd_ps(_val3, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 3))), _sum3); + _sum0 = _mm256_comp_fmadd_ps(_val4, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 4))), _sum0); + _sum1 = _mm256_comp_fmadd_ps(_val5, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 5))), _sum1); + _sum2 = _mm256_comp_fmadd_ps(_val6, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 6))), _sum2); + _sum3 = _mm256_comp_fmadd_ps(_val7, bfloat2float_avx(_mm_load_si128((const 
__m128i*)(kptr0 + 8 * 7))), _sum3); + } + if (elempack == 1) + { + const unsigned short* sptr0 = bottom_blob.channel(q).row(sy) + sx; + const unsigned short* sptr1 = bottom_blob.channel(q + 1).row(sy) + sx; + const unsigned short* sptr2 = bottom_blob.channel(q + 2).row(sy) + sx; + const unsigned short* sptr3 = bottom_blob.channel(q + 3).row(sy) + sx; + const unsigned short* sptr4 = bottom_blob.channel(q + 4).row(sy) + sx; + const unsigned short* sptr5 = bottom_blob.channel(q + 5).row(sy) + sx; + const unsigned short* sptr6 = bottom_blob.channel(q + 6).row(sy) + sx; + const unsigned short* sptr7 = bottom_blob.channel(q + 7).row(sy) + sx; + __m256 _val0 = _mm256_set1_ps(bfloat16_to_float32(sptr0[0])); + __m256 _val1 = _mm256_set1_ps(bfloat16_to_float32(sptr1[0])); + __m256 _val2 = _mm256_set1_ps(bfloat16_to_float32(sptr2[0])); + __m256 _val3 = _mm256_set1_ps(bfloat16_to_float32(sptr3[0])); + __m256 _val4 = _mm256_set1_ps(bfloat16_to_float32(sptr4[0])); + __m256 _val5 = _mm256_set1_ps(bfloat16_to_float32(sptr5[0])); + __m256 _val6 = _mm256_set1_ps(bfloat16_to_float32(sptr6[0])); + __m256 _val7 = _mm256_set1_ps(bfloat16_to_float32(sptr7[0])); + _sum0 = _mm256_comp_fmadd_ps(_val0, bfloat2float_avx(_mm_load_si128((const __m128i*)kptr0)), _sum0); + _sum1 = _mm256_comp_fmadd_ps(_val1, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8))), _sum1); + _sum2 = _mm256_comp_fmadd_ps(_val2, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 2))), _sum2); + _sum3 = _mm256_comp_fmadd_ps(_val3, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 3))), _sum3); + _sum0 = _mm256_comp_fmadd_ps(_val4, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 4))), _sum0); + _sum1 = _mm256_comp_fmadd_ps(_val5, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 5))), _sum1); + _sum2 = _mm256_comp_fmadd_ps(_val6, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 6))), _sum2); + _sum3 = _mm256_comp_fmadd_ps(_val7, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 7))), _sum3); + } + } + } + + kptr += maxk * 8 * 8; + } + for (; q + 3 < inch; q += 4) + { + for (int y = 0; y < kernel_h; y++) + { + int sys = (i + y * dilation_h - (kernel_extent_h - 1)); + if (sys < 0 || sys % stride_h != 0) + continue; + int sy = sys / stride_h; + if (sy >= h) + continue; + + for (int x = 0; x < kernel_w; x++) + { + int sxs = (j + x * dilation_w - (kernel_extent_w - 1)); + if (sxs < 0 || sxs % stride_w != 0) + continue; + int sx = sxs / stride_w; + if (sx >= w) + continue; + + int k = y * kernel_w + x; + const unsigned short* kptr0 = kptr + k * 4 * 8; + + if (elempack == 4) + { + const unsigned short* sptr = bottom_blob.channel(q / 4).row(sy) + sx * 4; + + __m256 _val0 = _mm256_set1_ps(bfloat16_to_float32(sptr[0])); + __m256 _val1 = _mm256_set1_ps(bfloat16_to_float32(sptr[1])); + __m256 _val2 = _mm256_set1_ps(bfloat16_to_float32(sptr[2])); + __m256 _val3 = _mm256_set1_ps(bfloat16_to_float32(sptr[3])); + _sum0 = _mm256_comp_fmadd_ps(_val0, bfloat2float_avx(_mm_load_si128((const __m128i*)kptr0)), _sum0); + _sum1 = _mm256_comp_fmadd_ps(_val1, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8))), _sum1); + _sum2 = _mm256_comp_fmadd_ps(_val2, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 2))), _sum2); + _sum3 = _mm256_comp_fmadd_ps(_val3, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 3))), _sum3); + } + if (elempack == 1) + { + const unsigned short* sptr0 = bottom_blob.channel(q).row(sy) + sx; + const unsigned short* sptr1 = 
bottom_blob.channel(q + 1).row(sy) + sx; + const unsigned short* sptr2 = bottom_blob.channel(q + 2).row(sy) + sx; + const unsigned short* sptr3 = bottom_blob.channel(q + 3).row(sy) + sx; + __m256 _val0 = _mm256_set1_ps(bfloat16_to_float32(sptr0[0])); + __m256 _val1 = _mm256_set1_ps(bfloat16_to_float32(sptr1[0])); + __m256 _val2 = _mm256_set1_ps(bfloat16_to_float32(sptr2[0])); + __m256 _val3 = _mm256_set1_ps(bfloat16_to_float32(sptr3[0])); + _sum0 = _mm256_comp_fmadd_ps(_val0, bfloat2float_avx(_mm_load_si128((const __m128i*)kptr0)), _sum0); + _sum1 = _mm256_comp_fmadd_ps(_val1, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8))), _sum1); + _sum2 = _mm256_comp_fmadd_ps(_val2, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 2))), _sum2); + _sum3 = _mm256_comp_fmadd_ps(_val3, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8 * 3))), _sum3); + } + } + } + + kptr += maxk * 4 * 8; + } + for (; q + 1 < inch; q += 2) + { + for (int y = 0; y < kernel_h; y++) + { + int sys = (i + y * dilation_h - (kernel_extent_h - 1)); + if (sys < 0 || sys % stride_h != 0) + continue; + int sy = sys / stride_h; + if (sy >= h) + continue; + + for (int x = 0; x < kernel_w; x++) + { + int sxs = (j + x * dilation_w - (kernel_extent_w - 1)); + if (sxs < 0 || sxs % stride_w != 0) + continue; + int sx = sxs / stride_w; + if (sx >= w) + continue; + + int k = y * kernel_w + x; + const unsigned short* kptr0 = kptr + k * 2 * 8; + + const unsigned short* sptr0 = bottom_blob.channel(q).row(sy) + sx; + const unsigned short* sptr1 = bottom_blob.channel(q + 1).row(sy) + sx; + __m256 _val0 = _mm256_set1_ps(bfloat16_to_float32(sptr0[0])); + __m256 _val1 = _mm256_set1_ps(bfloat16_to_float32(sptr1[0])); + _sum0 = _mm256_comp_fmadd_ps(_val0, bfloat2float_avx(_mm_load_si128((const __m128i*)kptr0)), _sum0); + _sum1 = _mm256_comp_fmadd_ps(_val1, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8))), _sum1); + } + } + + kptr += maxk * 2 * 8; + } + for (; q < inch; q++) + { + for (int y = 0; y < kernel_h; y++) + { + int sys = (i + y * dilation_h - (kernel_extent_h - 1)); + if (sys < 0 || sys % stride_h != 0) + continue; + int sy = sys / stride_h; + if (sy >= h) + continue; + + for (int x = 0; x < kernel_w; x++) + { + int sxs = (j + x * dilation_w - (kernel_extent_w - 1)); + if (sxs < 0 || sxs % stride_w != 0) + continue; + int sx = sxs / stride_w; + if (sx >= w) + continue; + + int k = y * kernel_w + x; + const unsigned short* kptr0 = kptr + k * 8; + + const unsigned short* sptr = bottom_blob.channel(q).row(sy) + sx; + __m256 _val = _mm256_set1_ps(bfloat16_to_float32(sptr[0])); + _sum0 = _mm256_comp_fmadd_ps(_val, bfloat2float_avx(_mm_load_si128((const __m128i*)kptr0)), _sum0); + } + } + + kptr += maxk * 8; + } + + _sum0 = _mm256_add_ps(_sum0, _sum1); + _sum2 = _mm256_add_ps(_sum2, _sum3); + _sum0 = _mm256_add_ps(_sum0, _sum2); + + _sum0 = activation_avx(_sum0, activation_type, activation_params); + + if (out_elempack == 8) + { + _mm_store_si128((__m128i*)outptr, float2bfloat_avx(_sum0)); + outptr += 8; + } + if (out_elempack == 4) + { + _mm_storel_epi64((__m128i*)outptr, float2bfloat_sse(_mm256_extractf128_ps(_sum0, 0))); + _mm_storel_epi64((__m128i*)(outptr + M), float2bfloat_sse(_mm256_extractf128_ps(_sum0, 1))); + outptr += 4; + } + if (out_elempack == 1) + { + float sum[8]; + _mm256_storeu_ps(sum, _sum0); + + outptr[0] = float32_to_bfloat16(sum[0]); + outptr[M] = float32_to_bfloat16(sum[1]); + outptr[M * 2] = float32_to_bfloat16(sum[2]); + outptr[M * 3] = float32_to_bfloat16(sum[3]); + outptr[M * 
4] = float32_to_bfloat16(sum[4]); + outptr[M * 5] = float32_to_bfloat16(sum[5]); + outptr[M * 6] = float32_to_bfloat16(sum[6]); + outptr[M * 7] = float32_to_bfloat16(sum[7]); + outptr += 1; + } + } + } + } + remain_outch_start += nn_outch * 8; + nn_outch = (outch - remain_outch_start) / 4; +#else // __AVX__ + nn_outch = (outch - remain_outch_start) / 4; + #pragma omp parallel for num_threads(opt.num_threads) +#endif // __AVX__ + for (int pp = 0; pp < nn_outch; pp++) + { + const int p = remain_outch_start + pp * 4; + + const int elempack = bottom_blob.elempack; + const int inch = bottom_blob.c * elempack; + const int w = bottom_blob.w; + const int h = bottom_blob.h; + const int outw = top_blob.w; + const int outh = top_blob.h; + const int out_elempack = top_blob.elempack; + + unsigned short* outptr = top_blob.channel(p / out_elempack); + + for (int i = 0; i < outh; i++) + { + for (int j = 0; j < outw; j++) + { + __m128 _sum0 = _mm_setzero_ps(); + __m128 _sum1 = _mm_setzero_ps(); + __m128 _sum2 = _mm_setzero_ps(); + __m128 _sum3 = _mm_setzero_ps(); + + if (bias_data_ptr) + { + _sum0 = _mm_loadu_ps(bias_data_ptr + p); + } + +#if __AVX512F__ + const unsigned short* kptr = weight_data_tm.channel(p / 16 + (p % 16) / 8 + (p % 8) / 4); +#elif __AVX__ + const unsigned short* kptr = weight_data_tm.channel(p / 8 + (p % 8) / 4); +#else + const unsigned short* kptr = weight_data_tm.channel(p / 4); +#endif + + int q = 0; +#if __AVX__ +#if __AVX512F__ + for (; q + 15 < inch; q += 16) + { + for (int y = 0; y < kernel_h; y++) + { + int sys = (i + y * dilation_h - (kernel_extent_h - 1)); + if (sys < 0 || sys % stride_h != 0) + continue; + int sy = sys / stride_h; + if (sy >= h) + continue; + for (int x = 0; x < kernel_w; x++) + { + int sxs = (j + x * dilation_w - (kernel_extent_w - 1)); + if (sxs < 0 || sxs % stride_w != 0) + continue; + int sx = sxs / stride_w; + if (sx >= w) + continue; + int k = y * kernel_w + x; + const unsigned short* kptr0 = kptr + k * 16 * 4; + + if (elempack == 16) + { + const unsigned short* sptr = bottom_blob.channel(q / 16).row(sy) + sx * 16; + __m128 _val0 = _mm_set1_ps(bfloat16_to_float32(sptr[0])); + __m128 _val1 = _mm_set1_ps(bfloat16_to_float32(sptr[1])); + __m128 _val2 = _mm_set1_ps(bfloat16_to_float32(sptr[2])); + __m128 _val3 = _mm_set1_ps(bfloat16_to_float32(sptr[3])); + __m128 _val4 = _mm_set1_ps(bfloat16_to_float32(sptr[4])); + __m128 _val5 = _mm_set1_ps(bfloat16_to_float32(sptr[5])); + __m128 _val6 = _mm_set1_ps(bfloat16_to_float32(sptr[6])); + __m128 _val7 = _mm_set1_ps(bfloat16_to_float32(sptr[7])); + __m128 _val8 = _mm_set1_ps(bfloat16_to_float32(sptr[8])); + __m128 _val9 = _mm_set1_ps(bfloat16_to_float32(sptr[9])); + __m128 _vala = _mm_set1_ps(bfloat16_to_float32(sptr[10])); + __m128 _valb = _mm_set1_ps(bfloat16_to_float32(sptr[11])); + __m128 _valc = _mm_set1_ps(bfloat16_to_float32(sptr[12])); + __m128 _vald = _mm_set1_ps(bfloat16_to_float32(sptr[13])); + __m128 _vale = _mm_set1_ps(bfloat16_to_float32(sptr[14])); + __m128 _valf = _mm_set1_ps(bfloat16_to_float32(sptr[15])); + _sum0 = _mm_fmadd_ps(_val0, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 0))), _sum0); + _sum1 = _mm_fmadd_ps(_val1, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 1))), _sum1); + _sum2 = _mm_fmadd_ps(_val2, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 2))), _sum2); + _sum3 = _mm_fmadd_ps(_val3, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 3))), _sum3); + _sum0 = _mm_fmadd_ps(_val4, bfloat2float_sse(_mm_loadl_epi64((const 
__m128i*)(kptr0 + 4 * 4))), _sum0); + _sum1 = _mm_fmadd_ps(_val5, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 5))), _sum1); + _sum2 = _mm_fmadd_ps(_val6, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 6))), _sum2); + _sum3 = _mm_fmadd_ps(_val7, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 7))), _sum3); + _sum0 = _mm_fmadd_ps(_val8, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 8))), _sum0); + _sum1 = _mm_fmadd_ps(_val9, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 9))), _sum1); + _sum2 = _mm_fmadd_ps(_vala, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 10))), _sum2); + _sum3 = _mm_fmadd_ps(_valb, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 11))), _sum3); + _sum0 = _mm_fmadd_ps(_valc, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 12))), _sum0); + _sum1 = _mm_fmadd_ps(_vald, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 13))), _sum1); + _sum2 = _mm_fmadd_ps(_vale, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 14))), _sum2); + _sum3 = _mm_fmadd_ps(_valf, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 15))), _sum3); + } + if (elempack == 8) + { + const unsigned short* sptr0 = bottom_blob.channel(q / 8).row(sy) + sx * 8; + const unsigned short* sptr1 = bottom_blob.channel(q / 8 + 1).row(sy) + sx * 8; + __m128 _val0 = _mm_set1_ps(bfloat16_to_float32(sptr0[0])); + __m128 _val1 = _mm_set1_ps(bfloat16_to_float32(sptr0[1])); + __m128 _val2 = _mm_set1_ps(bfloat16_to_float32(sptr0[2])); + __m128 _val3 = _mm_set1_ps(bfloat16_to_float32(sptr0[3])); + __m128 _val4 = _mm_set1_ps(bfloat16_to_float32(sptr0[4])); + __m128 _val5 = _mm_set1_ps(bfloat16_to_float32(sptr0[5])); + __m128 _val6 = _mm_set1_ps(bfloat16_to_float32(sptr0[6])); + __m128 _val7 = _mm_set1_ps(bfloat16_to_float32(sptr0[7])); + __m128 _val8 = _mm_set1_ps(bfloat16_to_float32(sptr1[0])); + __m128 _val9 = _mm_set1_ps(bfloat16_to_float32(sptr1[1])); + __m128 _vala = _mm_set1_ps(bfloat16_to_float32(sptr1[2])); + __m128 _valb = _mm_set1_ps(bfloat16_to_float32(sptr1[3])); + __m128 _valc = _mm_set1_ps(bfloat16_to_float32(sptr1[4])); + __m128 _vald = _mm_set1_ps(bfloat16_to_float32(sptr1[5])); + __m128 _vale = _mm_set1_ps(bfloat16_to_float32(sptr1[6])); + __m128 _valf = _mm_set1_ps(bfloat16_to_float32(sptr1[7])); + _sum0 = _mm_fmadd_ps(_val0, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 0))), _sum0); + _sum1 = _mm_fmadd_ps(_val1, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 1))), _sum1); + _sum2 = _mm_fmadd_ps(_val2, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 2))), _sum2); + _sum3 = _mm_fmadd_ps(_val3, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 3))), _sum3); + _sum0 = _mm_fmadd_ps(_val4, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 4))), _sum0); + _sum1 = _mm_fmadd_ps(_val5, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 5))), _sum1); + _sum2 = _mm_fmadd_ps(_val6, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 6))), _sum2); + _sum3 = _mm_fmadd_ps(_val7, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 7))), _sum3); + _sum0 = _mm_fmadd_ps(_val8, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 8))), _sum0); + _sum1 = _mm_fmadd_ps(_val9, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 9))), _sum1); + _sum2 = _mm_fmadd_ps(_vala, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 10))), _sum2); + _sum3 = 
_mm_fmadd_ps(_valb, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 11))), _sum3); + _sum0 = _mm_fmadd_ps(_valc, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 12))), _sum0); + _sum1 = _mm_fmadd_ps(_vald, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 13))), _sum1); + _sum2 = _mm_fmadd_ps(_vale, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 14))), _sum2); + _sum3 = _mm_fmadd_ps(_valf, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 15))), _sum3); + } + if (elempack == 4) + { + const unsigned short* sptr_q0 = bottom_blob.channel(q / 4).row(sy) + sx * 4; + const unsigned short* sptr_q1 = bottom_blob.channel(q / 4 + 1).row(sy) + sx * 4; + const unsigned short* sptr_q2 = bottom_blob.channel(q / 4 + 2).row(sy) + sx * 4; + const unsigned short* sptr_q3 = bottom_blob.channel(q / 4 + 3).row(sy) + sx * 4; + __m128 _val0 = _mm_set1_ps(bfloat16_to_float32(sptr_q0[0])); + __m128 _val1 = _mm_set1_ps(bfloat16_to_float32(sptr_q0[1])); + __m128 _val2 = _mm_set1_ps(bfloat16_to_float32(sptr_q0[2])); + __m128 _val3 = _mm_set1_ps(bfloat16_to_float32(sptr_q0[3])); + __m128 _val4 = _mm_set1_ps(bfloat16_to_float32(sptr_q1[0])); + __m128 _val5 = _mm_set1_ps(bfloat16_to_float32(sptr_q1[1])); + __m128 _val6 = _mm_set1_ps(bfloat16_to_float32(sptr_q1[2])); + __m128 _val7 = _mm_set1_ps(bfloat16_to_float32(sptr_q1[3])); + __m128 _val8 = _mm_set1_ps(bfloat16_to_float32(sptr_q2[0])); + __m128 _val9 = _mm_set1_ps(bfloat16_to_float32(sptr_q2[1])); + __m128 _vala = _mm_set1_ps(bfloat16_to_float32(sptr_q2[2])); + __m128 _valb = _mm_set1_ps(bfloat16_to_float32(sptr_q2[3])); + __m128 _valc = _mm_set1_ps(bfloat16_to_float32(sptr_q3[0])); + __m128 _vald = _mm_set1_ps(bfloat16_to_float32(sptr_q3[1])); + __m128 _vale = _mm_set1_ps(bfloat16_to_float32(sptr_q3[2])); + __m128 _valf = _mm_set1_ps(bfloat16_to_float32(sptr_q3[3])); + _sum0 = _mm_fmadd_ps(_val0, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 0))), _sum0); + _sum1 = _mm_fmadd_ps(_val1, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 1))), _sum1); + _sum2 = _mm_fmadd_ps(_val2, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 2))), _sum2); + _sum3 = _mm_fmadd_ps(_val3, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 3))), _sum3); + _sum0 = _mm_fmadd_ps(_val4, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 4))), _sum0); + _sum1 = _mm_fmadd_ps(_val5, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 5))), _sum1); + _sum2 = _mm_fmadd_ps(_val6, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 6))), _sum2); + _sum3 = _mm_fmadd_ps(_val7, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 7))), _sum3); + _sum0 = _mm_fmadd_ps(_val8, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 8))), _sum0); + _sum1 = _mm_fmadd_ps(_val9, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 9))), _sum1); + _sum2 = _mm_fmadd_ps(_vala, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 10))), _sum2); + _sum3 = _mm_fmadd_ps(_valb, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 11))), _sum3); + _sum0 = _mm_fmadd_ps(_valc, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 12))), _sum0); + _sum1 = _mm_fmadd_ps(_vald, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 13))), _sum1); + _sum2 = _mm_fmadd_ps(_vale, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 14))), _sum2); + _sum3 = _mm_fmadd_ps(_valf, bfloat2float_sse(_mm_loadl_epi64((const 
__m128i*)(kptr0 + 4 * 15))), _sum3); + } + if (elempack == 1) + { + __m128 _val0 = _mm_set1_ps(bfloat16_to_float32(bottom_blob.channel(q).row(sy)[sx])); + __m128 _val1 = _mm_set1_ps(bfloat16_to_float32(bottom_blob.channel(q + 1).row(sy)[sx])); + __m128 _val2 = _mm_set1_ps(bfloat16_to_float32(bottom_blob.channel(q + 2).row(sy)[sx])); + __m128 _val3 = _mm_set1_ps(bfloat16_to_float32(bottom_blob.channel(q + 3).row(sy)[sx])); + __m128 _val4 = _mm_set1_ps(bfloat16_to_float32(bottom_blob.channel(q + 4).row(sy)[sx])); + __m128 _val5 = _mm_set1_ps(bfloat16_to_float32(bottom_blob.channel(q + 5).row(sy)[sx])); + __m128 _val6 = _mm_set1_ps(bfloat16_to_float32(bottom_blob.channel(q + 6).row(sy)[sx])); + __m128 _val7 = _mm_set1_ps(bfloat16_to_float32(bottom_blob.channel(q + 7).row(sy)[sx])); + __m128 _val8 = _mm_set1_ps(bfloat16_to_float32(bottom_blob.channel(q + 8).row(sy)[sx])); + __m128 _val9 = _mm_set1_ps(bfloat16_to_float32(bottom_blob.channel(q + 9).row(sy)[sx])); + __m128 _vala = _mm_set1_ps(bfloat16_to_float32(bottom_blob.channel(q + 10).row(sy)[sx])); + __m128 _valb = _mm_set1_ps(bfloat16_to_float32(bottom_blob.channel(q + 11).row(sy)[sx])); + __m128 _valc = _mm_set1_ps(bfloat16_to_float32(bottom_blob.channel(q + 12).row(sy)[sx])); + __m128 _vald = _mm_set1_ps(bfloat16_to_float32(bottom_blob.channel(q + 13).row(sy)[sx])); + __m128 _vale = _mm_set1_ps(bfloat16_to_float32(bottom_blob.channel(q + 14).row(sy)[sx])); + __m128 _valf = _mm_set1_ps(bfloat16_to_float32(bottom_blob.channel(q + 15).row(sy)[sx])); + _sum0 = _mm_fmadd_ps(_val0, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 0))), _sum0); + _sum1 = _mm_fmadd_ps(_val1, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 1))), _sum1); + _sum2 = _mm_fmadd_ps(_val2, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 2))), _sum2); + _sum3 = _mm_fmadd_ps(_val3, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 3))), _sum3); + _sum0 = _mm_fmadd_ps(_val4, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 4))), _sum0); + _sum1 = _mm_fmadd_ps(_val5, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 5))), _sum1); + _sum2 = _mm_fmadd_ps(_val6, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 6))), _sum2); + _sum3 = _mm_fmadd_ps(_val7, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 7))), _sum3); + _sum0 = _mm_fmadd_ps(_val8, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 8))), _sum0); + _sum1 = _mm_fmadd_ps(_val9, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 9))), _sum1); + _sum2 = _mm_fmadd_ps(_vala, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 10))), _sum2); + _sum3 = _mm_fmadd_ps(_valb, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 11))), _sum3); + _sum0 = _mm_fmadd_ps(_valc, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 12))), _sum0); + _sum1 = _mm_fmadd_ps(_vald, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 13))), _sum1); + _sum2 = _mm_fmadd_ps(_vale, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 14))), _sum2); + _sum3 = _mm_fmadd_ps(_valf, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 15))), _sum3); + } + } + } + kptr += maxk * 16 * 4; + } +#endif // __AVX512F__ + for (; q + 7 < inch; q += 8) + { + for (int y = 0; y < kernel_h; y++) + { + int sys = (i + y * dilation_h - (kernel_extent_h - 1)); + if (sys < 0 || sys % stride_h != 0) + continue; + int sy = sys / stride_h; + if (sy >= h) + continue; + for (int x = 0; x < 
kernel_w; x++) + { + int sxs = (j + x * dilation_w - (kernel_extent_w - 1)); + if (sxs < 0 || sxs % stride_w != 0) + continue; + int sx = sxs / stride_w; + if (sx >= w) + continue; + int k = y * kernel_w + x; + const unsigned short* kptr0 = kptr + k * 8 * 4; + + if (elempack == 8) + { + const unsigned short* sptr = bottom_blob.channel(q / 8).row(sy) + sx * 8; + __m128 _val0 = _mm_set1_ps(bfloat16_to_float32(sptr[0])); + __m128 _val1 = _mm_set1_ps(bfloat16_to_float32(sptr[1])); + __m128 _val2 = _mm_set1_ps(bfloat16_to_float32(sptr[2])); + __m128 _val3 = _mm_set1_ps(bfloat16_to_float32(sptr[3])); + __m128 _val4 = _mm_set1_ps(bfloat16_to_float32(sptr[4])); + __m128 _val5 = _mm_set1_ps(bfloat16_to_float32(sptr[5])); + __m128 _val6 = _mm_set1_ps(bfloat16_to_float32(sptr[6])); + __m128 _val7 = _mm_set1_ps(bfloat16_to_float32(sptr[7])); + _sum0 = _mm_comp_fmadd_ps(_val0, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 0))), _sum0); + _sum1 = _mm_comp_fmadd_ps(_val1, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 1))), _sum1); + _sum2 = _mm_comp_fmadd_ps(_val2, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 2))), _sum2); + _sum3 = _mm_comp_fmadd_ps(_val3, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 3))), _sum3); + _sum0 = _mm_comp_fmadd_ps(_val4, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 4))), _sum0); + _sum1 = _mm_comp_fmadd_ps(_val5, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 5))), _sum1); + _sum2 = _mm_comp_fmadd_ps(_val6, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 6))), _sum2); + _sum3 = _mm_comp_fmadd_ps(_val7, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 7))), _sum3); + } + if (elempack == 4) + { + const unsigned short* sptr0 = bottom_blob.channel(q / 4).row(sy) + sx * 4; + const unsigned short* sptr1 = bottom_blob.channel(q / 4 + 1).row(sy) + sx * 4; + __m128 _val0 = _mm_set1_ps(bfloat16_to_float32(sptr0[0])); + __m128 _val1 = _mm_set1_ps(bfloat16_to_float32(sptr0[1])); + __m128 _val2 = _mm_set1_ps(bfloat16_to_float32(sptr0[2])); + __m128 _val3 = _mm_set1_ps(bfloat16_to_float32(sptr0[3])); + __m128 _val4 = _mm_set1_ps(bfloat16_to_float32(sptr1[0])); + __m128 _val5 = _mm_set1_ps(bfloat16_to_float32(sptr1[1])); + __m128 _val6 = _mm_set1_ps(bfloat16_to_float32(sptr1[2])); + __m128 _val7 = _mm_set1_ps(bfloat16_to_float32(sptr1[3])); + _sum0 = _mm_comp_fmadd_ps(_val0, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 0))), _sum0); + _sum1 = _mm_comp_fmadd_ps(_val1, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 1))), _sum1); + _sum2 = _mm_comp_fmadd_ps(_val2, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 2))), _sum2); + _sum3 = _mm_comp_fmadd_ps(_val3, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 3))), _sum3); + _sum0 = _mm_comp_fmadd_ps(_val4, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 4))), _sum0); + _sum1 = _mm_comp_fmadd_ps(_val5, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 5))), _sum1); + _sum2 = _mm_comp_fmadd_ps(_val6, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 6))), _sum2); + _sum3 = _mm_comp_fmadd_ps(_val7, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 7))), _sum3); + } + if (elempack == 1) + { + __m128 _val0 = _mm_set1_ps(bfloat16_to_float32(bottom_blob.channel(q).row(sy)[sx])); + __m128 _val1 = _mm_set1_ps(bfloat16_to_float32(bottom_blob.channel(q + 1).row(sy)[sx])); + __m128 _val2 = 
_mm_set1_ps(bfloat16_to_float32(bottom_blob.channel(q + 2).row(sy)[sx])); + __m128 _val3 = _mm_set1_ps(bfloat16_to_float32(bottom_blob.channel(q + 3).row(sy)[sx])); + __m128 _val4 = _mm_set1_ps(bfloat16_to_float32(bottom_blob.channel(q + 4).row(sy)[sx])); + __m128 _val5 = _mm_set1_ps(bfloat16_to_float32(bottom_blob.channel(q + 5).row(sy)[sx])); + __m128 _val6 = _mm_set1_ps(bfloat16_to_float32(bottom_blob.channel(q + 6).row(sy)[sx])); + __m128 _val7 = _mm_set1_ps(bfloat16_to_float32(bottom_blob.channel(q + 7).row(sy)[sx])); + _sum0 = _mm_comp_fmadd_ps(_val0, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 0))), _sum0); + _sum1 = _mm_comp_fmadd_ps(_val1, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 1))), _sum1); + _sum2 = _mm_comp_fmadd_ps(_val2, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 2))), _sum2); + _sum3 = _mm_comp_fmadd_ps(_val3, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 3))), _sum3); + _sum0 = _mm_comp_fmadd_ps(_val4, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 4))), _sum0); + _sum1 = _mm_comp_fmadd_ps(_val5, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 5))), _sum1); + _sum2 = _mm_comp_fmadd_ps(_val6, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 6))), _sum2); + _sum3 = _mm_comp_fmadd_ps(_val7, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 7))), _sum3); + } + } + } + kptr += maxk * 8 * 4; + } +#endif // __AVX__ + for (; q + 3 < inch; q += 4) + { + for (int y = 0; y < kernel_h; y++) + { + int sys = (i + y * dilation_h - (kernel_extent_h - 1)); + if (sys < 0 || sys % stride_h != 0) + continue; + int sy = sys / stride_h; + if (sy >= h) + continue; + for (int x = 0; x < kernel_w; x++) + { + int sxs = (j + x * dilation_w - (kernel_extent_w - 1)); + if (sxs < 0 || sxs % stride_w != 0) + continue; + int sx = sxs / stride_w; + if (sx >= w) + continue; + int k = y * kernel_w + x; + const unsigned short* kptr0 = kptr + k * 4 * 4; + + if (elempack == 4) + { + const unsigned short* sptr = bottom_blob.channel(q / 4).row(sy) + sx * 4; + __m128 _val0 = _mm_set1_ps(bfloat16_to_float32(sptr[0])); + __m128 _val1 = _mm_set1_ps(bfloat16_to_float32(sptr[1])); + __m128 _val2 = _mm_set1_ps(bfloat16_to_float32(sptr[2])); + __m128 _val3 = _mm_set1_ps(bfloat16_to_float32(sptr[3])); + _sum0 = _mm_comp_fmadd_ps(_val0, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)kptr0)), _sum0); + _sum1 = _mm_comp_fmadd_ps(_val1, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4))), _sum1); + _sum2 = _mm_comp_fmadd_ps(_val2, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 8))), _sum2); + _sum3 = _mm_comp_fmadd_ps(_val3, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 12))), _sum3); + } + if (elempack == 1) + { + __m128 _val0 = _mm_set1_ps(bfloat16_to_float32(bottom_blob.channel(q).row(sy)[sx])); + __m128 _val1 = _mm_set1_ps(bfloat16_to_float32(bottom_blob.channel(q + 1).row(sy)[sx])); + __m128 _val2 = _mm_set1_ps(bfloat16_to_float32(bottom_blob.channel(q + 2).row(sy)[sx])); + __m128 _val3 = _mm_set1_ps(bfloat16_to_float32(bottom_blob.channel(q + 3).row(sy)[sx])); + _sum0 = _mm_comp_fmadd_ps(_val0, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 0))), _sum0); + _sum1 = _mm_comp_fmadd_ps(_val1, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 1))), _sum1); + _sum2 = _mm_comp_fmadd_ps(_val2, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 2))), _sum2); + _sum3 = _mm_comp_fmadd_ps(_val3, 
bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4 * 3))), _sum3); + } + } + } + kptr += maxk * 4 * 4; + } + for (; q + 1 < inch; q += 2) + { + for (int y = 0; y < kernel_h; y++) + { + int sys = (i + y * dilation_h - (kernel_extent_h - 1)); + if (sys < 0 || sys % stride_h != 0) + continue; + int sy = sys / stride_h; + if (sy >= h) + continue; + for (int x = 0; x < kernel_w; x++) + { + int sxs = (j + x * dilation_w - (kernel_extent_w - 1)); + if (sxs < 0 || sxs % stride_w != 0) + continue; + int sx = sxs / stride_w; + if (sx >= w) + continue; + int k = y * kernel_w + x; + const unsigned short* kptr0 = kptr + k * 2 * 4; + + const unsigned short* sptr0 = bottom_blob.channel(q).row(sy) + sx; + const unsigned short* sptr1 = bottom_blob.channel(q + 1).row(sy) + sx; + __m128 _val0 = _mm_set1_ps(bfloat16_to_float32(sptr0[0])); + __m128 _val1 = _mm_set1_ps(bfloat16_to_float32(sptr1[0])); + _sum0 = _mm_comp_fmadd_ps(_val0, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)kptr0)), _sum0); + _sum1 = _mm_comp_fmadd_ps(_val1, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4))), _sum1); + } + } + kptr += maxk * 2 * 4; + } + for (; q < inch; q++) + { + for (int y = 0; y < kernel_h; y++) + { + int sys = (i + y * dilation_h - (kernel_extent_h - 1)); + if (sys < 0 || sys % stride_h != 0) + continue; + int sy = sys / stride_h; + if (sy >= h) + continue; + for (int x = 0; x < kernel_w; x++) + { + int sxs = (j + x * dilation_w - (kernel_extent_w - 1)); + if (sxs < 0 || sxs % stride_w != 0) + continue; + int sx = sxs / stride_w; + if (sx >= w) + continue; + int k = y * kernel_w + x; + const unsigned short* kptr0 = kptr + k * 4; + + const unsigned short* sptr = bottom_blob.channel(q).row(sy) + sx; + __m128 _val = _mm_set1_ps(bfloat16_to_float32(sptr[0])); + _sum0 = _mm_comp_fmadd_ps(_val, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)kptr0)), _sum0); + } + } + kptr += maxk * 4; + } + + _sum0 = _mm_add_ps(_sum0, _sum1); + _sum2 = _mm_add_ps(_sum2, _sum3); + _sum0 = _mm_add_ps(_sum0, _sum2); + + _sum0 = activation_sse(_sum0, activation_type, activation_params); + + if (out_elempack == 4) + { + _mm_storel_epi64((__m128i*)outptr, float2bfloat_sse(_sum0)); + outptr += 4; + } + if (out_elempack == 1) + { + float sum[4]; + _mm_storeu_ps(sum, _sum0); + + outptr[0] = float32_to_bfloat16(sum[0]); + outptr[M] = float32_to_bfloat16(sum[1]); + outptr[M * 2] = float32_to_bfloat16(sum[2]); + outptr[M * 3] = float32_to_bfloat16(sum[3]); + outptr += 1; + } + } + } + } + remain_outch_start += nn_outch * 4; + nn_outch = (outch - remain_outch_start) / 2; +#else // __SSE2__ + nn_outch = (outch - remain_outch_start) / 2; + #pragma omp parallel for num_threads(opt.num_threads) +#endif // __SSE2__ + for (int pp = 0; pp < nn_outch; pp++) + { + const int p = remain_outch_start + pp * 2; + + const int elempack = bottom_blob.elempack; + const int inch = bottom_blob.c * elempack; + const int w = bottom_blob.w; + const int h = bottom_blob.h; + const int outw = top_blob.w; + const int outh = top_blob.h; + + unsigned short* outptr0 = top_blob.channel(p); + unsigned short* outptr1 = top_blob.channel(p + 1); + + for (int i = 0; i < outh; i++) + { + for (int j = 0; j < outw; j++) + { + float sum0 = 0.f; + float sum1 = 0.f; + + if (bias_data_ptr) + { + sum0 = bias_data_ptr[p]; + sum1 = bias_data_ptr[p + 1]; + } + +#if __AVX512F__ + const unsigned short* kptr = weight_data_tm.channel(p / 16 + (p % 16) / 8 + (p % 8) / 4 + (p % 4) / 2); +#elif __AVX__ + const unsigned short* kptr = weight_data_tm.channel(p / 8 + (p % 8) / 4 + (p 
% 4) / 2); +#elif __SSE2__ + const unsigned short* kptr = weight_data_tm.channel(p / 4 + (p % 4) / 2); +#else + const unsigned short* kptr = weight_data_tm.channel(p / 2); +#endif + + int q = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + __m512 _sum0_avx512 = _mm512_setzero_ps(); + __m512 _sum1_avx512 = _mm512_setzero_ps(); + for (; q + 15 < inch; q += 16) + { + for (int y = 0; y < kernel_h; y++) + { + int sys = (i + y * dilation_h - (kernel_extent_h - 1)); + if (sys < 0 || sys % stride_h != 0) + continue; + int sy = sys / stride_h; + if (sy >= h) + continue; + for (int x = 0; x < kernel_w; x++) + { + int sxs = (j + x * dilation_w - (kernel_extent_w - 1)); + if (sxs < 0 || sxs % stride_w != 0) + continue; + int sx = sxs / stride_w; + if (sx >= w) + continue; + int k = y * kernel_w + x; + const unsigned short* kptr0 = kptr + k * 16 * 2; + + if (elempack == 16) + { + const unsigned short* sptr = bottom_blob.channel(q / 16).row(sy) + sx * 16; + __m512 _r0 = bfloat2float_avx512(_mm256_load_si256((const __m256i*)sptr)); + _sum0_avx512 = _mm512_fmadd_ps(_r0, bfloat2float_avx512(_mm256_load_si256((const __m256i*)kptr0)), _sum0_avx512); + _sum1_avx512 = _mm512_fmadd_ps(_r0, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16))), _sum1_avx512); + } + if (elempack == 8) + { + const unsigned short* sptr0 = bottom_blob.channel(q / 8).row(sy) + sx * 8; + const unsigned short* sptr1 = bottom_blob.channel(q / 8 + 1).row(sy) + sx * 8; + __m512 _r0 = combine8x2_ps(bfloat2float_avx(_mm_load_si128((const __m128i*)sptr0)), bfloat2float_avx(_mm_load_si128((const __m128i*)sptr1))); + _sum0_avx512 = _mm512_fmadd_ps(_r0, bfloat2float_avx512(_mm256_load_si256((const __m256i*)kptr0)), _sum0_avx512); + _sum1_avx512 = _mm512_fmadd_ps(_r0, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16))), _sum1_avx512); + } + if (elempack == 4) + { + const unsigned short* sptr0 = bottom_blob.channel(q / 4).row(sy) + sx * 4; + const unsigned short* sptr1 = bottom_blob.channel(q / 4 + 1).row(sy) + sx * 4; + const unsigned short* sptr2 = bottom_blob.channel(q / 4 + 2).row(sy) + sx * 4; + const unsigned short* sptr3 = bottom_blob.channel(q / 4 + 3).row(sy) + sx * 4; + __m512 _r0 = combine4x4_ps(bfloat2float_sse(_mm_loadl_epi64((const __m128i*)sptr0)), bfloat2float_sse(_mm_loadl_epi64((const __m128i*)sptr1)), bfloat2float_sse(_mm_loadl_epi64((const __m128i*)sptr2)), bfloat2float_sse(_mm_loadl_epi64((const __m128i*)sptr3))); + _sum0_avx512 = _mm512_fmadd_ps(_r0, bfloat2float_avx512(_mm256_load_si256((const __m256i*)kptr0)), _sum0_avx512); + _sum1_avx512 = _mm512_fmadd_ps(_r0, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16))), _sum1_avx512); + } + if (elempack == 1) + { + float tmp[16]; + for (int qi = 0; qi < 16; qi++) + { + const unsigned short* sptr = bottom_blob.channel(q + qi).row(sy) + sx; + tmp[qi] = bfloat16_to_float32(sptr[0]); + } + __m512 _r0 = _mm512_loadu_ps(tmp); + _sum0_avx512 = _mm512_fmadd_ps(_r0, bfloat2float_avx512(_mm256_load_si256((const __m256i*)kptr0)), _sum0_avx512); + _sum1_avx512 = _mm512_fmadd_ps(_r0, bfloat2float_avx512(_mm256_load_si256((const __m256i*)(kptr0 + 16))), _sum1_avx512); + } + } + } + kptr += maxk * 16 * 2; + } + sum0 += _mm512_comp_reduce_add_ps(_sum0_avx512); + sum1 += _mm512_comp_reduce_add_ps(_sum1_avx512); +#endif // __AVX512F__ + __m256 _sum0_avx = _mm256_setzero_ps(); + __m256 _sum1_avx = _mm256_setzero_ps(); + for (; q + 7 < inch; q += 8) + { + for (int y = 0; y < kernel_h; y++) + { + int sys = (i + y * dilation_h - (kernel_extent_h - 
1)); + if (sys < 0 || sys % stride_h != 0) + continue; + int sy = sys / stride_h; + if (sy >= h) + continue; + for (int x = 0; x < kernel_w; x++) + { + int sxs = (j + x * dilation_w - (kernel_extent_w - 1)); + if (sxs < 0 || sxs % stride_w != 0) + continue; + int sx = sxs / stride_w; + if (sx >= w) + continue; + int k = y * kernel_w + x; + const unsigned short* kptr0 = kptr + k * 8 * 2; + + if (elempack == 8) + { + const unsigned short* sptr = bottom_blob.channel(q / 8).row(sy) + sx * 8; + __m256 _r0 = bfloat2float_avx(_mm_load_si128((const __m128i*)sptr)); + _sum0_avx = _mm256_comp_fmadd_ps(_r0, bfloat2float_avx(_mm_load_si128((const __m128i*)kptr0)), _sum0_avx); + _sum1_avx = _mm256_comp_fmadd_ps(_r0, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8))), _sum1_avx); + } + if (elempack == 4) + { + const unsigned short* sptr0 = bottom_blob.channel(q / 4).row(sy) + sx * 4; + const unsigned short* sptr1 = bottom_blob.channel(q / 4 + 1).row(sy) + sx * 4; + __m256 _r0 = combine4x2_ps(bfloat2float_sse(_mm_loadl_epi64((const __m128i*)sptr0)), bfloat2float_sse(_mm_loadl_epi64((const __m128i*)sptr1))); + _sum0_avx = _mm256_comp_fmadd_ps(_r0, bfloat2float_avx(_mm_load_si128((const __m128i*)kptr0)), _sum0_avx); + _sum1_avx = _mm256_comp_fmadd_ps(_r0, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8))), _sum1_avx); + } + if (elempack == 1) + { + float tmp[8]; + for (int qi = 0; qi < 8; qi++) + { + const unsigned short* sptr = bottom_blob.channel(q + qi).row(sy) + sx; + tmp[qi] = bfloat16_to_float32(sptr[0]); + } + __m256 _r0 = _mm256_loadu_ps(tmp); + _sum0_avx = _mm256_comp_fmadd_ps(_r0, bfloat2float_avx(_mm_load_si128((const __m128i*)kptr0)), _sum0_avx); + _sum1_avx = _mm256_comp_fmadd_ps(_r0, bfloat2float_avx(_mm_load_si128((const __m128i*)(kptr0 + 8))), _sum1_avx); + } + } + } + kptr += maxk * 8 * 2; + } + sum0 += _mm256_reduce_add_ps(_sum0_avx); + sum1 += _mm256_reduce_add_ps(_sum1_avx); +#endif // __AVX__ + __m128 _sum0_sse = _mm_setzero_ps(); + __m128 _sum1_sse = _mm_setzero_ps(); + for (; q + 3 < inch; q += 4) + { + for (int y = 0; y < kernel_h; y++) + { + int sys = (i + y * dilation_h - (kernel_extent_h - 1)); + if (sys < 0 || sys % stride_h != 0) + continue; + int sy = sys / stride_h; + if (sy >= h) + continue; + for (int x = 0; x < kernel_w; x++) + { + int sxs = (j + x * dilation_w - (kernel_extent_w - 1)); + if (sxs < 0 || sxs % stride_w != 0) + continue; + int sx = sxs / stride_w; + if (sx >= w) + continue; + int k = y * kernel_w + x; + const unsigned short* kptr0 = kptr + k * 4 * 2; + + if (elempack == 4) + { + const unsigned short* sptr = bottom_blob.channel(q / 4).row(sy) + sx * 4; + __m128 _r0 = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)sptr)); + _sum0_sse = _mm_comp_fmadd_ps(_r0, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)kptr0)), _sum0_sse); + _sum1_sse = _mm_comp_fmadd_ps(_r0, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4))), _sum1_sse); + } + if (elempack == 1) + { + float tmp[4]; + for (int qi = 0; qi < 4; qi++) + { + const unsigned short* sptr = bottom_blob.channel(q + qi).row(sy) + sx; + tmp[qi] = bfloat16_to_float32(sptr[0]); + } + __m128 _r0 = _mm_loadu_ps(tmp); + _sum0_sse = _mm_comp_fmadd_ps(_r0, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)kptr0)), _sum0_sse); + _sum1_sse = _mm_comp_fmadd_ps(_r0, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr0 + 4))), _sum1_sse); + } + } + } + kptr += maxk * 4 * 2; + } + sum0 += _mm_reduce_add_ps(_sum0_sse); + sum1 += _mm_reduce_add_ps(_sum1_sse); +#endif // __SSE2__ + for (; q + 1 
< inch; q += 2) + { + for (int y = 0; y < kernel_h; y++) + { + int sys = (i + y * dilation_h - (kernel_extent_h - 1)); + if (sys < 0 || sys % stride_h != 0) + continue; + int sy = sys / stride_h; + if (sy >= h) + continue; + for (int x = 0; x < kernel_w; x++) + { + int sxs = (j + x * dilation_w - (kernel_extent_w - 1)); + if (sxs < 0 || sxs % stride_w != 0) + continue; + int sx = sxs / stride_w; + if (sx >= w) + continue; + int k = y * kernel_w + x; + const unsigned short* kptr0 = kptr + k * 2 * 2; + + for (int qi = 0; qi < 2; qi++) + { + const unsigned short* sptr = bottom_blob.channel(q + qi).row(sy) + sx; + sum0 += bfloat16_to_float32(sptr[0]) * bfloat16_to_float32(kptr0[qi]); + sum1 += bfloat16_to_float32(sptr[0]) * bfloat16_to_float32(kptr0[2 + qi]); + } + } + } + kptr += maxk * 2 * 2; + } + for (; q < inch; q++) + { + for (int y = 0; y < kernel_h; y++) + { + int sys = (i + y * dilation_h - (kernel_extent_h - 1)); + if (sys < 0 || sys % stride_h != 0) + continue; + int sy = sys / stride_h; + if (sy >= h) + continue; + for (int x = 0; x < kernel_w; x++) + { + int sxs = (j + x * dilation_w - (kernel_extent_w - 1)); + if (sxs < 0 || sxs % stride_w != 0) + continue; + int sx = sxs / stride_w; + if (sx >= w) + continue; + int k = y * kernel_w + x; + const unsigned short* kptr0 = kptr + k * 2; + + const unsigned short* sptr = bottom_blob.channel(q).row(sy) + sx; + sum0 += bfloat16_to_float32(sptr[0]) * bfloat16_to_float32(kptr0[0]); + sum1 += bfloat16_to_float32(sptr[0]) * bfloat16_to_float32(kptr0[1]); + } + } + kptr += maxk * 2; + } + + sum0 = activation_ss(sum0, activation_type, activation_params); + sum1 = activation_ss(sum1, activation_type, activation_params); + + outptr0[0] = float32_to_bfloat16(sum0); + outptr1[0] = float32_to_bfloat16(sum1); + outptr0 += 1; + outptr1 += 1; + } + } + } + remain_outch_start += nn_outch * 2; + for (int p = remain_outch_start; p < outch; p++) + { + const int elempack = bottom_blob.elempack; + const int inch = bottom_blob.c * elempack; + const int w = bottom_blob.w; + const int h = bottom_blob.h; + const int outw = top_blob.w; + const int outh = top_blob.h; + + unsigned short* outptr = top_blob.channel(p); + + for (int i = 0; i < outh; i++) + { + for (int j = 0; j < outw; j++) + { + float sum = 0.f; + + if (bias_data_ptr) + { + sum = bias_data_ptr[p]; + } + +#if __AVX512F__ + const unsigned short* kptr = weight_data_tm.channel(p / 16 + (p % 16) / 8 + (p % 8) / 4 + (p % 4) / 2 + p % 2); +#elif __AVX__ + const unsigned short* kptr = weight_data_tm.channel(p / 8 + (p % 8) / 4 + (p % 4) / 2 + p % 2); +#elif __SSE2__ + const unsigned short* kptr = weight_data_tm.channel(p / 4 + (p % 4) / 2 + p % 2); +#else + const unsigned short* kptr = weight_data_tm.channel(p / 2 + p % 2); +#endif + + int q = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + __m512 _sum_avx512 = _mm512_setzero_ps(); + for (; q + 15 < inch; q += 16) + { + for (int y = 0; y < kernel_h; y++) + { + int sys = (i + y * dilation_h - (kernel_extent_h - 1)); + if (sys < 0 || sys % stride_h != 0) + continue; + int sy = sys / stride_h; + if (sy >= h) + continue; + for (int x = 0; x < kernel_w; x++) + { + int sxs = (j + x * dilation_w - (kernel_extent_w - 1)); + if (sxs < 0 || sxs % stride_w != 0) + continue; + int sx = sxs / stride_w; + if (sx >= w) + continue; + int k = y * kernel_w + x; + const unsigned short* kptr0 = kptr + k * 16; + + if (elempack == 16) + { + const unsigned short* sptr = bottom_blob.channel(q / 16).row(sy) + sx * 16; + _sum_avx512 = 
_mm512_fmadd_ps(bfloat2float_avx512(_mm256_load_si256((const __m256i*)sptr)), bfloat2float_avx512(_mm256_load_si256((const __m256i*)kptr0)), _sum_avx512); + } + if (elempack == 8) + { + const unsigned short* sptr0 = bottom_blob.channel(q / 8).row(sy) + sx * 8; + const unsigned short* sptr1 = bottom_blob.channel(q / 8 + 1).row(sy) + sx * 8; + __m512 _r0 = combine8x2_ps(bfloat2float_avx(_mm_load_si128((const __m128i*)sptr0)), bfloat2float_avx(_mm_load_si128((const __m128i*)sptr1))); + _sum_avx512 = _mm512_fmadd_ps(_r0, bfloat2float_avx512(_mm256_load_si256((const __m256i*)kptr0)), _sum_avx512); + } + if (elempack == 4) + { + const unsigned short* sptr0 = bottom_blob.channel(q / 4).row(sy) + sx * 4; + const unsigned short* sptr1 = bottom_blob.channel(q / 4 + 1).row(sy) + sx * 4; + const unsigned short* sptr2 = bottom_blob.channel(q / 4 + 2).row(sy) + sx * 4; + const unsigned short* sptr3 = bottom_blob.channel(q / 4 + 3).row(sy) + sx * 4; + __m512 _r0 = combine4x4_ps(bfloat2float_sse(_mm_loadl_epi64((const __m128i*)sptr0)), bfloat2float_sse(_mm_loadl_epi64((const __m128i*)sptr1)), bfloat2float_sse(_mm_loadl_epi64((const __m128i*)sptr2)), bfloat2float_sse(_mm_loadl_epi64((const __m128i*)sptr3))); + _sum_avx512 = _mm512_fmadd_ps(_r0, bfloat2float_avx512(_mm256_load_si256((const __m256i*)kptr0)), _sum_avx512); + } + if (elempack == 1) + { + float tmp[16]; + for (int qi = 0; qi < 16; qi++) + { + const unsigned short* sptr = bottom_blob.channel(q + qi).row(sy) + sx; + tmp[qi] = bfloat16_to_float32(sptr[0]); + } + __m512 _r0 = _mm512_loadu_ps(tmp); + _sum_avx512 = _mm512_fmadd_ps(_r0, bfloat2float_avx512(_mm256_load_si256((const __m256i*)kptr0)), _sum_avx512); + } + } + } + kptr += maxk * 16; + } + sum += _mm512_comp_reduce_add_ps(_sum_avx512); +#endif // __AVX512F__ + __m256 _sum_avx = _mm256_setzero_ps(); + for (; q + 7 < inch; q += 8) + { + for (int y = 0; y < kernel_h; y++) + { + int sys = (i + y * dilation_h - (kernel_extent_h - 1)); + if (sys < 0 || sys % stride_h != 0) + continue; + int sy = sys / stride_h; + if (sy >= h) + continue; + for (int x = 0; x < kernel_w; x++) + { + int sxs = (j + x * dilation_w - (kernel_extent_w - 1)); + if (sxs < 0 || sxs % stride_w != 0) + continue; + int sx = sxs / stride_w; + if (sx >= w) + continue; + int k = y * kernel_w + x; + const unsigned short* kptr0 = kptr + k * 8; + + if (elempack == 8) + { + const unsigned short* sptr = bottom_blob.channel(q / 8).row(sy) + sx * 8; + _sum_avx = _mm256_comp_fmadd_ps(bfloat2float_avx(_mm_load_si128((const __m128i*)sptr)), bfloat2float_avx(_mm_load_si128((const __m128i*)kptr0)), _sum_avx); + } + if (elempack == 4) + { + const unsigned short* sptr0 = bottom_blob.channel(q / 4).row(sy) + sx * 4; + const unsigned short* sptr1 = bottom_blob.channel(q / 4 + 1).row(sy) + sx * 4; + __m256 _r0 = combine4x2_ps(bfloat2float_sse(_mm_loadl_epi64((const __m128i*)sptr0)), bfloat2float_sse(_mm_loadl_epi64((const __m128i*)sptr1))); + _sum_avx = _mm256_comp_fmadd_ps(_r0, bfloat2float_avx(_mm_load_si128((const __m128i*)kptr0)), _sum_avx); + } + if (elempack == 1) + { + float tmp[8]; + for (int qi = 0; qi < 8; qi++) + { + const unsigned short* sptr = bottom_blob.channel(q + qi).row(sy) + sx; + tmp[qi] = bfloat16_to_float32(sptr[0]); + } + __m256 _r0 = _mm256_loadu_ps(tmp); + _sum_avx = _mm256_comp_fmadd_ps(_r0, bfloat2float_avx(_mm_load_si128((const __m128i*)kptr0)), _sum_avx); + } + } + } + kptr += maxk * 8; + } + sum += _mm256_reduce_add_ps(_sum_avx); +#endif // __AVX__ + __m128 _sum_sse = _mm_setzero_ps(); + for (; q + 3 < inch; q += 4) 
+ { + for (int y = 0; y < kernel_h; y++) + { + int sys = (i + y * dilation_h - (kernel_extent_h - 1)); + if (sys < 0 || sys % stride_h != 0) + continue; + int sy = sys / stride_h; + if (sy >= h) + continue; + for (int x = 0; x < kernel_w; x++) + { + int sxs = (j + x * dilation_w - (kernel_extent_w - 1)); + if (sxs < 0 || sxs % stride_w != 0) + continue; + int sx = sxs / stride_w; + if (sx >= w) + continue; + int k = y * kernel_w + x; + const unsigned short* kptr0 = kptr + k * 4; + + if (elempack == 4) + { + const unsigned short* sptr = bottom_blob.channel(q / 4).row(sy) + sx * 4; + _sum_sse = _mm_comp_fmadd_ps(bfloat2float_sse(_mm_loadl_epi64((const __m128i*)sptr)), bfloat2float_sse(_mm_loadl_epi64((const __m128i*)kptr0)), _sum_sse); + } + if (elempack == 1) + { + float tmp[4]; + for (int qi = 0; qi < 4; qi++) + { + const unsigned short* sptr = bottom_blob.channel(q + qi).row(sy) + sx; + tmp[qi] = bfloat16_to_float32(sptr[0]); + } + __m128 _r0 = _mm_loadu_ps(tmp); + _sum_sse = _mm_comp_fmadd_ps(_r0, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)kptr0)), _sum_sse); + } + } + } + kptr += maxk * 4; + } + sum += _mm_reduce_add_ps(_sum_sse); +#endif // __SSE2__ + for (; q + 1 < inch; q += 2) + { + for (int y = 0; y < kernel_h; y++) + { + int sys = (i + y * dilation_h - (kernel_extent_h - 1)); + if (sys < 0 || sys % stride_h != 0) + continue; + int sy = sys / stride_h; + if (sy >= h) + continue; + for (int x = 0; x < kernel_w; x++) + { + int sxs = (j + x * dilation_w - (kernel_extent_w - 1)); + if (sxs < 0 || sxs % stride_w != 0) + continue; + int sx = sxs / stride_w; + if (sx >= w) + continue; + int k = y * kernel_w + x; + const unsigned short* kptr0 = kptr + k * 2; + + for (int qi = 0; qi < 2; qi++) + { + const unsigned short* sptr = bottom_blob.channel(q + qi).row(sy) + sx; + sum += bfloat16_to_float32(sptr[0]) * bfloat16_to_float32(kptr0[qi]); + } + } + } + kptr += maxk * 2; + } + for (; q < inch; q++) + { + for (int y = 0; y < kernel_h; y++) + { + int sys = (i + y * dilation_h - (kernel_extent_h - 1)); + if (sys < 0 || sys % stride_h != 0) + continue; + int sy = sys / stride_h; + if (sy >= h) + continue; + for (int x = 0; x < kernel_w; x++) + { + int sxs = (j + x * dilation_w - (kernel_extent_w - 1)); + if (sxs < 0 || sxs % stride_w != 0) + continue; + int sx = sxs / stride_w; + if (sx >= w) + continue; + int k = y * kernel_w + x; + + const unsigned short* sptr = bottom_blob.channel(q).row(sy) + sx; + sum += bfloat16_to_float32(sptr[0]) * bfloat16_to_float32(kptr[k]); + } + } + kptr += maxk; + } + + sum = activation_ss(sum, activation_type, activation_params); + + outptr[0] = float32_to_bfloat16(sum); + outptr += 1; + } + } + } +} diff --git a/src/layer/x86/deconvolution_x86.cpp b/src/layer/x86/deconvolution_x86.cpp index c1e2fb46533..6019e94077c 100644 --- a/src/layer/x86/deconvolution_x86.cpp +++ b/src/layer/x86/deconvolution_x86.cpp @@ -12,6 +12,7 @@ #endif #endif // __SSE2__ +#include "cpu.h" #include "x86_activation.h" #include "x86_usability.h" @@ -19,12 +20,20 @@ namespace ncnn { #include "deconvolution_packed.h" +#if NCNN_BF16 +#include "deconvolution_packed_bf16s.h" +#endif + Deconvolution_x86::Deconvolution_x86() { #if __SSE2__ support_packing = true; #endif // __SSE2__ +#if NCNN_BF16 + support_bf16_storage = true; +#endif + activation = 0; gemm = 0; } @@ -36,6 +45,13 @@ int Deconvolution_x86::create_pipeline(const Option& opt) activation = create_activation_layer(activation_type, activation_params, opt); +#if NCNN_BF16 + if (opt.use_bf16_storage) + { + return 
create_pipeline_bf16s(opt); + } +#endif + const int maxk = kernel_w * kernel_h; int num_input = weight_data_size / maxk / num_output; @@ -139,6 +155,13 @@ int Deconvolution_x86::destroy_pipeline(const Option& opt) int Deconvolution_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const { +#if NCNN_BF16 + if (opt.use_bf16_storage && bottom_blob.elembits() == 16) + { + return forward_bf16s(bottom_blob, top_blob, opt); + } +#endif + // deconvolv with NxN kernel // value = value + bias @@ -408,6 +431,15 @@ int Deconvolution_x86::forward(const std::vector& bottom_blobs, std::vector if (weight_data_flattened.empty()) return -100; +#if NCNN_BF16 + if (weight_data_flattened.elembits() == 16) + { + Mat tmp; + cast_bfloat16_to_float32(weight_data_flattened, tmp, opt); + weight_data_flattened = tmp; + } +#endif + // weight_data_flattened as pack1 weight_data_flattened.w *= weight_data_flattened.elempack; weight_data_flattened.elemsize /= weight_data_flattened.elempack; @@ -450,6 +482,15 @@ int Deconvolution_x86::forward(const std::vector& bottom_blobs, std::vector if (bias_data_flattened.empty()) return -100; +#if NCNN_BF16 + if (bias_data_flattened.elembits() == 16) + { + Mat tmp; + cast_bfloat16_to_float32(bias_data_flattened, tmp, opt); + bias_data_flattened = tmp; + } +#endif + // bias_data_flattened as pack1 bias_data_flattened.w *= bias_data_flattened.elempack; bias_data_flattened.elemsize /= bias_data_flattened.elempack; @@ -498,4 +539,69 @@ int Deconvolution_x86::forward(const std::vector& bottom_blobs, std::vector return 0; } +#if NCNN_BF16 +int Deconvolution_x86::create_pipeline_bf16s(const Option& opt) +{ + const int maxk = kernel_w * kernel_h; + const int num_input = weight_data_size / maxk / num_output; + + deconvolution_transform_kernel_packed_bf16s(weight_data, weight_data_tm, num_input, num_output, kernel_w, kernel_h); + + if (opt.lightmode) + weight_data.release(); + + return 0; +} + +int Deconvolution_x86::forward_bf16s(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const +{ + int w = bottom_blob.w; + int h = bottom_blob.h; + int elempack = bottom_blob.elempack; + + const int kernel_extent_w = dilation_w * (kernel_w - 1) + 1; + const int kernel_extent_h = dilation_h * (kernel_h - 1) + 1; + + int outw = (w - 1) * stride_w + kernel_extent_w + output_pad_right; + int outh = (h - 1) * stride_h + kernel_extent_h + output_pad_bottom; + int out_elempack = 1; +#if __SSE2__ + if (opt.use_packing_layout) + { +#if __AVX512F__ + out_elempack = num_output % 16 == 0 ? 16 : num_output % 8 == 0 ? 8 : num_output % 4 == 0 ? 4 : 1; +#elif __AVX__ + out_elempack = num_output % 8 == 0 ? 8 : num_output % 4 == 0 ? 4 : 1; +#else + out_elempack = num_output % 4 == 0 ? 
4 : 1; +#endif + } +#endif // __SSE2__ + size_t out_elemsize = 2u * out_elempack; + + int out_channels = num_output / out_elempack; + + Mat top_blob_bordered; + if (pad_left > 0 || pad_right > 0 || pad_top > 0 || pad_bottom > 0 || (output_w > 0 && output_h > 0)) + { + top_blob_bordered.create(outw, outh, out_channels, out_elemsize, out_elempack, opt.workspace_allocator); + } + else + { + top_blob_bordered = top_blob; + top_blob_bordered.create(outw, outh, out_channels, out_elemsize, out_elempack, opt.blob_allocator); + } + if (top_blob_bordered.empty()) + return -100; + + deconvolution_packed_bf16s(bottom_blob, top_blob_bordered, weight_data_tm, bias_data, kernel_w, kernel_h, dilation_w, dilation_h, stride_w, stride_h, activation_type, activation_params, opt); + + cut_padding(top_blob_bordered, top_blob, opt); + if (top_blob.empty()) + return -100; + + return 0; +} +#endif // NCNN_BF16 + } // namespace ncnn diff --git a/src/layer/x86/deconvolution_x86.h b/src/layer/x86/deconvolution_x86.h index ade9d458a1c..cb04feca029 100644 --- a/src/layer/x86/deconvolution_x86.h +++ b/src/layer/x86/deconvolution_x86.h @@ -20,6 +20,12 @@ class Deconvolution_x86 : public Deconvolution virtual int forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; +public: +#if NCNN_BF16 + int create_pipeline_bf16s(const Option& opt); + int forward_bf16s(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const; +#endif + public: Layer* activation; Layer* gemm; diff --git a/src/layer/x86/deconvolution_x86_avx512bf16.cpp b/src/layer/x86/deconvolution_x86_avx512bf16.cpp new file mode 100644 index 00000000000..072071017ba --- /dev/null +++ b/src/layer/x86/deconvolution_x86_avx512bf16.cpp @@ -0,0 +1,25 @@ +// Copyright 2022 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#include "cpu.h" +#include "layer.h" +#include "layer_type.h" +#include "mat.h" +#include "x86_activation.h" +#include "x86_usability.h" + +namespace ncnn { + +#include "deconvolution_packed_bf16s.h" + +void deconvolution_packed_bf16s_avx512bf16(const Mat& bottom_blob, Mat& top_blob, const Mat& weight_data_tm, const Mat& bias_data, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h, int activation_type, const Mat& activation_params, const Option& opt) +{ + deconvolution_packed_bf16s(bottom_blob, top_blob, weight_data_tm, bias_data, kernel_w, kernel_h, dilation_w, dilation_h, stride_w, stride_h, activation_type, activation_params, opt); +} + +void deconvolution_transform_kernel_packed_bf16s_avx512bf16(const Mat& weight_data, Mat& weight_data_tm, int num_input, int num_output, int kernel_w, int kernel_h) +{ + deconvolution_transform_kernel_packed_bf16s(weight_data, weight_data_tm, num_input, num_output, kernel_w, kernel_h); +} + +} // namespace ncnn diff --git a/src/layer/x86/dequantize_x86_avx512bf16.cpp b/src/layer/x86/dequantize_x86_avx512bf16.cpp index 0d821807f3a..e72939c1a95 100644 --- a/src/layer/x86/dequantize_x86_avx512bf16.cpp +++ b/src/layer/x86/dequantize_x86_avx512bf16.cpp @@ -1,19 +1,9 @@ // Copyright 2026 Tencent // SPDX-License-Identifier: BSD-3-Clause -#include "dequantize_x86.h" - -#if __SSE2__ -#include -#if __AVX__ -#include -#endif // __AVX__ -#endif // __SSE2__ - -#include "x86_usability.h" - #include "cpu.h" #include "mat.h" +#include "x86_usability.h" namespace ncnn { diff --git a/src/layer/x86/dropout_x86_avx512bf16.cpp b/src/layer/x86/dropout_x86_avx512bf16.cpp index 52cf5052219..c38420ae7f9 100644 --- a/src/layer/x86/dropout_x86_avx512bf16.cpp +++ 
b/src/layer/x86/dropout_x86_avx512bf16.cpp @@ -1,19 +1,9 @@ // Copyright 2026 Tencent // SPDX-License-Identifier: BSD-3-Clause -#include "dropout_x86.h" - -#if __SSE2__ -#include -#if __AVX__ -#include -#endif // __AVX__ -#endif // __SSE2__ - -#include "x86_usability.h" - #include "cpu.h" #include "mat.h" +#include "x86_usability.h" namespace ncnn { diff --git a/src/layer/x86/eltwise_x86_avx512bf16.cpp b/src/layer/x86/eltwise_x86_avx512bf16.cpp index f70951fe8f7..174160bca82 100644 --- a/src/layer/x86/eltwise_x86_avx512bf16.cpp +++ b/src/layer/x86/eltwise_x86_avx512bf16.cpp @@ -1,19 +1,9 @@ // Copyright 2026 Tencent // SPDX-License-Identifier: BSD-3-Clause -#include "eltwise_x86.h" - -#if __SSE2__ -#include -#if __AVX__ -#include -#endif // __AVX__ -#endif // __SSE2__ - -#include "x86_usability.h" - #include "cpu.h" #include "mat.h" +#include "x86_usability.h" namespace ncnn { diff --git a/src/layer/x86/elu_x86_avx512bf16.cpp b/src/layer/x86/elu_x86_avx512bf16.cpp index c8e3aef0259..3883919d171 100644 --- a/src/layer/x86/elu_x86_avx512bf16.cpp +++ b/src/layer/x86/elu_x86_avx512bf16.cpp @@ -1,13 +1,12 @@ // Copyright 2026 Tencent // SPDX-License-Identifier: BSD-3-Clause -#include "elu_x86.h" - -#include "x86_activation.h" -#include "x86_usability.h" - #include "cpu.h" +#include "layer.h" +#include "layer_type.h" #include "mat.h" +#include "x86_activation.h" +#include "x86_usability.h" namespace ncnn { diff --git a/src/layer/x86/erf_x86_avx512bf16.cpp b/src/layer/x86/erf_x86_avx512bf16.cpp index 1840dd87241..79e12c5a4f2 100644 --- a/src/layer/x86/erf_x86_avx512bf16.cpp +++ b/src/layer/x86/erf_x86_avx512bf16.cpp @@ -1,24 +1,12 @@ // Copyright 2026 Tencent // SPDX-License-Identifier: BSD-3-Clause -#include "erf_x86.h" - -#if __SSE2__ -#include -#include "sse_mathfun.h" -#if __AVX__ -#include -#include "avx_mathfun.h" -#if __AVX512F__ -#include "avx512_mathfun.h" -#endif // __AVX512F__ -#endif // __AVX__ -#endif // __SSE2__ - -#include "x86_usability.h" - #include "cpu.h" +#include "layer.h" +#include "layer_type.h" #include "mat.h" +#include "x86_activation.h" +#include "x86_usability.h" namespace ncnn { diff --git a/src/layer/x86/gelu_x86_avx512bf16.cpp b/src/layer/x86/gelu_x86_avx512bf16.cpp index 739917ab6ac..582c192762c 100644 --- a/src/layer/x86/gelu_x86_avx512bf16.cpp +++ b/src/layer/x86/gelu_x86_avx512bf16.cpp @@ -1,24 +1,12 @@ // Copyright 2026 Tencent // SPDX-License-Identifier: BSD-3-Clause -#include "gelu_x86.h" - -#if __SSE2__ -#include -#include "sse_mathfun.h" -#if __AVX__ -#include -#include "avx_mathfun.h" -#if __AVX512F__ -#include "avx512_mathfun.h" -#endif // __AVX512F__ -#endif // __AVX__ -#endif // __SSE2__ - -#include "x86_usability.h" - #include "cpu.h" +#include "layer.h" +#include "layer_type.h" #include "mat.h" +#include "x86_activation.h" +#include "x86_usability.h" namespace ncnn { diff --git a/src/layer/x86/hardsigmoid_x86_avx512bf16.cpp b/src/layer/x86/hardsigmoid_x86_avx512bf16.cpp index 284733672b9..ec3cc68d965 100644 --- a/src/layer/x86/hardsigmoid_x86_avx512bf16.cpp +++ b/src/layer/x86/hardsigmoid_x86_avx512bf16.cpp @@ -1,12 +1,9 @@ // Copyright 2026 Tencent // SPDX-License-Identifier: BSD-3-Clause -#include "hardsigmoid_x86.h" - -#include "x86_usability.h" - #include "cpu.h" #include "mat.h" +#include "x86_usability.h" namespace ncnn { diff --git a/src/layer/x86/hardswish_x86_avx512bf16.cpp b/src/layer/x86/hardswish_x86_avx512bf16.cpp index 32cbde6887e..b5e892f65a7 100644 --- a/src/layer/x86/hardswish_x86_avx512bf16.cpp +++ 
b/src/layer/x86/hardswish_x86_avx512bf16.cpp @@ -1,12 +1,9 @@ // Copyright 2026 Tencent // SPDX-License-Identifier: BSD-3-Clause -#include "hardswish_x86.h" - -#include "x86_usability.h" - #include "cpu.h" #include "mat.h" +#include "x86_usability.h" namespace ncnn { diff --git a/src/layer/x86/mish_x86_avx512bf16.cpp b/src/layer/x86/mish_x86_avx512bf16.cpp index 32cf22af607..c65e35573d8 100644 --- a/src/layer/x86/mish_x86_avx512bf16.cpp +++ b/src/layer/x86/mish_x86_avx512bf16.cpp @@ -1,11 +1,12 @@ // Copyright 2026 Tencent // SPDX-License-Identifier: BSD-3-Clause -#include "mish_x86.h" - -#include "x86_activation.h" - #include "cpu.h" +#include "layer.h" +#include "layer_type.h" +#include "mat.h" +#include "x86_activation.h" +#include "x86_usability.h" namespace ncnn { diff --git a/src/layer/x86/quantize_x86_avx512bf16.cpp b/src/layer/x86/quantize_x86_avx512bf16.cpp index 1b3e617be69..42164beb409 100644 --- a/src/layer/x86/quantize_x86_avx512bf16.cpp +++ b/src/layer/x86/quantize_x86_avx512bf16.cpp @@ -1,19 +1,9 @@ // Copyright 2026 Tencent // SPDX-License-Identifier: BSD-3-Clause -#include "quantize_x86.h" - -#if __SSE2__ -#include -#if __AVX__ -#include -#endif // __AVX__ -#endif // __SSE2__ - -#include "x86_usability.h" - #include "cpu.h" #include "mat.h" +#include "x86_usability.h" namespace ncnn { diff --git a/src/layer/x86/relu_x86_avx512bf16.cpp b/src/layer/x86/relu_x86_avx512bf16.cpp index 7cd9976329e..0a91d46c8b1 100644 --- a/src/layer/x86/relu_x86_avx512bf16.cpp +++ b/src/layer/x86/relu_x86_avx512bf16.cpp @@ -1,12 +1,9 @@ // Copyright 2026 Tencent // SPDX-License-Identifier: BSD-3-Clause -#include "relu_x86.h" - -#include "x86_usability.h" - #include "cpu.h" #include "mat.h" +#include "x86_usability.h" namespace ncnn { diff --git a/src/layer/x86/selu_x86_avx512bf16.cpp b/src/layer/x86/selu_x86_avx512bf16.cpp index 37c0362e08b..f363a8c9622 100644 --- a/src/layer/x86/selu_x86_avx512bf16.cpp +++ b/src/layer/x86/selu_x86_avx512bf16.cpp @@ -1,24 +1,12 @@ // Copyright 2026 Tencent // SPDX-License-Identifier: BSD-3-Clause -#include "selu_x86.h" - -#if __SSE2__ -#include -#include "sse_mathfun.h" -#if __AVX__ -#include -#include "avx_mathfun.h" -#if __AVX512F__ -#include "avx512_mathfun.h" -#endif // __AVX512F__ -#endif // __AVX__ -#endif // __SSE2__ - -#include "x86_usability.h" - #include "cpu.h" +#include "layer.h" +#include "layer_type.h" #include "mat.h" +#include "x86_activation.h" +#include "x86_usability.h" namespace ncnn { diff --git a/src/layer/x86/sigmoid_x86_avx512bf16.cpp b/src/layer/x86/sigmoid_x86_avx512bf16.cpp index 0b86d15bdef..7257eea3e42 100644 --- a/src/layer/x86/sigmoid_x86_avx512bf16.cpp +++ b/src/layer/x86/sigmoid_x86_avx512bf16.cpp @@ -1,24 +1,12 @@ // Copyright 2026 Tencent // SPDX-License-Identifier: BSD-3-Clause -#include "sigmoid_x86.h" - -#if __SSE2__ -#include -#include "sse_mathfun.h" -#if __AVX__ -#include -#include "avx_mathfun.h" -#if __AVX512F__ -#include "avx512_mathfun.h" -#endif // __AVX512F__ -#endif // __AVX__ -#endif // __SSE2__ - -#include "x86_usability.h" - #include "cpu.h" +#include "layer.h" +#include "layer_type.h" #include "mat.h" +#include "x86_activation.h" +#include "x86_usability.h" namespace ncnn { diff --git a/src/layer/x86/softmax_x86_avx512bf16.cpp b/src/layer/x86/softmax_x86_avx512bf16.cpp index 2f69d47749e..1ef4369a0d2 100644 --- a/src/layer/x86/softmax_x86_avx512bf16.cpp +++ b/src/layer/x86/softmax_x86_avx512bf16.cpp @@ -1,26 +1,14 @@ // Copyright 2026 Tencent // SPDX-License-Identifier: BSD-3-Clause -#include "softmax_x86.h" - 
#include -#if __SSE2__ -#include -#include "sse_mathfun.h" -#if __AVX__ -#include -#include "avx_mathfun.h" -#if __AVX512F__ -#include "avx512_mathfun.h" -#endif // __AVX512F__ -#endif // __AVX__ -#endif // __SSE2__ - -#include "x86_usability.h" - #include "cpu.h" +#include "layer.h" +#include "layer_type.h" #include "mat.h" +#include "x86_activation.h" +#include "x86_usability.h" namespace ncnn { diff --git a/src/layer/x86/swish_x86_avx512bf16.cpp b/src/layer/x86/swish_x86_avx512bf16.cpp index e95f23b5598..ab38f0226fb 100644 --- a/src/layer/x86/swish_x86_avx512bf16.cpp +++ b/src/layer/x86/swish_x86_avx512bf16.cpp @@ -1,24 +1,12 @@ // Copyright 2026 Tencent // SPDX-License-Identifier: BSD-3-Clause -#include "swish_x86.h" - -#if __SSE2__ -#include -#include "sse_mathfun.h" -#if __AVX__ -#include -#include "avx_mathfun.h" -#if __AVX512F__ -#include "avx512_mathfun.h" -#endif // __AVX512F__ -#endif // __AVX__ -#endif // __SSE2__ - -#include "x86_usability.h" - #include "cpu.h" +#include "layer.h" +#include "layer_type.h" #include "mat.h" +#include "x86_activation.h" +#include "x86_usability.h" namespace ncnn { diff --git a/src/layer/x86/tanh_x86_avx512bf16.cpp b/src/layer/x86/tanh_x86_avx512bf16.cpp index 50684126b16..68b138d27b1 100644 --- a/src/layer/x86/tanh_x86_avx512bf16.cpp +++ b/src/layer/x86/tanh_x86_avx512bf16.cpp @@ -1,11 +1,12 @@ // Copyright 2026 Tencent // SPDX-License-Identifier: BSD-3-Clause -#include "tanh_x86.h" - -#include "x86_activation.h" - #include "cpu.h" +#include "layer.h" +#include "layer_type.h" +#include "mat.h" +#include "x86_activation.h" +#include "x86_usability.h" namespace ncnn { diff --git a/src/layer/x86/unaryop_x86_avx512bf16.cpp b/src/layer/x86/unaryop_x86_avx512bf16.cpp index 9707ce82722..b7860cc7eac 100644 --- a/src/layer/x86/unaryop_x86_avx512bf16.cpp +++ b/src/layer/x86/unaryop_x86_avx512bf16.cpp @@ -3,23 +3,12 @@ #include "unaryop_x86.h" -#if __SSE2__ -#include -#include "sse_mathfun.h" -#if __SSE4_1__ -#include -#if __AVX__ -#include -#include "avx_mathfun.h" -#if __AVX512F__ -#include "avx512_mathfun.h" -#endif // __AVX512F__ -#endif // __AVX__ -#endif // __SSE4_1__ -#endif // __SSE2__ - -#include "x86_usability.h" +#include "cpu.h" +#include "layer.h" +#include "layer_type.h" +#include "mat.h" #include "x86_activation.h" +#include "x86_usability.h" namespace ncnn { From b6dab49bce9f7a8f17fa1b0f6f00c3109cfdf503 Mon Sep 17 00:00:00 2001 From: chenglimin Date: Wed, 1 Apr 2026 15:54:03 +0800 Subject: [PATCH 33/36] WIP: save local changes before rebase --- src/layer/riscv/sdpa_riscv.cpp | 213 ++++++++++++++++++++++----------- src/layer/riscv/sdpa_riscv.h | 3 +- toolchains/k1.toolchain.cmake | 4 +- 3 files changed, 146 insertions(+), 74 deletions(-) diff --git a/src/layer/riscv/sdpa_riscv.cpp b/src/layer/riscv/sdpa_riscv.cpp index b4d63e31566..4aae392bc29 100644 --- a/src/layer/riscv/sdpa_riscv.cpp +++ b/src/layer/riscv/sdpa_riscv.cpp @@ -1,4 +1,5 @@ -// Copyright 2026 Tencent + +// Copyright 2025 Tencent // SPDX-License-Identifier: BSD-3-Clause #include "sdpa_riscv.h" @@ -7,22 +8,26 @@ #if __riscv_vector #include -#endif #include "riscv_usability.h" +#endif // __riscv_vector namespace ncnn { SDPA_riscv::SDPA_riscv() { - support_packing = true; - qk_gemm = 0; qkv_gemm = 0; qk_softmax = 0; } -int SDPA_riscv::create_pipeline(const Option& opt) +int SDPA_riscv::create_pipeline(const Option& _opt) { + Option opt = _opt; + if (int8_scale_term) + { + opt.use_packing_layout = false; // TODO enable packing + } + { qk_softmax = 
ncnn::create_layer_cpu(ncnn::LayerType::Softmax); ncnn::ParamDict pd; @@ -34,23 +39,24 @@ int SDPA_riscv::create_pipeline(const Option& opt) } // Q * K^T + if (scale != 0.f) { qk_gemm = ncnn::create_layer_cpu(ncnn::LayerType::Gemm); ncnn::ParamDict pd; - pd.set(0, 1.f); // alpha (will be set in forward) - pd.set(1, 0.f); // beta - pd.set(2, 0); // transA (Q: Seq x Embed) - pd.set(3, 1); // transB (K: Seq x Embed -> K^T: Embed x Seq) => Q * K^T - pd.set(4, 0); // constantA - pd.set(5, 0); // constantB - pd.set(6, 1); // constantC (None) - pd.set(7, 0); // M - pd.set(8, 0); // N - pd.set(9, 0); // K - pd.set(10, -1); // constant_broadcast_type_C - pd.set(11, 0); // output_N1M - pd.set(12, 1); // output_elempack + pd.set(0, 1.f); // alpha (will be set in forward) + pd.set(1, 0.f); // beta + pd.set(2, 0); // transA (Q: Seq x Embed) + pd.set(3, 1); // transB (K: Seq x Embed -> K^T: Embed x Seq) => Q * K^T + pd.set(4, 0); // constantA + pd.set(5, 0); // constantB + pd.set(6, 1); // constantC (None) + pd.set(7, 0); // M + pd.set(8, 0); // N + pd.set(9, 0); // K + pd.set(10, -1); // constant_broadcast_type_C + pd.set(11, 0); // output_N1M + pd.set(12, 1); // output_elempack #if NCNN_INT8 pd.set(18, int8_scale_term); #endif @@ -92,8 +98,14 @@ int SDPA_riscv::create_pipeline(const Option& opt) return 0; } -int SDPA_riscv::destroy_pipeline(const Option& opt) +int SDPA_riscv::destroy_pipeline(const Option& _opt) { + Option opt = _opt; + if (int8_scale_term) + { + opt.use_packing_layout = false; // TODO enable packing + } + if (qk_softmax) { qk_softmax->destroy_pipeline(opt); @@ -121,6 +133,11 @@ int SDPA_riscv::destroy_pipeline(const Option& opt) int SDPA_riscv::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& _opt) const { Option opt = _opt; + if (int8_scale_term) + { + opt.use_packing_layout = false; // TODO enable packing + } + const Mat& query = bottom_blobs[0]; const Mat& cur_key = bottom_blobs[1]; const Mat& cur_value = bottom_blobs[2]; @@ -142,7 +159,7 @@ int SDPA_riscv::forward(const std::vector& bottom_blobs, std::vector& { // Fallback for packed data // TODO: Implement optimized RVV paths for group=2 with elempack=2,4,8, and group=4 with elempack=4 - + // Unpack input blobs std::vector bottom_blobs_unpacked = bottom_blobs; Option opt_unpack = opt; @@ -208,8 +225,44 @@ int SDPA_riscv::forward(const std::vector& bottom_blobs, std::vector& const Mat cur_key_head = cur_key.channel(q); Mat key_head = key.channel(q); - memcpy(key_head.row(0), past_key_head, embed_dim * past_seqlen * sizeof(float)); - memcpy(key_head.row(past_seqlen), cur_key_head, embed_dim * cur_seqlen * sizeof(float)); + // memcpy(key_head.row(0), past_key_head, embed_dim * past_seqlen * sizeof(float)); + // memcpy(key_head.row(past_seqlen), cur_key_head, embed_dim * cur_seqlen * sizeof(float)); + + const float* past_ptr = past_key_head; + float* key_ptr = key_head.row(0); + int len = embed_dim * past_seqlen; + +#if __riscv_vector + int n = len; + while (n > 0) { + size_t vl = __riscv_vsetvl_e32m8(n); + vfloat32m8_t v = __riscv_vle32_v_f32m8(past_ptr, vl); + __riscv_vse32_v_f32m8(key_ptr, v, vl); + past_ptr += vl; + key_ptr += vl; + n -= vl; + } +#else + memcpy(key_ptr, past_ptr, len * sizeof(float)); +#endif + + const float* cur_ptr = cur_key_head; + key_ptr = key_head.row(past_seqlen); + len = embed_dim * cur_seqlen; + +#if __riscv_vector + n = len; + while (n > 0) { + size_t vl = __riscv_vsetvl_e32m8(n); + vfloat32m8_t v = __riscv_vle32_v_f32m8(cur_ptr, vl); + __riscv_vse32_v_f32m8(key_ptr, v, 
vl); + cur_ptr += vl; + key_ptr += vl; + n -= vl; + } +#else + memcpy(key_ptr, cur_ptr, len * sizeof(float)); +#endif } } else @@ -231,8 +284,44 @@ int SDPA_riscv::forward(const std::vector& bottom_blobs, std::vector& const Mat cur_value_head = cur_value.channel(q); Mat value_head = value.channel(q); - memcpy(value_head.row(0), past_value_head, out_embed_dim * past_seqlen * sizeof(float)); - memcpy(value_head.row(past_seqlen), cur_value_head, out_embed_dim * cur_seqlen * sizeof(float)); + // memcpy(value_head.row(0), past_value_head, out_embed_dim * past_seqlen * sizeof(float)); + // memcpy(value_head.row(past_seqlen), cur_value_head, out_embed_dim * cur_seqlen * sizeof(float)); + + const float* past_ptr = past_value_head; + float* value_ptr = value_head.row(0); + int len = out_embed_dim * past_seqlen; + +#if __riscv_vector + int n = len; + while (n > 0) { + size_t vl = __riscv_vsetvl_e32m8(n); + vfloat32m8_t v = __riscv_vle32_v_f32m8(past_ptr, vl); + __riscv_vse32_v_f32m8(value_ptr, v, vl); + past_ptr += vl; + value_ptr += vl; + n -= vl; + } +#else + memcpy(value_ptr, past_ptr, len * sizeof(float)); +#endif + + const float* cur_ptr = cur_value_head; + value_ptr = value_head.row(past_seqlen); + len = out_embed_dim * cur_seqlen; + +#if __riscv_vector + n = len; + while (n > 0) { + size_t vl = __riscv_vsetvl_e32m8(n); + vfloat32m8_t v = __riscv_vle32_v_f32m8(cur_ptr, vl); + __riscv_vse32_v_f32m8(value_ptr, v, vl); + cur_ptr += vl; + value_ptr += vl; + n -= vl; + } +#else + memcpy(value_ptr, cur_ptr, len * sizeof(float)); +#endif } } else @@ -253,63 +342,51 @@ int SDPA_riscv::forward(const std::vector& bottom_blobs, std::vector& std::vector retqks(num_heads); - float _scale = scale; - if (_scale == 0.f) - { - _scale = 1.f / sqrt(embed_dim); - } - - // Create local Gemm if scale is dynamic or different from 1.f + // Dynamic Scale Calculation and Beta Correction Layer* _qk_gemm = qk_gemm; - bool local_gemm = false; - if (_scale != 1.f) + if (scale == 0.f) { + float _scale = 1.f / sqrt(embed_dim); + _qk_gemm = ncnn::create_layer_cpu(ncnn::LayerType::Gemm); ncnn::ParamDict pd; - pd.set(0, _scale); // alpha - pd.set(1, 0.f); // beta - pd.set(2, 0); // transA - pd.set(3, 1); // transB - pd.set(4, 0); // constantA - pd.set(5, 0); // constantB - pd.set(6, 1); // constantC (None) - pd.set(7, 0); // M - pd.set(8, 0); // N - pd.set(9, 0); // K - pd.set(10, -1); // constant_broadcast_type_C - pd.set(11, 0); // output_N1M - pd.set(12, 1); // output_elempack + pd.set(0, _scale); // alpha + pd.set(1, 0.f); // beta + pd.set(2, 0); // transA + pd.set(3, 1); // transB + pd.set(4, 0); // constantA + pd.set(5, 0); // constantB + pd.set(6, 1); // constantC (None) + pd.set(7, 0); // M + pd.set(8, 0); // N + pd.set(9, 0); // K + pd.set(10, -1); // constant_broadcast_type_C + pd.set(11, 0); // output_N1M + pd.set(12, 1); // output_elempack #if NCNN_INT8 pd.set(18, int8_scale_term); #endif _qk_gemm->load_param(pd); _qk_gemm->load_model(ModelBinFromMatArray(0)); + Option opt1 = opt; opt1.num_threads = 1; _qk_gemm->create_pipeline(opt1); - local_gemm = true; } #pragma omp parallel for num_threads(opt.num_threads) for (int i = 0; i < num_heads; i++) { // 1. 
Q * K^T - const Mat q_head = query.channel(i); - const Mat k_head = key.channel(i / num_heads_per_group); - Mat score_head = qk_cross.channel(i); + std::vector qk_bottom_blobs; + qk_bottom_blobs.push_back(query.channel(i)); // Q: [Seq, Embed] + qk_bottom_blobs.push_back(key.channel(i / num_heads_per_group)); // K: [DstSeq, Embed] - for (int j = 0; j < src_seqlen; j++) + if (attn_mask) { - const float* qptr = q_head.row(j); - float* outptr = score_head.row(j); - const float* mptr_row = 0; - if (attn_mask) - { - const Mat& maskm = attn_mask_blob.c > 1 ? attn_mask_blob.channel(i) : attn_mask_blob; - mptr_row = maskm.row(j); - } - - for (int k = 0; k < dst_seqlen; k++) + // Ensure mask is 2D for Gemm auto-broadcast detection + Mat maskm = attn_mask_blob; + if (maskm.dims == 3) { const float* kptr = k_head.row(k); float sum = 0.f; @@ -318,7 +395,7 @@ int SDPA_riscv::forward(const std::vector& bottom_blobs, std::vector& size_t vlmax = __riscv_vsetvlmax_e32m8(); vfloat32m8_t _sum_v = __riscv_vfmv_v_f_f32m8(0.0f, vlmax); int l = 0; - for (; l < embed_dim;) + for (; l < embed_dim; ) { size_t vl = __riscv_vsetvl_e32m8(embed_dim - l); vfloat32m8_t _q = __riscv_vle32_v_f32m8(qptr + l, vl); @@ -348,20 +425,14 @@ int SDPA_riscv::forward(const std::vector& bottom_blobs, std::vector& return retqks[i]; } - if (local_gemm) - { - Option opt1 = opt; - opt1.num_threads = 1; - _qk_gemm->destroy_pipeline(opt1); - delete _qk_gemm; - } - // 2. Softmax int retqk = qk_softmax->forward_inplace(qk_cross, opt); if (retqk != 0) return retqk; // 3. Attn * V + std::vector retqkvs(num_heads); + #pragma omp parallel for num_threads(opt.num_threads) for (int i = 0; i < num_heads; i++) { @@ -381,7 +452,7 @@ int SDPA_riscv::forward(const std::vector& bottom_blobs, std::vector& size_t vlmax = __riscv_vsetvlmax_e32m8(); vfloat32m8_t _sum_v = __riscv_vfmv_v_f_f32m8(0.0f, vlmax); int l = 0; - for (; l < dst_seqlen;) + for (; l < dst_seqlen; ) { size_t vl = __riscv_vsetvl_e32m8(dst_seqlen - l); vfloat32m8_t _qk = __riscv_vle32_v_f32m8(qkptr + l, vl); diff --git a/src/layer/riscv/sdpa_riscv.h b/src/layer/riscv/sdpa_riscv.h index 796a31b3eae..d30122d212f 100644 --- a/src/layer/riscv/sdpa_riscv.h +++ b/src/layer/riscv/sdpa_riscv.h @@ -1,4 +1,5 @@ -// Copyright 2026 Tencent + +// Copyright 2025 Tencent // SPDX-License-Identifier: BSD-3-Clause #ifndef LAYER_SDPA_RISCV_H diff --git a/toolchains/k1.toolchain.cmake b/toolchains/k1.toolchain.cmake index b45da6c9b8a..ab44ab8e604 100644 --- a/toolchains/k1.toolchain.cmake +++ b/toolchains/k1.toolchain.cmake @@ -29,8 +29,8 @@ if(NOT CMAKE_FIND_ROOT_PATH_MODE_PACKAGE) set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE ONLY) endif() -set(CMAKE_C_FLAGS "-march=rv64gc_zba_zbb_zbc_zbs_zicbop -mabi=lp64d -mtune=spacemit-x60") -set(CMAKE_CXX_FLAGS "-march=rv64gc_zba_zbb_zbc_zbs_zicbop -mabi=lp64d -mtune=spacemit-x60") +set(CMAKE_C_FLAGS "-march=rv64gcv_zba_zbb_zbc_zbs_zicbop -mabi=lp64d ") +set(CMAKE_CXX_FLAGS "-march=rv64gcv_zba_zbb_zbc_zbs_zicbop -mabi=lp64d ") # cache flags set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS}" CACHE STRING "c flags") From 1f7d768962d33d63dec57f4dc39cb38c93a467b0 Mon Sep 17 00:00:00 2001 From: chenglimin Date: Wed, 1 Apr 2026 16:30:33 +0800 Subject: [PATCH 34/36] Update k1 toolchain config --- toolchains/k1.toolchain.cmake | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/toolchains/k1.toolchain.cmake b/toolchains/k1.toolchain.cmake index ab44ab8e604..b45da6c9b8a 100644 --- a/toolchains/k1.toolchain.cmake +++ b/toolchains/k1.toolchain.cmake @@ -29,8 +29,8 @@ if(NOT 
CMAKE_FIND_ROOT_PATH_MODE_PACKAGE) set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE ONLY) endif() -set(CMAKE_C_FLAGS "-march=rv64gcv_zba_zbb_zbc_zbs_zicbop -mabi=lp64d ") -set(CMAKE_CXX_FLAGS "-march=rv64gcv_zba_zbb_zbc_zbs_zicbop -mabi=lp64d ") +set(CMAKE_C_FLAGS "-march=rv64gc_zba_zbb_zbc_zbs_zicbop -mabi=lp64d -mtune=spacemit-x60") +set(CMAKE_CXX_FLAGS "-march=rv64gc_zba_zbb_zbc_zbs_zicbop -mabi=lp64d -mtune=spacemit-x60") # cache flags set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS}" CACHE STRING "c flags") From 517bbd623d4c261117265bcf03b2c14c884c8e30 Mon Sep 17 00:00:00 2001 From: chenglimin <18213449+chenglimin@users.noreply.github.com> Date: Wed, 1 Apr 2026 08:32:42 +0000 Subject: [PATCH 35/36] apply code-format changes --- src/layer/riscv/sdpa_riscv.cpp | 74 ++++++++++++++++++---------------- 1 file changed, 39 insertions(+), 35 deletions(-) diff --git a/src/layer/riscv/sdpa_riscv.cpp b/src/layer/riscv/sdpa_riscv.cpp index 4aae392bc29..8eb09c1741e 100644 --- a/src/layer/riscv/sdpa_riscv.cpp +++ b/src/layer/riscv/sdpa_riscv.cpp @@ -44,19 +44,19 @@ int SDPA_riscv::create_pipeline(const Option& _opt) qk_gemm = ncnn::create_layer_cpu(ncnn::LayerType::Gemm); ncnn::ParamDict pd; - pd.set(0, 1.f); // alpha (will be set in forward) - pd.set(1, 0.f); // beta - pd.set(2, 0); // transA (Q: Seq x Embed) - pd.set(3, 1); // transB (K: Seq x Embed -> K^T: Embed x Seq) => Q * K^T - pd.set(4, 0); // constantA - pd.set(5, 0); // constantB - pd.set(6, 1); // constantC (None) - pd.set(7, 0); // M - pd.set(8, 0); // N - pd.set(9, 0); // K - pd.set(10, -1); // constant_broadcast_type_C - pd.set(11, 0); // output_N1M - pd.set(12, 1); // output_elempack + pd.set(0, 1.f); // alpha (will be set in forward) + pd.set(1, 0.f); // beta + pd.set(2, 0); // transA (Q: Seq x Embed) + pd.set(3, 1); // transB (K: Seq x Embed -> K^T: Embed x Seq) => Q * K^T + pd.set(4, 0); // constantA + pd.set(5, 0); // constantB + pd.set(6, 1); // constantC (None) + pd.set(7, 0); // M + pd.set(8, 0); // N + pd.set(9, 0); // K + pd.set(10, -1); // constant_broadcast_type_C + pd.set(11, 0); // output_N1M + pd.set(12, 1); // output_elempack #if NCNN_INT8 pd.set(18, int8_scale_term); #endif @@ -159,7 +159,7 @@ int SDPA_riscv::forward(const std::vector& bottom_blobs, std::vector& { // Fallback for packed data // TODO: Implement optimized RVV paths for group=2 with elempack=2,4,8, and group=4 with elempack=4 - + // Unpack input blobs std::vector bottom_blobs_unpacked = bottom_blobs; Option opt_unpack = opt; @@ -227,14 +227,15 @@ int SDPA_riscv::forward(const std::vector& bottom_blobs, std::vector& // memcpy(key_head.row(0), past_key_head, embed_dim * past_seqlen * sizeof(float)); // memcpy(key_head.row(past_seqlen), cur_key_head, embed_dim * cur_seqlen * sizeof(float)); - + const float* past_ptr = past_key_head; float* key_ptr = key_head.row(0); int len = embed_dim * past_seqlen; - + #if __riscv_vector int n = len; - while (n > 0) { + while (n > 0) + { size_t vl = __riscv_vsetvl_e32m8(n); vfloat32m8_t v = __riscv_vle32_v_f32m8(past_ptr, vl); __riscv_vse32_v_f32m8(key_ptr, v, vl); @@ -252,7 +253,8 @@ int SDPA_riscv::forward(const std::vector& bottom_blobs, std::vector& #if __riscv_vector n = len; - while (n > 0) { + while (n > 0) + { size_t vl = __riscv_vsetvl_e32m8(n); vfloat32m8_t v = __riscv_vle32_v_f32m8(cur_ptr, vl); __riscv_vse32_v_f32m8(key_ptr, v, vl); @@ -293,7 +295,8 @@ int SDPA_riscv::forward(const std::vector& bottom_blobs, std::vector& #if __riscv_vector int n = len; - while (n > 0) { + while (n > 0) + { size_t vl = __riscv_vsetvl_e32m8(n); 
vfloat32m8_t v = __riscv_vle32_v_f32m8(past_ptr, vl); __riscv_vse32_v_f32m8(value_ptr, v, vl); @@ -311,7 +314,8 @@ int SDPA_riscv::forward(const std::vector& bottom_blobs, std::vector& #if __riscv_vector n = len; - while (n > 0) { + while (n > 0) + { size_t vl = __riscv_vsetvl_e32m8(n); vfloat32m8_t v = __riscv_vle32_v_f32m8(cur_ptr, vl); __riscv_vse32_v_f32m8(value_ptr, v, vl); @@ -350,19 +354,19 @@ int SDPA_riscv::forward(const std::vector& bottom_blobs, std::vector& _qk_gemm = ncnn::create_layer_cpu(ncnn::LayerType::Gemm); ncnn::ParamDict pd; - pd.set(0, _scale); // alpha - pd.set(1, 0.f); // beta - pd.set(2, 0); // transA - pd.set(3, 1); // transB - pd.set(4, 0); // constantA - pd.set(5, 0); // constantB - pd.set(6, 1); // constantC (None) - pd.set(7, 0); // M - pd.set(8, 0); // N - pd.set(9, 0); // K - pd.set(10, -1); // constant_broadcast_type_C - pd.set(11, 0); // output_N1M - pd.set(12, 1); // output_elempack + pd.set(0, _scale); // alpha + pd.set(1, 0.f); // beta + pd.set(2, 0); // transA + pd.set(3, 1); // transB + pd.set(4, 0); // constantA + pd.set(5, 0); // constantB + pd.set(6, 1); // constantC (None) + pd.set(7, 0); // M + pd.set(8, 0); // N + pd.set(9, 0); // K + pd.set(10, -1); // constant_broadcast_type_C + pd.set(11, 0); // output_N1M + pd.set(12, 1); // output_elempack #if NCNN_INT8 pd.set(18, int8_scale_term); #endif @@ -395,7 +399,7 @@ int SDPA_riscv::forward(const std::vector& bottom_blobs, std::vector& size_t vlmax = __riscv_vsetvlmax_e32m8(); vfloat32m8_t _sum_v = __riscv_vfmv_v_f_f32m8(0.0f, vlmax); int l = 0; - for (; l < embed_dim; ) + for (; l < embed_dim;) { size_t vl = __riscv_vsetvl_e32m8(embed_dim - l); vfloat32m8_t _q = __riscv_vle32_v_f32m8(qptr + l, vl); @@ -452,7 +456,7 @@ int SDPA_riscv::forward(const std::vector& bottom_blobs, std::vector& size_t vlmax = __riscv_vsetvlmax_e32m8(); vfloat32m8_t _sum_v = __riscv_vfmv_v_f_f32m8(0.0f, vlmax); int l = 0; - for (; l < dst_seqlen; ) + for (; l < dst_seqlen;) { size_t vl = __riscv_vsetvl_e32m8(dst_seqlen - l); vfloat32m8_t _qk = __riscv_vle32_v_f32m8(qkptr + l, vl); From 64ac9a16556598d37b0d5f45a38095ec1997f53f Mon Sep 17 00:00:00 2001 From: chenglimin Date: Wed, 1 Apr 2026 16:46:07 +0800 Subject: [PATCH 36/36] change year --- src/layer/riscv/sdpa_riscv.cpp | 3 +-- src/layer/riscv/sdpa_riscv.h | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/layer/riscv/sdpa_riscv.cpp b/src/layer/riscv/sdpa_riscv.cpp index 8eb09c1741e..59653251e48 100644 --- a/src/layer/riscv/sdpa_riscv.cpp +++ b/src/layer/riscv/sdpa_riscv.cpp @@ -1,5 +1,4 @@ - -// Copyright 2025 Tencent +// Copyright 2026 Tencent // SPDX-License-Identifier: BSD-3-Clause #include "sdpa_riscv.h" diff --git a/src/layer/riscv/sdpa_riscv.h b/src/layer/riscv/sdpa_riscv.h index d30122d212f..796a31b3eae 100644 --- a/src/layer/riscv/sdpa_riscv.h +++ b/src/layer/riscv/sdpa_riscv.h @@ -1,5 +1,4 @@ - -// Copyright 2025 Tencent +// Copyright 2026 Tencent // SPDX-License-Identifier: BSD-3-Clause #ifndef LAYER_SDPA_RISCV_H
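The RVV copy loops introduced in PATCH 33 and reformatted in PATCH 35 all follow the same strip-mining idiom: ask vsetvl for the number of 32-bit lanes the hardware will handle this iteration, do one vector load and one vector store, then advance the pointers by that amount. As a reference, a minimal self-contained sketch of that idiom outside the patch context; the helper name copy_f32 is illustrative and not part of the patch:

#include <riscv_vector.h>
#include <cstring>

// Copy n floats from src to dst with LMUL=8 vector loads/stores,
// falling back to memcpy when RVV is not available at compile time.
static void copy_f32(float* dst, const float* src, int n)
{
#if __riscv_vector
    while (n > 0)
    {
        size_t vl = __riscv_vsetvl_e32m8(n); // lanes processed this iteration
        vfloat32m8_t v = __riscv_vle32_v_f32m8(src, vl);
        __riscv_vse32_v_f32m8(dst, v, vl);
        src += vl;
        dst += vl;
        n -= vl;
    }
#else
    memcpy(dst, src, n * sizeof(float));
#endif
}

The Q * K^T and Attn * V inner loops in the same patch follow the same vsetvl-driven loop structure, loading query/key (or score/value) slices into m8 registers before accumulating and reducing per output element.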