-
Notifications
You must be signed in to change notification settings - Fork 19
Expand file tree
/
Copy pathrecognizer2d.py
More file actions
124 lines (98 loc) · 4.06 KB
/
recognizer2d.py
File metadata and controls
124 lines (98 loc) · 4.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
from ..registry import RECOGNIZERS
from .base import BaseRecognizer
@RECOGNIZERS.register_module()
class Recognizer2D(BaseRecognizer):
    """2D recognizer model framework.

    Every entry point folds the first two axes of ``imgs`` (batch and
    segments) into one before calling the backbone. When a neck is present,
    all entry points feed it features in the same layout, via
    ``_forward_neck``.
    """

    @staticmethod
    def _fold_segments(imgs):
        """Merge the first two axes of ``imgs`` into a single axis.

        Args:
            imgs (torch.Tensor): Input whose first two axes are batch and
                segments.

        Returns:
            tuple: ``(imgs, batches, num_segs)``.
                ``imgs`` has its first two axes flattened.
                ``batches`` is the original ``imgs.shape[0]``.
                ``num_segs`` is the number of flattened items per sample.
        """
        batches = imgs.shape[0]
        imgs = imgs.reshape((-1, ) + imgs.shape[2:])
        num_segs = imgs.shape[0] // batches
        return imgs, batches, num_segs

    def _forward_neck(self, x, num_segs, target=None):
        """Apply ``self.neck`` to backbone features, if a neck exists.

        Each feature map is reshaped to ``(-1, num_segs, *rest)`` and then
        transposed on axes 1 and 2, so the segment axis follows the channel
        axis. The neck output is squeezed on axis 2.

        Args:
            x (Sequence[torch.Tensor]): Backbone feature maps.
                Assumed shape ``(batch * num_segs, C, ...)``; TODO confirm.
            num_segs (int): Number of segments per sample.
            target (torch.Tensor | None): Labels forwarded to the neck. When
                ``None``, the neck is called with the features only.

        Returns:
            tuple: ``(x, num_segs, loss_aux)``.
                Without a neck, ``x`` and ``num_segs`` are returned unchanged
                and ``loss_aux`` is an empty dict.
                With a neck, ``x`` is the squeezed neck output,
                ``num_segs`` is 1, and ``loss_aux`` is whatever the neck
                returned as its second output.
        """
        if not hasattr(self, 'neck'):
            return x, num_segs, dict()
        x = [
            each.reshape((-1, num_segs) +
                         each.shape[1:]).transpose(1, 2).contiguous()
            for each in x
        ]
        if target is None:
            x, loss_aux = self.neck(x)
        else:
            x, loss_aux = self.neck(x, target)
        return x.squeeze(2), 1, loss_aux

    def _average_scores(self, cls_score, batches):
        """Average ``cls_score`` over the crops/clips of each sample.

        The number of crops is ``cls_score.size()[0] // batches``. The first
        axis must be a multiple of ``batches``.
        """
        assert cls_score.size()[0] % batches == 0
        # calculate num_crops automatically
        return self.average_clip(cls_score, cls_score.size()[0] // batches)

    def forward_train(self, imgs, labels, **kwargs):
        """Defines the computation performed at every call when training.

        Args:
            imgs (torch.Tensor): Input whose first two axes are batch and
                segments.
            labels (torch.Tensor): Ground-truth labels; squeezed before use.
            **kwargs: Forwarded to ``debias_head`` and ``cls_head.loss``.

        Returns:
            dict: Loss terms from three sources, merged in order:
                the optional ``debias_head``,
                the optional neck auxiliary output,
                ``cls_head.loss``.
        """
        imgs, _, num_segs = self._fold_segments(imgs)
        gt_labels = labels.squeeze()
        losses = dict()

        x = self.extract_feat(imgs)

        # The debias head consumes the pre-neck backbone features.
        if hasattr(self, 'debias_head'):
            loss_debias = self.debias_head(
                x, num_segs=num_segs, target=gt_labels, **kwargs)
            losses.update(loss_debias)

        # Labels are handed to the neck, presumably so it can produce an
        # auxiliary loss (as mmaction2's TPN does). Keep that loss rather
        # than silently dropping it.
        x, num_segs, loss_aux = self._forward_neck(
            x, num_segs, target=gt_labels)
        if loss_aux:
            losses.update(loss_aux)

        cls_score = self.cls_head(x, num_segs)
        loss_cls = self.cls_head.loss(cls_score, gt_labels, **kwargs)
        losses.update(loss_cls)
        return losses

    def _do_test(self, imgs):
        """Defines the computation performed at every call when evaluation,
        testing and gradcam.

        Returns:
            torch.Tensor: Class scores averaged over crops, one row per
            input sample.
        """
        imgs, batches, num_segs = self._fold_segments(imgs)
        x = self.extract_feat(imgs)
        # No labels at test time; any auxiliary neck output is unused.
        x, num_segs, _ = self._forward_neck(x, num_segs)
        # When using `TSNHead` or `TPNHead`, shape is [batch_size, num_classes]
        # When using `TSMHead`, shape is [batch_size * num_crops, num_classes]
        # `num_crops` is calculated by:
        #   1) `twice_sample` in `SampleFrames`
        #   2) `num_sample_positions` in `DenseSampleFrames`
        #   3) `ThreeCrop/TenCrop/MultiGroupCrop` in `test_pipeline`
        #   4) `num_clips` in `SampleFrames` or its subclass if `clip_len != 1`
        cls_score = self.cls_head(x, num_segs)
        return self._average_scores(cls_score, batches)

    def forward_test(self, imgs):
        """Defines the computation performed at every call when evaluation and
        testing.

        Returns:
            numpy.ndarray: Averaged class scores, moved to CPU.
        """
        return self._do_test(imgs).cpu().numpy()

    def forward_dummy(self, imgs):
        """Used for computing network FLOPs.

        See ``tools/analysis/get_flops.py``.

        Args:
            imgs (torch.Tensor): Input images.

        Returns:
            tuple[Tensor]: Class score, wrapped in a 1-tuple.
        """
        imgs, _, num_segs = self._fold_segments(imgs)
        x = self.extract_feat(imgs)
        # Route through the neck like training and testing do. Otherwise
        # cls_head would receive features in a different layout (and with a
        # different num_segs) from the one it sees elsewhere.
        x, num_segs, _ = self._forward_neck(x, num_segs)
        return (self.cls_head(x, num_segs), )

    def forward_gradcam(self, imgs):
        """Defines the computation performed at every call when using gradcam
        utils.

        Returns:
            torch.Tensor: Averaged class scores (kept on device, no numpy
            conversion).
        """
        return self._do_test(imgs)

    def get_feat(self, imgs, return_score=False):
        """Defines the computation performed at every call when using get_feat
        utils.

        Args:
            imgs (torch.Tensor): Input whose first two axes are batch and
                segments.
            return_score (bool): Also return crop-averaged class scores.

        Returns:
            The features, or ``(features, cls_score)`` when ``return_score``
            is true. Features are taken after the neck when one exists.
        """
        imgs, batches, num_segs = self._fold_segments(imgs)
        x = self.extract_feat(imgs)
        # Use the same neck path as training and testing, so the neck gets
        # the reshaped layout and num_segs is reset consistently.
        x, num_segs, _ = self._forward_neck(x, num_segs)
        if return_score:
            cls_score = self._average_scores(
                self.cls_head(x, num_segs), batches)
            return x, cls_score
        return x