main code
update loss
This commit is contained in:
parent
ef4dde425b
commit
2267955898
@ -1102,12 +1102,20 @@ class ComputeLossBinOTA:
|
||||
matching_anchs[i].append(all_anch[layer_idx])
|
||||
|
||||
for i in range(nl):
|
||||
if matching_targets[i] != []:
|
||||
matching_bs[i] = torch.cat(matching_bs[i], dim=0)
|
||||
matching_as[i] = torch.cat(matching_as[i], dim=0)
|
||||
matching_gjs[i] = torch.cat(matching_gjs[i], dim=0)
|
||||
matching_gis[i] = torch.cat(matching_gis[i], dim=0)
|
||||
matching_targets[i] = torch.cat(matching_targets[i], dim=0)
|
||||
matching_anchs[i] = torch.cat(matching_anchs[i], dim=0)
|
||||
else:
|
||||
matching_bs[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
|
||||
matching_as[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
|
||||
matching_gjs[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
|
||||
matching_gis[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
|
||||
matching_targets[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
|
||||
matching_anchs[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
|
||||
|
||||
return matching_bs, matching_as, matching_gjs, matching_gis, matching_targets, matching_anchs
|
||||
|
||||
@ -1403,12 +1411,20 @@ class ComputeLossAuxOTA:
|
||||
matching_anchs[i].append(all_anch[layer_idx])
|
||||
|
||||
for i in range(nl):
|
||||
if matching_targets[i] != []:
|
||||
matching_bs[i] = torch.cat(matching_bs[i], dim=0)
|
||||
matching_as[i] = torch.cat(matching_as[i], dim=0)
|
||||
matching_gjs[i] = torch.cat(matching_gjs[i], dim=0)
|
||||
matching_gis[i] = torch.cat(matching_gis[i], dim=0)
|
||||
matching_targets[i] = torch.cat(matching_targets[i], dim=0)
|
||||
matching_anchs[i] = torch.cat(matching_anchs[i], dim=0)
|
||||
else:
|
||||
matching_bs[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
|
||||
matching_as[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
|
||||
matching_gjs[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
|
||||
matching_gis[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
|
||||
matching_targets[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
|
||||
matching_anchs[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
|
||||
|
||||
return matching_bs, matching_as, matching_gjs, matching_gis, matching_targets, matching_anchs
|
||||
|
||||
@ -1548,12 +1564,20 @@ class ComputeLossAuxOTA:
|
||||
matching_anchs[i].append(all_anch[layer_idx])
|
||||
|
||||
for i in range(nl):
|
||||
if matching_targets[i] != []:
|
||||
matching_bs[i] = torch.cat(matching_bs[i], dim=0)
|
||||
matching_as[i] = torch.cat(matching_as[i], dim=0)
|
||||
matching_gjs[i] = torch.cat(matching_gjs[i], dim=0)
|
||||
matching_gis[i] = torch.cat(matching_gis[i], dim=0)
|
||||
matching_targets[i] = torch.cat(matching_targets[i], dim=0)
|
||||
matching_anchs[i] = torch.cat(matching_anchs[i], dim=0)
|
||||
else:
|
||||
matching_bs[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
|
||||
matching_as[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
|
||||
matching_gjs[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
|
||||
matching_gis[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
|
||||
matching_targets[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
|
||||
matching_anchs[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
|
||||
|
||||
return matching_bs, matching_as, matching_gjs, matching_gis, matching_targets, matching_anchs
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user