This commit is contained in:
Glenn Jocher
2019-08-23 17:18:59 +02:00
parent d2ef817b1f
commit 5f2b551818
3 changed files with 23 additions and 22 deletions
+6 -6
View File
@@ -312,7 +312,7 @@ class FocalLoss(nn.Module):
return loss
def compute_loss(p, targets, model): # predictions, targets, model
def compute_loss(p, targets, model, arc='default'): # predictions, targets, model
ft = torch.cuda.FloatTensor if p[0].is_cuda else torch.Tensor
lcls, lbox, lobj = ft([0]), ft([0]), ft([0])
tcls, tbox, indices, anchor_vec = build_targets(model, targets)
@@ -321,12 +321,12 @@ def compute_loss(p, targets, model): # predictions, targets, model
# Define criteria
BCEcls = nn.BCEWithLogitsLoss(pos_weight=ft([h['cls_pw']]))
BCEobj = nn.BCEWithLogitsLoss(pos_weight=ft([h['obj_pw']]))
# CE = nn.CrossEntropyLoss(weight=model.class_weights)
BCE = nn.BCEWithLogitsLoss()
CE = nn.CrossEntropyLoss() # weight=model.class_weights
# Compute losses
bs = p[0].shape[0] # batch size
k = bs / 64 # loss gain
arc = 'normal' # (normal, uCE, uBCE, uBCEs) detection architectures
for i, pi in enumerate(p): # layer index, layer predictions
b, a, gj, gi = indices[i] # image, anchor, gridy, gridx
tobj = torch.zeros_like(pi[..., 0]) # target obj
@@ -344,7 +344,7 @@ def compute_loss(p, targets, model): # predictions, targets, model
giou = bbox_iou(pbox.t(), tbox[i], x1y1x2y2=False, GIoU=True) # giou computation
lbox += (1.0 - giou).mean() # giou loss
if arc == 'normal' and model.nc > 1: # cls loss (only if multiple classes)
if arc == 'default' and model.nc > 1: # cls loss (only if multiple classes)
t = torch.zeros_like(ps[:, 5:]) # targets
t[range(nb), tcls[i]] = 1.0
lcls += BCEcls(ps[:, 5:], t) # BCE
@@ -354,7 +354,7 @@ def compute_loss(p, targets, model): # predictions, targets, model
# with open('targets.txt', 'a') as file:
# [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)]
if arc == 'normal':
if arc == 'default': # (default, uCE, uBCE) detection architectures
lobj += BCEobj(pi[..., 4], tobj) # obj loss
elif arc == 'uCE': # unified CE (1 background + 80 classes), hyps 20
@@ -367,7 +367,7 @@ def compute_loss(p, targets, model): # predictions, targets, model
t = torch.zeros_like(pi[..., 5:]) # targets
if nb:
t[b, a, gj, gi, tcls[i]] = 1.0
lobj += BCEobj(pi[..., 5:], t)
lobj += BCE(pi[..., 5:], t)
lbox *= k * h['giou']
lobj *= k * h['obj']