diff --git a/test.py b/test.py index 3e739d8e..4d68438e 100644 --- a/test.py +++ b/test.py @@ -145,8 +145,8 @@ def test(cfg, # Per target class for cls in torch.unique(tcls_tensor): - ti = (cls == tcls_tensor).nonzero().view(-1) # prediction indices - pi = (cls == pred[:, 5]).nonzero().view(-1) # target indices + ti = (cls == tcls_tensor).nonzero().view(-1) # target indices + pi = (cls == pred[:, 5]).nonzero().view(-1) # prediction indices # Search for detections if pi.shape[0]: