updates
This commit is contained in:
+4
-3
@@ -291,12 +291,12 @@ def wh_iou(box1, box2):
|
||||
class FocalLoss(nn.Module):
|
||||
# Wraps focal loss around existing loss_fcn() https://arxiv.org/pdf/1708.02002.pdf
|
||||
# i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=2.5)
|
||||
def __init__(self, loss_fcn, alpha=1, gamma=0.5, reduction='mean'):
|
||||
def __init__(self, loss_fcn, gamma=0.5, alpha=1, reduction='mean'):
|
||||
super(FocalLoss, self).__init__()
|
||||
loss_fcn.reduction = 'none' # required to apply FL to each element
|
||||
self.loss_fcn = loss_fcn
|
||||
self.alpha = alpha
|
||||
self.gamma = gamma
|
||||
self.alpha = alpha
|
||||
self.reduction = reduction
|
||||
|
||||
def forward(self, input, target):
|
||||
@@ -325,7 +325,8 @@ def compute_loss(p, targets, model): # predictions, targets, model
|
||||
CE = nn.CrossEntropyLoss() # weight=model.class_weights
|
||||
|
||||
if 'F' in arc: # add focal loss
|
||||
BCEcls, BCEobj, BCE, CE = FocalLoss(BCEcls), FocalLoss(BCEobj), FocalLoss(BCE), FocalLoss(CE)
|
||||
g = h['fl_gamma']
|
||||
BCEcls, BCEobj, BCE, CE = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g), FocalLoss(BCE, g), FocalLoss(CE, g)
|
||||
|
||||
# Compute losses
|
||||
for i, pi in enumerate(p): # layer index, layer predictions
|
||||
|
||||
Reference in New Issue
Block a user