NMS and test batch_size updates
This commit is contained in:
+7
-5
@@ -543,7 +543,8 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, multi_label=T
|
||||
x = x[torch.isfinite(x).all(1)]
|
||||
|
||||
# If none remain process next image
|
||||
if not x.shape[0]:
|
||||
n = x.shape[0] # number of boxes
|
||||
if not n:
|
||||
continue
|
||||
|
||||
# Sort by confidence
|
||||
@@ -555,10 +556,11 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, multi_label=T
|
||||
boxes, scores = x[:, :4].clone() + c.view(-1, 1) * max_wh, x[:, 4] # boxes (offset by class), scores
|
||||
if method == 'merge': # Merge NMS (boxes merged using weighted mean)
|
||||
i = torchvision.ops.boxes.nms(boxes, scores, iou_thres)
|
||||
iou = box_iou(boxes, boxes).tril_() # lower triangular iou matrix
|
||||
weights = (iou > iou_thres) * scores.view(-1, 1)
|
||||
weights /= weights.sum(0)
|
||||
x[:, :4] = torch.mm(weights.T, x[:, :4]) # merged_boxes(n,4) = weights(n,n) * boxes(n,4)
|
||||
if n < 1000: # update boxes
|
||||
iou = box_iou(boxes, boxes).tril_() # lower triangular iou matrix
|
||||
weights = (iou > iou_thres) * scores.view(-1, 1)
|
||||
weights /= weights.sum(0)
|
||||
x[:, :4] = torch.mm(weights.T, x[:, :4]) # merged_boxes(n,4) = weights(n,n) * boxes(n,4)
|
||||
elif method == 'vision':
|
||||
i = torchvision.ops.boxes.nms(boxes, scores, iou_thres)
|
||||
elif method == 'fast': # FastNMS from https://github.com/dbolya/yolact
|
||||
|
||||
Reference in New Issue
Block a user