main code
fix gain for train_aux
This commit is contained in:
parent
4cebf4000a
commit
ef4dde425b
@ -1115,7 +1115,7 @@ class ComputeLossBinOTA:
|
|||||||
# Build targets for compute_loss(), input targets(image,class,x,y,w,h)
|
# Build targets for compute_loss(), input targets(image,class,x,y,w,h)
|
||||||
na, nt = self.na, targets.shape[0] # number of anchors, targets
|
na, nt = self.na, targets.shape[0] # number of anchors, targets
|
||||||
indices, anch = [], []
|
indices, anch = [], []
|
||||||
gain = torch.ones(7, device=targets.device) # normalized to gridspace gain
|
gain = torch.ones(7, device=targets.device).long() # normalized to gridspace gain
|
||||||
ai = torch.arange(na, device=targets.device).float().view(na, 1).repeat(1, nt) # same as .repeat_interleave(nt)
|
ai = torch.arange(na, device=targets.device).float().view(na, 1).repeat(1, nt) # same as .repeat_interleave(nt)
|
||||||
targets = torch.cat((targets.repeat(na, 1, 1), ai[:, :, None]), 2) # append anchor indices
|
targets = torch.cat((targets.repeat(na, 1, 1), ai[:, :, None]), 2) # append anchor indices
|
||||||
|
|
||||||
@ -1561,7 +1561,7 @@ class ComputeLossAuxOTA:
|
|||||||
# Build targets for compute_loss(), input targets(image,class,x,y,w,h)
|
# Build targets for compute_loss(), input targets(image,class,x,y,w,h)
|
||||||
na, nt = self.na, targets.shape[0] # number of anchors, targets
|
na, nt = self.na, targets.shape[0] # number of anchors, targets
|
||||||
indices, anch = [], []
|
indices, anch = [], []
|
||||||
gain = torch.ones(7, device=targets.device) # normalized to gridspace gain
|
gain = torch.ones(7, device=targets.device).long() # normalized to gridspace gain
|
||||||
ai = torch.arange(na, device=targets.device).float().view(na, 1).repeat(1, nt) # same as .repeat_interleave(nt)
|
ai = torch.arange(na, device=targets.device).float().view(na, 1).repeat(1, nt) # same as .repeat_interleave(nt)
|
||||||
targets = torch.cat((targets.repeat(na, 1, 1), ai[:, :, None]), 2) # append anchor indices
|
targets = torch.cat((targets.repeat(na, 1, 1), ai[:, :, None]), 2) # append anchor indices
|
||||||
|
|
||||||
@ -1614,7 +1614,7 @@ class ComputeLossAuxOTA:
|
|||||||
# Build targets for compute_loss(), input targets(image,class,x,y,w,h)
|
# Build targets for compute_loss(), input targets(image,class,x,y,w,h)
|
||||||
na, nt = self.na, targets.shape[0] # number of anchors, targets
|
na, nt = self.na, targets.shape[0] # number of anchors, targets
|
||||||
indices, anch = [], []
|
indices, anch = [], []
|
||||||
gain = torch.ones(7, device=targets.device) # normalized to gridspace gain
|
gain = torch.ones(7, device=targets.device).long() # normalized to gridspace gain
|
||||||
ai = torch.arange(na, device=targets.device).float().view(na, 1).repeat(1, nt) # same as .repeat_interleave(nt)
|
ai = torch.arange(na, device=targets.device).float().view(na, 1).repeat(1, nt) # same as .repeat_interleave(nt)
|
||||||
targets = torch.cat((targets.repeat(na, 1, 1), ai[:, :, None]), 2) # append anchor indices
|
targets = torch.cat((targets.repeat(na, 1, 1), ai[:, :, None]), 2) # append anchor indices
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user