updates
This commit is contained in:
+7
-1
@@ -242,8 +242,14 @@ def build_targets(target, anchor_vec, nA, nC, nG):
|
||||
tconf = torch.ByteTensor(nB, nA, nG, nG).fill_(0)
|
||||
tcls = torch.ByteTensor(nB, nA, nG, nG, nC).fill_(0) # nC = number of classes
|
||||
|
||||
if anchor_vec.is_cuda():
|
||||
txy = txy.cuda()
|
||||
twh = twh.cuda()
|
||||
tconf = tconf.cuda()
|
||||
tcls = tcls.cuda()
|
||||
|
||||
for b in range(nB):
|
||||
t = target[b].cpu()
|
||||
t = target[b]
|
||||
nTb = len(t) # number of targets
|
||||
if nTb == 0:
|
||||
continue
|
||||
|
||||
Reference in New Issue
Block a user