Skip to content

Commit ca910f8

Browse files
authored
[Model] Unify cache function (#777)
1 parent 5da31d1 commit ca910f8

33 files changed

+1123
-1141
lines changed

fla/layers/abc.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import torch.nn as nn
1010
from einops import rearrange
1111

12+
from fla.layers.utils import get_layer_cache, update_layer_cache
1213
from fla.modules import FusedRMSNormGated, RMSNorm, RotaryEmbedding, ShortConvolution
1314
from fla.modules.activations import swiglu, swish
1415
from fla.ops.abc.chunk import chunk_abc
@@ -146,9 +147,7 @@ def forward(
146147
"Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
147148
)
148149

149-
last_state = None
150-
if past_key_values is not None and len(past_key_values) > self.layer_idx:
151-
last_state = past_key_values[self.layer_idx]
150+
last_state = get_layer_cache(self, past_key_values)
152151

153152
cu_seqlens = kwargs.get('cu_seqlens')
154153
if cu_seqlens is not None:
@@ -210,13 +209,13 @@ def forward(
210209
initial_state=recurrent_state,
211210
output_final_state=use_cache,
212211
)
213-
if past_key_values is not None:
214-
past_key_values.update(
215-
recurrent_state=recurrent_state,
216-
conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
217-
layer_idx=self.layer_idx,
218-
offset=q.shape[1],
219-
)
212+
update_layer_cache(
213+
self,
214+
past_key_values,
215+
recurrent_state=recurrent_state,
216+
conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
217+
offset=q.shape[1],
218+
)
220219

221220
if self.use_norm and not self.use_output_gate:
222221
o = self.g_norm(o)

fla/layers/comba.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from einops import rearrange, repeat
1212
from torch.nn import functional as F
1313

14-
from fla.layers.utils import get_unpad_data, index_first_axis, pad_input
14+
from fla.layers.utils import get_layer_cache, get_unpad_data, index_first_axis, pad_input, update_layer_cache
1515
from fla.modules import FusedRMSNormGated, RMSNorm, ShortConvolution
1616
from fla.ops.comba import chunk_comba, fused_recurrent_comba
1717

@@ -225,9 +225,7 @@ def forward(
225225
mode = 'fused_recurrent' if (q_len <= 64 and not self.training) else self.mode
226226
if self.training:
227227
assert mode == 'chunk', "Only chunk mode is supported in training."
228-
last_state = None
229-
if past_key_values is not None and len(past_key_values) > self.layer_idx:
230-
last_state = past_key_values[self.layer_idx]
228+
last_state = get_layer_cache(self, past_key_values)
231229

232230
cu_seqlens = kwargs.get('cu_seqlens')
233231
if attention_mask is not None:
@@ -309,13 +307,13 @@ def forward(
309307
else:
310308
raise NotImplementedError(f"Not supported mode `{mode}`.")
311309

312-
if past_key_values is not None:
313-
past_key_values.update(
314-
recurrent_state=recurrent_state,
315-
conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
316-
layer_idx=self.layer_idx,
317-
offset=q_len,
318-
)
310+
update_layer_cache(
311+
self,
312+
past_key_values,
313+
recurrent_state=recurrent_state,
314+
conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
315+
offset=q_len,
316+
)
319317

320318
if self.use_output_gate:
321319
g = rearrange(self.g_proj(hidden_states), '... (h d) -> ... h d', d=self.head_v_dim)

fla/layers/delta_net.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from einops import rearrange
1111
from torch.nn import functional as F
1212

13-
from fla.layers.utils import get_unpad_data, index_first_axis, pad_input
13+
from fla.layers.utils import get_layer_cache, get_unpad_data, index_first_axis, pad_input, update_layer_cache
1414
from fla.modules import FusedRMSNormGated, RMSNorm, ShortConvolution
1515
from fla.ops.delta_rule import chunk_delta_rule, fused_recurrent_delta_rule
1616

@@ -182,9 +182,7 @@ def forward(
182182
# change to inference mode.
183183
mode = 'fused_recurrent' if q_len <= 64 else self.mode
184184

185-
last_state = None
186-
if past_key_values is not None and len(past_key_values) > self.layer_idx:
187-
last_state = past_key_values[self.layer_idx]
185+
last_state = get_layer_cache(self, past_key_values)
188186

189187
cu_seqlens = kwargs.get('cu_seqlens')
190188
if attention_mask is not None:
@@ -268,13 +266,13 @@ def forward(
268266
else:
269267
raise NotImplementedError(f"Not supported mode `{mode}`.")
270268

271-
if past_key_values is not None:
272-
past_key_values.update(
273-
recurrent_state=recurrent_state,
274-
conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
275-
layer_idx=self.layer_idx,
276-
offset=q_len,
277-
)
269+
update_layer_cache(
270+
self,
271+
past_key_values,
272+
recurrent_state=recurrent_state,
273+
conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
274+
offset=q_len,
275+
)
278276

279277
if self.use_gate:
280278
g = rearrange(self.g_proj(hidden_states), '... (h d) -> ... h d', d=self.head_v_dim)

fla/layers/gated_deltanet.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from einops import rearrange, repeat
1212
from torch.nn import functional as F
1313

14-
from fla.layers.utils import get_unpad_data, index_first_axis, pad_input
14+
from fla.layers.utils import get_layer_cache, get_unpad_data, index_first_axis, pad_input, update_layer_cache
1515
from fla.modules import FusedRMSNormGated, RMSNorm, ShortConvolution
1616
from fla.ops.gated_delta_rule import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule
1717

@@ -221,9 +221,7 @@ def forward(
221221
if self.training:
222222
assert mode == 'chunk', "Only chunk mode is supported in training."
223223

224-
last_state = None
225-
if past_key_values is not None and len(past_key_values) > self.layer_idx:
226-
last_state = past_key_values[self.layer_idx]
224+
last_state = get_layer_cache(self, past_key_values)
227225

228226
cu_seqlens = kwargs.get('cu_seqlens')
229227
if attention_mask is not None:
@@ -297,13 +295,13 @@ def forward(
297295
else:
298296
raise NotImplementedError(f"Not supported mode `{mode}`.")
299297

300-
if past_key_values is not None:
301-
past_key_values.update(
302-
recurrent_state=recurrent_state,
303-
conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
304-
layer_idx=self.layer_idx,
305-
offset=q_len,
306-
)
298+
update_layer_cache(
299+
self,
300+
past_key_values,
301+
recurrent_state=recurrent_state,
302+
conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
303+
offset=q_len,
304+
)
307305

308306
if self.use_gate:
309307
g = rearrange(self.g_proj(hidden_states), '... (h d) -> ... h d', d=self.head_v_dim)

fla/layers/gated_deltaproduct.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from einops import rearrange, repeat
1212
from torch.nn import functional as F
1313

14-
from fla.layers.utils import get_unpad_data, index_first_axis, pad_input
14+
from fla.layers.utils import get_layer_cache, get_unpad_data, index_first_axis, pad_input, update_layer_cache
1515
from fla.modules import FusedRMSNormGated, RMSNorm, ShortConvolution
1616
from fla.ops.gated_delta_product import chunk_gated_delta_product
1717
from fla.ops.gated_delta_rule import fused_recurrent_gated_delta_rule
@@ -178,9 +178,7 @@ def forward(
178178
if self.training:
179179
assert mode == 'chunk', "Only chunk mode is supported in training."
180180

181-
last_state = None
182-
if past_key_values is not None and len(past_key_values) > self.layer_idx:
183-
last_state = past_key_values[self.layer_idx]
181+
last_state = get_layer_cache(self, past_key_values)
184182

185183
cu_seqlens = kwargs.get('cu_seqlens')
186184
if attention_mask is not None:
@@ -268,13 +266,13 @@ def forward(
268266
)
269267
o = rearrange(o, '... (t n) h d -> ... t n h d', n=self.num_householder)[..., -1, :, :].contiguous()
270268

271-
if past_key_values is not None:
272-
past_key_values.update(
273-
recurrent_state=recurrent_state,
274-
conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
275-
layer_idx=self.layer_idx,
276-
offset=q_len,
277-
)
269+
update_layer_cache(
270+
self,
271+
past_key_values,
272+
recurrent_state=recurrent_state,
273+
conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
274+
offset=q_len,
275+
)
278276

279277
if self.use_output_gate:
280278
g = rearrange(self.g_proj(hidden_states), '... (h d) -> ... h d', d=self.head_v_dim)

fla/layers/gla.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import torch.nn.functional as F
1111
from einops import rearrange, repeat
1212

13-
from fla.layers.utils import get_unpad_data, index_first_axis, pad_input
13+
from fla.layers.utils import get_layer_cache, get_unpad_data, index_first_axis, pad_input, update_layer_cache
1414
from fla.modules import FusedRMSNormGated, RMSNorm, ShortConvolution
1515
from fla.modules.activations import ACT2FN
1616
from fla.ops.gla import chunk_gla, fused_chunk_gla, fused_recurrent_gla
@@ -189,9 +189,7 @@ def forward(
189189
batch_size, q_len, _ = hidden_states.shape
190190
mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode
191191

192-
last_state = None
193-
if past_key_values is not None and len(past_key_values) > self.layer_idx:
194-
last_state = past_key_values[self.layer_idx]
192+
last_state = get_layer_cache(self, past_key_values)
195193

196194
cu_seqlens = kwargs.get('cu_seqlens')
197195
if attention_mask is not None:
@@ -274,13 +272,13 @@ def forward(
274272
else:
275273
raise NotImplementedError(f"Not supported mode `{mode}`.")
276274

277-
if past_key_values is not None:
278-
past_key_values.update(
279-
recurrent_state=recurrent_state,
280-
conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
281-
layer_idx=self.layer_idx,
282-
offset=q_len,
283-
)
275+
update_layer_cache(
276+
self,
277+
past_key_values,
278+
recurrent_state=recurrent_state,
279+
conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
280+
offset=q_len,
281+
)
284282

285283
if self.use_output_gate:
286284
g = self.g_proj(hidden_states)

fla/layers/gsa.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import torch.nn.functional as F
1111
from einops import rearrange, repeat
1212

13-
from fla.layers.utils import get_unpad_data, index_first_axis, pad_input
13+
from fla.layers.utils import get_layer_cache, get_unpad_data, index_first_axis, pad_input, update_layer_cache
1414
from fla.modules import RMSNorm, ShortConvolution
1515
from fla.modules.feature_map import ReLUFeatureMap, SwishFeatureMap, T2RFeatureMap
1616
from fla.modules.layernorm import rms_norm_linear
@@ -143,9 +143,7 @@ def forward(
143143
batch_size, q_len, _ = hidden_states.shape
144144
mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode
145145

146-
last_state = None
147-
if past_key_values is not None and len(past_key_values) > self.layer_idx:
148-
last_state = past_key_values[self.layer_idx]
146+
last_state = get_layer_cache(self, past_key_values)
149147

150148
cu_seqlens = kwargs.get('cu_seqlens')
151149
if attention_mask is not None:
@@ -223,13 +221,13 @@ def forward(
223221
else:
224222
raise NotImplementedError(f"Not supported mode `{mode}`.")
225223

226-
if past_key_values is not None:
227-
past_key_values.update(
228-
recurrent_state=recurrent_state,
229-
conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
230-
layer_idx=self.layer_idx,
231-
offset=q_len,
232-
)
224+
update_layer_cache(
225+
self,
226+
past_key_values,
227+
recurrent_state=recurrent_state,
228+
conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
229+
offset=q_len,
230+
)
233231

234232
o = rearrange(o, '... h d -> ... (h d)')
235233
o = rms_norm_linear(F.silu(o), self.g_norm.weight, self.g_norm.bias, self.o_proj.weight, self.o_proj.bias)

fla/layers/hgrn.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import torch.nn as nn
1111
import torch.nn.functional as F
1212

13+
from fla.layers.utils import get_layer_cache, update_layer_cache
1314
from fla.modules import FusedRMSNormGated, ShortConvolution
1415
from fla.modules.activations import swiglu
1516
from fla.ops.hgrn import chunk_hgrn, fused_recurrent_hgrn
@@ -95,9 +96,7 @@ def forward(
9596
# launching the triton kernel for just one token will actually be slower
9697
mode = 'fused_recurrent' if not self.training and hidden_states.shape[1] <= 64 else self.mode
9798

98-
last_state = None
99-
if past_key_values is not None and len(past_key_values) > self.layer_idx:
100-
last_state = past_key_values[self.layer_idx]
99+
last_state = get_layer_cache(self, past_key_values)
101100

102101
cu_seqlens = kwargs.get('cu_seqlens')
103102
if self.use_short_conv:
@@ -154,13 +153,13 @@ def forward(
154153
else:
155154
raise NotImplementedError(f"Not supported mode `{mode}`.")
156155

157-
if past_key_values is not None:
158-
past_key_values.update(
159-
recurrent_state=recurrent_state,
160-
conv_state=(conv_state_i, conv_state_f) if self.use_short_conv else None,
161-
layer_idx=self.layer_idx,
162-
offset=i.shape[2],
163-
)
156+
update_layer_cache(
157+
self,
158+
past_key_values,
159+
recurrent_state=recurrent_state,
160+
conv_state=(conv_state_i, conv_state_f) if self.use_short_conv else None,
161+
offset=i.shape[1],
162+
)
164163

165164
o = self.g_norm(o, self.g_proj(hidden_states))
166165
o = self.o_proj(o)

fla/layers/hgrn2.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import torch.nn.functional as F
1212
from einops import rearrange
1313

14-
from fla.layers.utils import get_unpad_data, index_first_axis, pad_input
14+
from fla.layers.utils import get_layer_cache, get_unpad_data, index_first_axis, pad_input, update_layer_cache
1515
from fla.modules import RMSNorm, ShortConvolution
1616
from fla.modules.activations import swish
1717
from fla.modules.layernorm import rms_norm_linear
@@ -92,7 +92,8 @@ def __init__(
9292
activation=None,
9393
)
9494

95-
self.g_norm = RMSNorm(hidden_size=self.hidden_size, elementwise_affine=elementwise_affine, eps=norm_eps, dtype=torch.float32)
95+
self.g_norm = RMSNorm(hidden_size=self.hidden_size, elementwise_affine=elementwise_affine,
96+
eps=norm_eps, dtype=torch.float32)
9697
self.o_proj = nn.Linear(self.input_dim, hidden_size, bias=False)
9798

9899
def forward(
@@ -115,9 +116,7 @@ def forward(
115116
batch_size, q_len, _ = hidden_states.shape
116117
mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode
117118

118-
last_state = None
119-
if past_key_values is not None and len(past_key_values) > self.layer_idx:
120-
last_state = past_key_values[self.layer_idx]
119+
last_state = get_layer_cache(self, past_key_values)
121120

122121
cu_seqlens = kwargs.get('cu_seqlens')
123122
if attention_mask is not None:
@@ -195,13 +194,13 @@ def forward(
195194
else:
196195
raise NotImplementedError(f"Not supported mode `{mode}`.")
197196

198-
if past_key_values is not None:
199-
past_key_values.update(
200-
recurrent_state=recurrent_state,
201-
conv_state=(conv_state_q, conv_state_f, conv_state_i) if self.use_short_conv else None,
202-
layer_idx=self.layer_idx,
203-
offset=q_len,
204-
)
197+
update_layer_cache(
198+
self,
199+
past_key_values,
200+
recurrent_state=recurrent_state,
201+
conv_state=(conv_state_q, conv_state_f, conv_state_i) if self.use_short_conv else None,
202+
offset=q_len,
203+
)
205204

206205
o = rearrange(o, '... h d -> ... (h d)')
207206
o = rms_norm_linear(o, self.g_norm.weight, self.g_norm.bias, self.o_proj.weight, self.o_proj.bias)

0 commit comments

Comments
 (0)