updates
This commit is contained in:
+2
-9
@@ -1,6 +1,5 @@
|
||||
import glob
|
||||
import random
|
||||
from collections import defaultdict
|
||||
|
||||
import cv2
|
||||
import matplotlib
|
||||
@@ -244,7 +243,7 @@ def wh_iou(box1, box2):
|
||||
|
||||
|
||||
def compute_loss(p, targets): # predictions, targets
|
||||
FT = torch.cuda.FloatTensor if p[0].is_cuda else torch.FloatTensor
|
||||
FT = torch.cuda.Tensor if p[0].is_cuda else torch.Tensor
|
||||
lxy, lwh, lcls, lconf = FT([0]), FT([0]), FT([0]), FT([0])
|
||||
txy, twh, tcls, indices = targets
|
||||
MSE = nn.MSELoss()
|
||||
@@ -274,13 +273,7 @@ def compute_loss(p, targets): # predictions, targets
|
||||
lconf += (k * 64) * BCE(pi0[..., 4], tconf) # obj_conf loss
|
||||
loss = lxy + lwh + lconf + lcls
|
||||
|
||||
# Add to dictionary
|
||||
d = defaultdict(float)
|
||||
losses = [loss.item(), lxy.item(), lwh.item(), lconf.item(), lcls.item()]
|
||||
for k, v in zip(['total', 'xy', 'wh', 'conf', 'cls'], losses):
|
||||
d[k] = v
|
||||
|
||||
return loss, d
|
||||
return loss, torch.cat((lxy, lwh, lconf, lcls, loss)).detach()
|
||||
|
||||
|
||||
def build_targets(model, targets):
|
||||
|
||||
Reference in New Issue
Block a user