Skip to content

Commit b1a08ae

Browse files
authored
packed shape hint for vulkan layers (#6553)
1 parent 22ef29f commit b1a08ae

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

56 files changed

+1319
-2221
lines changed

src/layer/vulkan/absval_vulkan.cpp

Lines changed: 1 addition & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -19,31 +19,8 @@ int AbsVal_vulkan::create_pipeline(const Option& opt)
1919
{
2020
const Mat& shape = top_shapes.empty() ? Mat() : top_shapes[0];
2121

22-
const int dims = shape.dims;
23-
24-
int elempack = 0;
25-
if (dims == 1) elempack = shape.w % 4 == 0 ? 4 : 1;
26-
if (dims == 2) elempack = shape.h % 4 == 0 ? 4 : 1;
27-
if (dims == 3 || dims == 4) elempack = shape.c % 4 == 0 ? 4 : 1;
28-
29-
size_t elemsize;
30-
if (opt.use_fp16_storage || opt.use_fp16_packed || opt.use_bf16_storage || opt.use_bf16_packed)
31-
{
32-
elemsize = elempack * 2u;
33-
}
34-
else
35-
{
36-
elemsize = elempack * 4u;
37-
}
38-
39-
Mat shape_packed;
40-
if (dims == 1) shape_packed = Mat(shape.w / elempack, (void*)0, elemsize, elempack);
41-
if (dims == 2) shape_packed = Mat(shape.w, shape.h / elempack, (void*)0, elemsize, elempack);
42-
if (dims == 3) shape_packed = Mat(shape.w, shape.h, shape.c / elempack, (void*)0, elemsize, elempack);
43-
if (dims == 4) shape_packed = Mat(shape.w, shape.h, shape.d, shape.c / elempack, (void*)0, elemsize, elempack);
44-
4522
std::vector<vk_specialization_type> specializations(1);
46-
specializations[0].u32 = shape_packed.total() * elempack / 4;
23+
specializations[0].u32 = shape.total() * shape.elempack / 4;
4724

4825
const int local_size_x = vkdev->info.subgroup_size();
4926

src/layer/vulkan/batchnorm_vulkan.cpp

Lines changed: 18 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -22,53 +22,37 @@ int BatchNorm_vulkan::create_pipeline(const Option& opt)
2222

2323
int elempack = channels % 4 == 0 ? 4 : 1;
2424

25-
size_t elemsize;
26-
if (opt.use_fp16_storage || opt.use_fp16_packed || opt.use_bf16_storage || opt.use_bf16_packed)
27-
{
28-
elemsize = elempack * 2u;
29-
}
30-
else
31-
{
32-
elemsize = elempack * 4u;
33-
}
34-
35-
Mat shape_packed;
36-
if (shape.dims == 1) shape_packed = Mat(shape.w / elempack, (void*)0, elemsize, elempack);
37-
if (shape.dims == 2) shape_packed = Mat(shape.w, shape.h / elempack, (void*)0, elemsize, elempack);
38-
if (shape.dims == 3) shape_packed = Mat(shape.w, shape.h, shape.c / elempack, (void*)0, elemsize, elempack);
39-
if (shape.dims == 4) shape_packed = Mat(shape.w, shape.h, shape.d, shape.c / elempack, (void*)0, elemsize, elempack);
40-
4125
std::vector<vk_specialization_type> specializations(0 + 5);
42-
specializations[0 + 0].i = std::min(3, shape_packed.dims);
43-
specializations[0 + 1].i = shape_packed.w;
44-
specializations[0 + 2].i = shape_packed.h * shape_packed.d;
45-
specializations[0 + 3].i = shape_packed.c;
46-
specializations[0 + 4].i = shape_packed.cstep;
26+
specializations[0 + 0].i = std::min(3, shape.dims);
27+
specializations[0 + 1].i = shape.w;
28+
specializations[0 + 2].i = shape.h * shape.d;
29+
specializations[0 + 3].i = shape.c;
30+
specializations[0 + 4].i = shape.cstep;
4731

4832
Mat local_size_xyz(4, 4, std::min(4, channels / elempack), (void*)0);
49-
if (shape_packed.dims == 1)
33+
if (shape.dims == 1)
5034
{
51-
local_size_xyz.w = std::min(64, shape_packed.w);
35+
local_size_xyz.w = std::min(64, shape.w);
5236
local_size_xyz.h = 1;
5337
local_size_xyz.c = 1;
5438
}
55-
if (shape_packed.dims == 2)
39+
if (shape.dims == 2)
5640
{
57-
local_size_xyz.w = std::min(8, shape_packed.w);
58-
local_size_xyz.h = std::min(8, shape_packed.h);
41+
local_size_xyz.w = std::min(8, shape.w);
42+
local_size_xyz.h = std::min(8, shape.h);
5943
local_size_xyz.c = 1;
6044
}
61-
if (shape_packed.dims == 3)
45+
if (shape.dims == 3)
6246
{
63-
local_size_xyz.w = std::min(4, shape_packed.w);
64-
local_size_xyz.h = std::min(4, shape_packed.h);
65-
local_size_xyz.c = std::min(4, shape_packed.c);
47+
local_size_xyz.w = std::min(4, shape.w);
48+
local_size_xyz.h = std::min(4, shape.h);
49+
local_size_xyz.c = std::min(4, shape.c);
6650
}
67-
if (shape_packed.dims == 4)
51+
if (shape.dims == 4)
6852
{
69-
local_size_xyz.w = std::min(4, shape_packed.w);
70-
local_size_xyz.h = std::min(4, shape_packed.h * shape_packed.d);
71-
local_size_xyz.c = std::min(4, shape_packed.c);
53+
local_size_xyz.w = std::min(4, shape.w);
54+
local_size_xyz.h = std::min(4, shape.h * shape.d);
55+
local_size_xyz.c = std::min(4, shape.c);
7256
}
7357

7458
// pack1

0 commit comments

Comments
 (0)