From ef4dde425bd0603de373389dad55c43c5d6c219e Mon Sep 17 00:00:00 2001 From: "Kin-Yiu, Wong" <102582011@cc.ncu.edu.tw> Date: Thu, 14 Jul 2022 16:49:38 +0800 Subject: [PATCH] main code fix gain for train_aux --- utils/loss.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/utils/loss.py b/utils/loss.py index 6448cacf..6504301a 100644 --- a/utils/loss.py +++ b/utils/loss.py @@ -1115,7 +1115,7 @@ class ComputeLossBinOTA: # Build targets for compute_loss(), input targets(image,class,x,y,w,h) na, nt = self.na, targets.shape[0] # number of anchors, targets indices, anch = [], [] - gain = torch.ones(7, device=targets.device) # normalized to gridspace gain + gain = torch.ones(7, device=targets.device).long() # normalized to gridspace gain ai = torch.arange(na, device=targets.device).float().view(na, 1).repeat(1, nt) # same as .repeat_interleave(nt) targets = torch.cat((targets.repeat(na, 1, 1), ai[:, :, None]), 2) # append anchor indices @@ -1561,7 +1561,7 @@ class ComputeLossAuxOTA: # Build targets for compute_loss(), input targets(image,class,x,y,w,h) na, nt = self.na, targets.shape[0] # number of anchors, targets indices, anch = [], [] - gain = torch.ones(7, device=targets.device) # normalized to gridspace gain + gain = torch.ones(7, device=targets.device).long() # normalized to gridspace gain ai = torch.arange(na, device=targets.device).float().view(na, 1).repeat(1, nt) # same as .repeat_interleave(nt) targets = torch.cat((targets.repeat(na, 1, 1), ai[:, :, None]), 2) # append anchor indices @@ -1614,7 +1614,7 @@ class ComputeLossAuxOTA: # Build targets for compute_loss(), input targets(image,class,x,y,w,h) na, nt = self.na, targets.shape[0] # number of anchors, targets indices, anch = [], [] - gain = torch.ones(7, device=targets.device) # normalized to gridspace gain + gain = torch.ones(7, device=targets.device).long() # normalized to gridspace gain ai = torch.arange(na, device=targets.device).float().view(na, 1).repeat(1, nt) # same as .repeat_interleave(nt) targets = torch.cat((targets.repeat(na, 1, 1), ai[:, :, None]), 2) # append anchor indices