@@ -15,9 +15,10 @@ template<typename JIT_IMPL>
1515class __attribute__ ((visibility (" default" ))) JITConvImpl {
1616public:
1717 JIT_IMPL jit;
18- KernelLaunchConfig forward_config;
19- KernelLaunchConfig backward_config;
20- KernelLaunchConfig double_backward_config;
18+
19+ KernelLaunchConfig forward_config_ref;
20+ KernelLaunchConfig backward_config_ref;
21+ KernelLaunchConfig double_backward_config_ref;
2122 int opt_level;
2223
2324 JITConvImpl (
@@ -27,25 +28,25 @@ class __attribute__ ((visibility ("default"))) JITConvImpl {
2728 KernelLaunchConfig double_backward_config_i,
2829 int opt_level_i) :
2930 jit (jit_kernel),
30- forward_config (forward_config_i),
31- backward_config (backward_config_i),
32- double_backward_config (double_backward_config_i),
31+ forward_config_ref (forward_config_i),
32+ backward_config_ref (backward_config_i),
33+ double_backward_config_ref (double_backward_config_i),
3334 opt_level (opt_level_i) {
3435
3536 vector<string> kernels = {" forward" , " backward" , " fixup_forward" , " fixup_backward" , " double_backward_A" , " double_backward_B" , " fixup_double_backwardB" };
3637 jit.compile (kernels, {{}, {}, {}, {}, {}, {}, {}}, opt_level);
3738
38- if (forward_config .smem > 0 ) {
39- jit.set_max_smem (0 , forward_config .smem );
40- jit.set_max_smem (4 , forward_config .smem );
39+ if (forward_config_ref .smem > 0 ) {
40+ jit.set_max_smem (0 , forward_config_ref .smem );
41+ jit.set_max_smem (4 , forward_config_ref .smem );
4142 }
4243
43- if (backward_config .smem > 0 ) {
44- jit.set_max_smem (1 , backward_config .smem );
44+ if (backward_config_ref .smem > 0 ) {
45+ jit.set_max_smem (1 , backward_config_ref .smem );
4546 }
4647
47- if (double_backward_config .smem > 0 ) {
48- jit.set_max_smem (5 , double_backward_config .smem );
48+ if (double_backward_config_ref .smem > 0 ) {
49+ jit.set_max_smem (5 , double_backward_config_ref .smem );
4950 }
5051 }
5152
@@ -89,16 +90,16 @@ class __attribute__ ((visibility ("default"))) JITConvImpl {
8990 ConvData conv_data = {rows, cols, nnz, node_count};
9091
9192 void *args[] = {&L1_in, &L2_in, &weights, &L3_out, &conv_data, &workspace};
92- forward_config.hStream = stream;
93- jit.execute (0 , args, forward_config);
93+ jit.execute (0 , args, with_stream (forward_config_ref, stream));
9494
9595 if (reinterpret_cast <uint64_t >(workspace) != 0 ) {
9696 void *fixup_args[] = {&workspace, &L3_out};
9797
98- KernelLaunchConfig fixup_config;
99- fixup_config.num_blocks = forward_config.num_blocks ;
100- fixup_config.num_threads = forward_config.num_threads ;
101- fixup_config.smem = 0 ;
98+ KernelLaunchConfig fixup_config (
99+ forward_config_ref.num_blocks ,
100+ forward_config_ref.num_threads ,
101+ 0
102+ );
102103 fixup_config.hStream = stream;
103104
104105 jit.execute (2 , fixup_args, fixup_config);
@@ -118,16 +119,17 @@ class __attribute__ ((visibility ("default"))) JITConvImpl {
118119
119120 ConvData conv_data = {rows, cols, nnz, node_count};
120121 void *args[] = {&L1_in, &L1_grad, &L2_in, &L2_grad, &weight, &weight_grad, &L3_grad, &conv_data, &workspace, &transpose_perm};
121- backward_config.hStream = stream;
122- jit.execute (1 , args, backward_config);
122+ jit.execute (1 , args, with_stream (backward_config_ref, stream));
123123
124124 if (reinterpret_cast <uint64_t >(workspace) != 0 ) {
125125 void *fixup_args[] = {&workspace, &L1_grad};
126126
127- KernelLaunchConfig fixup_config;
128- fixup_config.num_blocks = backward_config.num_blocks ;
129- fixup_config.num_threads = backward_config.num_threads ;
130- fixup_config.smem = 0 ; fixup_config.hStream = stream;
127+ KernelLaunchConfig fixup_config (
128+ backward_config_ref.num_blocks ,
129+ backward_config_ref.num_threads ,
130+ 0
131+ );
132+ fixup_config.hStream = stream;
131133
132134 jit.execute (3 , fixup_args, fixup_config);
133135 }
@@ -147,24 +149,28 @@ class __attribute__ ((visibility ("default"))) JITConvImpl {
147149 &L1_in, &L2_in, &W, &L3_grad, &L1_dgrad, &L2_dgrad, &w_dgrad,
148150 &L1_grad, &L2_grad, &W_grad, &L3_dgrad, &conv_data, &wspace, &transpose_perm
149151 };
150- double_backward_config. hStream = stream;
151- jit.execute (4 , args, forward_config );
152+
153+ jit.execute (4 , args, with_stream (forward_config_ref, stream) );
152154 if (reinterpret_cast <uint64_t >(wspace) != 0 ) {
153155 void *fixup_args[] = {&wspace, &L3_dgrad};
154- KernelLaunchConfig fixup_config;
155- fixup_config.num_blocks = forward_config.num_blocks ;
156- fixup_config.num_threads = forward_config.num_threads ;
157- fixup_config.smem = 0 ; fixup_config.hStream = stream;
156+ KernelLaunchConfig fixup_config (
157+ forward_config_ref.num_blocks ,
158+ forward_config_ref.num_threads ,
159+ 0
160+ );
161+ fixup_config.hStream = stream;
158162 jit.execute (2 , fixup_args, fixup_config);
159163 }
160164
161- jit.execute (5 , args, double_backward_config );
165+ jit.execute (5 , args, with_stream (double_backward_config_ref, stream) );
162166 if (reinterpret_cast <uint64_t >(wspace) != 0 ) {
163167 void *fixup_args[] = {&wspace, &L1_grad};
164- KernelLaunchConfig fixup_config;
165- fixup_config.num_blocks = double_backward_config.num_blocks ;
166- fixup_config.num_threads = double_backward_config.num_threads ;
167- fixup_config.smem = 0 ; fixup_config.hStream = stream;
168+ KernelLaunchConfig fixup_config (
169+ double_backward_config_ref.num_blocks ,
170+ double_backward_config_ref.num_threads ,
171+ 0
172+ );
173+ fixup_config.hStream = stream;
168174 jit.execute (6 , fixup_args, fixup_config);
169175 }
170176 }
0 commit comments