main code
update loss
This commit is contained in:
parent
ef4dde425b
commit
2267955898
@ -1102,12 +1102,20 @@ class ComputeLossBinOTA:
|
|||||||
matching_anchs[i].append(all_anch[layer_idx])
|
matching_anchs[i].append(all_anch[layer_idx])
|
||||||
|
|
||||||
for i in range(nl):
|
for i in range(nl):
|
||||||
|
if matching_targets[i] != []:
|
||||||
matching_bs[i] = torch.cat(matching_bs[i], dim=0)
|
matching_bs[i] = torch.cat(matching_bs[i], dim=0)
|
||||||
matching_as[i] = torch.cat(matching_as[i], dim=0)
|
matching_as[i] = torch.cat(matching_as[i], dim=0)
|
||||||
matching_gjs[i] = torch.cat(matching_gjs[i], dim=0)
|
matching_gjs[i] = torch.cat(matching_gjs[i], dim=0)
|
||||||
matching_gis[i] = torch.cat(matching_gis[i], dim=0)
|
matching_gis[i] = torch.cat(matching_gis[i], dim=0)
|
||||||
matching_targets[i] = torch.cat(matching_targets[i], dim=0)
|
matching_targets[i] = torch.cat(matching_targets[i], dim=0)
|
||||||
matching_anchs[i] = torch.cat(matching_anchs[i], dim=0)
|
matching_anchs[i] = torch.cat(matching_anchs[i], dim=0)
|
||||||
|
else:
|
||||||
|
matching_bs[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
|
||||||
|
matching_as[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
|
||||||
|
matching_gjs[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
|
||||||
|
matching_gis[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
|
||||||
|
matching_targets[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
|
||||||
|
matching_anchs[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
|
||||||
|
|
||||||
return matching_bs, matching_as, matching_gjs, matching_gis, matching_targets, matching_anchs
|
return matching_bs, matching_as, matching_gjs, matching_gis, matching_targets, matching_anchs
|
||||||
|
|
||||||
@ -1403,12 +1411,20 @@ class ComputeLossAuxOTA:
|
|||||||
matching_anchs[i].append(all_anch[layer_idx])
|
matching_anchs[i].append(all_anch[layer_idx])
|
||||||
|
|
||||||
for i in range(nl):
|
for i in range(nl):
|
||||||
|
if matching_targets[i] != []:
|
||||||
matching_bs[i] = torch.cat(matching_bs[i], dim=0)
|
matching_bs[i] = torch.cat(matching_bs[i], dim=0)
|
||||||
matching_as[i] = torch.cat(matching_as[i], dim=0)
|
matching_as[i] = torch.cat(matching_as[i], dim=0)
|
||||||
matching_gjs[i] = torch.cat(matching_gjs[i], dim=0)
|
matching_gjs[i] = torch.cat(matching_gjs[i], dim=0)
|
||||||
matching_gis[i] = torch.cat(matching_gis[i], dim=0)
|
matching_gis[i] = torch.cat(matching_gis[i], dim=0)
|
||||||
matching_targets[i] = torch.cat(matching_targets[i], dim=0)
|
matching_targets[i] = torch.cat(matching_targets[i], dim=0)
|
||||||
matching_anchs[i] = torch.cat(matching_anchs[i], dim=0)
|
matching_anchs[i] = torch.cat(matching_anchs[i], dim=0)
|
||||||
|
else:
|
||||||
|
matching_bs[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
|
||||||
|
matching_as[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
|
||||||
|
matching_gjs[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
|
||||||
|
matching_gis[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
|
||||||
|
matching_targets[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
|
||||||
|
matching_anchs[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
|
||||||
|
|
||||||
return matching_bs, matching_as, matching_gjs, matching_gis, matching_targets, matching_anchs
|
return matching_bs, matching_as, matching_gjs, matching_gis, matching_targets, matching_anchs
|
||||||
|
|
||||||
@ -1548,12 +1564,20 @@ class ComputeLossAuxOTA:
|
|||||||
matching_anchs[i].append(all_anch[layer_idx])
|
matching_anchs[i].append(all_anch[layer_idx])
|
||||||
|
|
||||||
for i in range(nl):
|
for i in range(nl):
|
||||||
|
if matching_targets[i] != []:
|
||||||
matching_bs[i] = torch.cat(matching_bs[i], dim=0)
|
matching_bs[i] = torch.cat(matching_bs[i], dim=0)
|
||||||
matching_as[i] = torch.cat(matching_as[i], dim=0)
|
matching_as[i] = torch.cat(matching_as[i], dim=0)
|
||||||
matching_gjs[i] = torch.cat(matching_gjs[i], dim=0)
|
matching_gjs[i] = torch.cat(matching_gjs[i], dim=0)
|
||||||
matching_gis[i] = torch.cat(matching_gis[i], dim=0)
|
matching_gis[i] = torch.cat(matching_gis[i], dim=0)
|
||||||
matching_targets[i] = torch.cat(matching_targets[i], dim=0)
|
matching_targets[i] = torch.cat(matching_targets[i], dim=0)
|
||||||
matching_anchs[i] = torch.cat(matching_anchs[i], dim=0)
|
matching_anchs[i] = torch.cat(matching_anchs[i], dim=0)
|
||||||
|
else:
|
||||||
|
matching_bs[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
|
||||||
|
matching_as[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
|
||||||
|
matching_gjs[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
|
||||||
|
matching_gis[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
|
||||||
|
matching_targets[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
|
||||||
|
matching_anchs[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
|
||||||
|
|
||||||
return matching_bs, matching_as, matching_gjs, matching_gis, matching_targets, matching_anchs
|
return matching_bs, matching_as, matching_gjs, matching_gis, matching_targets, matching_anchs
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user