Skip to content

Commit f8c695d

Browse files
committed
Improve 2d and latent attention pool dimension handling. Fix #2682
1 parent a94c10f commit f8c695d

File tree

3 files changed

+158
-12
lines changed

3 files changed

+158
-12
lines changed

tests/test_layers_pool.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,125 @@ def test_rot_attention_pool2d_rope_types(self):
137137
out = pool(x)
138138
assert out.shape == (2, 64)
139139

140+
@pytest.mark.parametrize('pool_cls,base_kwargs,input_shape', [
141+
('RotAttentionPool2d', {'in_features': 64, 'ref_feat_size': 7}, (2, 64, 7, 7)),
142+
('AttentionPool2d', {'in_features': 64, 'feat_size': 7}, (2, 64, 7, 7)),
143+
])
144+
@pytest.mark.parametrize('out_features,embed_dim,expected_out', [
145+
(None, None, 64), # default: out_features = in_features
146+
(None, 128, 64), # default with different embed_dim
147+
(32, None, 32), # explicit out_features
148+
(32, 128, 32), # explicit out_features with different embed_dim
149+
(0, None, 64), # disabled projection, out = embed_dim = in_features
150+
(0, 128, 128), # disabled projection, out = embed_dim
151+
])
152+
def test_attention_pool2d_out_features(
153+
self, pool_cls, base_kwargs, input_shape, out_features, embed_dim, expected_out,
154+
):
155+
import timm.layers as layers
156+
kwargs = {**base_kwargs, 'out_features': out_features}
157+
if embed_dim is not None:
158+
kwargs['embed_dim'] = embed_dim
159+
pool = getattr(layers, pool_cls)(**kwargs).to(torch_device)
160+
assert pool.out_features == expected_out
161+
if out_features == 0:
162+
assert isinstance(pool.proj, nn.Identity)
163+
else:
164+
assert isinstance(pool.proj, nn.Linear)
165+
x = torch.randn(*input_shape, device=torch_device)
166+
out = pool(x)
167+
assert out.shape == (2, expected_out)
168+
169+
@pytest.mark.parametrize('pool_cls,base_kwargs,input_shape', [
170+
('RotAttentionPool2d', {'in_features': 64, 'ref_feat_size': 7, 'embed_dim': 128}, (2, 64, 7, 7)),
171+
('AttentionPool2d', {'in_features': 64, 'feat_size': 7, 'embed_dim': 128}, (2, 64, 7, 7)),
172+
])
173+
@pytest.mark.parametrize('num_classes,expected_out', [
174+
(10, 10),
175+
(0, 128), # reset to 0 => Identity, out_features = embed_dim
176+
(100, 100),
177+
])
178+
def test_attention_pool2d_reset(
179+
self, pool_cls, base_kwargs, input_shape, num_classes, expected_out,
180+
):
181+
import timm.layers as layers
182+
pool = getattr(layers, pool_cls)(**base_kwargs).to(torch_device)
183+
pool.reset(num_classes=num_classes)
184+
assert pool.out_features == expected_out
185+
if num_classes > 0:
186+
assert isinstance(pool.proj, nn.Linear)
187+
assert pool.proj.in_features == 128 # embed_dim, not in_features
188+
assert pool.proj.out_features == num_classes
189+
else:
190+
assert isinstance(pool.proj, nn.Identity)
191+
x = torch.randn(*input_shape, device=torch_device)
192+
out = pool(x)
193+
assert out.shape == (2, expected_out)
194+
195+
@pytest.mark.parametrize('pool_cls,base_kwargs,input_shape', [
196+
('RotAttentionPool2d', {'in_features': 64, 'ref_feat_size': 7}, (2, 64, 7, 7)),
197+
('AttentionPool2d', {'in_features': 64, 'feat_size': 7}, (2, 64, 7, 7)),
198+
])
199+
def test_attention_pool2d_pre_logits(self, pool_cls, base_kwargs, input_shape):
200+
import timm.layers as layers
201+
pool = getattr(layers, pool_cls)(**base_kwargs, out_features=32).to(torch_device)
202+
x = torch.randn(*input_shape, device=torch_device)
203+
out = pool(x, pre_logits=True)
204+
# pre_logits skips proj, so output dim = embed_dim (= in_features by default)
205+
assert out.shape == (2, 64)
206+
207+
@pytest.mark.parametrize('pool_cls,base_kwargs,input_shape', [
208+
('RotAttentionPool2d', {'in_features': 64, 'ref_feat_size': 7}, (2, 64, 7, 7)),
209+
('AttentionPool2d', {'in_features': 64, 'feat_size': 7}, (2, 64, 7, 7)),
210+
])
211+
def test_attention_pool2d_qkv_separate(self, pool_cls, base_kwargs, input_shape):
212+
import timm.layers as layers
213+
pool = getattr(layers, pool_cls)(**base_kwargs, qkv_separate=True).to(torch_device)
214+
assert pool.qkv is None
215+
x = torch.randn(*input_shape, device=torch_device)
216+
out = pool(x)
217+
assert out.shape == (2, 64)
218+
219+
@pytest.mark.parametrize('pool_cls,base_kwargs,input_shape', [
220+
('RotAttentionPool2d', {'in_features': 64, 'ref_feat_size': 7}, (2, 64, 7, 7)),
221+
('AttentionPool2d', {'in_features': 64, 'feat_size': 7}, (2, 64, 7, 7)),
222+
])
223+
def test_attention_pool2d_class_token(self, pool_cls, base_kwargs, input_shape):
224+
import timm.layers as layers
225+
pool = getattr(layers, pool_cls)(**base_kwargs, class_token=True).to(torch_device)
226+
assert pool.cls_token is not None
227+
x = torch.randn(*input_shape, device=torch_device)
228+
out = pool(x)
229+
assert out.shape == (2, 64)
230+
231+
@pytest.mark.parametrize('out_features,embed_dim,expected_out', [
232+
(None, None, 64), # default: out_features = in_features
233+
(None, 128, 64), # default with different embed_dim
234+
(32, None, 32), # explicit out_features
235+
(32, 128, 32), # explicit out_features with different embed_dim
236+
(0, None, 64), # disabled projection, out = embed_dim = in_features
237+
(0, 128, 128), # disabled projection, out = embed_dim
238+
])
239+
def test_attention_pool_latent_out_features(self, out_features, embed_dim, expected_out):
240+
from timm.layers import AttentionPoolLatent
241+
kwargs = {'in_features': 64, 'num_heads': 4}
242+
if out_features is not None:
243+
kwargs['out_features'] = out_features
244+
if embed_dim is not None:
245+
kwargs['embed_dim'] = embed_dim
246+
pool = AttentionPoolLatent(**kwargs).to(torch_device)
247+
assert pool.out_features == expected_out
248+
if out_features == 0:
249+
assert isinstance(pool.proj, nn.Identity)
250+
assert pool.mlp is None
251+
else:
252+
assert isinstance(pool.proj, nn.Linear)
253+
assert pool.mlp is not None
254+
in_dim = embed_dim or 64
255+
x = torch.randn(2, 49, in_dim, device=torch_device)
256+
out = pool(x)
257+
assert out.shape == (2, expected_out)
258+
140259

