updates
This commit is contained in:
@@ -17,7 +17,7 @@ def select_device(force_cpu=False):
|
||||
|
||||
if torch.cuda.device_count() > 1:
|
||||
print('Found %g GPUs' % torch.cuda.device_count())
|
||||
print('WARNING Using GPU0 Only: https://github.com/ultralytics/yolov3/issues/21')
|
||||
print('WARNING Multi-GPU Issue: https://github.com/ultralytics/yolov3/issues/21')
|
||||
# torch.cuda.set_device(0) # OPTIONAL: Set your GPU if multiple available
|
||||
# # print('Using ', torch.cuda.device_count(), ' GPUs')
|
||||
|
||||
|
||||
@@ -242,12 +242,6 @@ def build_targets(target, anchor_vec, nA, nC, nG):
|
||||
tconf = torch.ByteTensor(nB, nA, nG, nG).fill_(0)
|
||||
tcls = torch.ByteTensor(nB, nA, nG, nG, nC).fill_(0) # nC = number of classes
|
||||
|
||||
if anchor_vec.is_cuda:
|
||||
txy = txy.cuda()
|
||||
twh = twh.cuda()
|
||||
tconf = tconf.cuda()
|
||||
tcls = tcls.cuda()
|
||||
|
||||
for b in range(nB):
|
||||
t = target[b]
|
||||
nTb = len(t) # number of targets
|
||||
@@ -263,8 +257,6 @@ def build_targets(target, anchor_vec, nA, nC, nG):
|
||||
box1 = gwh
|
||||
box2 = anchor_vec.unsqueeze(1)
|
||||
|
||||
print(box1.device, box2.device)
|
||||
|
||||
inter_area = torch.min(box1, box2).prod(2)
|
||||
iou = inter_area / (box1.prod(1) + box2.prod(2) - inter_area + 1e-16)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user