updates
This commit is contained in:
+3
-2
@@ -312,11 +312,12 @@ class FocalLoss(nn.Module):
|
||||
return loss
|
||||
|
||||
|
||||
def compute_loss(p, targets, model, arc='default'): # predictions, targets, model
|
||||
def compute_loss(p, targets, model): # predictions, targets, model
|
||||
ft = torch.cuda.FloatTensor if p[0].is_cuda else torch.Tensor
|
||||
lcls, lbox, lobj = ft([0]), ft([0]), ft([0])
|
||||
tcls, tbox, indices, anchor_vec = build_targets(model, targets)
|
||||
h = model.hyp # hyperparameters
|
||||
arc = model.arc # # (default, uCE, uBCE) detection architectures
|
||||
|
||||
# Define criteria
|
||||
BCEcls = nn.BCEWithLogitsLoss(pos_weight=ft([h['cls_pw']]))
|
||||
@@ -354,7 +355,7 @@ def compute_loss(p, targets, model, arc='default'): # predictions, targets, mod
|
||||
# with open('targets.txt', 'a') as file:
|
||||
# [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)]
|
||||
|
||||
if arc == 'default': # (default, uCE, uBCE) detection architectures
|
||||
if arc == 'default':
|
||||
lobj += BCEobj(pi[..., 4], tobj) # obj loss
|
||||
|
||||
elif arc == 'uCE': # unified CE (1 background + 80 classes), hyps 20
|
||||
|
||||
Reference in New Issue
Block a user