From 0fe246f3995dc77367d6c8d6158c9112ff8a309f Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Mon, 2 Dec 2019 18:22:21 -0800 Subject: [PATCH] updates --- utils/utils.py | 34 +++++++++++++++++++++++++--------- 1 file changed, 25 insertions(+), 9 deletions(-) diff --git a/utils/utils.py b/utils/utils.py index 12083988..3984d20b 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -1,4 +1,5 @@ import glob +import math import os import random import shutil @@ -10,8 +11,8 @@ import matplotlib.pyplot as plt import numpy as np import torch import torch.nn as nn +import torchvision from tqdm import tqdm -import math from . import torch_utils # , google_utils @@ -503,7 +504,6 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5): # Box (center x, center y, width, height) to (x1, y1, x2, y2) pred[:, :4] = xywh2xyxy(pred[:, :4]) - # pred[:, 4] *= class_conf # improves mAP from 0.549 to 0.551 # Detections ordered as (x1y1x2y2, obj_conf, class_conf, class_pred) pred = torch.cat((pred[:, :5], class_conf.unsqueeze(1), class_pred), 1) @@ -511,8 +511,21 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5): # Get detections sorted by decreasing confidence scores pred = pred[(-pred[:, 4]).argsort()] + # Set NMS method https://github.com/ultralytics/yolov3/issues/679 + # 'OR', 'AND', 'MERGE', 'VISION', 'VISION_BATCHED' + method = 'MERGE' if conf_thres <= 0.01 else 'VISION' # MERGE is highest mAP, VISION is fastest + + # Batched NMS + if method == 'VISION_BATCHED': + i = torchvision.ops.boxes.batched_nms(boxes=pred[:, :4], + scores=pred[:, 4], + idxs=pred[:, 6], + iou_threshold=nms_thres) + output[image_i] = pred[i] + continue + + # Non-maximum suppression det_max = [] - nms_style = 'MERGE' # 'OR' (default), 'AND', 'MERGE' (experimental) for c in pred[:, -1].unique(): dc = pred[pred[:, -1] == c] # select class c n = len(dc) @@ -520,10 +533,13 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5): det_max.append(dc) # No NMS required if only 1 prediction continue elif n > 500: - dc = dc[:500] # limit to first 100 boxes: https://github.com/ultralytics/yolov3/issues/117 + dc = dc[:500] # limit to first 500 boxes: https://github.com/ultralytics/yolov3/issues/117 - # Non-maximum suppression - if nms_style == 'OR': # default + if method == 'VISION': + i = torchvision.ops.boxes.nms(dc[:, :4], dc[:, 4], nms_thres) + det_max.append(dc[i]) + + elif method == 'OR': # default # METHOD1 # ind = list(range(len(dc))) # while len(ind): @@ -540,14 +556,14 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5): iou = bbox_iou(dc[0], dc[1:]) # iou with other boxes dc = dc[1:][iou < nms_thres] # remove ious > threshold - elif nms_style == 'AND': # requires overlap, single boxes erased + elif method == 'AND': # requires overlap, single boxes erased while len(dc) > 1: iou = bbox_iou(dc[0], dc[1:]) # iou with other boxes if iou.max() > 0.5: det_max.append(dc[:1]) dc = dc[1:][iou < nms_thres] # remove ious > threshold - elif nms_style == 'MERGE': # weighted mixture box + elif method == 'MERGE': # weighted mixture box while len(dc): if len(dc) == 1: det_max.append(dc) @@ -558,7 +574,7 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5): det_max.append(dc[:1]) dc = dc[i == 0] - elif nms_style == 'SOFT': # soft-NMS https://arxiv.org/abs/1704.04503 + elif method == 'SOFT': # soft-NMS https://arxiv.org/abs/1704.04503 sigma = 0.5 # soft-nms sigma parameter while len(dc): if len(dc) == 1: