Commit 0c90043
fix: branch Hiera MaskUnitAttention into 4D global path for FlashAttention dispatch
The global attention path in MaskUnitAttention.forward() used a 5D tensor
reshape with num_windows=1 as a shortcut. This caused PyTorch SDPA to
silently fall back from efficient backends (FlashAttention, Memory-Efficient,
cuDNN) to the O(N^2) math backend, since all of the fused kernels require 4D
contiguous tensors.
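The fallback is silent rather than an error because F.scaled_dot_product_attention accepts any number of leading batch dimensions; a 5D call runs fine, it just cannot dispatch to the fused 4D-only kernels. A minimal demonstration (shapes chosen arbitrarily for illustration):

```python
import torch
import torch.nn.functional as F

# A 5D call with num_windows=1 "works" -- SDPA treats the extra dim as
# another batch dim and quietly routes to the math backend.
q = k = v = torch.randn(2, 4, 1, 64, 16)  # [B, heads, num_windows=1, N, head_dim]
out5d = F.scaled_dot_product_attention(q, k, v)
print(out5d.shape)  # torch.Size([2, 4, 1, 64, 16])

# The same data as a 4D tensor is eligible for the efficient backends.
out4d = F.scaled_dot_product_attention(q.squeeze(2), k.squeeze(2), v.squeeze(2))
print(out4d.shape)  # torch.Size([2, 4, 64, 16])
```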
At high resolutions (e.g. 2048x2048 -> 16384 tokens), the math backend
materializes the full N*N attention matrix, causing catastrophic VRAM usage
and OOM on consumer GPUs.
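A back-of-envelope check of the blow-up (only the 16384-token figure comes from the text above; the 16x16 patch stride, fp16 dtype, and 8-head count are illustrative assumptions):

```python
# Score matrix cost of the math backend at 2048x2048 input.
N = 2048 * 2048 // (16 * 16)  # 16384 tokens, assuming a 16x16 patch stride
bytes_per_el = 2              # fp16 (assumed)
attn_matrix = N * N * bytes_per_el

print(attn_matrix / 2**20)    # 512.0 -> MiB per head, per batch element

heads = 8                     # illustrative head count
print(attn_matrix * heads / 2**30)  # 4.0 -> GiB of scores for a single layer
```

FlashAttention-style kernels never materialize this matrix, which is why restoring the fused dispatch fixes the OOM rather than merely speeding things up.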
Changes:
- Branch forward() into windowed (5D, unchanged) and global (4D) paths
- Global path reshapes directly to [B, N, 3, heads, head_dim] -> 4D QKV
- Adjust q_stride pooling dim from amax(dim=3) to amax(dim=2) for global
- Add .contiguous() on q, k, v to guarantee FlashAttention compatibility
- Split output transpose: transpose(1,3) for windowed, transpose(1,2) for global

1 parent a346c76
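The listed changes can be sketched as follows. This is a minimal illustration, not the upstream code: the module shape (qkv/proj projections, window_size, use_mask_unit_attn) follows the public Hiera reference implementation, but treat every name and default here as an assumption.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MaskUnitAttention(nn.Module):
    """Sketch of the branched forward(): 5D windowed path vs. 4D global path."""

    def __init__(self, dim, dim_out, heads, q_stride=1,
                 window_size=0, use_mask_unit_attn=False):
        super().__init__()
        self.dim_out = dim_out
        self.heads = heads
        self.head_dim = dim_out // heads
        self.q_stride = q_stride
        self.window_size = window_size
        self.use_mask_unit_attn = use_mask_unit_attn
        self.qkv = nn.Linear(dim, 3 * dim_out)
        self.proj = nn.Linear(dim_out, dim_out)

    def forward(self, x):
        B, N, _ = x.shape
        if self.use_mask_unit_attn:
            # Windowed path: original 5D layout, unchanged.
            # qkv -> [3, B, heads, num_windows, tokens_per_window, head_dim]
            num_windows = N // (self.q_stride * self.window_size)
            qkv = (self.qkv(x)
                   .reshape(B, -1, num_windows, 3, self.heads, self.head_dim)
                   .permute(3, 0, 4, 2, 1, 5))
            q, k, v = qkv.unbind(0)
            if self.q_stride > 1:
                q = q.view(B, self.heads, num_windows, self.q_stride,
                           -1, self.head_dim).amax(dim=3)
            x = F.scaled_dot_product_attention(q, k, v)
            x = x.transpose(1, 3).reshape(B, -1, self.dim_out)
        else:
            # Global path: stay 4D [B, heads, N, head_dim] so SDPA can pick
            # an efficient backend; .contiguous() for FlashAttention.
            qkv = (self.qkv(x)
                   .reshape(B, N, 3, self.heads, self.head_dim)
                   .permute(2, 0, 3, 1, 4))
            q, k, v = (t.contiguous() for t in qkv.unbind(0))
            if self.q_stride > 1:
                # Pooling axis shifts: amax(dim=3) in 5D -> amax(dim=2) in 4D.
                q = q.view(B, self.heads, self.q_stride,
                           -1, self.head_dim).amax(dim=2)
            x = F.scaled_dot_product_attention(q, k, v)
            x = x.transpose(1, 2).reshape(B, -1, self.dim_out)
        return self.proj(x)
```

In this sketch both branches produce the same [B, N', dim_out] output layout, so downstream Hiera blocks need no changes; only the internal tensor rank differs.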
1 file changed: +30 −7 lines changed