Skip to content

Commit 99f777c

Browse files
authored
update pytorch models following DeepLabCut 3.0.0rc13 (#151)
* DEKRPredictor: add non-maximum suppression (NMS) This commit Updates the DEKR predictor to follow the DeepLabCut implementation in version 3.0.0rc7, see DeepLabCut/DeepLabCut#2907 * DEKRPredictor: speed up with vectorized operations This commit updates the DEKRPredictor to follow the DeepLabCut implementation in version 3.0.0rc13. see DeepLabCut/DeepLabCut#3121 * PartAffinityFieldPredictor (PAF): Speed up cost computation This commit updates the PAF predictor to follow the DeepLabCut implementation in version 3.0.0.rc13. See DeepLabCut/DeepLabCut#3117 * HeatmapPredictor (single animal): speed up with vecorized operations This commit updates the `HeatmapPredictor` in single_predictor.py to follow the implementation in DeepLabCut 3.0.0rc13. See DeepLabCut/DeepLabCut#3110
1 parent 595c295 commit 99f777c

File tree

3 files changed

+436
-177
lines changed

3 files changed

+436
-177
lines changed

dlclive/pose_estimation_pytorch/models/predictors/dekr_predictor.py

Lines changed: 114 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,8 @@ def __init__(
7575
use_heatmap: bool = True,
7676
keypoint_score_type: str = "combined",
7777
max_absorb_distance: int = 75,
78+
nms_threshold: float = 0.05,
79+
apply_pose_nms: bool = True,
7880
):
7981
"""
8082
Args:
@@ -88,6 +90,8 @@ def __init__(
8890
applies the heatmap score to each keypoint. "center" applies the score
8991
of the center of each individual to all of its keypoints. "combined"
9092
multiplies the score of the heatmap and individual for each keypoint.
93+
nms_threshold: Threshold for NMS of pose.
94+
apply_pose_nms: Whether to apply pose NMS
9195
"""
9296
super().__init__()
9397
self.num_animals = num_animals
@@ -99,8 +103,9 @@ def __init__(
99103
if self.keypoint_score_type not in ("heatmap", "center", "combined"):
100104
raise ValueError(f"Unknown keypoint score type: {self.keypoint_score_type}")
101105

102-
# TODO: Set as in HRNet/DEKR configs. Define as a constant.
103106
self.max_absorb_distance = max_absorb_distance
107+
self.nms_threshold = nms_threshold
108+
self.apply_pose_nms = apply_pose_nms
104109

105110
def forward(
106111
self, stride: float, outputs: dict[str, torch.Tensor]
@@ -134,12 +139,9 @@ def forward(
134139
pose_ind, ctr_scores = self.get_top_values(center_heatmaps)
135140

136141
posemap = posemap.permute(0, 2, 3, 1).view(batch_size, h * w, -1, 2)
137-
poses = torch.zeros(batch_size, pose_ind.shape[1], num_joints, 2).to(
138-
ctr_scores.device
139-
)
140-
for i in range(batch_size):
141-
pose = posemap[i, pose_ind[i]]
142-
poses[i] = pose
142+
143+
batch_indices = torch.arange(batch_size, device=pose_ind.device)[:, None]
144+
poses = posemap[batch_indices, pose_ind]
143145

144146
if self.use_heatmap:
145147
poses = self._update_pose_with_heatmaps(poses, heatmaps[:, :-1])
@@ -174,7 +176,9 @@ def forward(
174176
score = torch.clip(score, min=0, max=1)
175177

176178
poses_w_scores = torch.cat([poses, score], dim=3)
177-
# self.pose_nms(heatmaps, poses_w_scores)
179+
if self.apply_pose_nms:
180+
poses_w_scores = self.pose_nms(poses_w_scores)
181+
178182
return {"poses": poses_w_scores}
179183

180184
def get_locations(
@@ -263,7 +267,7 @@ def max_pool(self, heatmap: torch.Tensor) -> torch.Tensor:
263267
# Assuming you have 'heatmap' tensor
264268
max_pooled_heatmap = predictor.max_pool(heatmap)
265269
"""
266-
pool1 = torch.nn.MaxPool2d(3, 1, 1) # TODO JR 01/2026: Are these unused variables informative?
270+
pool1 = torch.nn.MaxPool2d(3, 1, 1)
267271
pool2 = torch.nn.MaxPool2d(5, 1, 2)
268272
pool3 = torch.nn.MaxPool2d(7, 1, 3)
269273
map_size = (heatmap.shape[1] + heatmap.shape[2]) / 2.0
@@ -299,11 +303,11 @@ def get_top_values(
299303

300304
return pos_ind, scores
301305

302-
########## WIP to take heatmap into account for scoring ##########
303306
def _update_pose_with_heatmaps(
304307
self, _poses: torch.Tensor, kpt_heatmaps: torch.Tensor
305308
):
306-
"""If a heatmap center is close enough from the regressed point, the final prediction is the center of this heatmap
309+
"""If a heatmap center is close enough from the regressed point, the final
310+
prediction is the center of this heatmap
307311
308312
Args:
309313
poses: poses tensor, shape (batch_size, num_animals, num_keypoints, 2)
@@ -315,25 +319,49 @@ def _update_pose_with_heatmaps(
315319
kpt_heatmaps *= maxm
316320
batch_size, num_keypoints, h, w = kpt_heatmaps.shape
317321
kpt_heatmaps = kpt_heatmaps.view(batch_size, num_keypoints, -1)
318-
val_k, ind = kpt_heatmaps.topk(self.num_animals, dim=2)
322+
_val_k, ind = kpt_heatmaps.topk(self.num_animals, dim=2)
319323

320324
x = ind % w
321325
y = (ind / w).long()
322-
heats_ind = torch.stack((x, y), dim=3)
323-
324-
for b in range(batch_size):
325-
for i in range(num_keypoints):
326-
heat_ind = heats_ind[b, i].float()
327-
pose_ind = poses[b, :, i]
328-
pose_heat_diff = pose_ind[:, None, :] - heat_ind
329-
pose_heat_diff.pow_(2)
330-
pose_heat_diff = pose_heat_diff.sum(2)
331-
pose_heat_diff.sqrt_()
332-
keep_ind = torch.argmin(pose_heat_diff, dim=1)
333-
334-
for p in range(keep_ind.shape[0]):
335-
if pose_heat_diff[p, keep_ind[p]] < self.max_absorb_distance:
336-
poses[b, p, i] = heat_ind[keep_ind[p]]
326+
heats_ind = torch.stack(
327+
(x, y), dim=3
328+
) # (batch_size, num_keypoints, num_animals, 2)
329+
330+
# Calculate differences between all pose-heat pairs
331+
# (batch_size, num_animals, num_keypoints, 1, 2) - (batch_size, 1, num_keypoints, num_animals, 2)
332+
pose_heat_diff = poses.unsqueeze(3) - heats_ind.unsqueeze(
333+
1
334+
) # (batch_size, num_animals, num_keypoints, num_animals, 2)
335+
336+
pose_heat_dist = torch.norm(
337+
pose_heat_diff, dim=-1
338+
) # (batch_size, num_animals, num_keypoints, num_animals)
339+
340+
# Find closest heat point for each pose
341+
keep_ind = torch.argmin(
342+
pose_heat_dist, dim=-1
343+
) # (batch_size, num_animals, num_keypoints)
344+
345+
# Get minimum distances for filtering
346+
min_distances = torch.gather(pose_heat_dist, 3, keep_ind.unsqueeze(-1)).squeeze(
347+
-1
348+
) # (batch_size, num_animals, num_keypoints)
349+
350+
absorb_mask = (
351+
min_distances < self.max_absorb_distance
352+
) # (batch_size, num_animals, num_keypoints)
353+
354+
# Create indices for gathering the correct heat points
355+
batch_indices = torch.arange(batch_size, device=poses.device).view(-1, 1, 1)
356+
keypoint_indices = torch.arange(num_keypoints, device=poses.device).view(
357+
1, 1, -1
358+
)
359+
360+
selected_heat_points = heats_ind[
361+
batch_indices, keypoint_indices, keep_ind
362+
] # (batch_size, num_animals, num_keypoints, 2)
363+
364+
poses = torch.where(absorb_mask.unsqueeze(-1), selected_heat_points, poses)
337365

338366
return poses
339367

@@ -358,51 +386,79 @@ def get_heat_value(
358386
2, 3
359387
) # (batch_size, num_joints, h*w)
360388

361-
# Predicted poses based on the offset can be outside of the image
389+
# Predicted poses based on the offset can be outside the image
362390
x = torch.clamp(torch.floor(pose_coords[:, :, :, 0]), 0, w - 1).long()
363391
y = torch.clamp(torch.floor(pose_coords[:, :, :, 1]), 0, h - 1).long()
364392
keypoint_poses = (y * w + x).mT # (batch, num_joints, num_individuals)
365-
heatscores = torch.gather(heatmaps_nocenter, 2, keypoint_poses)
366-
return heatscores.mT # (batch, num_individuals, num_joints)
393+
scores = torch.gather(heatmaps_nocenter, 2, keypoint_poses)
394+
return scores.mT # (batch, num_individuals, num_joints)
367395

368-
def pose_nms(self, heatmaps: torch.Tensor, poses: torch.Tensor):
396+
def pose_nms(self, poses: torch.Tensor) -> torch.Tensor:
369397
"""Non-Maximum Suppression (NMS) for regressed poses.
370398
371399
Args:
372-
heatmaps: Heatmaps tensor.
373-
poses: Pose proposals.
400+
poses: Pose proposals of shape (batch_size, num_people, num_joints, 3).
401+
The poses for each element in the batch should be sorted by score (the
402+
highest score prediction should be first).
374403
375404
Returns:
376-
None
377-
378-
Example:
379-
# Assuming you have 'heatmaps' and 'poses' tensors
380-
predictor.pose_nms(heatmaps, poses)
405+
Pose proposals after non-maximum suppression.
381406
"""
382-
pose_scores = poses[:, :, :, 2]
383-
pose_coords = poses[:, :, :, :2]
407+
batch_size, num_people, num_joints, _ = poses.shape
408+
device = poses.device
409+
if num_people == 0:
410+
return poses
411+
412+
xy = poses[:, :, :, :2]
413+
w = xy[..., 0].max(dim=-1)[0] - xy[..., 0].min(dim=-1)[0]
414+
h = xy[..., 1].max(dim=-1)[0] - xy[..., 1].min(dim=-1)[0]
415+
area = torch.clamp((w * w) + (h * h), min=1)
416+
area = (
417+
area.unsqueeze(1)
418+
.unsqueeze(3)
419+
.expand(batch_size, num_people, num_people, num_joints)
420+
)
421+
422+
# compute the difference between keypoints
423+
pose_diff = xy.unsqueeze(2) - xy.unsqueeze(1)
424+
pose_diff.pow_(2)
425+
426+
# Compute error between people pairs
427+
pose_dist = pose_diff.sum(dim=-1)
428+
pose_dist.sqrt_()
429+
430+
pose_thresh = self.nms_threshold * torch.sqrt(area)
431+
pose_dist = (pose_dist < pose_thresh).sum(dim=-1)
432+
nms_pose = pose_dist > self.nms_threshold # shape (b, num_people, num_people)
384433

385-
if pose_coords.shape[1] == 0:
386-
return [], []
434+
# Upper triangular mask matrix to avoid double processing
435+
triu_mask = torch.triu(
436+
torch.ones(num_people, num_people, device=device), diagonal=1
437+
).bool()
387438

388-
batch_size, num_people, num_joints, _ = pose_coords.shape
389-
heatvals = self.get_heat_value(pose_coords, heatmaps)
390-
heat_score = (torch.sum(heatvals, dim=1) / num_joints)[:, 0]
439+
suppress_pairs = nms_pose & triu_mask.unsqueeze(
440+
0
441+
) # (batch_size, num_people, num_people)
391442

392-
# return heat_score
393-
# pose_score = pose_score*heatvals
394-
# poses = torch.cat([pose_coord.cpu(), pose_score.cpu()], dim=2)
443+
# For each batch, determine which poses to suppress
444+
suppressed = suppress_pairs.any(dim=1) # (batch_size, num_people)
445+
446+
kept = ~suppressed # (batch_size, num_people)
447+
448+
# Indices for reordering
449+
batch_indices = torch.arange(batch_size, device=device).unsqueeze(1)
450+
people_indices = (
451+
torch.arange(num_people, device=device).unsqueeze(0).expand(batch_size, -1)
452+
)
395453

396-
# keep_pose_inds = nms_core(cfg, pose_coord, heat_score)
397-
# poses = poses[keep_pose_inds]
398-
# heat_score = heat_score[keep_pose_inds]
454+
# non-suppressed first, then suppressed
455+
sort_keys = kept.float() + (people_indices.float() + 1) / (num_people + 1)
456+
_, sort_indices = torch.sort(sort_keys, dim=1, descending=True)
399457

400-
# if len(keep_pose_inds) > cfg.DATASET.MAX_NUM_PEOPLE:
401-
# heat_score, topk_inds = torch.topk(heat_score,
402-
# cfg.DATASET.MAX_NUM_PEOPLE)
403-
# poses = poses[topk_inds]
458+
# Mask out suppressed predictions
459+
poses[~kept] = -1
404460

405-
# poses = [poses.numpy()]
406-
# scores = [i[:, 2].mean() for i in poses[0]]
461+
# Re-order predictions so the non-suppressed ones are up top
462+
poses = poses[batch_indices, sort_indices]
407463

408-
# return poses, scores
464+
return poses

0 commit comments

Comments
 (0)