+1
-2
@@ -301,8 +301,7 @@ class FocalLoss(nn.Module):
|
|||||||
|
|
||||||
def forward(self, input, target):
|
def forward(self, input, target):
|
||||||
loss = self.loss_fcn(input, target)
|
loss = self.loss_fcn(input, target)
|
||||||
pt = torch.exp(-loss)
|
loss *= self.alpha * (1.000001 - torch.exp(-loss)) ** self.gamma # non-zero power for gradient stability
|
||||||
loss *= self.alpha * (1 - pt) ** self.gamma
|
|
||||||
|
|
||||||
if self.reduction == 'mean':
|
if self.reduction == 'mean':
|
||||||
return loss.mean()
|
return loss.mean()
|
||||||
|
|||||||
Reference in New Issue
Block a user