This commit is contained in:
Glenn Jocher
2019-01-02 16:32:38 +01:00
parent 7283f26f6f
commit b181c61f4b
3 changed files with 37 additions and 37 deletions
+27 -29
View File
@@ -309,8 +309,6 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4):
# cross-class NMS (experimental)
cross_class_nms = False
if cross_class_nms:
# thresh = 0.85
thresh = nms_thres
a = pred.clone()
_, indices = torch.sort(-a[:, 4], 0) # sort best to worst
a = a[indices]
@@ -325,7 +323,7 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4):
if len(close) > 0:
close = close + i + 1
iou = bbox_iou(a[i:i + 1, :4], a[close.squeeze(), :4].reshape(-1, 4), x1y1x2y2=False)
bad = close[iou > thresh]
bad = close[iou > nms_thres]
if len(bad) > 0:
mask = torch.ones(len(a)).type(torch.ByteTensor)
@@ -333,13 +331,12 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4):
a = a[mask]
pred = a
x, y, w, h = pred[:, 0], pred[:, 1], pred[:, 2], pred[:, 3]
a = w * h # area
ar = w / (h + 1e-16) # aspect ratio
log_w, log_h, log_a, log_ar = torch.log(w), torch.log(h), torch.log(a), torch.log(ar)
# Experiment: Prior class size rejection
# x, y, w, h = pred[:, 0], pred[:, 1], pred[:, 2], pred[:, 3]
# a = w * h # area
# ar = w / (h + 1e-16) # aspect ratio
# n = len(w)
# log_w, log_h, log_a, log_ar = torch.log(w), torch.log(h), torch.log(a), torch.log(ar)
# shape_likelihood = np.zeros((n, 60), dtype=np.float32)
# x = np.concatenate((log_w.reshape(-1, 1), log_h.reshape(-1, 1)), 1)
# from scipy.stats import multivariate_normal
@@ -348,7 +345,7 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4):
class_prob, class_pred = torch.max(F.softmax(pred[:, 5:], 1), 1)
v = ((pred[:, 4] > conf_thres) & (class_prob > .3))
v = ((pred[:, 4] > conf_thres) & (class_prob > .3)) # TODO examine arbitrary 0.3 thres here
v = v.nonzero().squeeze()
if len(v.shape) == 0:
v = v.unsqueeze(0)
@@ -375,44 +372,43 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4):
nms_style = 'OR' # 'AND' or 'OR' (classical)
for c in unique_labels:
# Get the detections with the particular class
detections_class = detections[detections[:, -1] == c]
det_class = detections[detections[:, -1] == c]
# Sort the detections by maximum objectness confidence
_, conf_sort_index = torch.sort(detections_class[:, 4], descending=True)
detections_class = detections_class[conf_sort_index]
_, conf_sort_index = torch.sort(det_class[:, 4], descending=True)
det_class = det_class[conf_sort_index]
# Perform non-maximum suppression
max_detections = []
det_max = []
if nms_style == 'OR': # Classical NMS
while detections_class.shape[0]:
while det_class.shape[0]:
# Get detection with highest confidence and save as max detection
max_detections.append(detections_class[0].unsqueeze(0))
det_max.append(det_class[0].unsqueeze(0))
# Stop if we're at the last detection
if len(detections_class) == 1:
if len(det_class) == 1:
break
# Get the IOUs for all boxes with lower confidence
ious = bbox_iou(max_detections[-1], detections_class[1:])
ious = bbox_iou(det_max[-1], det_class[1:])
# Remove detections with IoU >= NMS threshold
detections_class = detections_class[1:][ious < nms_thres]
det_class = det_class[1:][ious < nms_thres]
elif nms_style == 'AND': # 'AND'-style NMS, at least two boxes must share commonality to pass, single boxes erased
while detections_class.shape[0]:
if len(detections_class) == 1:
elif nms_style == 'AND': # 'AND'-style NMS: >=2 boxes must share commonality to pass, single boxes erased
while det_class.shape[0]:
if len(det_class) == 1:
break
ious = bbox_iou(detections_class[:1], detections_class[1:])
ious = bbox_iou(det_class[:1], det_class[1:])
if ious.max() > 0.5:
max_detections.append(detections_class[0].unsqueeze(0))
det_max.append(det_class[0].unsqueeze(0))
# Remove detections with IoU >= NMS threshold
detections_class = detections_class[1:][ious < nms_thres]
det_class = det_class[1:][ious < nms_thres]
if len(max_detections) > 0:
max_detections = torch.cat(max_detections).data
if len(det_max) > 0:
det_max = torch.cat(det_max).data
# Add max detections to outputs
output[image_i] = max_detections if output[image_i] is None else torch.cat(
(output[image_i], max_detections))
output[image_i] = det_max if output[image_i] is None else torch.cat((output[image_i], det_max))
return output
@@ -426,6 +422,7 @@ def strip_optimizer_from_checkpoint(filename='weights/best.pt'):
def coco_class_count(path='../coco/labels/train2014/'):
# histogram of occurrences per class
import glob
nC = 80 # number classes
@@ -443,6 +440,7 @@ def plot_results():
import numpy as np
import matplotlib.pyplot as plt
# import os; os.system('rm -rf results.txt && wget https://storage.googleapis.com/ultralytics/results_v1_0.txt')
plt.figure(figsize=(16, 8))
s = ['X', 'Y', 'Width', 'Height', 'Objectness', 'Classification', 'Total Loss', 'Precision', 'Recall', 'mAP']
files = sorted(glob.glob('results*.txt'))