diff --git a/test.py b/test.py index bdcf8cf5..dc6c9f50 100644 --- a/test.py +++ b/test.py @@ -126,7 +126,7 @@ def test(cfg, # Assign all predictions as incorrect correct = torch.zeros(len(pred), niou) if nl: - detected = [] + detected = [] # target indices tcls_tensor = labels[:, 0] # target boxes @@ -134,26 +134,24 @@ def test(cfg, tbox[:, [0, 2]] *= width tbox[:, [1, 3]] *= height - # Search for correct predictions - for i, (*pbox, _, pcls) in enumerate(pred): + # Per target class + for cls in torch.unique(tcls_tensor): + ti = (cls == tcls_tensor).nonzero().view(-1) # prediction indices + pi = (cls == pred[:, 5]).nonzero().view(-1) # target indices - # Break if all targets already located in image - if len(detected) == nl: - break + # Search for detections + if len(pi): + # Prediction to target ious + ious, i = box_iou(pred[pi, :4], tbox[ti]).max(1) # best ious, indices - # Continue if predicted class not among image classes - if pcls.item() not in tcls: - continue - - # Best iou, index between pred and targets - m = (pcls == tcls_tensor).nonzero().view(-1) - iou, j = bbox_iou(pbox, tbox[m]).max(0) - m = m[j] - - # Per iou_thres 'correct' vector - if iou > iou_thres[0] and m not in detected: - detected.append(m) - correct[i] = iou > iou_thres + # Append detections + for j in (ious > iou_thres[0]).nonzero(): + d = ti[i[j]] # detected target + if d not in detected: + detected.append(d) + correct[pi[j]] = (ious[j] > iou_thres).float() # iou_thres is 1xn + if len(detected) == nl: # all targets already located in image + break # Append statistics (correct, conf, pcls, tcls) stats.append((correct, pred[:, 4].cpu(), pred[:, 5].cpu(), tcls))