Skip to content

Commit 682c845

Browse files
committed
Implement PRR as a pooling module. Alternative to #2678
1 parent 3e8def8 commit 682c845

File tree

4 files changed

+123
-7
lines changed

4 files changed

+123
-7
lines changed

tests/test_layers_pool.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,41 @@ def test_attention_pool2d_class_token(self, pool_cls, base_kwargs, input_shape):
228228
out = pool(x)
229229
assert out.shape == (2, 64)
230230

231+
def test_attention_pool_prr_basic(self):
    """Default (token) pooling of a CLS + 49-patch sequence collapses to (B, dim)."""
    from timm.layers import AttentionPoolPrr
    pool = AttentionPoolPrr(dim=64, num_heads=4).to(torch_device)
    tokens = torch.randn(2, 50, 64, device=torch_device)  # 1 CLS + 49 patches
    pooled = pool(tokens)
    assert pooled.shape == (2, 64)
237+
238+
def test_attention_pool_prr_avg_pool(self):
    """pool_type='avg' works on a patch-only (no CLS) sequence and yields (B, dim)."""
    from timm.layers import AttentionPoolPrr
    pool = AttentionPoolPrr(dim=64, num_heads=4, pool_type='avg').to(torch_device)
    patches = torch.randn(2, 49, 64, device=torch_device)
    pooled = pool(patches)
    assert pooled.shape == (2, 64)
244+
245+
def test_attention_pool_prr_parameter_free(self):
    """Without pre/post norms the PRR pool must introduce zero learnable parameters."""
    from timm.layers import AttentionPoolPrr
    pool = AttentionPoolPrr(dim=64, num_heads=4)
    num_params = sum(p.numel() for p in pool.parameters())
    assert num_params == 0, f"Expected 0 parameters, got {num_params}"
250+
251+
def test_attention_pool_prr_with_norms(self):
    """Enabling pre/post norms adds LayerNorm parameters; forward still pools to (B, dim)."""
    from timm.layers import AttentionPoolPrr
    pool = AttentionPoolPrr(
        dim=64,
        num_heads=4,
        pre_norm=True,
        post_norm=True,
    ).to(torch_device)
    # Should have parameters from the two LayerNorms
    num_params = sum(p.numel() for p in pool.parameters())
    assert num_params > 0
    tokens = torch.randn(2, 49, 64, device=torch_device)
    pooled = pool(tokens)
    assert pooled.shape == (2, 64)
265+
231266
@pytest.mark.parametrize('out_features,embed_dim,expected_out', [
232267
(None, None, 64), # default: out_features = in_features
233268
(None, 128, 64), # default with different embed_dim
@@ -365,6 +400,7 @@ class TestPoolingCommon:
365400
('SimPool1d', {'dim': 64}, (2, 49, 64)),
366401
('SelectAdaptivePool2d', {'pool_type': 'avg', 'flatten': True}, (2, 64, 7, 7)),
367402
('AttentionPoolLatent', {'in_features': 64, 'num_heads': 4}, (2, 49, 64)),
403+
('AttentionPoolPrr', {'dim': 64, 'num_heads': 4}, (2, 49, 64)),
368404
('AttentionPool2d', {'in_features': 64, 'feat_size': 7}, (2, 64, 7, 7)),
369405
('RotAttentionPool2d', {'in_features': 64, 'ref_feat_size': 7}, (2, 64, 7, 7)),
370406
])
@@ -383,6 +419,7 @@ def test_gradient_flow(self, pool_cls, kwargs, input_shape):
383419
('LsePlus1d', {}, (2, 49, 64)),
384420
('SimPool2d', {'dim': 64}, (2, 64, 7, 7)),
385421
('SimPool1d', {'dim': 64}, (2, 49, 64)),
422+
('AttentionPoolPrr', {'dim': 64, 'num_heads': 4}, (2, 49, 64)),
386423
('AttentionPool2d', {'in_features': 64, 'feat_size': 7}, (2, 64, 7, 7)),
387424
('RotAttentionPool2d', {'in_features': 64, 'ref_feat_size': 7}, (2, 64, 7, 7)),
388425
])
@@ -401,6 +438,7 @@ def test_torchscript(self, pool_cls, kwargs, input_shape):
401438
('LsePlus1d', {}, (2, 49, 64)),
402439
('SimPool2d', {'dim': 64}, (2, 64, 7, 7)),
403440
('SimPool1d', {'dim': 64}, (2, 49, 64)),
441+
('AttentionPoolPrr', {'dim': 64, 'num_heads': 4}, (2, 49, 64)),
404442
('AttentionPool2d', {'in_features': 64, 'feat_size': 7}, (2, 64, 7, 7)),
405443
('RotAttentionPool2d', {'in_features': 64, 'ref_feat_size': 7}, (2, 64, 7, 7)),
406444
])

timm/layers/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
)
1818
from .attention import Attention, AttentionRope, maybe_add_mask
1919
from .attention2d import MultiQueryAttention2d, Attention2d, MultiQueryAttentionV2
20-
from .attention_pool import AttentionPoolLatent
20+
from .attention_pool import AttentionPoolLatent, AttentionPoolPrr
2121
from .attention_pool2d import AttentionPool2d, RotAttentionPool2d
2222
from .blur_pool import BlurPool2d, create_aa
2323
from .classifier import create_classifier, ClassifierHead, NormMlpClassifierHead, ClNormMlpClassifierHead

