@@ -75,6 +75,8 @@ def __init__(
         use_heatmap: bool = True,
         keypoint_score_type: str = "combined",
         max_absorb_distance: int = 75,
+        nms_threshold: float = 0.05,
+        apply_pose_nms: bool = True,
     ):
         """
         Args:
@@ -88,6 +90,8 @@ def __init__(
                 applies the heatmap score to each keypoint. "center" applies the score
                 of the center of each individual to all of its keypoints. "combined"
                 multiplies the score of the heatmap and individual for each keypoint.
+            nms_threshold: Threshold used for pose non-maximum suppression.
+            apply_pose_nms: Whether to apply non-maximum suppression to pose proposals.
         """
         super().__init__()
         self.num_animals = num_animals
@@ -99,8 +103,9 @@ def __init__(
         if self.keypoint_score_type not in ("heatmap", "center", "combined"):
             raise ValueError(f"Unknown keypoint score type: {self.keypoint_score_type}")

-        # TODO: Set as in HRNet/DEKR configs. Define as a constant.
         self.max_absorb_distance = max_absorb_distance
+        self.nms_threshold = nms_threshold
+        self.apply_pose_nms = apply_pose_nms

     def forward(
         self, stride: float, outputs: dict[str, torch.Tensor]
@@ -134,12 +139,9 @@ def forward(
         pose_ind, ctr_scores = self.get_top_values(center_heatmaps)

         posemap = posemap.permute(0, 2, 3, 1).view(batch_size, h * w, -1, 2)
-        poses = torch.zeros(batch_size, pose_ind.shape[1], num_joints, 2).to(
-            ctr_scores.device
-        )
-        for i in range(batch_size):
-            pose = posemap[i, pose_ind[i]]
-            poses[i] = pose
+
+        batch_indices = torch.arange(batch_size, device=pose_ind.device)[:, None]
+        poses = posemap[batch_indices, pose_ind]

         if self.use_heatmap:
             poses = self._update_pose_with_heatmaps(poses, heatmaps[:, :-1])
@@ -174,7 +176,9 @@ def forward(
         score = torch.clip(score, min=0, max=1)

         poses_w_scores = torch.cat([poses, score], dim=3)
-        # self.pose_nms(heatmaps, poses_w_scores)
+        if self.apply_pose_nms:
+            poses_w_scores = self.pose_nms(poses_w_scores)
+
         return {"poses": poses_w_scores}

     def get_locations(
179183
180184 def get_locations (
@@ -263,7 +267,7 @@ def max_pool(self, heatmap: torch.Tensor) -> torch.Tensor:
             # Assuming you have 'heatmap' tensor
             max_pooled_heatmap = predictor.max_pool(heatmap)
         """
-        pool1 = torch.nn.MaxPool2d(3, 1, 1)  # TODO JR 01/2026: Are these unused variables informative?
+        pool1 = torch.nn.MaxPool2d(3, 1, 1)
         pool2 = torch.nn.MaxPool2d(5, 1, 2)
         pool3 = torch.nn.MaxPool2d(7, 1, 3)
         map_size = (heatmap.shape[1] + heatmap.shape[2]) / 2.0
@@ -299,11 +303,11 @@ def get_top_values(

         return pos_ind, scores

-    ########## WIP to take heatmap into account for scoring ##########
     def _update_pose_with_heatmaps(
         self, _poses: torch.Tensor, kpt_heatmaps: torch.Tensor
     ):
-        """If a heatmap center is close enough from the regressed point, the final prediction is the center of this heatmap
+        """If a heatmap center is close enough to the regressed point, the final
+        prediction is the center of this heatmap.

         Args:
             poses: poses tensor, shape (batch_size, num_animals, num_keypoints, 2)
@@ -315,25 +319,49 @@ def _update_pose_with_heatmaps(
         kpt_heatmaps *= maxm
         batch_size, num_keypoints, h, w = kpt_heatmaps.shape
         kpt_heatmaps = kpt_heatmaps.view(batch_size, num_keypoints, -1)
-        val_k, ind = kpt_heatmaps.topk(self.num_animals, dim=2)
+        _val_k, ind = kpt_heatmaps.topk(self.num_animals, dim=2)

         x = ind % w
         y = (ind / w).long()
-        heats_ind = torch.stack((x, y), dim=3)
-
-        for b in range(batch_size):
-            for i in range(num_keypoints):
-                heat_ind = heats_ind[b, i].float()
-                pose_ind = poses[b, :, i]
-                pose_heat_diff = pose_ind[:, None, :] - heat_ind
-                pose_heat_diff.pow_(2)
-                pose_heat_diff = pose_heat_diff.sum(2)
-                pose_heat_diff.sqrt_()
-                keep_ind = torch.argmin(pose_heat_diff, dim=1)
-
-                for p in range(keep_ind.shape[0]):
-                    if pose_heat_diff[p, keep_ind[p]] < self.max_absorb_distance:
-                        poses[b, p, i] = heat_ind[keep_ind[p]]
+        heats_ind = torch.stack(
+            (x, y), dim=3
+        )  # (batch_size, num_keypoints, num_animals, 2)
+
+        # Calculate differences between all pose-heat pairs
+        # (batch_size, num_animals, num_keypoints, 1, 2) - (batch_size, 1, num_keypoints, num_animals, 2)
+        pose_heat_diff = poses.unsqueeze(3) - heats_ind.unsqueeze(
+            1
+        )  # (batch_size, num_animals, num_keypoints, num_animals, 2)
+
+        pose_heat_dist = torch.norm(
+            pose_heat_diff, dim=-1
+        )  # (batch_size, num_animals, num_keypoints, num_animals)
+
+        # Find closest heat point for each pose
+        keep_ind = torch.argmin(
+            pose_heat_dist, dim=-1
+        )  # (batch_size, num_animals, num_keypoints)
+
+        # Get minimum distances for filtering
+        min_distances = torch.gather(pose_heat_dist, 3, keep_ind.unsqueeze(-1)).squeeze(
+            -1
+        )  # (batch_size, num_animals, num_keypoints)
+
+        absorb_mask = (
+            min_distances < self.max_absorb_distance
+        )  # (batch_size, num_animals, num_keypoints)
+
+        # Create indices for gathering the correct heat points
+        batch_indices = torch.arange(batch_size, device=poses.device).view(-1, 1, 1)
+        keypoint_indices = torch.arange(num_keypoints, device=poses.device).view(
+            1, 1, -1
+        )
+
+        selected_heat_points = heats_ind[
+            batch_indices, keypoint_indices, keep_ind
+        ]  # (batch_size, num_animals, num_keypoints, 2)
+
+        poses = torch.where(absorb_mask.unsqueeze(-1), selected_heat_points, poses)

         return poses

@@ -358,51 +386,79 @@ def get_heat_value(
             2, 3
         )  # (batch_size, num_joints, h*w)

-        # Predicted poses based on the offset can be outside of the image
+        # Predicted poses based on the offset can be outside the image
         x = torch.clamp(torch.floor(pose_coords[:, :, :, 0]), 0, w - 1).long()
         y = torch.clamp(torch.floor(pose_coords[:, :, :, 1]), 0, h - 1).long()
         keypoint_poses = (y * w + x).mT  # (batch, num_joints, num_individuals)
-        heatscores = torch.gather(heatmaps_nocenter, 2, keypoint_poses)
-        return heatscores.mT  # (batch, num_individuals, num_joints)
+        scores = torch.gather(heatmaps_nocenter, 2, keypoint_poses)
+        return scores.mT  # (batch, num_individuals, num_joints)

-    def pose_nms(self, heatmaps: torch.Tensor, poses: torch.Tensor):
+    def pose_nms(self, poses: torch.Tensor) -> torch.Tensor:
         """Non-Maximum Suppression (NMS) for regressed poses.

         Args:
-            heatmaps: Heatmaps tensor.
-            poses: Pose proposals.
+            poses: Pose proposals of shape (batch_size, num_people, num_joints, 3).
+                The poses for each element in the batch should be sorted by score (the
+                highest-scoring prediction should be first).

         Returns:
-            None
-
-        Example:
-            # Assuming you have 'heatmaps' and 'poses' tensors
-            predictor.pose_nms(heatmaps, poses)
+            Pose proposals after NMS, with suppressed poses set to -1 and moved to the end.
         """
-        pose_scores = poses[:, :, :, 2]
-        pose_coords = poses[:, :, :, :2]
+        batch_size, num_people, num_joints, _ = poses.shape
+        device = poses.device
+        if num_people == 0:
+            return poses
+
+        xy = poses[:, :, :, :2]
+        w = xy[..., 0].max(dim=-1)[0] - xy[..., 0].min(dim=-1)[0]
+        h = xy[..., 1].max(dim=-1)[0] - xy[..., 1].min(dim=-1)[0]
+        area = torch.clamp((w * w) + (h * h), min=1)
+        area = (
+            area.unsqueeze(1)
+            .unsqueeze(3)
+            .expand(batch_size, num_people, num_people, num_joints)
+        )
+
+        # compute the difference between keypoints
+        pose_diff = xy.unsqueeze(2) - xy.unsqueeze(1)
+        pose_diff.pow_(2)
+
+        # Compute error between people pairs
+        pose_dist = pose_diff.sum(dim=-1)
+        pose_dist.sqrt_()
+
+        pose_thresh = self.nms_threshold * torch.sqrt(area)
+        pose_dist = (pose_dist < pose_thresh).sum(dim=-1)
+        nms_pose = pose_dist > self.nms_threshold  # shape (b, num_people, num_people)

-        if pose_coords.shape[1] == 0:
-            return [], []
+        # Upper triangular mask matrix to avoid double processing
+        triu_mask = torch.triu(
+            torch.ones(num_people, num_people, device=device), diagonal=1
+        ).bool()

-        batch_size, num_people, num_joints, _ = pose_coords.shape
-        heatvals = self.get_heat_value(pose_coords, heatmaps)
-        heat_score = (torch.sum(heatvals, dim=1) / num_joints)[:, 0]
+        suppress_pairs = nms_pose & triu_mask.unsqueeze(
+            0
+        )  # (batch_size, num_people, num_people)

-        # return heat_score
-        # pose_score = pose_score*heatvals
-        # poses = torch.cat([pose_coord.cpu(), pose_score.cpu()], dim=2)
+        # For each batch, determine which poses to suppress
+        suppressed = suppress_pairs.any(dim=1)  # (batch_size, num_people)
+
+        kept = ~suppressed  # (batch_size, num_people)
+
+        # Indices for reordering
+        batch_indices = torch.arange(batch_size, device=device).unsqueeze(1)
+        people_indices = (
+            torch.arange(num_people, device=device).unsqueeze(0).expand(batch_size, -1)
+        )

-        # keep_pose_inds = nms_core(cfg, pose_coord, heat_score)
-        # poses = poses[keep_pose_inds]
-        # heat_score = heat_score[keep_pose_inds]
+        # non-suppressed first (preserving their original score order), then suppressed
+        sort_keys = kept.float() - (people_indices.float() + 1) / (num_people + 1)
+        _, sort_indices = torch.sort(sort_keys, dim=1, descending=True)

-        # if len(keep_pose_inds) > cfg.DATASET.MAX_NUM_PEOPLE:
-        #     heat_score, topk_inds = torch.topk(heat_score,
-        #                                        cfg.DATASET.MAX_NUM_PEOPLE)
-        #     poses = poses[topk_inds]
+        # Mask out suppressed predictions
+        poses[~kept] = -1

-        # poses = [poses.numpy()]
-        # scores = [i[:, 2].mean() for i in poses[0] ]
+        # Re-order predictions so the non-suppressed ones are up top
+        poses = poses[batch_indices, sort_indices]

-        # return poses, scores
+        return poses
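
For reviewers, a minimal standalone sketch (not part of this commit) of the advanced-indexing gather that replaces the per-batch loop in forward(). The sizes and values below are made up for illustration; the check simply confirms that the vectorized gather matches the old loop.

import torch

batch_size, hw, num_joints, num_animals = 2, 64 * 48, 17, 3  # illustrative sizes
posemap = torch.randn(batch_size, hw, num_joints, 2)
pose_ind = torch.randint(0, hw, (batch_size, num_animals))

# Old behaviour: fill a pre-allocated tensor one batch element at a time
poses_loop = torch.zeros(batch_size, num_animals, num_joints, 2)
for i in range(batch_size):
    poses_loop[i] = posemap[i, pose_ind[i]]

# New behaviour: a single advanced-indexing gather; batch_indices has shape
# (batch_size, 1) and broadcasts against pose_ind of shape (batch_size, num_animals)
batch_indices = torch.arange(batch_size)[:, None]
poses_vec = posemap[batch_indices, pose_ind]  # (batch_size, num_animals, num_joints, 2)

assert torch.equal(poses_loop, poses_vec)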