44
55
66def test_alignnet_forward_shapes_cpu ():
7- model = AlignNet (backbone_name = "resnet18" , backbone_weights = None , use_vector_input = False , output_dim = 7 )
7+ model = AlignNet (
8+ backbone_name = "resnet18" ,
9+ backbone_weights = None ,
10+ use_vector_input = False ,
11+ output_dim = 7 ,
12+ )
813 model .eval ()
914 x = torch .randn (2 , 3 , 3 , 64 , 64 ) # B=2, N=3 views
1015 with torch .no_grad ():
@@ -13,7 +18,12 @@ def test_alignnet_forward_shapes_cpu():
1318
1419
1520def test_alignnet_with_vector_input ():
16- model = AlignNet (backbone_name = "resnet18" , backbone_weights = None , use_vector_input = True , output_dim = 7 )
21+ model = AlignNet (
22+ backbone_name = "resnet18" ,
23+ backbone_weights = None ,
24+ use_vector_input = True ,
25+ output_dim = 7 ,
26+ )
1727 model .eval ()
1828 x = torch .randn (1 , 2 , 3 , 64 , 64 )
1929 vecs = [torch .randn (5 )]
@@ -23,12 +33,18 @@ def test_alignnet_with_vector_input():
2333
2434
2535def test_alignnet_performance ():
26- model = AlignNet (backbone_name = "efficientnet_b0" , backbone_weights = None , use_vector_input = True , output_dim = 7 )
36+ model = AlignNet (
37+ backbone_name = "efficientnet_b0" ,
38+ backbone_weights = None ,
39+ use_vector_input = True ,
40+ output_dim = 7 ,
41+ )
2742 model .eval ()
2843 x = torch .randn (1 , 3 , 3 , 224 , 224 ) # B=1, N=3 views
2944 vecs = [torch .randn (5 )]
3045 with torch .no_grad ():
3146 import time
47+
3248 start_time = time .time ()
3349 for _ in range (10 ):
3450 y = model (x , vecs )
0 commit comments