diff --git a/utils/utils.py b/utils/utils.py index ae52c4da..7c4c0c81 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -349,7 +349,8 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4): class_prob, class_pred = torch.max(F.softmax(pred[:, 5:], 1), 1) - v = ((pred[:, 4] > conf_thres) & (class_prob > .4)) # TODO examine arbitrary 0.4 thres here + # v = ((pred[:, 4] > conf_thres) & (class_prob > .4)) # TODO examine arbitrary 0.4 thres here + v = pred[:, 4] > conf_thres v = v.nonzero().squeeze() if len(v.shape) == 0: v = v.unsqueeze(0)