updates
This commit is contained in:
+9
-5
@@ -492,9 +492,13 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5, multi_cls=Tru
|
||||
output = [None] * len(prediction)
|
||||
for image_i, pred in enumerate(prediction):
|
||||
# Remove rows
|
||||
pred = pred[(pred[:, 4:] > conf_thres).any(1)] # retain above threshold
|
||||
pred = pred[pred[:, 4] > conf_thres] # retain above threshold
|
||||
|
||||
# Select only suitable predictions
|
||||
# compute conf
|
||||
torch.sigmoid_(pred[..., 5:])
|
||||
pred[..., 5:] *= pred[..., 4:5] # conf = obj_conf * cls_conf
|
||||
|
||||
# Apply width-height constraint
|
||||
i = (pred[:, 2:4] > min_wh).all(1) & (pred[:, 2:4] < max_wh).all(1) & torch.isfinite(pred).all(1)
|
||||
pred = pred[i]
|
||||
|
||||
@@ -507,10 +511,10 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5, multi_cls=Tru
|
||||
|
||||
# Multi-class
|
||||
if multi_cls or conf_thres < 0.01:
|
||||
i, j = (pred[:, 4:] > conf_thres).nonzero().t()
|
||||
pred = torch.cat((pred[i, :4], pred[i, j + 4].unsqueeze(1), j.float().unsqueeze(1)), 1)
|
||||
i, j = (pred[:, 5:] > conf_thres).nonzero().t()
|
||||
pred = torch.cat((pred[i, :4], pred[i, j + 5].unsqueeze(1), j.float().unsqueeze(1)), 1)
|
||||
else: # best class only
|
||||
conf, j = pred[:, 4:].max(1)
|
||||
conf, j = pred[:, 5:].max(1)
|
||||
pred = torch.cat((pred[:, :4], conf.unsqueeze(1), j.float().unsqueeze(1)), 1) # (xyxy, conf, cls)
|
||||
|
||||
# Get detections sorted by decreasing confidence scores
|
||||
|
||||
Reference in New Issue
Block a user