This commit is contained in:
Glenn Jocher
2020-01-01 12:44:33 -08:00
parent 935bbfcc2b
commit d92b75aec8
2 changed files with 10 additions and 5 deletions
+5 -1
View File
@@ -498,7 +498,7 @@ def build_targets(model, targets):
return tcls, tbox, indices, av
def non_max_suppression(prediction, conf_thres=0.5, iou_thres=0.5, multi_cls=True, method='vision_batch'):
def non_max_suppression(prediction, conf_thres=0.5, iou_thres=0.5, multi_cls=True, method='vision_batch', classes=None):
"""
Removes detections with lower object confidence score than 'conf_thres'
Non-Maximum Suppression to further filter detections.
@@ -537,6 +537,10 @@ def non_max_suppression(prediction, conf_thres=0.5, iou_thres=0.5, multi_cls=Tru
conf, j = pred[:, 5:].max(1)
pred = torch.cat((box, conf.unsqueeze(1), j.float().unsqueeze(1)), 1)
# Filter by class
if classes:
pred = pred[(j.view(-1, 1) == torch.Tensor(classes)).any(1)]
# Apply finite constraint
if not torch.isfinite(pred).all():
pred = pred[torch.isfinite(pred).all(1)]