This commit is contained in:
Glenn Jocher
2019-08-17 02:14:28 +02:00
parent 9953335cfe
commit 321bd95764
3 changed files with 11 additions and 2 deletions
+8
View File
@@ -324,6 +324,14 @@ def compute_loss(p, targets, model, giou_loss=True): # predictions, targets, mo
lcls += (k * h['cls']) * BCEcls(pi[..., 5:], tclsm) # BCE
# lcls += (k * h['cls']) * CE(pi[..., 5:], tcls[i]) # CE
# udm_ce = torch.zeros_like(pi0[..., 0]).long() # unified detection matrix for CE
# udm_ce[b, a, gj, gi] = tcls[i] + 1
# lcls += (k * h['cls']) * CE(pi0[..., 4:].view(-1, model.nc + 1), udm_ce.view(-1)) # unified CE
# udm = torch.zeros_like(pi0[..., 5:]) # unified detection matrix for BCE
# udm[b, a, gj, gi, tcls[i]] = 1.0
# lcls += (k * h['cls']) * BCEcls(pi0[..., 5:], udm) # unified BCE (hyps 200-30)
# Append targets to text file
# with open('targets.txt', 'a') as file:
# [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)]