This commit is contained in:
Glenn Jocher
2019-08-29 18:58:09 +02:00
parent 894fc1c47f
commit 12b169158f
2 changed files with 24 additions and 1 deletions
+1 -1
View File
@@ -291,7 +291,7 @@ def wh_iou(box1, box2):
class FocalLoss(nn.Module):
# Wraps focal loss around existing loss_fcn() https://arxiv.org/pdf/1708.02002.pdf
# i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=2.5)
def __init__(self, loss_fcn, alpha=1, gamma=2, reduction='mean'):
def __init__(self, loss_fcn, alpha=1, gamma=0.5, reduction='mean'):
super(FocalLoss, self).__init__()
loss_fcn.reduction = 'none' # required to apply FL to each element
self.loss_fcn = loss_fcn