Skip to content

Commit 770a149

Browse files
authored
Merge branch 'main' into modular-index-tests
2 parents 94457fd + 40e9645 commit 770a149

File tree

11 files changed

+630
-334
lines changed

11 files changed

+630
-334
lines changed

docs/source/en/api/pipelines/cosmos.md

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,20 @@ output = pipe(
4646
output.save("output.png")
4747
```
4848

49+
## Cosmos2_5_TransferPipeline
50+
51+
[[autodoc]] Cosmos2_5_TransferPipeline
52+
- all
53+
- __call__
54+
55+
56+
## Cosmos2_5_PredictBasePipeline
57+
58+
[[autodoc]] Cosmos2_5_PredictBasePipeline
59+
- all
60+
- __call__
61+
62+
4963
## CosmosTextToWorldPipeline
5064

5165
[[autodoc]] CosmosTextToWorldPipeline
@@ -70,12 +84,6 @@ output.save("output.png")
7084
- all
7185
- __call__
7286

73-
## Cosmos2_5_PredictBasePipeline
74-
75-
[[autodoc]] Cosmos2_5_PredictBasePipeline
76-
- all
77-
- __call__
78-
7987
## CosmosPipelineOutput
8088

8189
[[autodoc]] pipelines.cosmos.pipeline_output.CosmosPipelineOutput

docs/source/en/training/distributed_inference.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ if __name__ == "__main__":
111111
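Call `torchrun` to run the inference script and use the `--nproc_per_node` argument to set the number of GPUs to use. Launcher options such as `--nproc_per_node` must come before the script name; anything placed after the script name is passed to the script itself rather than to `torchrun`.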
112112

113113
```bash
114-
torchrun run_distributed.py --nproc_per_node=2
114+
torchrun --nproc_per_node=2 run_distributed.py
115115
```
116116

117117
## device_map

scripts/convert_cosmos_to_diffusers.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,9 +94,15 @@
9494
--transformer_type Cosmos-2.5-Transfer-General-2B \
9595
--transformer_ckpt_path $transformer_ckpt_path \
9696
--vae_type wan2.1 \
97-
--output_path converted/transfer/2b/general/depth \
97+
--output_path converted/transfer/2b/general/depth/pipeline \
9898
--save_pipeline
9999
100+
python scripts/convert_cosmos_to_diffusers.py \
101+
--transformer_type Cosmos-2.5-Transfer-General-2B \
102+
--transformer_ckpt_path $transformer_ckpt_path \
103+
--vae_type wan2.1 \
104+
--output_path converted/transfer/2b/general/depth/models
105+
100106
# edge
101107
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Transfer2.5-2B/snapshots/eb5325b77d358944da58a690157dd2b8071bbf85/general/edge/61f5694b-0ad5-4ecd-8ad7-c8545627d125_ema_bf16.pt
102108
@@ -120,18 +126,30 @@
120126
--transformer_type Cosmos-2.5-Transfer-General-2B \
121127
--transformer_ckpt_path $transformer_ckpt_path \
122128
--vae_type wan2.1 \
123-
--output_path converted/transfer/2b/general/blur \
129+
--output_path converted/transfer/2b/general/blur/pipeline \
124130
--save_pipeline
125131
132+
python scripts/convert_cosmos_to_diffusers.py \
133+
--transformer_type Cosmos-2.5-Transfer-General-2B \
134+
--transformer_ckpt_path $transformer_ckpt_path \
135+
--vae_type wan2.1 \
136+
--output_path converted/transfer/2b/general/blur/models
137+
126138
# seg
127139
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Transfer2.5-2B/snapshots/eb5325b77d358944da58a690157dd2b8071bbf85/general/seg/5136ef49-6d8d-42e8-8abf-7dac722a304a_ema_bf16.pt
128140
129141
python scripts/convert_cosmos_to_diffusers.py \
130142
--transformer_type Cosmos-2.5-Transfer-General-2B \
131143
--transformer_ckpt_path $transformer_ckpt_path \
132144
--vae_type wan2.1 \
133-
--output_path converted/transfer/2b/general/seg \
145+
--output_path converted/transfer/2b/general/seg/pipeline \
134146
--save_pipeline
147+
148+
python scripts/convert_cosmos_to_diffusers.py \
149+
--transformer_type Cosmos-2.5-Transfer-General-2B \
150+
--transformer_ckpt_path $transformer_ckpt_path \
151+
--vae_type wan2.1 \
152+
--output_path converted/transfer/2b/general/seg/models
135153
```
136154
"""
137155

src/diffusers/models/attention_dispatch.py

Lines changed: 101 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,11 @@ class _HubKernelConfig:
329329
_HUB_KERNELS_REGISTRY: dict["AttentionBackendName", _HubKernelConfig] = {
330330
# TODO: temporary revision pin. Remove once the `fake-ops-return-probs` changes are merged upstream into `main`.
331331
AttentionBackendName._FLASH_3_HUB: _HubKernelConfig(
332-
repo_id="kernels-community/flash-attn3", function_attr="flash_attn_func", revision="fake-ops-return-probs"
332+
repo_id="kernels-community/flash-attn3",
333+
function_attr="flash_attn_func",
334+
revision="fake-ops-return-probs",
335+
wrapped_forward_attr="flash_attn_interface._flash_attn_forward",
336+
wrapped_backward_attr="flash_attn_interface._flash_attn_backward",
333337
),
334338
AttentionBackendName._FLASH_3_VARLEN_HUB: _HubKernelConfig(
335339
repo_id="kernels-community/flash-attn3",
@@ -729,7 +733,7 @@ def _wrapped_flash_attn_3(
729733
) -> tuple[torch.Tensor, torch.Tensor]:
730734
# Hardcoded for now because pytorch does not support tuple/int type hints
731735
window_size = (-1, -1)
732-
out, lse, *_ = flash_attn_3_func(
736+
result = flash_attn_3_func(
733737
q=q,
734738
k=k,
735739
v=v,
@@ -746,7 +750,9 @@ def _wrapped_flash_attn_3(
746750
pack_gqa=pack_gqa,
747751
deterministic=deterministic,
748752
sm_margin=sm_margin,
753+
return_attn_probs=True,
749754
)
755+
out, lse, *_ = result
750756
lse = lse.permute(0, 2, 1)
751757
return out, lse
752758

@@ -1290,36 +1296,62 @@ def _flash_attention_3_hub_forward_op(
12901296
if enable_gqa:
12911297
raise ValueError("`enable_gqa` is not yet supported for flash-attn 3 hub kernels.")
12921298

1293-
func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
1294-
out = func(
1295-
q=query,
1296-
k=key,
1297-
v=value,
1298-
softmax_scale=scale,
1299+
config = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB]
1300+
wrapped_forward_fn = config.wrapped_forward_fn
1301+
if wrapped_forward_fn is None:
1302+
raise RuntimeError(
1303+
"Flash attention 3 hub kernels must expose `flash_attn_interface._flash_attn_forward` "
1304+
"for context parallel execution."
1305+
)
1306+
1307+
if scale is None:
1308+
scale = query.shape[-1] ** (-0.5)
1309+
1310+
out, softmax_lse, *_ = wrapped_forward_fn(
1311+
query,
1312+
key,
1313+
value,
1314+
None,
1315+
None, # k_new, v_new
1316+
None, # qv
1317+
None, # out
1318+
None,
1319+
None,
1320+
None, # cu_seqlens_q/k/k_new
1321+
None,
1322+
None, # seqused_q/k
1323+
None,
1324+
None, # max_seqlen_q/k
1325+
None,
1326+
None,
1327+
None, # page_table, kv_batch_idx, leftpad_k
1328+
None,
1329+
None,
1330+
None, # rotary_cos/sin, seqlens_rotary
1331+
None,
1332+
None,
1333+
None, # q_descale, k_descale, v_descale
1334+
scale,
12991335
causal=is_causal,
1300-
qv=None,
1301-
q_descale=None,
1302-
k_descale=None,
1303-
v_descale=None,
1304-
window_size=window_size,
1336+
window_size_left=window_size[0],
1337+
window_size_right=window_size[1],
1338+
attention_chunk=0,
13051339
softcap=softcap,
13061340
num_splits=num_splits,
13071341
pack_gqa=pack_gqa,
1308-
deterministic=deterministic,
13091342
sm_margin=sm_margin,
1310-
return_attn_probs=return_lse,
13111343
)
13121344

1313-
lse = None
1314-
if return_lse:
1315-
out, lse = out
1316-
lse = lse.permute(0, 2, 1).contiguous()
1345+
lse = softmax_lse.permute(0, 2, 1).contiguous() if return_lse else None
13171346

13181347
if _save_ctx:
1319-
ctx.save_for_backward(query, key, value)
1348+
ctx.save_for_backward(query, key, value, out, softmax_lse)
13201349
ctx.scale = scale
13211350
ctx.is_causal = is_causal
1322-
ctx._hub_kernel = func
1351+
ctx.window_size = window_size
1352+
ctx.softcap = softcap
1353+
ctx.deterministic = deterministic
1354+
ctx.sm_margin = sm_margin
13231355

13241356
return (out, lse) if return_lse else out
13251357

@@ -1328,55 +1360,50 @@ def _flash_attention_3_hub_backward_op(
13281360
ctx: torch.autograd.function.FunctionCtx,
13291361
grad_out: torch.Tensor,
13301362
*args,
1331-
window_size: tuple[int, int] = (-1, -1),
1332-
softcap: float = 0.0,
1333-
num_splits: int = 1,
1334-
pack_gqa: bool | None = None,
1335-
deterministic: bool = False,
1336-
sm_margin: int = 0,
1363+
**kwargs,
13371364
):
1338-
query, key, value = ctx.saved_tensors
1339-
kernel_fn = ctx._hub_kernel
1340-
# NOTE: Unlike the FA2 hub kernel, the FA3 hub kernel does not expose separate wrapped forward/backward
1341-
# primitives (no `wrapped_forward_attr`/`wrapped_backward_attr` in its `_HubKernelConfig`). We
1342-
# therefore rerun the forward pass under `torch.enable_grad()` and differentiate through it with
1343-
# `torch.autograd.grad()`. This is a second forward pass during backward; it can be avoided once
1344-
# the FA3 hub exposes a dedicated fused backward kernel (analogous to `_wrapped_flash_attn_backward`
1345-
# in the FA2 hub), at which point this can be refactored to match `_flash_attention_hub_backward_op`.
1346-
with torch.enable_grad():
1347-
query_r = query.detach().requires_grad_(True)
1348-
key_r = key.detach().requires_grad_(True)
1349-
value_r = value.detach().requires_grad_(True)
1350-
1351-
out = kernel_fn(
1352-
q=query_r,
1353-
k=key_r,
1354-
v=value_r,
1355-
softmax_scale=ctx.scale,
1356-
causal=ctx.is_causal,
1357-
qv=None,
1358-
q_descale=None,
1359-
k_descale=None,
1360-
v_descale=None,
1361-
window_size=window_size,
1362-
softcap=softcap,
1363-
num_splits=num_splits,
1364-
pack_gqa=pack_gqa,
1365-
deterministic=deterministic,
1366-
sm_margin=sm_margin,
1367-
return_attn_probs=False,
1368-
)
1369-
if isinstance(out, tuple):
1370-
out = out[0]
1371-
1372-
grad_query, grad_key, grad_value = torch.autograd.grad(
1373-
out,
1374-
(query_r, key_r, value_r),
1375-
grad_out,
1376-
retain_graph=False,
1377-
allow_unused=False,
1365+
config = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB]
1366+
wrapped_backward_fn = config.wrapped_backward_fn
1367+
if wrapped_backward_fn is None:
1368+
raise RuntimeError(
1369+
"Flash attention 3 hub kernels must expose `flash_attn_interface._flash_attn_backward` "
1370+
"for context parallel execution."
13781371
)
13791372

1373+
query, key, value, out, softmax_lse = ctx.saved_tensors
1374+
grad_query = torch.empty_like(query)
1375+
grad_key = torch.empty_like(key)
1376+
grad_value = torch.empty_like(value)
1377+
1378+
wrapped_backward_fn(
1379+
grad_out,
1380+
query,
1381+
key,
1382+
value,
1383+
out,
1384+
softmax_lse,
1385+
None,
1386+
None, # cu_seqlens_q, cu_seqlens_k
1387+
None,
1388+
None, # seqused_q, seqused_k
1389+
None,
1390+
None, # max_seqlen_q, max_seqlen_k
1391+
grad_query,
1392+
grad_key,
1393+
grad_value,
1394+
ctx.scale,
1395+
ctx.is_causal,
1396+
ctx.window_size[0],
1397+
ctx.window_size[1],
1398+
ctx.softcap,
1399+
ctx.deterministic,
1400+
ctx.sm_margin,
1401+
)
1402+
1403+
grad_query = grad_query[..., : grad_out.shape[-1]]
1404+
grad_key = grad_key[..., : grad_out.shape[-1]]
1405+
grad_value = grad_value[..., : grad_out.shape[-1]]
1406+
13801407
return grad_query, grad_key, grad_value
13811408

13821409

@@ -2676,7 +2703,7 @@ def _flash_varlen_attention_3(
26762703
key_packed = torch.cat(key_valid, dim=0)
26772704
value_packed = torch.cat(value_valid, dim=0)
26782705

2679-
out, lse, *_ = flash_attn_3_varlen_func(
2706+
result = flash_attn_3_varlen_func(
26802707
q=query_packed,
26812708
k=key_packed,
26822709
v=value_packed,
@@ -2686,7 +2713,13 @@ def _flash_varlen_attention_3(
26862713
max_seqlen_k=max_seqlen_k,
26872714
softmax_scale=scale,
26882715
causal=is_causal,
2716+
return_attn_probs=return_lse,
26892717
)
2718+
if isinstance(result, tuple):
2719+
out, lse, *_ = result
2720+
else:
2721+
out = result
2722+
lse = None
26902723
out = out.unflatten(0, (batch_size, -1))
26912724

26922725
return (out, lse) if return_lse else out

src/diffusers/models/controlnets/controlnet_cosmos.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,12 @@ def forward(
191191
dim=1,
192192
)
193193

194-
control_hidden_states = torch.cat([control_hidden_states, torch.zeros_like(controls_latents[:, :1])], dim=1)
194+
if condition_mask is not None:
195+
control_hidden_states = torch.cat([control_hidden_states, condition_mask], dim=1)
196+
else:
197+
control_hidden_states = torch.cat(
198+
[control_hidden_states, torch.zeros_like(controls_latents[:, :1])], dim=1
199+
)
195200

196201
padding_mask_resized = transforms.functional.resize(
197202
padding_mask, list(control_hidden_states.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST

0 commit comments

Comments
 (0)