141260
# LSE Pool Tests
142261

timm/layers/attention_pool.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
class AttentionPoolLatent(nn.Module):
1414
""" Attention pooling w/ latent query
15+
16+
Setting out_features=0 disables the output projection, norm, and MLP layers (pre_logits mode).
1517
"""
1618
fused_attn: torch.jit.Final[bool]
1719

@@ -38,7 +40,8 @@ def __init__(
3840
dd = {'device': device, 'dtype': dtype}
3941
super().__init__()
4042
embed_dim = embed_dim or in_features
41-
out_features = out_features or in_features
43+
if out_features is None:
44+
out_features = in_features
4245
assert embed_dim % num_heads == 0
4346
self.num_heads = num_heads
4447
self.head_dim = embed_dim // num_heads
@@ -66,11 +69,20 @@ def __init__(
6669
else:
6770
self.q_norm = nn.Identity()
6871
self.k_norm = nn.Identity()
69-
self.proj = nn.Linear(embed_dim, embed_dim, **dd)
70-
self.proj_drop = nn.Dropout(drop)
7172

72-
self.norm = norm_layer(out_features, **dd) if norm_layer is not None else nn.Identity()
73-
self.mlp = Mlp(embed_dim, int(embed_dim * mlp_ratio), act_layer=act_layer, **dd)
73+
if out_features > 0:
74+
self.proj = nn.Linear(embed_dim, out_features, **dd)
75+
self.proj_drop = nn.Dropout(drop)
76+
self.norm = norm_layer(out_features, **dd) if norm_layer is not None else nn.Identity()
77+
self.mlp = Mlp(out_features, int(out_features * mlp_ratio), out_features=out_features, act_layer=act_layer, **dd)
78+
else:
79+
self.proj = nn.Identity()
80+
self.proj_drop = nn.Dropout(drop)
81+
self.norm = nn.Identity()
82+
self.mlp = None
83+
out_features = embed_dim
84+
85+
self.out_features = out_features
7486

7587
self.init_weights()
7688

@@ -106,7 +118,8 @@ def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
106118
x = self.proj(x)
107119
x = self.proj_drop(x)
108120

109-
x = x + self.mlp(self.norm(x))
121+
if self.mlp is not None:
122+
x = x + self.mlp(self.norm(x))
110123

111124
# optional pool if latent seq_len > 1 and pooled output is desired
112125
if self.pool == 'token':

timm/layers/attention_pool2d.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ class RotAttentionPool2d(nn.Module):
2828
2929
NOTE: While this impl does not require a fixed feature size, performance at differing resolutions from
3030
train varies widely and falls off dramatically. I'm not sure if there is a way around this... -RW
31+
32+
Setting out_features=0 disables the output projection (pre_logits mode).
3133
"""
3234
fused_attn: torch.jit.Final[bool]
3335

@@ -53,7 +55,12 @@ def __init__(
5355
assert pool_type in ('', 'token')
5456
self.embed_dim = embed_dim = embed_dim or in_features
5557
self.in_features = in_features
56-
self.out_features = out_features or in_features
58+
if out_features is None:
59+
self.out_features = in_features
60+
elif out_features > 0:
61+
self.out_features = out_features
62+
else:
63+
self.out_features = embed_dim # out_features=0 disables projection
5764
ref_feat_size = to_2tuple(ref_feat_size)
5865
if num_heads is not None:
5966
assert embed_dim % num_heads == 0
@@ -81,7 +88,7 @@ def __init__(
8188
else:
8289
self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias, **dd)
8390
self.drop = nn.Dropout(drop_rate)
84-
self.proj = nn.Linear(embed_dim, self.out_features, **dd)
91+
self.proj = nn.Linear(embed_dim, self.out_features, **dd) if out_features != 0 else nn.Identity()
8592

8693
self.pos_embed = create_rope_embed(
8794
rope_type=rope_type,
@@ -113,7 +120,7 @@ def reset(self, num_classes: Optional[int] = None, pool_type: Optional[str] = No
113120
assert pool_type in ('', 'token')
114121
self.pool_type = pool_type
115122
if num_classes is not None:
116-
self.proj = nn.Linear(self.in_features, num_classes) if num_classes > 0 else nn.Identity()
123+
self.proj = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
117124
self.out_features = num_classes if num_classes > 0 else self.embed_dim
118125

119126
def _pool(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
@@ -172,6 +179,8 @@ class AttentionPool2d(nn.Module):
172179
https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py
173180
174181
NOTE: This requires feature size upon construction and will prevent adaptive sizing of the network.
182+
183+
Setting out_features=0 disables the output projection (pre_logits mode).
175184
"""
176185
fused_attn: torch.jit.Final[bool]
177186

@@ -196,7 +205,12 @@ def __init__(
196205
assert pool_type in ('', 'token')
197206
self.embed_dim = embed_dim = embed_dim or in_features
198207
self.in_features = in_features
199-
self.out_features = out_features or in_features
208+
if out_features is None:
209+
self.out_features = in_features
210+
elif out_features > 0:
211+
self.out_features = out_features
212+
else:
213+
self.out_features = embed_dim # out_features=0 disables projection
200214
if num_heads is not None:
201215
assert embed_dim % num_heads == 0
202216
head_dim = embed_dim // num_heads
@@ -225,7 +239,7 @@ def __init__(
225239
self.q = self.k = self.v = None
226240
self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias, **dd)
227241
self.drop = nn.Dropout(drop_rate)
228-
self.proj = nn.Linear(embed_dim, self.out_features, **dd)
242+
self.proj = nn.Linear(embed_dim, self.out_features, **dd) if out_features != 0 else nn.Identity()
229243
self.pos_embed = nn.Parameter(torch.zeros(self.seq_len + 1, in_features, **dd))
230244

231245
self.init_weights()
@@ -251,7 +265,7 @@ def reset(self, num_classes: Optional[int] = None, pool_type: Optional[str] = No
251265
assert pool_type in ('', 'token')
252266
self.pool_type = pool_type
253267
if num_classes is not None:
254-
self.proj = nn.Linear(self.in_features, num_classes) if num_classes > 0 else nn.Identity()
268+
self.proj = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
255269
self.out_features = num_classes if num_classes > 0 else self.embed_dim
256270

257271
def _pool(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:

0 commit comments

Comments
 (0)