
Commit 0c90043

Your Name authored and rwightman committed
fix: branch Hiera MaskUnitAttention into 4D global path for FlashAttention dispatch
The global attention path in MaskUnitAttention.forward() used a 5D tensor reshape with num_windows=1 as a shortcut. This caused PyTorch SDPA to silently fall back from efficient backends (FlashAttention, Memory-Efficient, CuDNN) to the O(N^2) math backend, as all efficient kernels require 4D contiguous tensors. At high resolutions (e.g. 2048x2048 -> 16384 tokens), the math backend materializes the full N*N attention matrix, causing catastrophic VRAM usage and OOM on consumer GPUs. Changes: - Branch forward() into windowed (5D, unchanged) and global (4D) paths - Global path reshapes directly to [B, N, 3, heads, head_dim] -> 4D QKV - Adjust q_stride pooling dim from amax(dim=3) to amax(dim=2) for global - Add .contiguous() on q, k, v to guarantee FlashAttention compatibility - Split output transpose: transpose(1,3) for windowed, transpose(1,2) for global
1 parent a346c76 commit 0c90043
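The numerical claim behind this change — that the 4D global reshape, the dim=2 pooling, and the transpose(1,2) output path are all equivalent to the old num_windows=1 shortcut — can be sanity-checked in isolation. The sketch below uses illustrative sizes, not values from any actual Hiera config:

```python
import torch

# Illustrative sizes only -- not taken from an actual Hiera config
B, N, heads, head_dim = 2, 16, 4, 8
q_stride = 2
x = torch.randn(B, N, 3 * heads * head_dim)  # stand-in for the self.qkv(x) output

# Old global shortcut: 5D reshape with a dummy num_windows=1 dimension
qkv5 = x.reshape(B, -1, 1, 3, heads, head_dim).permute(3, 0, 4, 2, 1, 5)
q5, k5, v5 = qkv5.unbind(0)  # each [B, heads, 1, N, head_dim]

# New global path: direct 4D reshape
qkv4 = x.reshape(B, N, 3, heads, head_dim).permute(2, 0, 3, 1, 4)
q4, k4, v4 = qkv4.unbind(0)  # each [B, heads, N, head_dim]

# Same values, minus the dummy window dimension
assert torch.equal(q5.squeeze(2), q4) and torch.equal(v5.squeeze(2), v4)

# q_stride pooling: amax(dim=3) in 5D corresponds to amax(dim=2) in 4D
p5 = q5.view(B, heads, 1, q_stride, -1, head_dim).amax(dim=3)
p4 = q4.view(B, heads, q_stride, -1, head_dim).amax(dim=2)
assert torch.equal(p5.squeeze(2), p4)

# Output path: transpose(1, 3) on 5D matches transpose(1, 2) on 4D after flattening
o5 = v5.transpose(1, 3).reshape(B, -1, heads * head_dim)
o4 = v4.transpose(1, 2).reshape(B, -1, heads * head_dim)
assert torch.equal(o5, o4)
```

The same bookkeeping holds with real num_windows > 1 in the windowed branch, which the commit leaves untouched.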

File tree

1 file changed: +30 −7 lines changed

timm/models/hiera.py

Lines changed: 30 additions & 7 deletions
```diff
@@ -299,13 +299,31 @@ def __init__(
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """ Input should be of shape [batch, tokens, channels]. """
         B, N, _ = x.shape
-        num_windows = (N // (self.q_stride * self.window_size)) if self.use_mask_unit_attn else 1
-        qkv = self.qkv(x).reshape(B, -1, num_windows, 3, self.heads, self.head_dim).permute(3, 0, 4, 2, 1, 5)
-        q, k, v = qkv.unbind(0)
 
-        if self.q_stride > 1:
-            # Refer to Unroll to see how this performs a maxpool-Nd
-            q = q.view(B, self.heads, num_windows, self.q_stride, -1, self.head_dim).amax(dim=3)
+        if self.use_mask_unit_attn:
+            # Windowed attention: 5D path [B, heads, num_windows, tokens_per_window, head_dim]
+            num_windows = N // (self.q_stride * self.window_size)
+            qkv = self.qkv(x).reshape(
+                B, -1, num_windows, 3, self.heads, self.head_dim,
+            ).permute(3, 0, 4, 2, 1, 5)
+            q, k, v = qkv.unbind(0)
+
+            if self.q_stride > 1:
+                # Refer to Unroll to see how this performs a maxpool-Nd
+                q = q.view(B, self.heads, num_windows, self.q_stride, -1, self.head_dim).amax(dim=3)
+        else:
+            # Global attention: 4D path [B, heads, N, head_dim]
+            # Avoids the dummy num_windows=1 dimension that prevents FlashAttention dispatch.
+            qkv = self.qkv(x).reshape(B, N, 3, self.heads, self.head_dim).permute(2, 0, 3, 1, 4)
+            q, k, v = qkv.unbind(0)
+
+            if self.q_stride > 1:
+                # dim=2 instead of dim=3 because num_windows dimension is absent
+                q = q.view(B, self.heads, self.q_stride, -1, self.head_dim).amax(dim=2)
+
+        # Enforce contiguous memory layout so SDPA dispatches to FlashAttention
+        # instead of silently falling back to the O(N^2) math backend.
+        q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
 
         if self.fused_attn:
             # Note: the original paper did *not* use SDPA, it's a free boost!
@@ -315,7 +333,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         attn = attn.softmax(dim=-1)
         x = attn @ v
 
-        x = x.transpose(1, 3).reshape(B, -1, self.dim_out)
+        # Output transpose adapts to 5D (windowed) vs 4D (global) layout
+        if self.use_mask_unit_attn:
+            x = x.transpose(1, 3).reshape(B, -1, self.dim_out)
+        else:
+            x = x.transpose(1, 2).reshape(B, -1, self.dim_out)
+
         x = self.proj(x)
         return x
```
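As a sketch of why the added .contiguous() calls matter: reshape + permute + unbind produce strided views rather than dense buffers, and fused SDPA kernels generally require dense 4D inputs. A minimal illustration (shapes arbitrary, runs on CPU; actual backend selection depends on device, dtype, and build):

```python
import torch
import torch.nn.functional as F

# qkv-style tensor [B, N, 3, heads, head_dim], split the way the model does
x = torch.randn(2, 16, 3, 4, 8)
q = x.permute(2, 0, 3, 1, 4).unbind(0)[0]  # strided view, [B, heads, N, head_dim]
assert not q.is_contiguous()

# .contiguous() copies into a dense buffer; values are unchanged
qc = q.contiguous()
assert qc.is_contiguous() and torch.equal(q, qc)

# SDPA accepts the dense 4D layout directly
out = F.scaled_dot_product_attention(qc, qc, qc)
assert out.shape == qc.shape
```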
