This commit is contained in:
Glenn Jocher
2019-08-18 02:02:04 +02:00
parent 0aece25ef6
commit 43230c48bf
2 changed files with 39 additions and 38 deletions
+2 -1
View File
@@ -293,13 +293,14 @@ class FocalLoss(nn.Module):
# i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=2.5)
def __init__(self, loss_fcn, alpha=1, gamma=2, reduction='mean'):
super(FocalLoss, self).__init__()
loss_fcn.reduction = 'none' # required to apply FL to each element
self.loss_fcn = loss_fcn
self.alpha = alpha
self.gamma = gamma
self.reduction = reduction
def forward(self, input, target):
loss = self.loss_fcn(input, target, reduction='none')
loss = self.loss_fcn(input, target)
pt = torch.exp(-loss)
loss *= self.alpha * (1 - pt) ** self.gamma