@@ -134,6 +134,7 @@ def chunk_gsa_fwd_k_kernel_intra(
134134 else :
135135 bos , eos = i_b * T , i_b * T + T
136136
137+ o_i = tl .arange (0 , BC )
137138 o_v = i_v * BV + tl .arange (0 , BV )
138139 m_v = o_v < V
139140
@@ -161,7 +162,6 @@ def chunk_gsa_fwd_k_kernel_intra(
161162 b_g = tl .load (p_g , boundary_check = (0 , 1 ))
162163 b_o *= exp (b_g - b_gn [None , :])
163164
164- o_i = tl .arange (0 , BC )
165165 o_A = (bos + i_t * BT + i_i * BC + tl .arange (0 , BC )) * HQ * BT + i_hq * BT + i_i * BC
166166 m_A = (i_t * BT + i_i * BC + tl .arange (0 , BC )) < T
167167 for j in range (0 , min (BC , T - i_t * BT - i_i * BC )):
@@ -247,13 +247,13 @@ def chunk_gsa_bwd_k_kernel_dA(
247247 # [BC, BV]
248248 b_g = tl .load (p_g , boundary_check = (0 , 1 ))
249249 b_do = tl .load (p_do , boundary_check = (0 , 1 ))
250- b_do = (b_do * exp (b_g - b_gn [None , :]) * scale ).to (b_do .dtype )
250+ b_do = (b_do * exp (b_g - b_gn [None , :])).to (b_do .dtype )
251251 # [BV, BC]
252252 b_v = tl .load (p_v , boundary_check = (0 , 1 ))
253253 b_gv = tl .load (p_gv , boundary_check = (0 , 1 ))
254254 b_vg = (b_v * exp (b_gn [:, None ] - b_gv )).to (b_v .dtype )
255255 # [BC, BC]
256- b_dA = tl .dot (b_do , b_vg )
256+ b_dA = tl .dot (b_do , b_vg ) * scale
257257 elif i_i == i_j :
258258 p_g = tl .make_block_ptr (g + (bos * H + i_h ) * V , (T , V ), (H * V , 1 ), (i_t * BT + i_i * BC , i_v * BV ), (BC , BV ), (1 , 0 ))
259259 p_do = tl .make_block_ptr (do + (bos * HQ + i_hq ) * V , (T , V ), (HQ * V , 1 ), (i_t * BT + i_i * BC , i_v * BV ), (BC , BV ), (1 , 0 ))
@@ -381,15 +381,15 @@ def chunk_gsa_bwd_k_kernel_dqkvg(
381381 b_h = tl .load (p_h , boundary_check = (0 , 1 ))
382382 # [BT, BV]
383383 b_do = tl .load (p_do , boundary_check = (0 , 1 ))
384- b_do = (b_do * exp (b_g ) * scale ).to (b_do .dtype )
384+ b_do = (b_do * exp (b_g )).to (b_do .dtype )
385385 # [BK, BV]
386386 b_dh = tl .load (p_dh , boundary_check = (0 , 1 ))
387387 # [BV]
388388 b_dg = tl .sum (tl .trans (b_h ) * b_dh , 0 ) * exp (b_gn )
389389
390390 b_dh = b_dh .to (b_k .dtype )
391391 # [BT, BK]
392- b_dq += tl .dot (b_do , b_h .to (b_k .dtype ))
392+ b_dq += tl .dot (b_do , b_h .to (b_k .dtype )) * scale
393393 b_dk += tl .dot ((b_v * b_gv ).to (b_v .dtype ), tl .trans (b_dh ))
394394 # [BT, BV]
395395 b_dv = tl .dot (b_k , b_dh ) * b_gv
@@ -452,6 +452,7 @@ def chunk_gsa_bwd_k_kernel_intra_dvg(
452452 else :
453453 bos , eos = i_b * T , i_b * T + T
454454
455+ o_i = tl .arange (0 , BC )
455456 o_v = i_v * BV + tl .arange (0 , BV )
456457 m_v = o_v < V
457458
@@ -469,20 +470,19 @@ def chunk_gsa_bwd_k_kernel_intra_dvg(
469470 p_g = tl .make_block_ptr (g + (bos * H + i_h ) * V , (T , V ), (H * V , 1 ), (i_t * BT + i_j * BC , i_v * BV ), (BC , BV ), (1 , 0 ))
470471 p_A = tl .make_block_ptr (A + (bos * HQ + i_hq ) * BT , (BT , T ), (1 , HQ * BT ), (i_i * BC , i_t * BT + i_j * BC ), (BC , BC ), (0 , 1 ))
471472 p_do = tl .make_block_ptr (do + (bos * HQ + i_hq ) * V , (T , V ), (HQ * V , 1 ), (i_t * BT + i_j * BC , i_v * BV ), (BC , BV ), (1 , 0 ))
473+
474+ m_j = (i_t * BT + i_j * BC + o_i ) < T
472475 # [BC, BV]
473476 b_g = tl .load (p_g , boundary_check = (0 , 1 ))
474- b_do = tl .load (p_do , boundary_check = (0 , 1 )) * exp (b_g - b_gn [None , :])
477+ b_do = tl .load (p_do , boundary_check = (0 , 1 )) * tl . where ( m_j [:, None ], exp (b_g - b_gn [None , :]), 0 )
475478 # [BC, BC]
476479 b_A = tl .load (p_A , boundary_check = (0 , 1 ))
477480 # [BC, BV]
478481 b_dv += tl .dot (b_A , b_do .to (b_A .dtype ))
479482 b_dv *= exp (b_gn [None , :] - b_gv )
480483
481- o_i = tl .arange (0 , BC )
482- o_c = i_i * BC + tl .arange (0 , BC )
483-
484484 p_g = g + (bos + i_t * BT + i_i * BC ) * H * V + i_h * V + o_v
485- p_A = A + (bos + i_t * BT + i_i * BC ) * HQ * BT + i_hq * BT + o_c
485+ p_A = A + (bos + i_t * BT + i_i * BC ) * HQ * BT + i_hq * BT + i_i * BC + o_i
486486 p_do = do + (bos + i_t * BT + i_i * BC ) * HQ * V + i_hq * V + o_v
487487 for j in range (0 , min (BC , T - i_t * BT - i_i * BC )):
488488 # [BC,]
0 commit comments