updates
This commit is contained in:
+7
-7
@@ -497,7 +497,7 @@ def build_targets(model, targets):
|
||||
return tcls, tbox, indices, av
|
||||
|
||||
|
||||
def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5, multi_cls=True, method='vision_batch'):
|
||||
def non_max_suppression(prediction, conf_thres=0.5, iou_thres=0.5, multi_cls=True, method='vision_batch'):
|
||||
"""
|
||||
Removes detections with lower object confidence score than 'conf_thres'
|
||||
Non-Maximum Suppression to further filter detections.
|
||||
@@ -542,7 +542,7 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5, multi_cls=Tru
|
||||
|
||||
# Batched NMS
|
||||
if method == 'vision_batch':
|
||||
output[image_i] = pred[torchvision.ops.boxes.batched_nms(pred[:, :4], pred[:, 4], pred[:, 5], nms_thres)]
|
||||
output[image_i] = pred[torchvision.ops.boxes.batched_nms(pred[:, :4], pred[:, 4], pred[:, 5], iou_thres)]
|
||||
continue
|
||||
|
||||
# Sort by confidence
|
||||
@@ -562,7 +562,7 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5, multi_cls=Tru
|
||||
dc = dc[:500] # limit to first 500 boxes: https://github.com/ultralytics/yolov3/issues/117
|
||||
|
||||
if method == 'vision':
|
||||
det_max.append(dc[torchvision.ops.boxes.nms(dc[:, :4], dc[:, 4], nms_thres)])
|
||||
det_max.append(dc[torchvision.ops.boxes.nms(dc[:, :4], dc[:, 4], iou_thres)])
|
||||
|
||||
elif method == 'or': # default
|
||||
# METHOD1
|
||||
@@ -570,7 +570,7 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5, multi_cls=Tru
|
||||
# while len(ind):
|
||||
# j = ind[0]
|
||||
# det_max.append(dc[j:j + 1]) # save highest conf detection
|
||||
# reject = (bbox_iou(dc[j], dc[ind]) > nms_thres).nonzero()
|
||||
# reject = (bbox_iou(dc[j], dc[ind]) > iou_thres).nonzero()
|
||||
# [ind.pop(i) for i in reversed(reject)]
|
||||
|
||||
# METHOD2
|
||||
@@ -579,21 +579,21 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5, multi_cls=Tru
|
||||
if len(dc) == 1: # Stop if we're at the last detection
|
||||
break
|
||||
iou = bbox_iou(dc[0], dc[1:]) # iou with other boxes
|
||||
dc = dc[1:][iou < nms_thres] # remove ious > threshold
|
||||
dc = dc[1:][iou < iou_thres] # remove ious > threshold
|
||||
|
||||
elif method == 'and': # requires overlap, single boxes erased
|
||||
while len(dc) > 1:
|
||||
iou = bbox_iou(dc[0], dc[1:]) # iou with other boxes
|
||||
if iou.max() > 0.5:
|
||||
det_max.append(dc[:1])
|
||||
dc = dc[1:][iou < nms_thres] # remove ious > threshold
|
||||
dc = dc[1:][iou < iou_thres] # remove ious > threshold
|
||||
|
||||
elif method == 'merge': # weighted mixture box
|
||||
while len(dc):
|
||||
if len(dc) == 1:
|
||||
det_max.append(dc)
|
||||
break
|
||||
i = bbox_iou(dc[0], dc) > nms_thres # iou with other boxes
|
||||
i = bbox_iou(dc[0], dc) > iou_thres # iou with other boxes
|
||||
weights = dc[i, 4:5]
|
||||
dc[0, :4] = (weights * dc[i, :4]).sum(0) / weights.sum()
|
||||
det_max.append(dc[:1])
|
||||
|
||||
Reference in New Issue
Block a user