@@ -137,6 +137,125 @@ def test_rot_attention_pool2d_rope_types(self):
137137 out = pool (x )
138138 assert out .shape == (2 , 64 )
139139
@pytest.mark.parametrize('pool_cls,base_kwargs,input_shape', [
    ('RotAttentionPool2d', {'in_features': 64, 'ref_feat_size': 7}, (2, 64, 7, 7)),
    ('AttentionPool2d', {'in_features': 64, 'feat_size': 7}, (2, 64, 7, 7)),
])
@pytest.mark.parametrize('out_features,embed_dim,expected_out', [
    (None, None, 64),  # nothing specified: output width follows in_features
    (None, 128, 64),   # wider embed_dim, output width still follows in_features
    (32, None, 32),    # out_features given explicitly
    (32, 128, 32),     # out_features given explicitly alongside a wider embed_dim
    (0, None, 64),     # projection disabled: output width is embed_dim (= in_features)
    (0, 128, 128),     # projection disabled: output width is embed_dim
])
def test_attention_pool2d_out_features(
        self, pool_cls, base_kwargs, input_shape, out_features, embed_dim, expected_out,
):
    """Check how ``out_features`` / ``embed_dim`` determine the pooled width.

    For each 2D attention pool class the test asserts three things:

    * ``pool.out_features`` equals ``expected_out``;
    * ``pool.proj`` is ``nn.Identity`` when ``out_features == 0`` and
      ``nn.Linear`` otherwise (including ``out_features=None``);
    * a forward pass on a random input yields shape ``(2, expected_out)``.
    """
    import timm.layers as layers

    # out_features is always passed (even when None); embed_dim only when set.
    ctor_kwargs = dict(base_kwargs, out_features=out_features)
    if embed_dim is not None:
        ctor_kwargs['embed_dim'] = embed_dim

    pool_type = getattr(layers, pool_cls)
    module = pool_type(**ctor_kwargs).to(torch_device)

    assert module.out_features == expected_out
    expected_proj_type = nn.Identity if out_features == 0 else nn.Linear
    assert isinstance(module.proj, expected_proj_type)

    feats = torch.randn(*input_shape, device=torch_device)
    pooled = module(feats)
    assert pooled.shape == (2, expected_out)
168+
@pytest.mark.parametrize('pool_cls,base_kwargs,input_shape', [
    ('RotAttentionPool2d', {'in_features': 64, 'ref_feat_size': 7, 'embed_dim': 128}, (2, 64, 7, 7)),
    ('AttentionPool2d', {'in_features': 64, 'feat_size': 7, 'embed_dim': 128}, (2, 64, 7, 7)),
])
@pytest.mark.parametrize('num_classes,expected_out', [
    (10, 10),
    (0, 128),  # resetting to 0 classes => Identity head, width falls back to embed_dim
    (100, 100),
])
def test_attention_pool2d_reset(
        self, pool_cls, base_kwargs, input_shape, num_classes, expected_out,
):
    """Check the state of a 2D attention pool after ``reset(num_classes=...)``.

    After the reset, ``out_features`` must equal ``expected_out``. With a
    positive ``num_classes`` the head must be an ``nn.Linear`` mapping
    128 (the configured ``embed_dim``) to ``num_classes``; with zero it must
    be ``nn.Identity``. A forward pass must then yield ``(2, expected_out)``.
    """
    import timm.layers as layers

    pool_type = getattr(layers, pool_cls)
    module = pool_type(**base_kwargs).to(torch_device)
    module.reset(num_classes=num_classes)

    assert module.out_features == expected_out

    head = module.proj
    if num_classes <= 0:
        assert isinstance(head, nn.Identity)
    else:
        assert isinstance(head, nn.Linear)
        assert head.in_features == 128  # embed_dim, not in_features
        assert head.out_features == num_classes

    feats = torch.randn(*input_shape, device=torch_device)
    pooled = module(feats)
    assert pooled.shape == (2, expected_out)
194+
@pytest.mark.parametrize('pool_cls,base_kwargs,input_shape', [
    ('RotAttentionPool2d', {'in_features': 64, 'ref_feat_size': 7}, (2, 64, 7, 7)),
    ('AttentionPool2d', {'in_features': 64, 'feat_size': 7}, (2, 64, 7, 7)),
])
def test_attention_pool2d_pre_logits(self, pool_cls, base_kwargs, input_shape):
    """Check the output width of a forward pass with ``pre_logits=True``.

    The pool is built with ``out_features=32``, yet the pre-logits output is
    expected to have width 64.
    """
    import timm.layers as layers

    pool_type = getattr(layers, pool_cls)
    module = pool_type(**base_kwargs, out_features=32).to(torch_device)

    feats = torch.randn(*input_shape, device=torch_device)
    pre_logit_feats = module(feats, pre_logits=True)

    # proj is bypassed for pre_logits, so the width is embed_dim, which
    # defaults to in_features (64) rather than the requested out_features.
    assert pre_logit_feats.shape == (2, 64)
