updates
This commit is contained in:
+15
-12
@@ -322,8 +322,11 @@ def compute_loss(p, targets, model): # predictions, targets, model
|
||||
# Define criteria
|
||||
BCEcls = nn.BCEWithLogitsLoss(pos_weight=ft([h['cls_pw']]))
|
||||
BCEobj = nn.BCEWithLogitsLoss(pos_weight=ft([h['obj_pw']]))
|
||||
FBCE = nn.BCEWithLogitsLoss()
|
||||
FCE = nn.CrossEntropyLoss() # weight=model.class_weights
|
||||
BCE = nn.BCEWithLogitsLoss()
|
||||
CE = nn.CrossEntropyLoss() # weight=model.class_weights
|
||||
|
||||
if 'F' in arc: # add focal loss
|
||||
BCEcls, BCEobj, BCE, CE = FocalLoss(BCEcls), FocalLoss(BCEobj), FocalLoss(BCE), FocalLoss(CE)
|
||||
|
||||
# Compute losses
|
||||
for i, pi in enumerate(p): # layer index, layer predictions
|
||||
@@ -343,7 +346,7 @@ def compute_loss(p, targets, model): # predictions, targets, model
|
||||
giou = bbox_iou(pbox.t(), tbox[i], x1y1x2y2=False, GIoU=True) # giou computation
|
||||
lbox += (1.0 - giou).mean() # giou loss
|
||||
|
||||
if arc == 'default' and model.nc > 1: # cls loss (only if multiple classes)
|
||||
if 'default' in arc and model.nc > 1: # cls loss (only if multiple classes)
|
||||
t = torch.zeros_like(ps[:, 5:]) # targets
|
||||
t[range(nb), tcls[i]] = 1.0
|
||||
lcls += BCEcls(ps[:, 5:], t) # BCE
|
||||
@@ -353,20 +356,20 @@ def compute_loss(p, targets, model): # predictions, targets, model
|
||||
# with open('targets.txt', 'a') as file:
|
||||
# [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)]
|
||||
|
||||
if arc == 'default':
|
||||
if 'default' in arc: # seperate obj and cls
|
||||
lobj += BCEobj(pi[..., 4], tobj) # obj loss
|
||||
|
||||
elif arc == 'uCE': # unified CE (1 background + 80 classes), hyps 20
|
||||
t = torch.zeros_like(pi[..., 0], dtype=torch.long) # targets
|
||||
if nb:
|
||||
t[b, a, gj, gi] = tcls[i] + 1
|
||||
lcls += FCE(pi[..., 4:].view(-1, model.nc + 1), t.view(-1))
|
||||
|
||||
elif arc == 'uBCE': # unified BCE (1 background + 80 classes), hyps 200-30
|
||||
elif 'BCE' in arc: # unified BCE (80 classes)
|
||||
t = torch.zeros_like(pi[..., 5:]) # targets
|
||||
if nb:
|
||||
t[b, a, gj, gi, tcls[i]] = 1.0
|
||||
lobj += FBCE(pi[..., 5:], t)
|
||||
lobj += BCE(pi[..., 5:], t)
|
||||
|
||||
elif 'CE' in arc: # unified CE (1 background + 80 classes)
|
||||
t = torch.zeros_like(pi[..., 0], dtype=torch.long) # targets
|
||||
if nb:
|
||||
t[b, a, gj, gi] = tcls[i] + 1
|
||||
lcls += CE(pi[..., 4:].view(-1, model.nc + 1), t.view(-1))
|
||||
|
||||
lbox *= h['giou']
|
||||
lobj *= h['obj']
|
||||
|
||||
Reference in New Issue
Block a user