Commit 773d581

add tests for mixtral kv-cache generation
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
1 parent 31148bb commit 773d581

File tree

1 file changed (+144 −1)

bionemo-recipes/models/mixtral/tests/test_modeling_mixtral.py

@@ -37,7 +37,7 @@
 
 from collator import DataCollatorWithFlattening
 from convert import convert_mixtral_hf_to_te, convert_mixtral_te_to_hf
-from modeling_mixtral_te import NVMixtralConfig, NVMixtralForCausalLM
+from modeling_mixtral_te import HFInferenceParams, NVMixtralConfig, NVMixtralForCausalLM
 from tests.common import BaseModelTest, TestTolerances
 
 
@@ -145,3 +145,146 @@ def get_tolerances(self) -> TestTolerances:
             cp_loss_atol=0.5,
             cp_loss_rtol=0.25,
         )
+
+    # ==================== Mixtral-Specific KV-Cache Tests ====================
+
+    def _create_inference_params(self, config, batch_size=1, max_seq_len=256, num_beams=1):
+        """Create HFInferenceParams for the given config."""
+        past_key_values = HFInferenceParams(
+            max_batch_size=batch_size * num_beams,
+            max_sequence_length=max_seq_len,
+            num_heads_kv=config.num_key_value_heads,
+            head_dim_k=config.hidden_size // config.num_attention_heads,
+            dtype=torch.bfloat16,
+            qkv_format="thd",
+            max_ctx_len=max_seq_len,
+        )
+        for layer_number in range(1, config.num_hidden_layers + 1):
+            past_key_values.allocate_memory(layer_number)
+        return past_key_values
+
+    def test_generate_with_cache(self):
+        """Test single-prompt generation with KV-cache (THD format)."""
+        config = self.create_test_config(attn_input_format="thd", self_attn_mask_type="padding_causal")
+        model = self.get_model_class()(config).to("cuda").to(torch.bfloat16)
+        model.eval()
+
+        tokenizer = self.get_tokenizer()
+        prompt = "The quick brown fox jumps over"
+        inputs = tokenizer(prompt, return_tensors="pt")
+        inputs = {k: v.to("cuda") for k, v in inputs.items()}
+
+        past_key_values = self._create_inference_params(config, batch_size=1)
+
+        with torch.no_grad():
+            output_ids = model.generate(**inputs, max_new_tokens=16, use_cache=True, past_key_values=past_key_values)
+
+        # Verify generation produced new tokens
+        assert output_ids.shape[1] > inputs["input_ids"].shape[1]
+
+    def test_generate_with_cache_batched(self):
+        """Test batched generation with KV-cache (left-padded BSHD converted to THD)."""
+        config = self.create_test_config(attn_input_format="thd", self_attn_mask_type="padding_causal")
+        model = self.get_model_class()(config).to("cuda").to(torch.bfloat16)
+        model.eval()
+
+        tokenizer = self.get_tokenizer()
+        prompts = (
+            "The quick brown fox jumps over the lazy dog.",
+            "Lorem ipsum dolor sit amet, consectetur adipiscing elit.",
+        )
+        inputs = tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left")
+        inputs = {k: v.to("cuda") for k, v in inputs.items()}
+
+        past_key_values = self._create_inference_params(config, batch_size=2)
+
+        with torch.no_grad():
+            output_ids = model.generate(**inputs, max_new_tokens=16, use_cache=True, past_key_values=past_key_values)
+
+        # Verify generation produced new tokens for both sequences
+        assert output_ids.shape[0] == 2
+        assert output_ids.shape[1] > inputs["input_ids"].shape[1]
+
+    def test_generate_with_cache_beam_search(self):
+        """Test batched generation with KV-cache and beam search."""
+        config = self.create_test_config(attn_input_format="thd", self_attn_mask_type="padding_causal")
+        model = self.get_model_class()(config).to("cuda").to(torch.bfloat16)
+        model.eval()
+
+        tokenizer = self.get_tokenizer()
+        prompts = (
+            "The quick brown fox jumps over the lazy dog.",
+            "Lorem ipsum dolor sit amet, consectetur adipiscing elit.",
+        )
+        inputs = tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left")
+        inputs = {k: v.to("cuda") for k, v in inputs.items()}
+
+        num_beams = 2
+        past_key_values = self._create_inference_params(config, batch_size=2, num_beams=num_beams)
+
+        with torch.no_grad():
+            output_ids = model.generate(
+                **inputs,
+                max_new_tokens=16,
+                use_cache=True,
+                past_key_values=past_key_values,
+                num_beams=num_beams,
+                do_sample=True,
+            )
+
+        # Verify generation produced new tokens for both sequences
+        assert output_ids.shape[0] == 2
+        assert output_ids.shape[1] > inputs["input_ids"].shape[1]
+
+    # ==================== Standalone Mixtral Generation Tests ====================
+
+    def test_te_mixtral_model_generate_with_cache_beam_search(self):
+        """Test Mixtral generation with KV-cache and beam search using real model weights."""
+        import gc
+
+        model_hf = self.get_reference_model()
+        model_te = convert_mixtral_hf_to_te(model_hf, attn_input_format="thd", self_attn_mask_type="padding_causal")
+        del model_hf
+        gc.collect()
+
+        model_te.to("cuda")
+        model_te.eval()
+
+        tokenizer = self.get_tokenizer()
+
+        prompts = (
+            'Licensed under the Apache License, Version 2.0 (the "License");'
+            " you may not use this file except in compliance with the License."
+            " You may obtain a copy of the License at",
+            "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore",
+        )
+        inputs = tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left")
+        inputs = {k: v.to("cuda") for k, v in inputs.items()}
+
+        num_beams = 2
+        config = model_te.config
+        past_key_values = HFInferenceParams(
+            max_batch_size=2 * num_beams,
+            max_sequence_length=256,
+            num_heads_kv=config.num_key_value_heads,
+            head_dim_k=config.hidden_size // config.num_attention_heads,
+            dtype=torch.bfloat16,
+            qkv_format="thd",
+            max_ctx_len=256,
+        )
+        for layer_number in range(1, config.num_hidden_layers + 1):
+            past_key_values.allocate_memory(layer_number)
+
+        with torch.no_grad():
+            output_ids = model_te.generate(
+                **inputs,
+                max_new_tokens=16,
+                use_cache=True,
+                past_key_values=past_key_values,
+                num_beams=num_beams,
+                do_sample=False,
+            )
+
+        generated_text = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+        assert "http://www.apache.org/licenses/LICENSE-2.0" in generated_text[0]
+        assert "et dolore magna aliqua" in generated_text[1]
