Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 30 additions & 7 deletions timm/models/hiera.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,13 +299,31 @@ def __init__(
def forward(self, x: torch.Tensor) -> torch.Tensor:
""" Input should be of shape [batch, tokens, channels]. """
B, N, _ = x.shape
num_windows = (N // (self.q_stride * self.window_size)) if self.use_mask_unit_attn else 1
qkv = self.qkv(x).reshape(B, -1, num_windows, 3, self.heads, self.head_dim).permute(3, 0, 4, 2, 1, 5)
q, k, v = qkv.unbind(0)

if self.q_stride > 1:
# Refer to Unroll to see how this performs a maxpool-Nd
q = q.view(B, self.heads, num_windows, self.q_stride, -1, self.head_dim).amax(dim=3)
if self.use_mask_unit_attn:
# Windowed attention: 5D path [B, heads, num_windows, tokens_per_window, head_dim]
num_windows = N // (self.q_stride * self.window_size)
qkv = self.qkv(x).reshape(
B, -1, num_windows, 3, self.heads, self.head_dim,
).permute(3, 0, 4, 2, 1, 5)
q, k, v = qkv.unbind(0)

if self.q_stride > 1:
# Refer to Unroll to see how this performs a maxpool-Nd
q = q.view(B, self.heads, num_windows, self.q_stride, -1, self.head_dim).amax(dim=3)
else:
# Global attention: 4D path [B, heads, N, head_dim]
# Avoids the dummy num_windows=1 dimension that prevents FlashAttention dispatch.
qkv = self.qkv(x).reshape(B, N, 3, self.heads, self.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)

if self.q_stride > 1:
# dim=2 instead of dim=3 because num_windows dimension is absent
q = q.view(B, self.heads, self.q_stride, -1, self.head_dim).amax(dim=2)

# Enforce contiguous memory layout so SDPA dispatches to FlashAttention
# instead of silently falling back to the O(N^2) math backend.
q, k, v = q.contiguous(), k.contiguous(), v.contiguous()

if self.fused_attn:
# Note: the original paper did *not* use SDPA, it's a free boost!
Expand All @@ -315,7 +333,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
attn = attn.softmax(dim=-1)
x = attn @ v

x = x.transpose(1, 3).reshape(B, -1, self.dim_out)
# Output transpose adapts to 5D (windowed) vs 4D (global) layout
if self.use_mask_unit_attn:
x = x.transpose(1, 3).reshape(B, -1, self.dim_out)
else:
x = x.transpose(1, 2).reshape(B, -1, self.dim_out)

x = self.proj(x)
return x

Expand Down