@@ -149,29 +149,36 @@ def test_jit_compiled_helpers(self):
149149
150150 qkv_last_dim = (2 * gdn .qk_dim_local_tp + gdn .v_dim_local_tp ) // gdn .cp_size
151151 qkv = torch .randn (
152- batch , seq_len , qkv_last_dim ,
153- device = torch .cuda .current_device (), dtype = torch .bfloat16 ,
152+ batch , seq_len , qkv_last_dim , device = torch .cuda .current_device (), dtype = torch .bfloat16
154153 )
155154 gate = torch .randn (
156- batch , seq_len , num_v_heads_local , gdn .value_head_dim ,
157- device = torch .cuda .current_device (), dtype = torch .bfloat16 ,
155+ batch ,
156+ seq_len ,
157+ num_v_heads_local ,
158+ gdn .value_head_dim ,
159+ device = torch .cuda .current_device (),
160+ dtype = torch .bfloat16 ,
158161 )
159162 beta = torch .randn (
160- batch , seq_len , num_v_heads_local ,
161- device = torch .cuda .current_device (), dtype = torch .bfloat16 ,
163+ batch ,
164+ seq_len ,
165+ num_v_heads_local ,
166+ device = torch .cuda .current_device (),
167+ dtype = torch .bfloat16 ,
162168 )
163169 alpha = torch .randn (
164- batch , seq_len , num_v_heads_local ,
165- device = torch .cuda .current_device (), dtype = torch .bfloat16 ,
170+ batch ,
171+ seq_len ,
172+ num_v_heads_local ,
173+ device = torch .cuda .current_device (),
174+ dtype = torch .bfloat16 ,
166175 )
167176
168177 # Disable dynamo so coverage.py can trace through the method bodies,
169178 # which are normally wrapped by @jit_fuser (torch.compile).
170179 with torch ._dynamo .config .patch (disable = True ):
171180 query , key , value , gate_out , beta_out , alpha_out = (
172- gdn ._prepare_qkv_for_gated_delta_rule (
173- qkv , gate , beta , alpha , batch , seq_len
174- )
181+ gdn ._prepare_qkv_for_gated_delta_rule (qkv , gate , beta , alpha , batch , seq_len )
175182 )
176183
177184 assert query .shape == (batch , seq_len , num_v_heads_local , gdn .key_head_dim )
@@ -182,16 +189,14 @@ def test_jit_compiled_helpers(self):
182189 assert value .is_contiguous ()
183190
184191 A_log_mock = torch .randn (
185- num_v_heads_local , device = torch .cuda .current_device (), dtype = torch .bfloat16 ,
192+ num_v_heads_local , device = torch .cuda .current_device (), dtype = torch .bfloat16
186193 )
187194 dt_bias_mock = torch .randn (
188- num_v_heads_local , device = torch .cuda .current_device (), dtype = torch .bfloat16 ,
195+ num_v_heads_local , device = torch .cuda .current_device (), dtype = torch .bfloat16
189196 )
190197
191198 with torch ._dynamo .config .patch (disable = True ):
192- g , beta_sig = gdn ._compute_g_and_beta (
193- A_log_mock , dt_bias_mock , alpha , beta
194- )
199+ g , beta_sig = gdn ._compute_g_and_beta (A_log_mock , dt_bias_mock , alpha , beta )
195200
196201 assert g .dtype == torch .float32
197202 assert g .shape == alpha .shape
0 commit comments