|
24 | 24 | LTX2ImageToVideoPipeline, |
25 | 25 | LTX2VideoTransformer3DModel, |
26 | 26 | ) |
27 | | -from diffusers.pipelines.ltx2 import LTX2TextConnectors |
| 27 | +from diffusers.pipelines.ltx2 import LTX2LatentUpsamplePipeline, LTX2TextConnectors |
| 28 | +from diffusers.pipelines.ltx2.latent_upsampler import LTX2LatentUpsamplerModel |
28 | 29 | from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder |
29 | 30 |
|
30 | 31 | from ...testing_utils import enable_full_determinism |
@@ -174,6 +175,15 @@ def get_dummy_components(self): |
174 | 175 |
|
175 | 176 | return components |
176 | 177 |
|
| 178 | + def get_dummy_upsample_component(self, in_channels=4, mid_channels=32, num_blocks_per_stage=1): |
| 179 | + upsampler = LTX2LatentUpsamplerModel( |
| 180 | + in_channels=in_channels, |
| 181 | + mid_channels=mid_channels, |
| 182 | + num_blocks_per_stage=num_blocks_per_stage, |
| 183 | + ) |
| 184 | + |
| 185 | + return upsampler |
| 186 | + |
177 | 187 | def get_dummy_inputs(self, device, seed=0): |
178 | 188 | if str(device).startswith("mps"): |
179 | 189 | generator = torch.manual_seed(seed) |
@@ -287,5 +297,60 @@ def test_two_stages_inference(self): |
287 | 297 | assert torch.allclose(expected_video_slice, generated_video_slice, atol=1e-4, rtol=1e-4) |
288 | 298 | assert torch.allclose(expected_audio_slice, generated_audio_slice, atol=1e-4, rtol=1e-4) |
289 | 299 |
|
| 300 | + def test_two_stages_inference_with_upsampler(self): |
| 301 | + device = "cpu" |
| 302 | + |
| 303 | + components = self.get_dummy_components() |
| 304 | + pipe = self.pipeline_class(**components) |
| 305 | + pipe.to(device) |
| 306 | + pipe.set_progress_bar_config(disable=None) |
| 307 | + |
| 308 | + inputs = self.get_dummy_inputs(device) |
| 309 | + inputs["output_type"] = "latent" |
| 310 | + first_stage_output = pipe(**inputs) |
| 311 | + video_latent = first_stage_output.frames |
| 312 | + audio_latent = first_stage_output.audio |
| 313 | + |
| 314 | + self.assertEqual(video_latent.shape, (1, 4, 3, 16, 16)) |
| 315 | + self.assertEqual(audio_latent.shape, (1, 2, 5, 2)) |
| 316 | + self.assertEqual(audio_latent.shape[1], components["vocoder"].config.out_channels) |
| 317 | + |
| 318 | + upsampler = self.get_dummy_upsample_component(in_channels=video_latent.shape[1]) |
| 319 | + upsample_pipe = LTX2LatentUpsamplePipeline(vae=pipe.vae, latent_upsampler=upsampler) |
| 320 | + upscaled_video_latent = upsample_pipe(latents=video_latent, output_type="latent", return_dict=False)[0] |
| 321 | + self.assertEqual(upscaled_video_latent.shape, (1, 4, 3, 32, 32)) |
| 322 | + |
| 323 | + inputs["latents"] = upscaled_video_latent |
| 324 | + inputs["audio_latents"] = audio_latent |
| 325 | + inputs["output_type"] = "pt" |
| 326 | + second_stage_output = pipe(**inputs) |
| 327 | + video = second_stage_output.frames |
| 328 | + audio = second_stage_output.audio |
| 329 | + |
| 330 | + self.assertEqual(video.shape, (1, 5, 3, 64, 64)) |
| 331 | + self.assertEqual(audio.shape[0], 1) |
| 332 | + self.assertEqual(audio.shape[1], components["vocoder"].config.out_channels) |
| 333 | + |
| 334 | + # fmt: off |
| 335 | + expected_video_slice = torch.tensor( |
| 336 | + [ |
| 337 | + 0.4497, 0.6757, 0.4219, 0.7686, 0.4525, 0.6483, 0.3969, 0.7404, 0.3541, 0.3039, 0.4592, 0.3521, 0.3665, 0.2785, 0.3336, 0.3079 |
| 338 | + ] |
| 339 | + ) |
| 340 | + expected_audio_slice = torch.tensor( |
| 341 | + [ |
| 342 | + 0.0271, 0.0492, 0.1249, 0.1126, 0.1661, 0.1060, 0.1717, 0.0944, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739 |
| 343 | + ] |
| 344 | + ) |
| 345 | + # fmt: on |
| 346 | + |
| 347 | + video = video.flatten() |
| 348 | + audio = audio.flatten() |
| 349 | + generated_video_slice = torch.cat([video[:8], video[-8:]]) |
| 350 | + generated_audio_slice = torch.cat([audio[:8], audio[-8:]]) |
| 351 | + |
| 352 | + assert torch.allclose(expected_video_slice, generated_video_slice, atol=1e-4, rtol=1e-4) |
| 353 | + assert torch.allclose(expected_audio_slice, generated_audio_slice, atol=1e-4, rtol=1e-4) |
| 354 | + |
290 | 355 | def test_inference_batch_single_identical(self): |
291 | 356 | self._test_inference_batch_single_identical(batch_size=2, expected_max_diff=2e-2) |
0 commit comments