diff --git a/utils/utils.py b/utils/utils.py index ecb669c7..b6b15b84 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -495,78 +495,75 @@ def build_targets(model, targets): def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, multi_label=True, classes=None, agnostic=False): """ - Removes detections with lower object confidence score than 'conf_thres' - Non-Maximum Suppression to further filter detections. + Performs Non-Maximum Suppression on inference results Returns detections with shape: - (x1, y1, x2, y2, object_conf, conf, class) + nx6 (x1, y1, x2, y2, conf, cls) """ - # NMS methods https://github.com/ultralytics/yolov3/issues/679 'or', 'and', 'merge', 'vision', 'vision_batch' # Box constraints min_wh, max_wh = 2, 4096 # (pixels) minimum and maximum box width and height - method = 'vision' + method = 'merge' nc = prediction[0].shape[1] - 5 # number of classes multi_label &= nc > 1 # multiple labels per box output = [None] * len(prediction) - for image_i, pred in enumerate(prediction): + for xi, x in enumerate(prediction): # image index, image inference # Apply conf constraint - pred = pred[pred[:, 4] > conf_thres] + x = x[x[:, 4] > conf_thres] # Apply width-height constraint - pred = pred[((pred[:, 2:4] > min_wh) & (pred[:, 2:4] < max_wh)).all(1)] + x = x[((x[:, 2:4] > min_wh) & (x[:, 2:4] < max_wh)).all(1)] # If none remain process next image - if not pred.shape[0]: + if not x.shape[0]: continue # Compute conf - pred[..., 5:] *= pred[..., 4:5] # conf = obj_conf * cls_conf + x[..., 5:] *= x[..., 4:5] # conf = obj_conf * cls_conf # Box (center x, center y, width, height) to (x1, y1, x2, y2) - box = xywh2xyxy(pred[:, :4]) + box = xywh2xyxy(x[:, :4]) # Detections matrix nx6 (xyxy, conf, cls) if multi_label: - i, j = (pred[:, 5:] > conf_thres).nonzero().t() - pred = torch.cat((box[i], pred[i, j + 5].unsqueeze(1), j.float().unsqueeze(1)), 1) + i, j = (x[:, 5:] > conf_thres).nonzero().t() + x = torch.cat((box[i], x[i, j + 5].unsqueeze(1), j.float().unsqueeze(1)), 1) else: # best class only - conf, j = pred[:, 5:].max(1) - pred = torch.cat((box, conf.unsqueeze(1), j.float().unsqueeze(1)), 1) + conf, j = x[:, 5:].max(1) + x = torch.cat((box, conf.unsqueeze(1), j.float().unsqueeze(1)), 1) # Filter by class if classes: - pred = pred[(j.view(-1, 1) == torch.tensor(classes, device=j.device)).any(1)] + x = x[(j.view(-1, 1) == torch.tensor(classes, device=j.device)).any(1)] # Apply finite constraint - if not torch.isfinite(pred).all(): - pred = pred[torch.isfinite(pred).all(1)] + if not torch.isfinite(x).all(): + x = x[torch.isfinite(x).all(1)] # If none remain process next image - if not pred.shape[0]: + if not x.shape[0]: continue # Sort by confidence # if method == 'fast_batch': - # pred = pred[pred[:, 4].argsort(descending=True)] + # x = x[x[:, 4].argsort(descending=True)] # Batched NMS - c = pred[:, 5] * 0 if agnostic else pred[:, 5] # classes - boxes, scores = pred[:, :4].clone(), pred[:, 4] - boxes += c.view(-1, 1) * max_wh # offset boxes by class - if method == 'vision': - i = torchvision.ops.boxes.nms(boxes, scores, iou_thres) - elif method == 'merge': # Merge NMS (boxes merged using weighted mean) + c = x[:, 5] * 0 if agnostic else x[:, 5] # classes + boxes, scores = x[:, :4].clone() + c.view(-1, 1) * max_wh, x[:, 4] # boxes (offset by class), scores + if method == 'merge': # Merge NMS (boxes merged using weighted mean) i = torchvision.ops.boxes.nms(boxes, scores, iou_thres) iou = box_iou(boxes, boxes[i]).tril_() # lower triangular iou matrix weights = (iou > iou_thres) * scores.view(-1, 1) weights /= weights.sum(0) - pred[i, :4] = torch.matmul(weights.T, pred[:, :4]) # merged_boxes(n,4) = weights(n,n) * boxes(n,4) + x[i, :4] = torch.mm(weights.T, x[:, :4]) # merged_boxes(n,4) = weights(n,n) * boxes(n,4) + elif method == 'vision': + i = torchvision.ops.boxes.nms(boxes, scores, iou_thres) elif method == 'fast': # FastNMS from https://github.com/dbolya/yolact iou = box_iou(boxes, boxes).triu_(diagonal=1) # upper triangular iou matrix i = iou.max(0)[0] < iou_thres - output[image_i] = pred[i] + output[xi] = x[i] def get_yolo_layers(model):