diff --git a/utils/utils.py b/utils/utils.py index dcb684b2..6f3dfabf 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -372,8 +372,8 @@ def compute_loss(p, targets, model): # predictions, targets, model # Define criteria BCEcls = nn.BCEWithLogitsLoss(pos_weight=ft([h['cls_pw']]), reduction=red) BCEobj = nn.BCEWithLogitsLoss(pos_weight=ft([h['obj_pw']]), reduction=red) - BCE = nn.BCEWithLogitsLoss() - CE = nn.CrossEntropyLoss() # weight=model.class_weights + BCE = nn.BCEWithLogitsLoss(reduction=red) + CE = nn.CrossEntropyLoss(reduction=red) # weight=model.class_weights if 'F' in arc: # add focal loss g = h['fl_gamma']