1111import torch .nn .functional as F
1212from einops import rearrange
1313
14- from fla .layers .utils import get_unpad_data , index_first_axis , pad_input
14+ from fla .layers .utils import get_layer_cache , get_unpad_data , index_first_axis , pad_input , update_layer_cache
1515from fla .modules import RMSNorm , ShortConvolution
1616from fla .modules .activations import swish
1717from fla .modules .layernorm import rms_norm_linear
@@ -92,7 +92,8 @@ def __init__(
9292 activation = None ,
9393 )
9494
95- self .g_norm = RMSNorm (hidden_size = self .hidden_size , elementwise_affine = elementwise_affine , eps = norm_eps , dtype = torch .float32 )
95+ self .g_norm = RMSNorm (hidden_size = self .hidden_size , elementwise_affine = elementwise_affine ,
96+ eps = norm_eps , dtype = torch .float32 )
9697 self .o_proj = nn .Linear (self .input_dim , hidden_size , bias = False )
9798
9899 def forward (
@@ -115,9 +116,7 @@ def forward(
115116 batch_size , q_len , _ = hidden_states .shape
116117 mode = 'fused_recurrent' if hidden_states .shape [1 ] <= 64 else self .mode
117118
118- last_state = None
119- if past_key_values is not None and len (past_key_values ) > self .layer_idx :
120- last_state = past_key_values [self .layer_idx ]
119+ last_state = get_layer_cache (self , past_key_values )
121120
122121 cu_seqlens = kwargs .get ('cu_seqlens' )
123122 if attention_mask is not None :
@@ -195,13 +194,13 @@ def forward(
195194 else :
196195 raise NotImplementedError (f"Not supported mode `{ mode } `." )
197196
198- if past_key_values is not None :
199- past_key_values . update (
200- recurrent_state = recurrent_state ,
201- conv_state = ( conv_state_q , conv_state_f , conv_state_i ) if self . use_short_conv else None ,
202- layer_idx = self .layer_idx ,
203- offset = q_len ,
204- )
197+ update_layer_cache (
198+ self ,
199+ past_key_values ,
200+ recurrent_state = recurrent_state ,
201+ conv_state = ( conv_state_q , conv_state_f , conv_state_i ) if self .use_short_conv else None ,
202+ offset = q_len ,
203+ )
205204
206205 o = rearrange (o , '... h d -> ... (h d)' )
207206 o = rms_norm_linear (o , self .g_norm .weight , self .g_norm .bias , self .o_proj .weight , self .o_proj .bias )
0 commit comments