diff --git a/utils/utils.py b/utils/utils.py index 8907b626..2e45ddb4 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -243,7 +243,7 @@ def build_targets(target, anchor_vec, nA, nC, nG): tcls = torch.ByteTensor(nB, nA, nG, nG, nC).fill_(0) # nC = number of classes for b in range(nB): - t = target[b] + t = target[b].cpu() nTb = len(t) # number of targets if nTb == 0: continue @@ -257,7 +257,7 @@ def build_targets(target, anchor_vec, nA, nC, nG): box1 = gwh box2 = anchor_vec.unsqueeze(1) - print(box1.device,box2.device) + print(box1.device, box2.device) inter_area = torch.min(box1, box2).prod(2) iou = inter_area / (box1.prod(1) + box2.prod(2) - inter_area + 1e-16)