206+
@pytest.mark.parametrize('pool_cls,base_kwargs,input_shape', [
    ('RotAttentionPool2d', {'in_features': 64, 'ref_feat_size': 7}, (2, 64, 7, 7)),
    ('AttentionPool2d', {'in_features': 64, 'feat_size': 7}, (2, 64, 7, 7)),
])
def test_attention_pool2d_qkv_separate(self, pool_cls, base_kwargs, input_shape):
    """Check a 2D attention pool built with ``qkv_separate=True``.

    The ``qkv`` attribute must be ``None`` in this mode, and a forward pass
    must still produce an output of shape ``(2, 64)``.
    """
    import timm.layers as layers

    pool_type = getattr(layers, pool_cls)
    module = pool_type(**base_kwargs, qkv_separate=True).to(torch_device)
    assert module.qkv is None

    feats = torch.randn(*input_shape, device=torch_device)
    pooled = module(feats)
    assert pooled.shape == (2, 64)
218+
@pytest.mark.parametrize('pool_cls,base_kwargs,input_shape', [
    ('RotAttentionPool2d', {'in_features': 64, 'ref_feat_size': 7}, (2, 64, 7, 7)),
    ('AttentionPool2d', {'in_features': 64, 'feat_size': 7}, (2, 64, 7, 7)),
])
def test_attention_pool2d_class_token(self, pool_cls, base_kwargs, input_shape):
    """Check a 2D attention pool built with ``class_token=True``.

    The ``cls_token`` attribute must be set (not ``None``), and a forward pass
    must produce an output of shape ``(2, 64)``.
    """
    import timm.layers as layers

    pool_type = getattr(layers, pool_cls)
    module = pool_type(**base_kwargs, class_token=True).to(torch_device)
    assert module.cls_token is not None

    feats = torch.randn(*input_shape, device=torch_device)
    pooled = module(feats)
    assert pooled.shape == (2, 64)
230+
@pytest.mark.parametrize('out_features,embed_dim,expected_out', [
    (None, None, 64),  # nothing specified: output width follows in_features
    (None, 128, 64),   # wider embed_dim, output width still follows in_features
    (32, None, 32),    # out_features given explicitly
    (32, 128, 32),     # out_features given explicitly alongside a wider embed_dim
    (0, None, 64),     # projection disabled: output width is embed_dim (= in_features)
    (0, 128, 128),     # projection disabled: output width is embed_dim
])
def test_attention_pool_latent_out_features(self, out_features, embed_dim, expected_out):
    """Check how ``out_features`` / ``embed_dim`` shape ``AttentionPoolLatent``.

    The pool is built with ``in_features=64`` and ``num_heads=4``; the two
    parametrized options are forwarded only when they are not ``None``.
    With ``out_features == 0`` the test expects an ``nn.Identity`` projection
    and no MLP; otherwise an ``nn.Linear`` projection and an MLP. A forward
    pass on a ``(2, 49, in_dim)`` input must yield ``(2, expected_out)``.
    """
    from timm.layers import AttentionPoolLatent

    optional = {'out_features': out_features, 'embed_dim': embed_dim}
    ctor_kwargs = {
        'in_features': 64,
        'num_heads': 4,
        **{key: value for key, value in optional.items() if value is not None},
    }
    module = AttentionPoolLatent(**ctor_kwargs).to(torch_device)

    assert module.out_features == expected_out
    projection_disabled = out_features == 0
    expected_proj_type = nn.Identity if projection_disabled else nn.Linear
    assert isinstance(module.proj, expected_proj_type)
    # The MLP is expected to be absent exactly when the projection is disabled.
    if projection_disabled:
        assert module.mlp is None
    else:
        assert module.mlp is not None

    # Token width fed to the pool: embed_dim when given, else 64.
    in_dim = embed_dim if embed_dim else 64
    tokens = torch.randn(2, 49, in_dim, device=torch_device)
    pooled = module(tokens)
    assert pooled.shape == (2, expected_out)
258+
140259
141260# LSE Pool Tests
142261
0 commit comments