diff --git a/utils/utils.py b/utils/utils.py index 929f6ba4..82b75d41 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -349,7 +349,8 @@ class FocalLoss(nn.Module): def forward(self, pred, true): loss = self.loss_fcn(pred, true) - # loss *= self.alpha * (1.000001 - torch.exp(-loss)) ** self.gamma # non-zero power for gradient stability + # p_t = torch.exp(-loss) + # loss *= self.alpha * (1.000001 - p_t) ** self.gamma # non-zero power for gradient stability # TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py pred_prob = torch.sigmoid(pred) # prob from logits