diff --git a/test.py b/test.py index ec9d717d..b260714e 100644 --- a/test.py +++ b/test.py @@ -60,6 +60,7 @@ def test( for batch_i, (imgs, targets, paths, shapes) in enumerate(tqdm(dataloader, desc='Computing mAP')): targets = targets.to(device) imgs = imgs.to(device) + _, _, height, width = imgs.shape # batch size, channels, height, width # Plot images with bounding boxes if batch_i == 0 and not os.path.exists('test_batch0.jpg'): @@ -108,7 +109,12 @@ def test( correct = [0] * len(pred) if nl: detected = [] - tbox = xywh2xyxy(labels[:, 1:5]) * img_size # target boxes + tcls_tensor = labels[:, 0] + + # target boxes + tbox = xywh2xyxy(labels[:, 1:5]) + tbox[[0, 2]] *= width + tbox[[1, 3]] *= height # Search for correct predictions for i, (*pbox, pconf, pcls_conf, pcls) in enumerate(pred):