NMS and test batch_size updates

This commit is contained in:
Glenn Jocher
2020-03-29 20:41:32 -07:00
parent c6b59a0e8a
commit eb151a881e
2 changed files with 10 additions and 8 deletions
+7 -5
View File
@@ -543,7 +543,8 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, multi_label=T
x = x[torch.isfinite(x).all(1)]
# If none remain process next image
if not x.shape[0]:
n = x.shape[0] # number of boxes
if not n:
continue
# Sort by confidence
@@ -555,10 +556,11 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, multi_label=T
boxes, scores = x[:, :4].clone() + c.view(-1, 1) * max_wh, x[:, 4] # boxes (offset by class), scores
if method == 'merge': # Merge NMS (boxes merged using weighted mean)
i = torchvision.ops.boxes.nms(boxes, scores, iou_thres)
iou = box_iou(boxes, boxes).tril_() # lower triangular iou matrix
weights = (iou > iou_thres) * scores.view(-1, 1)
weights /= weights.sum(0)
x[:, :4] = torch.mm(weights.T, x[:, :4]) # merged_boxes(n,4) = weights(n,n) * boxes(n,4)
if n < 1000: # update boxes
iou = box_iou(boxes, boxes).tril_() # lower triangular iou matrix
weights = (iou > iou_thres) * scores.view(-1, 1)
weights /= weights.sum(0)
x[:, :4] = torch.mm(weights.T, x[:, :4]) # merged_boxes(n,4) = weights(n,n) * boxes(n,4)
elif method == 'vision':
i = torchvision.ops.boxes.nms(boxes, scores, iou_thres)
elif method == 'fast': # FastNMS from https://github.com/dbolya/yolact