updates
This commit is contained in:
+13
-25
@@ -464,7 +464,7 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5):
|
||||
Removes detections with lower object confidence score than 'conf_thres'
|
||||
Non-Maximum Suppression to further filter detections.
|
||||
Returns detections with shape:
|
||||
(x1, y1, x2, y2, object_conf, class_conf, class)
|
||||
(x1, y1, x2, y2, object_conf, conf, class)
|
||||
"""
|
||||
# NMS method https://github.com/ultralytics/yolov3/issues/679 'OR', 'AND', 'MERGE', 'VISION', 'VISION_BATCHED'
|
||||
method = 'MERGE' if conf_thres <= 0.01 else 'VISION' # MERGE is highest mAP, VISION is fastest
|
||||
@@ -474,47 +474,35 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5):
|
||||
|
||||
output = [None] * len(prediction)
|
||||
for image_i, pred in enumerate(prediction):
|
||||
# Experiment: Prior class size rejection
|
||||
# x, y, w, h = pred[:, 0], pred[:, 1], pred[:, 2], pred[:, 3]
|
||||
# a = w * h # area
|
||||
# ar = w / (h + 1e-16) # aspect ratio
|
||||
# n = len(w)
|
||||
# log_w, log_h, log_a, log_ar = torch.log(w), torch.log(h), torch.log(a), torch.log(ar)
|
||||
# shape_likelihood = np.zeros((n, 60), dtype=np.float32)
|
||||
# x = np.concatenate((log_w.reshape(-1, 1), log_h.reshape(-1, 1)), 1)
|
||||
# from scipy.stats import multivariate_normal
|
||||
# for c in range(60):
|
||||
# shape_likelihood[:, c] =
|
||||
# multivariate_normal.pdf(x, mean=mat['class_mu'][c, :2], cov=mat['class_cov'][c, :2, :2])
|
||||
# Duplicate ambiguous
|
||||
# b = pred[pred[:, 5:].sum(1) > 1.1]
|
||||
# if len(b):
|
||||
# b[range(len(b)), 5 + b[:, 5:].argmax(1)] = 0
|
||||
# pred = torch.cat((pred, b), 0)
|
||||
|
||||
# Multiply conf by class conf to get combined confidence
|
||||
class_conf, class_pred = pred[:, 5:].max(1)
|
||||
pred[:, 4] *= class_conf
|
||||
conf, cls = pred[:, 4:].max(1)
|
||||
|
||||
# # Merge classes (optional)
|
||||
# class_pred[(class_pred.view(-1,1) == torch.LongTensor([2, 3, 5, 6, 7]).view(1,-1)).any(1)] = 2
|
||||
# cls[(cls.view(-1,1) == torch.LongTensor([2, 3, 5, 6, 7]).view(1,-1)).any(1)] = 2
|
||||
#
|
||||
# # Remove classes (optional)
|
||||
# pred[class_pred != 2, 4] = 0.0
|
||||
# pred[cls != 2, 4] = 0.0
|
||||
|
||||
# Select only suitable predictions
|
||||
i = (pred[:, 4] > conf_thres) & (pred[:, 2:4] > min_wh).all(1) & (pred[:, 2:4] < max_wh).all(1) & \
|
||||
torch.isfinite(pred).all(1)
|
||||
i = (conf > conf_thres) & (pred[:, 2:4] > min_wh).all(1) & (pred[:, 2:4] < max_wh).all(1) & torch.isfinite(
|
||||
pred).all(1)
|
||||
pred = pred[i]
|
||||
|
||||
# If none are remaining => process next image
|
||||
if len(pred) == 0:
|
||||
continue
|
||||
|
||||
# Select predicted classes
|
||||
class_conf = class_conf[i]
|
||||
class_pred = class_pred[i].unsqueeze(1).float()
|
||||
|
||||
# Box (center x, center y, width, height) to (x1, y1, x2, y2)
|
||||
pred[:, :4] = xywh2xyxy(pred[:, :4])
|
||||
|
||||
# Detections ordered as (x1y1x2y2, obj_conf, class_conf, class_pred)
|
||||
pred = torch.cat((pred[:, :5], class_conf.unsqueeze(1), class_pred), 1)
|
||||
# Detections ordered as (x1y1x2y2, conf, cls)
|
||||
pred = torch.cat((pred[:, :4], conf[i].unsqueeze(1), cls[i].unsqueeze(1).float()), 1)
|
||||
|
||||
# Get detections sorted by decreasing confidence scores
|
||||
pred = pred[(-pred[:, 4]).argsort()]
|
||||
|
||||
Reference in New Issue
Block a user