This commit is contained in:
Glenn Jocher
2019-06-15 02:10:15 +02:00
parent bb3682024e
commit 02291622fa
2 changed files with 8 additions and 6 deletions
+1 -1
View File
@@ -279,7 +279,7 @@ def compute_loss(p, targets, model): # predictions, targets, model
# Define criteria
MSE = nn.MSELoss()
CE = nn.CrossEntropyLoss() # (weight=model.class_weights)
BCE = nn.BCEWithLogitsLoss()
BCE = nn.BCEWithLogitsLoss(pos_weight=ft([h['conf_bpw']]))
# Compute losses
bs = p[0].shape[0] # batch size