diff --git a/test.py b/test.py
index 4f70d297..d85c5162 100644
--- a/test.py
+++ b/test.py
@@ -133,18 +133,29 @@ def test(
         stats.append((correct, pred[:, 4].cpu(), pred[:, 6].cpu(), tcls))

     # Compute statistics
-    stats_np = [np.concatenate(x, 0) for x in list(zip(*stats))]
-    nt = np.bincount(stats_np[3].astype(np.int64), minlength=nc)  # number of targets per class
-    if len(stats_np):
-        p, r, ap, f1, ap_class = ap_per_class(*stats_np)
+    stats = [np.concatenate(x, 0) for x in list(zip(*stats))]  # to numpy
+    nt = np.bincount(stats[3].astype(np.int64), minlength=nc)  # number of targets per class
+    if len(stats):
+        p, r, ap, f1, ap_class = ap_per_class(*stats)
         mp, mr, map, mf1 = p.mean(), r.mean(), ap.mean(), f1.mean()

+        if any(r > 1):
+            chkpt = {'epoch': -1,
+                     'best_loss': None,
+                     'model': model.module.state_dict() if type(
+                         model) is nn.parallel.DistributedDataParallel else model.state_dict(),
+                     'optimizer': None}
+
+            # Save problem checkpoint
+            torch.save(chkpt, 'recall_issue.pt')
+            del chkpt
+
     # Print results
     pf = '%20s' + '%10.3g' * 6  # print format
     print(pf % ('all', seen, nt.sum(), mp, mr, map, mf1), end='\n\n')

     # Print results per class
-    if nc > 1 and len(stats_np):
+    if nc > 1 and len(stats):
         for i, c in enumerate(ap_class):
             print(pf % (names[c], seen, nt[c], p[i], r[i], ap[i], f1[i]))
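
Note (not part of the patch): a minimal sketch, assuming only what the hunk above writes to disk, of how the saved 'recall_issue.pt' checkpoint could be reloaded for offline inspection once the recall > 1 condition fires. The dictionary keys mirror the chkpt dict in this diff; everything else is illustrative.

    # Hypothetical debugging snippet: reload the problem checkpoint saved
    # by the patch above and inspect its contents on CPU.
    import torch

    chkpt = torch.load('recall_issue.pt', map_location='cpu')
    print(chkpt['epoch'], chkpt['best_loss'])         # -1 and None, as saved above
    state_dict = chkpt['model']                       # raw model state_dict
    print(len(state_dict), 'parameter tensors saved')
    for name, tensor in list(state_dict.items())[:5]:
        print(name, tuple(tensor.shape))              # first few layer shapes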