updates
This commit is contained in:
+2
-1
@@ -293,13 +293,14 @@ class FocalLoss(nn.Module):
|
||||
# i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=2.5)
|
||||
def __init__(self, loss_fcn, alpha=1, gamma=2, reduction='mean'):
|
||||
super(FocalLoss, self).__init__()
|
||||
loss_fcn.reduction = 'none' # required to apply FL to each element
|
||||
self.loss_fcn = loss_fcn
|
||||
self.alpha = alpha
|
||||
self.gamma = gamma
|
||||
self.reduction = reduction
|
||||
|
||||
def forward(self, input, target):
|
||||
loss = self.loss_fcn(input, target, reduction='none')
|
||||
loss = self.loss_fcn(input, target)
|
||||
pt = torch.exp(-loss)
|
||||
loss *= self.alpha * (1 - pt) ** self.gamma
|
||||
|
||||
|
||||
Reference in New Issue
Block a user