main code
fix gain for train_aux
This commit is contained in:
parent
4cebf4000a
commit
ef4dde425b
@ -1115,7 +1115,7 @@ class ComputeLossBinOTA:
|
||||
# Build targets for compute_loss(), input targets(image,class,x,y,w,h)
|
||||
na, nt = self.na, targets.shape[0] # number of anchors, targets
|
||||
indices, anch = [], []
|
||||
gain = torch.ones(7, device=targets.device) # normalized to gridspace gain
|
||||
gain = torch.ones(7, device=targets.device).long() # normalized to gridspace gain
|
||||
ai = torch.arange(na, device=targets.device).float().view(na, 1).repeat(1, nt) # same as .repeat_interleave(nt)
|
||||
targets = torch.cat((targets.repeat(na, 1, 1), ai[:, :, None]), 2) # append anchor indices
|
||||
|
||||
@ -1561,7 +1561,7 @@ class ComputeLossAuxOTA:
|
||||
# Build targets for compute_loss(), input targets(image,class,x,y,w,h)
|
||||
na, nt = self.na, targets.shape[0] # number of anchors, targets
|
||||
indices, anch = [], []
|
||||
gain = torch.ones(7, device=targets.device) # normalized to gridspace gain
|
||||
gain = torch.ones(7, device=targets.device).long() # normalized to gridspace gain
|
||||
ai = torch.arange(na, device=targets.device).float().view(na, 1).repeat(1, nt) # same as .repeat_interleave(nt)
|
||||
targets = torch.cat((targets.repeat(na, 1, 1), ai[:, :, None]), 2) # append anchor indices
|
||||
|
||||
@ -1614,7 +1614,7 @@ class ComputeLossAuxOTA:
|
||||
# Build targets for compute_loss(), input targets(image,class,x,y,w,h)
|
||||
na, nt = self.na, targets.shape[0] # number of anchors, targets
|
||||
indices, anch = [], []
|
||||
gain = torch.ones(7, device=targets.device) # normalized to gridspace gain
|
||||
gain = torch.ones(7, device=targets.device).long() # normalized to gridspace gain
|
||||
ai = torch.arange(na, device=targets.device).float().view(na, 1).repeat(1, nt) # same as .repeat_interleave(nt)
|
||||
targets = torch.cat((targets.repeat(na, 1, 1), ai[:, :, None]), 2) # append anchor indices
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user