Skip to content

Commit 8cd4d82

Browse files
committed
Optimize GPU Permute Kernel for B↔F Axis Swap ([1,0,2,3] order)
Signed-off-by: Andrew Park <andrew.park@intel.com>
1 parent eb86915 commit 8cd4d82

File tree

5 files changed

+404
-0
lines changed

5 files changed

+404
-0
lines changed
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
// Copyright (C) 2018-2026 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
// Permute kernel for B <-> F axis swap (order [1,0,2,3] and higher-dim equivalents).
6+
// X is contiguous before and after the swap, so no SLM transpose is needed.
7+
// Each work item vectorizes along X with vload/vstore and loops over F internally.
8+
//
9+
// GWS: (ceil(X / VEC_WIDTH), Y [* Z [* W]], B)
10+
11+
#include "include/batch_headers/fetch_data.cl"
12+
13+
// permute_b_f_axes: out[f, b, <spatial>] = in[b, f, <spatial>] for 4D/5D/6D tensors.
//
// Work-item mapping:
//   get_global_id(0) - tile of VEC_WIDTH consecutive X elements
//   get_global_id(1) - flattened Y (4D), Z*Y (5D) or W*Z*Y (6D) coordinate
//   get_global_id(2) - input batch index
// Every work item walks all input features itself.
//
// The partial last X tile is detected from INPUT0_SIZE_X instead of the
// host-baked X_TILES / X_REMAINDER_SIZE constants, so the vector path can
// never run past the end of a row when the X extent is only known at run time.
KERNEL(permute_b_f_axes)(
    OPTIONAL_SHAPE_INFO_ARG
    const __global INPUT0_TYPE* input,
    __global OUTPUT_TYPE* output
#if HAS_FUSED_OPS_DECLS
    , FUSED_OPS_DECLS
#endif
)
{
    const uint x_t = get_global_id(0);  // X tile index

#if INPUT0_DIMS == 4
    const uint y = get_global_id(1);
#elif INPUT0_DIMS == 5
    const uint z = get_global_id(1) / INPUT0_SIZE_Y;
    const uint y = get_global_id(1) % INPUT0_SIZE_Y;
#elif INPUT0_DIMS == 6
    const uint w = get_global_id(1) / (INPUT0_SIZE_Z * INPUT0_SIZE_Y);
    const uint z = get_global_id(1) / INPUT0_SIZE_Y % INPUT0_SIZE_Z;
    const uint y = get_global_id(1) % INPUT0_SIZE_Y;
#endif

    const uint b = get_global_id(2);
    const uint x_base = x_t * VEC_WIDTH;
    const uint x_size = INPUT0_SIZE_X;

    // A tile that starts at or past the end of the row has nothing to copy.
    if (x_base >= x_size)
        return;

    // Number of valid elements from x_base to the end of the row; a tile with
    // fewer than VEC_WIDTH of them must be copied element by element.
    const uint x_left = x_size - x_base;
    const bool is_tail = x_left < VEC_WIDTH;

    for (uint f = 0; f < INPUT0_FEATURE_NUM; ++f) {
        if (is_tail) {
            // Scalar path for the partial last tile.
            for (uint i = 0; i < x_left; ++i) {
                const uint x = x_base + i;
#if INPUT0_DIMS == 4
                const uint in_idx = INPUT0_GET_INDEX(b, f, y, x);
                const uint out_idx = OUTPUT_GET_INDEX(f, b, y, x);
#elif INPUT0_DIMS == 5
                const uint in_idx = INPUT0_GET_INDEX(b, f, z, y, x);
                const uint out_idx = OUTPUT_GET_INDEX(f, b, z, y, x);
#elif INPUT0_DIMS == 6
                const uint in_idx = INPUT0_GET_INDEX(b, f, w, z, y, x);
                const uint out_idx = OUTPUT_GET_INDEX(f, b, w, z, y, x);
#endif
                INPUT0_TYPE val = input[in_idx];
#if HAS_FUSED_OPS
                INPUT0_TYPE input_var = val;
                FUSED_OPS;
                output[out_idx] = FUSED_OPS_RESULT;
#else
                output[out_idx] = ACTIVATION(val, ACTIVATION_PARAMS);
#endif
            }
            continue;
        }

        // Vector path: one VEC_WIDTH-wide load and store per feature.
#if INPUT0_DIMS == 4
        const uint in_idx = INPUT0_GET_INDEX(b, f, y, x_base);
        const uint out_idx = OUTPUT_GET_INDEX(f, b, y, x_base);
#elif INPUT0_DIMS == 5
        const uint in_idx = INPUT0_GET_INDEX(b, f, z, y, x_base);
        const uint out_idx = OUTPUT_GET_INDEX(f, b, z, y, x_base);
#elif INPUT0_DIMS == 6
        const uint in_idx = INPUT0_GET_INDEX(b, f, w, z, y, x_base);
        const uint out_idx = OUTPUT_GET_INDEX(f, b, w, z, y, x_base);
#endif

        INPUTVTYPE vals = CAT(vload, VEC_WIDTH)(0, input + in_idx);
#if HAS_FUSED_OPS
        OUTPUTVTYPE out_vals;
        __attribute__((opencl_unroll_hint(VEC_WIDTH)))
        for (uint i = 0; i < VEC_WIDTH; ++i) {
            // The fused-op index order names `x`, so the per-lane X coordinate
            // has to exist in this scope just as it does in the scalar path.
            const uint x = x_base + i;
            INPUT0_TYPE input_var = vals[i];
            FUSED_OPS;
            out_vals[i] = FUSED_OPS_RESULT;
        }
        CAT(vstore, VEC_WIDTH)(out_vals, 0, output + out_idx);
#else
        OUTPUTVTYPE out_vals = ACTIVATION(CAT(convert_, OUTPUTVTYPE)(vals), ACTIVATION_PARAMS);
        CAT(vstore, VEC_WIDTH)(out_vals, 0, output + out_idx);
#endif
    }  // for f
}
Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
// Copyright (C) 2018-2026 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#include "permute_kernel_b_f_axes.h"
6+
7+
#include <string>
8+
9+
#include "common_tools.h"
10+
#include "kernel_selector_utils.h"
11+
12+
namespace kernel_selector {
13+
14+
// Vector width chosen to make each vload/vstore a 16-byte transaction.
15+
static size_t GetVecWidth(const permute_params& params) {
16+
switch (params.inputs[0].GetDType()) {
17+
case Datatype::F16:
18+
case Datatype::INT16:
19+
case Datatype::UINT16:
20+
return 8;
21+
case Datatype::F32:
22+
case Datatype::INT32:
23+
return 4;
24+
case Datatype::INT8:
25+
case Datatype::UINT8:
26+
return 16;
27+
case Datatype::INT64:
28+
return 2;
29+
default:
30+
return 4;
31+
}
32+
}
33+
34+
// Describes what this kernel accepts: six element types on either side (input
// and output types may differ), the bfyx / bfzyx / bfwzyx layouts, tensor
// offsets and pitches, batching and dynamic shapes.
ParamsKey PermuteKernel_b_f_axes::GetSupportedKey() const {
    const Datatype element_types[] = {Datatype::F16,
                                      Datatype::F32,
                                      Datatype::INT8,
                                      Datatype::UINT8,
                                      Datatype::INT32,
                                      Datatype::INT64};
    const DataLayout layouts[] = {DataLayout::bfyx, DataLayout::bfzyx, DataLayout::bfwzyx};

    ParamsKey key;
    for (const auto type : element_types)
        key.EnableInputDataType(type);
    for (const auto type : element_types)
        key.EnableOutputDataType(type);
    key.EnableDifferentTypes();

    for (const auto layout : layouts) {
        key.EnableInputLayout(layout);
        key.EnableOutputLayout(layout);
    }

    key.EnableTensorOffset();
    key.EnableTensorPitches();
    key.EnableBatching();
    key.EnableDynamicShapesSupport();
    return key;
}
61+
62+
// Builds the JIT constants consumed by the permute_b_f_axes kernel.
//
// On top of the base-class constants this emits:
//   VEC_WIDTH        - elements per vload/vstore (see GetVecWidth)
//   X_TILES          - number of full VEC_WIDTH tiles along X
//   X_REMAINDER_SIZE - elements left over after the full tiles
//   INPUTVTYPE / OUTPUTVTYPE - vector types built from the scalar types
// and, when fused ops are present, their JIT code. The dispatch data argument
// is not used.
JitConstants PermuteKernel_b_f_axes::GetJitConstants(const permute_params& params,
                                                     const CommonDispatchData& /*dispatchData*/) const {
    auto jit = Parent::GetJitConstants(params, {});

    const size_t vec_width = GetVecWidth(params);
    const size_t x_size = params.inputs[0].X().v;
    const size_t x_tiles = x_size / vec_width;
    const size_t x_rem = x_size % vec_width;

    jit.AddConstant(MakeJitConstant("VEC_WIDTH", vec_width));
    jit.AddConstant(MakeJitConstant("X_TILES", x_tiles));
    jit.AddConstant(MakeJitConstant("X_REMAINDER_SIZE", x_rem));

    jit.AddConstant(MakeJitConstant("INPUTVTYPE", "CAT(INPUT0_TYPE, VEC_WIDTH)"));
    jit.AddConstant(MakeJitConstant("OUTPUTVTYPE", "CAT(OUTPUT_TYPE, VEC_WIDTH)"));

    if (!params.fused_ops.empty()) {
        // Fused operands are addressed with OUTPUT coordinates. The kernel
        // stores input element (b, f, ...) at OUTPUT_GET_INDEX(f, b, ...), so
        // the output batch coordinate is the kernel's `f` and the output
        // feature coordinate is its `b`; the spatial names are unchanged.
        std::vector<std::string> output_order;
        switch (params.inputs[0].GetDims().size()) {
            case 4: output_order = {"f", "b", "y", "x"}; break;
            case 5: output_order = {"f", "b", "z", "y", "x"}; break;
            case 6: output_order = {"f", "b", "w", "z", "y", "x"}; break;
            default: break;
        }
        FusedOpsConfiguration conf = {"", output_order, "input_var", params.inputs[0].GetDType(), 1};
        jit.Merge(MakeFusedOpsJitConstants(params, {conf}));
    }

    return jit;
}
92+
93+
// Computes the dispatch sizes:
//   gws[0] = ceil(X / vector width)
//   gws[1] = Y, times Z for 5D+ inputs, times W for 6D inputs
//   gws[2] = batch
// The feature axis is not dispatched; the kernel loops over it.
CommonDispatchData PermuteKernel_b_f_axes::SetDefault(const permute_params& params) const {
    const auto& input = params.inputs[0];
    const size_t rank = input.GetDims().size();

    // Product of the spatial axes above Y that exist for this rank.
    size_t outer_spatial = 1;
    if (rank >= 5)
        outer_spatial *= input.Z().v;
    if (rank >= 6)
        outer_spatial *= input.W().v;

    CommonDispatchData dispatch;
    dispatch.gws = {CeilDiv(input.X().v, GetVecWidth(params)),
                    input.Y().v * outer_spatial,
                    input.Batch().v};

    // Tensor channels covered by each of the three global work dimensions.
    const std::vector<std::vector<Tensor::DataChannelName>> channels_per_gws_dim = {
        {Tensor::DataChannelName::X},
        {Tensor::DataChannelName::Y, Tensor::DataChannelName::Z, Tensor::DataChannelName::W},
        {Tensor::DataChannelName::BATCH}};
    dispatch.lws = GetOptimalLocalWorkGroupSizes(dispatch.gws,
                                                 params.engineInfo,
                                                 input.GetLayout(),
                                                 params.outputs[0].GetLayout(),
                                                 channels_per_gws_dim);
    return dispatch;
}
118+
119+
bool PermuteKernel_b_f_axes::Validate(const Params& p) const {
120+
if (!Parent::Validate(p)) DO_NOT_USE_THIS_KERNEL(p.layerID);
121+
122+
const permute_params& params = static_cast<const permute_params&>(p);
123+
124+
if (params.outputs[0].PitchesDifferFromLogicalDims() || params.inputs[0].PitchesDifferFromLogicalDims())
125+
DO_NOT_USE_THIS_KERNEL(p.layerID);
126+
127+
if (!SimpleLayout(params.inputs[0].GetLayout()))
128+
DO_NOT_USE_THIS_KERNEL(p.layerID);
129+
130+
if (params.inputs[0].GetLayout() != params.outputs[0].GetLayout())
131+
DO_NOT_USE_THIS_KERNEL(p.layerID);
132+
133+
// Only accept [1, 0, 2, ...] — B <-> F swap with spatial axes unchanged.
134+
const auto& order = params.order;
135+
const size_t ndim = order.size();
136+
if (ndim < 3 || ndim > 6)
137+
DO_NOT_USE_THIS_KERNEL(p.layerID);
138+
139+
if (order[0] != 1 || order[1] != 0)
140+
DO_NOT_USE_THIS_KERNEL(p.layerID);
141+
142+
for (size_t i = 2; i < ndim; ++i) {
143+
if (order[i] != static_cast<uint16_t>(i))
144+
DO_NOT_USE_THIS_KERNEL(p.layerID);
145+
}
146+
147+
return true;
148+
}
149+
150+
// Ranks this kernel by how well X fits the vector width:
//   FORCE_PRIORITY_2 - X is a multiple of the vector width and at least two vectors long
//   FORCE_PRIORITY_3 - X holds at least one full vector
//   FORCE_PRIORITY_4 - X is shorter than one vector
// The incoming params are read in place (same cast as Validate) rather than
// cloned into a temporary KernelData just to look at one dimension.
KernelsPriority PermuteKernel_b_f_axes::GetKernelsPriority(const Params& params) const {
    const auto& permute = static_cast<const permute_params&>(params);

    const size_t vec_width = GetVecWidth(permute);
    const size_t x_size = permute.inputs[0].X().v;

    if (x_size >= vec_width * 2 && (x_size % vec_width == 0))
        return FORCE_PRIORITY_2;
    if (x_size >= vec_width)
        return FORCE_PRIORITY_3;

    return FORCE_PRIORITY_4;
}
164+
165+
} // namespace kernel_selector
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// Copyright (C) 2018-2026 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#pragma once
6+
7+
#include "permute_kernel_base.h"
8+
9+
namespace kernel_selector {
10+
11+
class PermuteKernel_b_f_axes : public PermuteKernelBase {
12+
public:
13+
using Parent = PermuteKernelBase;
14+
using Parent::Parent;
15+
PermuteKernel_b_f_axes() : PermuteKernelBase("permute_b_f_axes") {}
16+
virtual ~PermuteKernel_b_f_axes() {}
17+
18+
bool Validate(const Params& p) const override;
19+
KernelsPriority GetKernelsPriority(const Params& params) const override;
20+
ParamsKey GetSupportedKey() const override;
21+
22+
protected:
23+
JitConstants GetJitConstants(const permute_params& params, const CommonDispatchData& dispatchData) const override;
24+
CommonDispatchData SetDefault(const permute_params& params) const override;
25+
std::vector<FusedOpType> GetSupportedFusedOps() const override {
26+
return {
27+
FusedOpType::REORDER,
28+
FusedOpType::ACTIVATION,
29+
FusedOpType::QUANTIZE,
30+
FusedOpType::ELTWISE
31+
};
32+
}
33+
};
34+
} // namespace kernel_selector

src/plugins/intel_gpu/src/kernel_selector/kernels/permute/permute_kernel_selector.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include "permute_kernel_tile_8x8_4x4_fsv.h"
99
#include "permute_kernel_bfzyx_to_bfyxz.h"
1010
#include "permute_kernel_f_y_axes.h"
11+
#include "permute_kernel_b_f_axes.h"
1112

1213
namespace kernel_selector {
1314

@@ -17,6 +18,7 @@ permute_kernel_selector::permute_kernel_selector() {
1718
Attach<PermuteKernel_tile_8x8_4x4_fsv>();
1819
Attach<PermuteKernel_bfzyx_to_bfyxz>();
1920
Attach<PermuteKernel_f_y_axes>();
21+
Attach<PermuteKernel_b_f_axes>();
2022
}
2123

2224
KernelsData permute_kernel_selector::GetBestKernels(const Params& params) const {

0 commit comments

Comments
 (0)