Skip to content

Commit 4ae1a02

Browse files
committed
Fix lint
1 parent 7edef03 commit 4ae1a02

File tree

1 file changed

+21
-16
lines changed

1 file changed

+21
-16
lines changed

tests/unit_tests/ssm/test_gated_delta_net.py

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -149,29 +149,36 @@ def test_jit_compiled_helpers(self):
149149

150150
qkv_last_dim = (2 * gdn.qk_dim_local_tp + gdn.v_dim_local_tp) // gdn.cp_size
151151
qkv = torch.randn(
152-
batch, seq_len, qkv_last_dim,
153-
device=torch.cuda.current_device(), dtype=torch.bfloat16,
152+
batch, seq_len, qkv_last_dim, device=torch.cuda.current_device(), dtype=torch.bfloat16
154153
)
155154
gate = torch.randn(
156-
batch, seq_len, num_v_heads_local, gdn.value_head_dim,
157-
device=torch.cuda.current_device(), dtype=torch.bfloat16,
155+
batch,
156+
seq_len,
157+
num_v_heads_local,
158+
gdn.value_head_dim,
159+
device=torch.cuda.current_device(),
160+
dtype=torch.bfloat16,
158161
)
159162
beta = torch.randn(
160-
batch, seq_len, num_v_heads_local,
161-
device=torch.cuda.current_device(), dtype=torch.bfloat16,
163+
batch,
164+
seq_len,
165+
num_v_heads_local,
166+
device=torch.cuda.current_device(),
167+
dtype=torch.bfloat16,
162168
)
163169
alpha = torch.randn(
164-
batch, seq_len, num_v_heads_local,
165-
device=torch.cuda.current_device(), dtype=torch.bfloat16,
170+
batch,
171+
seq_len,
172+
num_v_heads_local,
173+
device=torch.cuda.current_device(),
174+
dtype=torch.bfloat16,
166175
)
167176

168177
# Disable dynamo so coverage.py can trace through the method bodies,
169178
# which are normally wrapped by @jit_fuser (torch.compile).
170179
with torch._dynamo.config.patch(disable=True):
171180
query, key, value, gate_out, beta_out, alpha_out = (
172-
gdn._prepare_qkv_for_gated_delta_rule(
173-
qkv, gate, beta, alpha, batch, seq_len
174-
)
181+
gdn._prepare_qkv_for_gated_delta_rule(qkv, gate, beta, alpha, batch, seq_len)
175182
)
176183

177184
assert query.shape == (batch, seq_len, num_v_heads_local, gdn.key_head_dim)
@@ -182,16 +189,14 @@ def test_jit_compiled_helpers(self):
182189
assert value.is_contiguous()
183190

184191
A_log_mock = torch.randn(
185-
num_v_heads_local, device=torch.cuda.current_device(), dtype=torch.bfloat16,
192+
num_v_heads_local, device=torch.cuda.current_device(), dtype=torch.bfloat16
186193
)
187194
dt_bias_mock = torch.randn(
188-
num_v_heads_local, device=torch.cuda.current_device(), dtype=torch.bfloat16,
195+
num_v_heads_local, device=torch.cuda.current_device(), dtype=torch.bfloat16
189196
)
190197

191198
with torch._dynamo.config.patch(disable=True):
192-
g, beta_sig = gdn._compute_g_and_beta(
193-
A_log_mock, dt_bias_mock, alpha, beta
194-
)
199+
g, beta_sig = gdn._compute_g_and_beta(A_log_mock, dt_bias_mock, alpha, beta)
195200

196201
assert g.dtype == torch.float32
197202
assert g.shape == alpha.shape

0 commit comments

Comments (0)