@@ -22,53 +22,37 @@ int BatchNorm_vulkan::create_pipeline(const Option& opt)
2222
2323 int elempack = channels % 4 == 0 ? 4 : 1 ;
2424
25- size_t elemsize;
26- if (opt.use_fp16_storage || opt.use_fp16_packed || opt.use_bf16_storage || opt.use_bf16_packed )
27- {
28- elemsize = elempack * 2u ;
29- }
30- else
31- {
32- elemsize = elempack * 4u ;
33- }
34-
35- Mat shape_packed;
36- if (shape.dims == 1 ) shape_packed = Mat (shape.w / elempack, (void *)0 , elemsize, elempack);
37- if (shape.dims == 2 ) shape_packed = Mat (shape.w , shape.h / elempack, (void *)0 , elemsize, elempack);
38- if (shape.dims == 3 ) shape_packed = Mat (shape.w , shape.h , shape.c / elempack, (void *)0 , elemsize, elempack);
39- if (shape.dims == 4 ) shape_packed = Mat (shape.w , shape.h , shape.d , shape.c / elempack, (void *)0 , elemsize, elempack);
40-
4125 std::vector<vk_specialization_type> specializations (0 + 5 );
42- specializations[0 + 0 ].i = std::min (3 , shape_packed .dims );
43- specializations[0 + 1 ].i = shape_packed .w ;
44- specializations[0 + 2 ].i = shape_packed .h * shape_packed .d ;
45- specializations[0 + 3 ].i = shape_packed .c ;
46- specializations[0 + 4 ].i = shape_packed .cstep ;
26+ specializations[0 + 0 ].i = std::min (3 , shape .dims );
27+ specializations[0 + 1 ].i = shape .w ;
28+ specializations[0 + 2 ].i = shape .h * shape .d ;
29+ specializations[0 + 3 ].i = shape .c ;
30+ specializations[0 + 4 ].i = shape .cstep ;
4731
4832 Mat local_size_xyz (4 , 4 , std::min (4 , channels / elempack), (void *)0 );
49- if (shape_packed .dims == 1 )
33+ if (shape .dims == 1 )
5034 {
51- local_size_xyz.w = std::min (64 , shape_packed .w );
35+ local_size_xyz.w = std::min (64 , shape .w );
5236 local_size_xyz.h = 1 ;
5337 local_size_xyz.c = 1 ;
5438 }
55- if (shape_packed .dims == 2 )
39+ if (shape .dims == 2 )
5640 {
57- local_size_xyz.w = std::min (8 , shape_packed .w );
58- local_size_xyz.h = std::min (8 , shape_packed .h );
41+ local_size_xyz.w = std::min (8 , shape .w );
42+ local_size_xyz.h = std::min (8 , shape .h );
5943 local_size_xyz.c = 1 ;
6044 }
61- if (shape_packed .dims == 3 )
45+ if (shape .dims == 3 )
6246 {
63- local_size_xyz.w = std::min (4 , shape_packed .w );
64- local_size_xyz.h = std::min (4 , shape_packed .h );
65- local_size_xyz.c = std::min (4 , shape_packed .c );
47+ local_size_xyz.w = std::min (4 , shape .w );
48+ local_size_xyz.h = std::min (4 , shape .h );
49+ local_size_xyz.c = std::min (4 , shape .c );
6650 }
67- if (shape_packed .dims == 4 )
51+ if (shape .dims == 4 )
6852 {
69- local_size_xyz.w = std::min (4 , shape_packed .w );
70- local_size_xyz.h = std::min (4 , shape_packed .h * shape_packed .d );
71- local_size_xyz.c = std::min (4 , shape_packed .c );
53+ local_size_xyz.w = std::min (4 , shape .w );
54+ local_size_xyz.h = std::min (4 , shape .h * shape .d );
55+ local_size_xyz.c = std::min (4 , shape .c );
7256 }
7357
7458 // pack1
0 commit comments