|
37 | 37 |
|
38 | 38 | from collator import DataCollatorWithFlattening |
39 | 39 | from convert import convert_mixtral_hf_to_te, convert_mixtral_te_to_hf |
40 | | -from modeling_mixtral_te import NVMixtralConfig, NVMixtralForCausalLM |
| 40 | +from modeling_mixtral_te import HFInferenceParams, NVMixtralConfig, NVMixtralForCausalLM |
41 | 41 | from tests.common import BaseModelTest, TestTolerances |
42 | 42 |
|
43 | 43 |
|
@@ -145,3 +145,146 @@ def get_tolerances(self) -> TestTolerances: |
145 | 145 | cp_loss_atol=0.5, |
146 | 146 | cp_loss_rtol=0.25, |
147 | 147 | ) |
| 148 | + |
| 149 | + # ==================== Mixtral-Specific KV-Cache Tests ==================== |
| 150 | + |
def _create_inference_params(self, config, batch_size=1, max_seq_len=256, num_beams=1):
    """Build an ``HFInferenceParams`` cache and allocate memory for every layer.

    Args:
        config: Model config. ``num_key_value_heads``, ``hidden_size``,
            ``num_attention_heads`` and ``num_hidden_layers`` are read from it.
        batch_size: Number of prompts in the batch.
        max_seq_len: Used for both ``max_sequence_length`` and ``max_ctx_len``.
        num_beams: Beams per prompt; the cache is sized for
            ``batch_size * num_beams`` sequences.

    Returns:
        The ``HFInferenceParams`` object after ``allocate_memory`` has been
        called for layer numbers ``1`` through ``config.num_hidden_layers``.
    """
    cache_kwargs = {
        "max_batch_size": batch_size * num_beams,
        "max_sequence_length": max_seq_len,
        "num_heads_kv": config.num_key_value_heads,
        "head_dim_k": config.hidden_size // config.num_attention_heads,
        "dtype": torch.bfloat16,
        "qkv_format": "thd",
        "max_ctx_len": max_seq_len,
    }
    kv_cache = HFInferenceParams(**cache_kwargs)

    # Layer numbers passed to allocate_memory start at 1, not 0.
    for layer_index in range(config.num_hidden_layers):
        kv_cache.allocate_memory(layer_index + 1)

    return kv_cache
| 165 | + |
def test_generate_with_cache(self):
    """Generate from a single prompt using the KV-cache with THD-format attention.

    Builds a test-config model in bfloat16 on CUDA, runs ``generate`` with a
    pre-allocated ``HFInferenceParams`` cache, and checks that the output is
    longer than the prompt.
    """
    config = self.create_test_config(attn_input_format="thd", self_attn_mask_type="padding_causal")
    model_cls = self.get_model_class()
    model = model_cls(config).to("cuda").to(torch.bfloat16)
    model.eval()

    tokenizer = self.get_tokenizer()
    encoded = tokenizer("The quick brown fox jumps over", return_tensors="pt")
    inputs = {name: tensor.to("cuda") for name, tensor in encoded.items()}
    prompt_length = inputs["input_ids"].shape[1]

    kv_cache = self._create_inference_params(config, batch_size=1)

    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=16,
            use_cache=True,
            past_key_values=kv_cache,
        )

    # At least one new token must have been appended to the prompt.
    assert output_ids.shape[1] > prompt_length
| 184 | + |
def test_generate_with_cache_batched(self):
    """Generate from two left-padded prompts using the KV-cache.

    The tokenizer pads the batch on the left; the model is configured for
    THD-format attention. Checks that both sequences are returned and that
    the output is longer than the padded prompt.
    """
    config = self.create_test_config(attn_input_format="thd", self_attn_mask_type="padding_causal")
    model_cls = self.get_model_class()
    model = model_cls(config).to("cuda").to(torch.bfloat16)
    model.eval()

    prompts = (
        "The quick brown fox jumps over the lazy dog.",
        "Lorem ipsum dolor sit amet, consectetur adipiscing elit.",
    )
    tokenizer = self.get_tokenizer()
    encoded = tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left")
    inputs = {name: tensor.to("cuda") for name, tensor in encoded.items()}
    padded_prompt_length = inputs["input_ids"].shape[1]

    kv_cache = self._create_inference_params(config, batch_size=2)

    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=16,
            use_cache=True,
            past_key_values=kv_cache,
        )

    # One output row per prompt, each extended beyond the padded prompt.
    assert output_ids.shape[0] == 2
    assert output_ids.shape[1] > padded_prompt_length
