main code

update loss
This commit is contained in:
Kin-Yiu, Wong 2022-07-15 13:43:12 +08:00 committed by GitHub
parent ef4dde425b
commit 2267955898
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1102,12 +1102,20 @@ class ComputeLossBinOTA:
matching_anchs[i].append(all_anch[layer_idx]) matching_anchs[i].append(all_anch[layer_idx])
for i in range(nl): for i in range(nl):
matching_bs[i] = torch.cat(matching_bs[i], dim=0) if matching_targets[i] != []:
matching_as[i] = torch.cat(matching_as[i], dim=0) matching_bs[i] = torch.cat(matching_bs[i], dim=0)
matching_gjs[i] = torch.cat(matching_gjs[i], dim=0) matching_as[i] = torch.cat(matching_as[i], dim=0)
matching_gis[i] = torch.cat(matching_gis[i], dim=0) matching_gjs[i] = torch.cat(matching_gjs[i], dim=0)
matching_targets[i] = torch.cat(matching_targets[i], dim=0) matching_gis[i] = torch.cat(matching_gis[i], dim=0)
matching_anchs[i] = torch.cat(matching_anchs[i], dim=0) matching_targets[i] = torch.cat(matching_targets[i], dim=0)
matching_anchs[i] = torch.cat(matching_anchs[i], dim=0)
else:
matching_bs[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
matching_as[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
matching_gjs[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
matching_gis[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
matching_targets[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
matching_anchs[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
return matching_bs, matching_as, matching_gjs, matching_gis, matching_targets, matching_anchs return matching_bs, matching_as, matching_gjs, matching_gis, matching_targets, matching_anchs
@ -1403,12 +1411,20 @@ class ComputeLossAuxOTA:
matching_anchs[i].append(all_anch[layer_idx]) matching_anchs[i].append(all_anch[layer_idx])
for i in range(nl): for i in range(nl):
matching_bs[i] = torch.cat(matching_bs[i], dim=0) if matching_targets[i] != []:
matching_as[i] = torch.cat(matching_as[i], dim=0) matching_bs[i] = torch.cat(matching_bs[i], dim=0)
matching_gjs[i] = torch.cat(matching_gjs[i], dim=0) matching_as[i] = torch.cat(matching_as[i], dim=0)
matching_gis[i] = torch.cat(matching_gis[i], dim=0) matching_gjs[i] = torch.cat(matching_gjs[i], dim=0)
matching_targets[i] = torch.cat(matching_targets[i], dim=0) matching_gis[i] = torch.cat(matching_gis[i], dim=0)
matching_anchs[i] = torch.cat(matching_anchs[i], dim=0) matching_targets[i] = torch.cat(matching_targets[i], dim=0)
matching_anchs[i] = torch.cat(matching_anchs[i], dim=0)
else:
matching_bs[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
matching_as[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
matching_gjs[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
matching_gis[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
matching_targets[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
matching_anchs[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
return matching_bs, matching_as, matching_gjs, matching_gis, matching_targets, matching_anchs return matching_bs, matching_as, matching_gjs, matching_gis, matching_targets, matching_anchs
@ -1548,12 +1564,20 @@ class ComputeLossAuxOTA:
matching_anchs[i].append(all_anch[layer_idx]) matching_anchs[i].append(all_anch[layer_idx])
for i in range(nl): for i in range(nl):
matching_bs[i] = torch.cat(matching_bs[i], dim=0) if matching_targets[i] != []:
matching_as[i] = torch.cat(matching_as[i], dim=0) matching_bs[i] = torch.cat(matching_bs[i], dim=0)
matching_gjs[i] = torch.cat(matching_gjs[i], dim=0) matching_as[i] = torch.cat(matching_as[i], dim=0)
matching_gis[i] = torch.cat(matching_gis[i], dim=0) matching_gjs[i] = torch.cat(matching_gjs[i], dim=0)
matching_targets[i] = torch.cat(matching_targets[i], dim=0) matching_gis[i] = torch.cat(matching_gis[i], dim=0)
matching_anchs[i] = torch.cat(matching_anchs[i], dim=0) matching_targets[i] = torch.cat(matching_targets[i], dim=0)
matching_anchs[i] = torch.cat(matching_anchs[i], dim=0)
else:
matching_bs[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
matching_as[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
matching_gjs[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
matching_gis[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
matching_targets[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
matching_anchs[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
return matching_bs, matching_as, matching_gjs, matching_gis, matching_targets, matching_anchs return matching_bs, matching_as, matching_gjs, matching_gis, matching_targets, matching_anchs