diff --git a/utils/loss.py b/utils/loss.py index 6448cacf..6504301a 100644 --- a/utils/loss.py +++ b/utils/loss.py @@ -1115,7 +1115,7 @@ class ComputeLossBinOTA: # Build targets for compute_loss(), input targets(image,class,x,y,w,h) na, nt = self.na, targets.shape[0] # number of anchors, targets indices, anch = [], [] - gain = torch.ones(7, device=targets.device) # normalized to gridspace gain + gain = torch.ones(7, device=targets.device).long() # normalized to gridspace gain ai = torch.arange(na, device=targets.device).float().view(na, 1).repeat(1, nt) # same as .repeat_interleave(nt) targets = torch.cat((targets.repeat(na, 1, 1), ai[:, :, None]), 2) # append anchor indices @@ -1561,7 +1561,7 @@ class ComputeLossAuxOTA: # Build targets for compute_loss(), input targets(image,class,x,y,w,h) na, nt = self.na, targets.shape[0] # number of anchors, targets indices, anch = [], [] - gain = torch.ones(7, device=targets.device) # normalized to gridspace gain + gain = torch.ones(7, device=targets.device).long() # normalized to gridspace gain ai = torch.arange(na, device=targets.device).float().view(na, 1).repeat(1, nt) # same as .repeat_interleave(nt) targets = torch.cat((targets.repeat(na, 1, 1), ai[:, :, None]), 2) # append anchor indices @@ -1614,7 +1614,7 @@ class ComputeLossAuxOTA: # Build targets for compute_loss(), input targets(image,class,x,y,w,h) na, nt = self.na, targets.shape[0] # number of anchors, targets indices, anch = [], [] - gain = torch.ones(7, device=targets.device) # normalized to gridspace gain + gain = torch.ones(7, device=targets.device).long() # normalized to gridspace gain ai = torch.arange(na, device=targets.device).float().view(na, 1).repeat(1, nt) # same as .repeat_interleave(nt) targets = torch.cat((targets.repeat(na, 1, 1), ai[:, :, None]), 2) # append anchor indices