updates
This commit is contained in:
+5
-1
@@ -498,7 +498,7 @@ def build_targets(model, targets):
|
||||
return tcls, tbox, indices, av
|
||||
|
||||
|
||||
def non_max_suppression(prediction, conf_thres=0.5, iou_thres=0.5, multi_cls=True, method='vision_batch'):
|
||||
def non_max_suppression(prediction, conf_thres=0.5, iou_thres=0.5, multi_cls=True, method='vision_batch', classes=None):
|
||||
"""
|
||||
Removes detections with lower object confidence score than 'conf_thres'
|
||||
Non-Maximum Suppression to further filter detections.
|
||||
@@ -537,6 +537,10 @@ def non_max_suppression(prediction, conf_thres=0.5, iou_thres=0.5, multi_cls=Tru
|
||||
conf, j = pred[:, 5:].max(1)
|
||||
pred = torch.cat((box, conf.unsqueeze(1), j.float().unsqueeze(1)), 1)
|
||||
|
||||
# Filter by class
|
||||
if classes:
|
||||
pred = pred[(j.view(-1, 1) == torch.Tensor(classes)).any(1)]
|
||||
|
||||
# Apply finite constraint
|
||||
if not torch.isfinite(pred).all():
|
||||
pred = pred[torch.isfinite(pred).all(1)]
|
||||
|
||||
Reference in New Issue
Block a user