This commit is contained in:
Glenn Jocher
2019-12-19 18:09:13 -08:00
parent ad73ce4334
commit f5cd3596f5
4 changed files with 30 additions and 37 deletions
+13 -25
View File
@@ -464,7 +464,7 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5):
Removes detections with lower object confidence score than 'conf_thres'
Non-Maximum Suppression to further filter detections.
Returns detections with shape:
(x1, y1, x2, y2, object_conf, class_conf, class)
(x1, y1, x2, y2, conf, class)
"""
# NMS method https://github.com/ultralytics/yolov3/issues/679 'OR', 'AND', 'MERGE', 'VISION', 'VISION_BATCHED'
method = 'MERGE' if conf_thres <= 0.01 else 'VISION' # MERGE is highest mAP, VISION is fastest
@@ -474,47 +474,35 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5):
output = [None] * len(prediction)
for image_i, pred in enumerate(prediction):
# Experiment: Prior class size rejection
# x, y, w, h = pred[:, 0], pred[:, 1], pred[:, 2], pred[:, 3]
# a = w * h # area
# ar = w / (h + 1e-16) # aspect ratio
# n = len(w)
# log_w, log_h, log_a, log_ar = torch.log(w), torch.log(h), torch.log(a), torch.log(ar)
# shape_likelihood = np.zeros((n, 60), dtype=np.float32)
# x = np.concatenate((log_w.reshape(-1, 1), log_h.reshape(-1, 1)), 1)
# from scipy.stats import multivariate_normal
# for c in range(60):
# shape_likelihood[:, c] =
# multivariate_normal.pdf(x, mean=mat['class_mu'][c, :2], cov=mat['class_cov'][c, :2, :2])
# Duplicate ambiguous
# b = pred[pred[:, 5:].sum(1) > 1.1]
# if len(b):
# b[range(len(b)), 5 + b[:, 5:].argmax(1)] = 0
# pred = torch.cat((pred, b), 0)
# Multiply conf by class conf to get combined confidence
class_conf, class_pred = pred[:, 5:].max(1)
pred[:, 4] *= class_conf
conf, cls = pred[:, 4:].max(1)
# # Merge classes (optional)
# class_pred[(class_pred.view(-1,1) == torch.LongTensor([2, 3, 5, 6, 7]).view(1,-1)).any(1)] = 2
# cls[(cls.view(-1,1) == torch.LongTensor([2, 3, 5, 6, 7]).view(1,-1)).any(1)] = 2
#
# # Remove classes (optional)
# pred[class_pred != 2, 4] = 0.0
# pred[cls != 2, 4] = 0.0
# Select only suitable predictions
i = (pred[:, 4] > conf_thres) & (pred[:, 2:4] > min_wh).all(1) & (pred[:, 2:4] < max_wh).all(1) & \
torch.isfinite(pred).all(1)
i = (conf > conf_thres) & (pred[:, 2:4] > min_wh).all(1) & (pred[:, 2:4] < max_wh).all(1) & torch.isfinite(
pred).all(1)
pred = pred[i]
# If none are remaining => process next image
if len(pred) == 0:
continue
# Select predicted classes
class_conf = class_conf[i]
class_pred = class_pred[i].unsqueeze(1).float()
# Box (center x, center y, width, height) to (x1, y1, x2, y2)
pred[:, :4] = xywh2xyxy(pred[:, :4])
# Detections ordered as (x1y1x2y2, obj_conf, class_conf, class_pred)
pred = torch.cat((pred[:, :5], class_conf.unsqueeze(1), class_pred), 1)
# Detections ordered as (x1y1x2y2, conf, cls)
pred = torch.cat((pred[:, :4], conf[i].unsqueeze(1), cls[i].unsqueeze(1).float()), 1)
# Get detections sorted by decreasing confidence scores
pred = pred[(-pred[:, 4]).argsort()]