updates
This commit is contained in:
+4
-2
@@ -499,7 +499,7 @@ def build_targets(model, targets):
|
||||
return tcls, tbox, indices, av
|
||||
|
||||
|
||||
def non_max_suppression(prediction, conf_thres=0.5, iou_thres=0.5, multi_cls=True, method='vision_batch', classes=None):
|
||||
def non_max_suppression(prediction, conf_thres=0.5, iou_thres=0.5, multi_cls=True, classes=None, agnostic=False):
|
||||
"""
|
||||
Removes detections with lower object confidence score than 'conf_thres'
|
||||
Non-Maximum Suppression to further filter detections.
|
||||
@@ -511,6 +511,7 @@ def non_max_suppression(prediction, conf_thres=0.5, iou_thres=0.5, multi_cls=Tru
|
||||
# Box constraints
|
||||
min_wh, max_wh = 2, 4096 # (pixels) minimum and maximum box width and height
|
||||
|
||||
method = 'vision_batch'
|
||||
output = [None] * len(prediction)
|
||||
for image_i, pred in enumerate(prediction):
|
||||
# Apply conf constraint
|
||||
@@ -548,7 +549,8 @@ def non_max_suppression(prediction, conf_thres=0.5, iou_thres=0.5, multi_cls=Tru
|
||||
|
||||
# Batched NMS
|
||||
if method == 'vision_batch':
|
||||
output[image_i] = pred[torchvision.ops.boxes.batched_nms(pred[:, :4], pred[:, 4], pred[:, 5], iou_thres)]
|
||||
c = j * 0 if agnostic else j # class-agnostic NMS
|
||||
output[image_i] = pred[torchvision.ops.boxes.batched_nms(pred[:, :4], pred[:, 4], c, iou_thres)]
|
||||
continue
|
||||
|
||||
# Sort by confidence
|
||||
|
||||
Reference in New Issue
Block a user