Skip to content

Commit 3a904f0

Browse files
authored
[GSA] Fix gate oob bugs (#694)
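Refactor calculations involving 'b_do' and 'b_A' to improve clarity and correctness: in `chunk_gsa_bwd_k_kernel_intra_dvg`, the gate exponent multiplied into 'b_do' is now zeroed (via `tl.where` on `m_j`) for rows at or beyond the sequence length `T`, and the 'b_A' pointer is indexed with `i_i * BC + o_i`; in `chunk_gsa_bwd_k_kernel_dA` and `chunk_gsa_bwd_k_kernel_dqkvg`, `scale` is now applied to the result of `tl.dot` instead of being folded into 'b_do' before it is cast back to its original dtype.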
1 parent 45fc45d commit 3a904f0

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

fla/ops/gsa/chunk.py

Lines changed: 10 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -134,6 +134,7 @@ def chunk_gsa_fwd_k_kernel_intra(
134134
else:
135135
bos, eos = i_b * T, i_b * T + T
136136

137+
o_i = tl.arange(0, BC)
137138
o_v = i_v * BV + tl.arange(0, BV)
138139
m_v = o_v < V
139140

@@ -161,7 +162,6 @@ def chunk_gsa_fwd_k_kernel_intra(
161162
b_g = tl.load(p_g, boundary_check=(0, 1))
162163
b_o *= exp(b_g - b_gn[None, :])
163164

164-
o_i = tl.arange(0, BC)
165165
o_A = (bos + i_t * BT + i_i * BC + tl.arange(0, BC)) * HQ*BT + i_hq * BT + i_i * BC
166166
m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T
167167
for j in range(0, min(BC, T - i_t * BT - i_i * BC)):
@@ -247,13 +247,13 @@ def chunk_gsa_bwd_k_kernel_dA(
247247
# [BC, BV]
248248
b_g = tl.load(p_g, boundary_check=(0, 1))
249249
b_do = tl.load(p_do, boundary_check=(0, 1))
250-
b_do = (b_do * exp(b_g - b_gn[None, :]) * scale).to(b_do.dtype)
250+
b_do = (b_do * exp(b_g - b_gn[None, :])).to(b_do.dtype)
251251
# [BV, BC]
252252
b_v = tl.load(p_v, boundary_check=(0, 1))
253253
b_gv = tl.load(p_gv, boundary_check=(0, 1))
254254
b_vg = (b_v * exp(b_gn[:, None] - b_gv)).to(b_v.dtype)
255255
# [BC, BC]
256-
b_dA = tl.dot(b_do, b_vg)
256+
b_dA = tl.dot(b_do, b_vg) * scale
257257
elif i_i == i_j:
258258
p_g = tl.make_block_ptr(g + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0))
259259
p_do = tl.make_block_ptr(do + (bos*HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0))
@@ -381,15 +381,15 @@ def chunk_gsa_bwd_k_kernel_dqkvg(
381381
b_h = tl.load(p_h, boundary_check=(0, 1))
382382
# [BT, BV]
383383
b_do = tl.load(p_do, boundary_check=(0, 1))
384-
b_do = (b_do * exp(b_g) * scale).to(b_do.dtype)
384+
b_do = (b_do * exp(b_g)).to(b_do.dtype)
385385
# [BK, BV]
386386
b_dh = tl.load(p_dh, boundary_check=(0, 1))
387387
# [BV]
388388
b_dg = tl.sum(tl.trans(b_h) * b_dh, 0) * exp(b_gn)
389389

390390
b_dh = b_dh.to(b_k.dtype)
391391
# [BT, BK]
392-
b_dq += tl.dot(b_do, b_h.to(b_k.dtype))
392+
b_dq += tl.dot(b_do, b_h.to(b_k.dtype)) * scale
393393
b_dk += tl.dot((b_v * b_gv).to(b_v.dtype), tl.trans(b_dh))
394394
# [BT, BV]
395395
b_dv = tl.dot(b_k, b_dh) * b_gv
@@ -452,6 +452,7 @@ def chunk_gsa_bwd_k_kernel_intra_dvg(
452452
else:
453453
bos, eos = i_b * T, i_b * T + T
454454

455+
o_i = tl.arange(0, BC)
455456
o_v = i_v * BV + tl.arange(0, BV)
456457
m_v = o_v < V
457458

@@ -469,20 +470,19 @@ def chunk_gsa_bwd_k_kernel_intra_dvg(
469470
p_g = tl.make_block_ptr(g + (bos*H+i_h) * V, (T, V), (H*V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0))
470471
p_A = tl.make_block_ptr(A + (bos*HQ+i_hq) * BT, (BT, T), (1, HQ*BT), (i_i*BC, i_t*BT + i_j*BC), (BC, BC), (0, 1))
471472
p_do = tl.make_block_ptr(do + (bos*HQ+i_hq) * V, (T, V), (HQ*V, 1), (i_t*BT + i_j*BC, i_v*BV), (BC, BV), (1, 0))
473+
474+
m_j = (i_t * BT + i_j * BC + o_i) < T
472475
# [BC, BV]
473476
b_g = tl.load(p_g, boundary_check=(0, 1))
474-
b_do = tl.load(p_do, boundary_check=(0, 1)) * exp(b_g - b_gn[None, :])
477+
b_do = tl.load(p_do, boundary_check=(0, 1)) * tl.where(m_j[:, None], exp(b_g - b_gn[None, :]), 0)
475478
# [BC, BC]
476479
b_A = tl.load(p_A, boundary_check=(0, 1))
477480
# [BC, BV]
478481
b_dv += tl.dot(b_A, b_do.to(b_A.dtype))
479482
b_dv *= exp(b_gn[None, :] - b_gv)
480483

481-
o_i = tl.arange(0, BC)
482-
o_c = i_i * BC + tl.arange(0, BC)
483-
484484
p_g = g + (bos + i_t * BT + i_i * BC) * H*V + i_h * V + o_v
485-
p_A = A + (bos + i_t*BT + i_i*BC) * HQ*BT + i_hq * BT + o_c
485+
p_A = A + (bos + i_t*BT + i_i*BC) * HQ*BT + i_hq * BT + i_i * BC + o_i
486486
p_do = do + (bos + i_t*BT + i_i*BC) * HQ*V + i_hq * V + o_v
487487
for j in range(0, min(BC, T - i_t * BT - i_i * BC)):
488488
# [BC,]

0 commit comments

Comments
 (0)