| 207 | + |
def test_generate_with_cache_beam_search(self):
    """Test batched generation with KV-cache and beam search with sampling enabled.

    ``generate`` is called with ``num_beams=2`` and ``do_sample=True``, so the
    token choice draws on torch's random number generator. The seed is fixed
    at the start of the test so that both the randomly initialised weights
    and the sampling draw from a fixed seed instead of varying between runs.
    """
    # Seed before the model is built: the test-config model has random
    # weights and generation samples, so both depend on the RNG state.
    torch.manual_seed(0)

    config = self.create_test_config(attn_input_format="thd", self_attn_mask_type="padding_causal")
    model = self.get_model_class()(config).to("cuda").to(torch.bfloat16)
    model.eval()

    tokenizer = self.get_tokenizer()
    prompts = (
        "The quick brown fox jumps over the lazy dog.",
        "Lorem ipsum dolor sit amet, consectetur adipiscing elit.",
    )
    inputs = tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left")
    inputs = {k: v.to("cuda") for k, v in inputs.items()}

    num_beams = 2
    # The cache is sized for batch_size * num_beams sequences.
    past_key_values = self._create_inference_params(config, batch_size=2, num_beams=num_beams)

    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=16,
            use_cache=True,
            past_key_values=past_key_values,
            num_beams=num_beams,
            do_sample=True,
        )

    # Verify generation produced new tokens for both sequences
    assert output_ids.shape[0] == 2
    assert output_ids.shape[1] > inputs["input_ids"].shape[1]
| 238 | + |
| 239 | + # ==================== Standalone Mixtral Generation Tests ==================== |
| 240 | + |
def test_te_mixtral_model_generate_with_cache_beam_search(self):
    """Test Mixtral generation with KV-cache and beam search using real model weights.

    The reference model is converted with ``convert_mixtral_hf_to_te`` and
    then released before the converted model is moved to CUDA. Generation
    runs without sampling (``do_sample=False``), and the decoded text of each
    prompt is checked for an expected continuation.
    """
    import gc

    model_hf = self.get_reference_model()
    model_te = convert_mixtral_hf_to_te(model_hf, attn_input_format="thd", self_attn_mask_type="padding_causal")
    # Drop the reference model before moving the converted one to the GPU.
    del model_hf
    gc.collect()

    model_te.to("cuda")
    model_te.eval()

    tokenizer = self.get_tokenizer()

    prompts = (
        'Licensed under the Apache License, Version 2.0 (the "License");'
        " you may not use this file except in compliance with the License."
        " You may obtain a copy of the License at",
        "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore",
    )
    inputs = tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left")
    inputs = {k: v.to("cuda") for k, v in inputs.items()}

    num_beams = 2
    # Reuse the shared helper rather than duplicating the cache setup; its
    # default max_seq_len of 256 matches what this test needs.
    past_key_values = self._create_inference_params(
        model_te.config,
        batch_size=len(prompts),
        num_beams=num_beams,
    )

    with torch.no_grad():
        output_ids = model_te.generate(
            **inputs,
            max_new_tokens=16,
            use_cache=True,
            past_key_values=past_key_values,
            num_beams=num_beams,
            do_sample=False,
        )

    generated_text = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    assert "http://www.apache.org/licenses/LICENSE-2.0" in generated_text[0], generated_text[0]
    assert "et dolore magna aliqua" in generated_text[1], generated_text[1]
0 commit comments