updates
This commit is contained in:
+1
-1
@@ -279,7 +279,7 @@ def compute_loss(p, targets, model): # predictions, targets, model
|
||||
# Define criteria
|
||||
MSE = nn.MSELoss()
|
||||
CE = nn.CrossEntropyLoss() # (weight=model.class_weights)
|
||||
BCE = nn.BCEWithLogitsLoss()
|
||||
BCE = nn.BCEWithLogitsLoss(pos_weight=ft([h['conf_bpw']]))
|
||||
|
||||
# Compute losses
|
||||
bs = p[0].shape[0] # batch size
|
||||
|
||||
Reference in New Issue
Block a user