This commit is contained in:
Glenn Jocher
2019-04-15 13:55:52 +02:00
parent 3c6b168a0a
commit 1191dee71b
2 changed files with 9 additions and 17 deletions
+2 -9
View File
@@ -1,6 +1,5 @@
import glob
import random
from collections import defaultdict
import cv2
import matplotlib
@@ -244,7 +243,7 @@ def wh_iou(box1, box2):
def compute_loss(p, targets): # predictions, targets
FT = torch.cuda.FloatTensor if p[0].is_cuda else torch.FloatTensor
FT = torch.cuda.Tensor if p[0].is_cuda else torch.Tensor
lxy, lwh, lcls, lconf = FT([0]), FT([0]), FT([0]), FT([0])
txy, twh, tcls, indices = targets
MSE = nn.MSELoss()
@@ -274,13 +273,7 @@ def compute_loss(p, targets): # predictions, targets
lconf += (k * 64) * BCE(pi0[..., 4], tconf) # obj_conf loss
loss = lxy + lwh + lconf + lcls
# Add to dictionary
d = defaultdict(float)
losses = [loss.item(), lxy.item(), lwh.item(), lconf.item(), lcls.item()]
for k, v in zip(['total', 'xy', 'wh', 'conf', 'cls'], losses):
d[k] = v
return loss, d
return loss, torch.cat((lxy, lwh, lconf, lcls, loss)).detach()
def build_targets(model, targets):