Skip to content

Commit 40e9645

Browse files
Fix LTX-2 image-to-video generation failure in two-stage generation (#13187)
* Fix LTX-2 image-to-video generation failure in two-stage generation In LTX-2's two-stage image-to-video generation task, specifically after the upsampling step, a shape mismatch occurs between the `latents` and the `conditioning_mask`, which causes an error in the function `_create_noised_state`. Fix it by creating the `conditioning_mask` based on the shape of the `latents`. * Add unit test for LTX-2 i2v two-stage inference with upsampler * Downscale the upsampler in the LTX-2 image-to-video unit test * Apply style fixes --------- Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 47455bd commit 40e9645

File tree

2 files changed

+75
-3
lines changed

2 files changed

+75
-3
lines changed

src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -699,9 +699,13 @@ def prepare_latents(
699699
mask_shape = (batch_size, 1, num_frames, height, width)
700700

701701
if latents is not None:
702-
conditioning_mask = latents.new_zeros(mask_shape)
703-
conditioning_mask[:, :, 0] = 1.0
704702
if latents.ndim == 5:
703+
# conditioning_mask needs to be the same shape as latents in two-stage generation.
704+
batch_size, _, num_frames, height, width = latents.shape
705+
mask_shape = (batch_size, 1, num_frames, height, width)
706+
conditioning_mask = latents.new_zeros(mask_shape)
707+
conditioning_mask[:, :, 0] = 1.0
708+
705709
latents = self._normalize_latents(
706710
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
707711
)
@@ -710,6 +714,9 @@ def prepare_latents(
710714
latents = self._pack_latents(
711715
latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
712716
)
717+
else:
718+
conditioning_mask = latents.new_zeros(mask_shape)
719+
conditioning_mask[:, :, 0] = 1.0
713720
conditioning_mask = self._pack_latents(
714721
conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
715722
).squeeze(-1)

tests/pipelines/ltx2/test_ltx2_image2video.py

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@
2424
LTX2ImageToVideoPipeline,
2525
LTX2VideoTransformer3DModel,
2626
)
27-
from diffusers.pipelines.ltx2 import LTX2TextConnectors
27+
from diffusers.pipelines.ltx2 import LTX2LatentUpsamplePipeline, LTX2TextConnectors
28+
from diffusers.pipelines.ltx2.latent_upsampler import LTX2LatentUpsamplerModel
2829
from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder
2930

3031
from ...testing_utils import enable_full_determinism
@@ -174,6 +175,15 @@ def get_dummy_components(self):
174175

175176
return components
176177

178+
def get_dummy_upsample_component(self, in_channels=4, mid_channels=32, num_blocks_per_stage=1):
179+
upsampler = LTX2LatentUpsamplerModel(
180+
in_channels=in_channels,
181+
mid_channels=mid_channels,
182+
num_blocks_per_stage=num_blocks_per_stage,
183+
)
184+
185+
return upsampler
186+
177187
def get_dummy_inputs(self, device, seed=0):
178188
if str(device).startswith("mps"):
179189
generator = torch.manual_seed(seed)
@@ -287,5 +297,60 @@ def test_two_stages_inference(self):
287297
assert torch.allclose(expected_video_slice, generated_video_slice, atol=1e-4, rtol=1e-4)
288298
assert torch.allclose(expected_audio_slice, generated_audio_slice, atol=1e-4, rtol=1e-4)
289299

300+
def test_two_stages_inference_with_upsampler(self):
301+
device = "cpu"
302+
303+
components = self.get_dummy_components()
304+
pipe = self.pipeline_class(**components)
305+
pipe.to(device)
306+
pipe.set_progress_bar_config(disable=None)
307+
308+
inputs = self.get_dummy_inputs(device)
309+
inputs["output_type"] = "latent"
310+
first_stage_output = pipe(**inputs)
311+
video_latent = first_stage_output.frames
312+
audio_latent = first_stage_output.audio
313+
314+
self.assertEqual(video_latent.shape, (1, 4, 3, 16, 16))
315+
self.assertEqual(audio_latent.shape, (1, 2, 5, 2))
316+
self.assertEqual(audio_latent.shape[1], components["vocoder"].config.out_channels)
317+
318+
upsampler = self.get_dummy_upsample_component(in_channels=video_latent.shape[1])
319+
upsample_pipe = LTX2LatentUpsamplePipeline(vae=pipe.vae, latent_upsampler=upsampler)
320+
upscaled_video_latent = upsample_pipe(latents=video_latent, output_type="latent", return_dict=False)[0]
321+
self.assertEqual(upscaled_video_latent.shape, (1, 4, 3, 32, 32))
322+
323+
inputs["latents"] = upscaled_video_latent
324+
inputs["audio_latents"] = audio_latent
325+
inputs["output_type"] = "pt"
326+
second_stage_output = pipe(**inputs)
327+
video = second_stage_output.frames
328+
audio = second_stage_output.audio
329+
330+
self.assertEqual(video.shape, (1, 5, 3, 64, 64))
331+
self.assertEqual(audio.shape[0], 1)
332+
self.assertEqual(audio.shape[1], components["vocoder"].config.out_channels)
333+
334+
# fmt: off
335+
expected_video_slice = torch.tensor(
336+
[
337+
0.4497, 0.6757, 0.4219, 0.7686, 0.4525, 0.6483, 0.3969, 0.7404, 0.3541, 0.3039, 0.4592, 0.3521, 0.3665, 0.2785, 0.3336, 0.3079
338+
]
339+
)
340+
expected_audio_slice = torch.tensor(
341+
[
342+
0.0271, 0.0492, 0.1249, 0.1126, 0.1661, 0.1060, 0.1717, 0.0944, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739
343+
]
344+
)
345+
# fmt: on
346+
347+
video = video.flatten()
348+
audio = audio.flatten()
349+
generated_video_slice = torch.cat([video[:8], video[-8:]])
350+
generated_audio_slice = torch.cat([audio[:8], audio[-8:]])
351+
352+
assert torch.allclose(expected_video_slice, generated_video_slice, atol=1e-4, rtol=1e-4)
353+
assert torch.allclose(expected_audio_slice, generated_audio_slice, atol=1e-4, rtol=1e-4)
354+
290355
def test_inference_batch_single_identical(self):
291356
self._test_inference_batch_single_identical(batch_size=2, expected_max_diff=2e-2)

0 commit comments

Comments
 (0)