Commit 212db7b

Cosmos Transfer2.5 Auto-Regressive Inference Pipeline (#13114)
* AR
* address comments
* address comments 2
1 parent 3105848 commit 212db7b

6 files changed, with 452 additions and 261 deletions.
docs/source/en/api/pipelines/cosmos.md

Lines changed: 14 additions & 6 deletions
@@ -46,6 +46,20 @@ output = pipe(
 output.save("output.png")
 ```
 
+## Cosmos2_5_TransferPipeline
+
+[[autodoc]] Cosmos2_5_TransferPipeline
+  - all
+  - __call__
+
+
+## Cosmos2_5_PredictBasePipeline
+
+[[autodoc]] Cosmos2_5_PredictBasePipeline
+  - all
+  - __call__
+
+
 ## CosmosTextToWorldPipeline
 
 [[autodoc]] CosmosTextToWorldPipeline
@@ -70,12 +84,6 @@ output.save("output.png")
   - all
   - __call__
 
-## Cosmos2_5_PredictBasePipeline
-
-[[autodoc]] Cosmos2_5_PredictBasePipeline
-  - all
-  - __call__
-
 ## CosmosPipelineOutput
 
 [[autodoc]] pipelines.cosmos.pipeline_output.CosmosPipelineOutput

scripts/convert_cosmos_to_diffusers.py

Lines changed: 21 additions & 3 deletions
@@ -94,9 +94,15 @@
     --transformer_type Cosmos-2.5-Transfer-General-2B \
     --transformer_ckpt_path $transformer_ckpt_path \
     --vae_type wan2.1 \
-    --output_path converted/transfer/2b/general/depth \
+    --output_path converted/transfer/2b/general/depth/pipeline \
     --save_pipeline
 
+python scripts/convert_cosmos_to_diffusers.py \
+    --transformer_type Cosmos-2.5-Transfer-General-2B \
+    --transformer_ckpt_path $transformer_ckpt_path \
+    --vae_type wan2.1 \
+    --output_path converted/transfer/2b/general/depth/models
+
 # edge
 transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Transfer2.5-2B/snapshots/eb5325b77d358944da58a690157dd2b8071bbf85/general/edge/61f5694b-0ad5-4ecd-8ad7-c8545627d125_ema_bf16.pt
 
@@ -120,18 +126,30 @@
     --transformer_type Cosmos-2.5-Transfer-General-2B \
     --transformer_ckpt_path $transformer_ckpt_path \
     --vae_type wan2.1 \
-    --output_path converted/transfer/2b/general/blur \
+    --output_path converted/transfer/2b/general/blur/pipeline \
     --save_pipeline
 
+python scripts/convert_cosmos_to_diffusers.py \
+    --transformer_type Cosmos-2.5-Transfer-General-2B \
+    --transformer_ckpt_path $transformer_ckpt_path \
+    --vae_type wan2.1 \
+    --output_path converted/transfer/2b/general/blur/models
+
 # seg
 transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Transfer2.5-2B/snapshots/eb5325b77d358944da58a690157dd2b8071bbf85/general/seg/5136ef49-6d8d-42e8-8abf-7dac722a304a_ema_bf16.pt
 
 python scripts/convert_cosmos_to_diffusers.py \
     --transformer_type Cosmos-2.5-Transfer-General-2B \
     --transformer_ckpt_path $transformer_ckpt_path \
     --vae_type wan2.1 \
-    --output_path converted/transfer/2b/general/seg \
+    --output_path converted/transfer/2b/general/seg/pipeline \
     --save_pipeline
+
+python scripts/convert_cosmos_to_diffusers.py \
+    --transformer_type Cosmos-2.5-Transfer-General-2B \
+    --transformer_ckpt_path $transformer_ckpt_path \
+    --vae_type wan2.1 \
+    --output_path converted/transfer/2b/general/seg/models
 ```
 """
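
Reading aid for the docstring change above: each control type shown in these hunks (depth, blur, seg) now gets two invocations, one writing to a `pipeline/` subdirectory with `--save_pipeline` and one writing to `models/` without it. The following is a minimal sketch that only assembles those command lines. The helper name `build_commands` and the checkpoint path are hypothetical and not part of the commit; the flags and output paths are copied from the diff.

```python
import shlex
from typing import List

SCRIPT = "scripts/convert_cosmos_to_diffusers.py"


def build_commands(control: str, ckpt_path: str) -> List[List[str]]:
    """Return the two conversion commands for one control type (hypothetical helper)."""
    base = [
        "python", SCRIPT,
        "--transformer_type", "Cosmos-2.5-Transfer-General-2B",
        "--transformer_ckpt_path", ckpt_path,
        "--vae_type", "wan2.1",
    ]
    root = f"converted/transfer/2b/general/{control}"
    # First invocation: full pipeline, written under <root>/pipeline.
    pipeline_cmd = base + ["--output_path", f"{root}/pipeline", "--save_pipeline"]
    # Second invocation: same arguments without --save_pipeline, written under <root>/models.
    models_cmd = base + ["--output_path", f"{root}/models"]
    return [pipeline_cmd, models_cmd]


# "/path/to/checkpoint.pt" is a placeholder, not a real checkpoint location.
for control in ("depth", "blur", "seg"):
    for cmd in build_commands(control, "/path/to/checkpoint.pt"):
        print(shlex.join(cmd))
```

The sketch prints the commands rather than running them, so it can be checked against the docstring without any checkpoints on disk.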

src/diffusers/models/controlnets/controlnet_cosmos.py

Lines changed: 6 additions & 1 deletion
@@ -191,7 +191,12 @@ def forward(
             dim=1,
         )
 
-        control_hidden_states = torch.cat([control_hidden_states, torch.zeros_like(controls_latents[:, :1])], dim=1)
+        if condition_mask is not None:
+            control_hidden_states = torch.cat([control_hidden_states, condition_mask], dim=1)
+        else:
+            control_hidden_states = torch.cat(
+                [control_hidden_states, torch.zeros_like(controls_latents[:, :1])], dim=1
+            )
 
         padding_mask_resized = transforms.functional.resize(
             padding_mask, list(control_hidden_states.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
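
The branch added in this hunk can be exercised in isolation. Below is a minimal sketch, assuming `condition_mask` carries a single channel and matches `controls_latents` in every other dimension (the diff does not state the mask's shape). The function name `append_condition_channel` and the tensor sizes are hypothetical and chosen only for illustration.

```python
from typing import Optional

import torch


def append_condition_channel(
    control_hidden_states: torch.Tensor,
    controls_latents: torch.Tensor,
    condition_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Append one channel along dim=1: the mask if given, zeros otherwise."""
    if condition_mask is not None:
        extra = condition_mask
    else:
        # Previous behavior: always pad with a zero channel shaped like one latent channel.
        extra = torch.zeros_like(controls_latents[:, :1])
    return torch.cat([control_hidden_states, extra], dim=1)


# Hypothetical shapes: (batch, channels, frames, height, width).
hidden = torch.randn(1, 16, 4, 8, 8)
latents = torch.randn(1, 16, 4, 8, 8)
mask = torch.ones(1, 1, 4, 8, 8)  # assumed single-channel mask

with_mask = append_condition_channel(hidden, latents, mask)
without_mask = append_condition_channel(hidden, latents)
print(with_mask.shape, without_mask.shape)
```

Under that shape assumption, both branches yield the same channel count, so callers that pass no mask see the earlier zero-padded result.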