timm/layers/attention_pool.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,4 +126,70 @@ def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
126126
x = x[:, 0]
127127
elif self.pool == 'avg':
128128
x = x.mean(1)
129+
return x
130+
131+
132+
class AttentionPoolPrr(nn.Module):
    """ Patch Representation Refinement (PRR) attention pool.

    From "Locality-Attending Vision Transformer" (ICLR 2026).

    A parameter-free multi-head self-attention pass refines every patch
    representation before the final pooling step. There are no Q/K/V
    projections: the (optionally normalized) input itself is reshaped into
    multi-head form and used as query, key, and value simultaneously.
    Optional pre/post LayerNorms are the only source of parameters.
    """
    fused_attn: torch.jit.Final[bool]

    def __init__(
            self,
            dim: int,
            num_heads: int = 8,
            pool_type: str = 'token',
            pre_norm: bool = False,
            post_norm: bool = False,
            norm_layer: Optional[Type[nn.Module]] = None,
            device=None,
            dtype=None,
    ):
        """
        Args:
            dim: Input (and output) embedding dimension.
            num_heads: Number of attention heads; must divide ``dim``.
            pool_type: 'token' (take refined first token) or 'avg' (mean over tokens).
            pre_norm: Apply a norm layer to the input before attention.
            post_norm: Apply a norm layer to the refined tokens before pooling.
            norm_layer: Norm layer type; defaults to ``nn.LayerNorm`` when a norm is enabled.
        """
        super().__init__()
        dd = {'device': device, 'dtype': dtype}
        assert pool_type in ('token', 'avg'), f"pool_type must be 'token' or 'avg', got '{pool_type}'"
        assert dim % num_heads == 0, f"dim ({dim}) must be divisible by num_heads ({num_heads})"

        if (pre_norm or post_norm) and norm_layer is None:
            norm_layer = nn.LayerNorm

        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.pool = pool_type
        self.fused_attn = use_fused_attn()
        self.out_features = dim

        self.pre_norm = norm_layer(dim, **dd) if pre_norm else nn.Identity()
        self.post_norm = norm_layer(dim, **dd) if post_norm else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, C = x.shape

        x = self.pre_norm(x)

        # Parameter-free self-attention: the input doubles as q, k and v,
        # split across heads -> (B, num_heads, N, head_dim).
        heads = x.reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)
        if self.fused_attn:
            refined = F.scaled_dot_product_attention(heads, heads, heads)
        else:
            scores = (heads * self.scale) @ heads.transpose(-2, -1)
            refined = scores.softmax(dim=-1) @ heads
        x = refined.transpose(1, 2).reshape(B, N, C)

        x = self.post_norm(x)

        # Pool the refined tokens down to (B, C).
        if self.pool == 'token':
            x = x[:, 0]
        elif self.pool == 'avg':
            x = x.mean(dim=1)
        return x

timm/models/vision_transformer.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
Attention,
5050
DiffAttention,
5151
AttentionPoolLatent,
52+
AttentionPoolPrr,
5253
PatchEmbed,
5354
Mlp,
5455
SwiGLUPacked,
@@ -692,7 +693,7 @@ def __init__(
692693
patch_size: Union[int, Tuple[int, int]] = 16,
693694
in_chans: int = 3,
694695
num_classes: int = 1000,
695-
global_pool: Literal['', 'avg', 'avgmax', 'max', 'token', 'map'] = 'token',
696+
global_pool: Literal['', 'avg', 'avgmax', 'max', 'token', 'map', 'prr'] = 'token',
696697
embed_dim: int = 768,
697698
depth: int = 12,
698699
num_heads: int = 12,
@@ -764,7 +765,7 @@ def __init__(
764765
"""
765766
super().__init__()
766767
dd = {'device': device, 'dtype': dtype}
767-
assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map')
768+
assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map', 'prr')
768769
assert class_token or global_pool != 'token'
769770
assert pos_embed in ('', 'none', 'learn')
770771
use_fc_norm = global_pool in ('avg', 'avgmax', 'max') if fc_norm is None else fc_norm
@@ -858,6 +859,15 @@ def __init__(
858859
act_layer=act_layer,
859860
**dd,
860861
)
862+
elif global_pool == 'prr':
863+
self.attn_pool = AttentionPoolPrr(
864+
self.embed_dim,
865+
num_heads=num_heads,
866+
pool_type='token' if class_token else 'avg',
867+
norm_layer=norm_layer,
868+
**dd,
869+
)
870+
self.pool_include_prefix = True
861871
else:
862872
self.attn_pool = None
863873
self.fc_norm = norm_layer(embed_dim, **dd) if final_norm and use_fc_norm else nn.Identity()
@@ -961,11 +971,13 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None)
961971
"""
962972
self.num_classes = num_classes
963973
if global_pool is not None:
964-
assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map')
965-
if global_pool == 'map' and self.attn_pool is None:
974+
assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map', 'prr')
975+
if global_pool in ('map', 'prr') and self.attn_pool is None:
966976
assert False, "Cannot currently add attention pooling in reset_classifier()."
967-
elif global_pool != 'map' and self.attn_pool is not None:
977+
elif global_pool not in ('map', 'prr') and self.attn_pool is not None:
968978
self.attn_pool = None # remove attention pooling
979+
elif global_pool in ('map', 'prr') and self.global_pool != global_pool:
980+
assert False, "Cannot currently change attention pooling type in reset_classifier()."
969981
self.global_pool = global_pool
970982
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
971983

@@ -1476,7 +1488,7 @@ def _n2p(_w, t=True, idx=None):
14761488
# if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
14771489
# model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
14781490
# model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
1479-
if model.attn_pool is not None:
1491+
if isinstance(model.attn_pool, AttentionPoolLatent):
14801492
block_prefix = f'{prefix}MAPHead_0/'
14811493
mha_prefix = block_prefix + f'MultiHeadDotProductAttention_0/'
14821494
model.attn_pool.latent.copy_(_n2p(w[f'{block_prefix}probe'], t=False))

0 commit comments

Comments (0)