From 2267955898f54d476e83db8c7f1df14d54636a22 Mon Sep 17 00:00:00 2001 From: "Kin-Yiu, Wong" <102582011@cc.ncu.edu.tw> Date: Fri, 15 Jul 2022 13:43:12 +0800 Subject: [PATCH] main code update loss --- utils/loss.py | 60 +++++++++++++++++++++++++++++++++++---------------- 1 file changed, 42 insertions(+), 18 deletions(-) diff --git a/utils/loss.py b/utils/loss.py index 6504301a..17d195f3 100644 --- a/utils/loss.py +++ b/utils/loss.py @@ -1102,12 +1102,20 @@ class ComputeLossBinOTA: matching_anchs[i].append(all_anch[layer_idx]) for i in range(nl): - matching_bs[i] = torch.cat(matching_bs[i], dim=0) - matching_as[i] = torch.cat(matching_as[i], dim=0) - matching_gjs[i] = torch.cat(matching_gjs[i], dim=0) - matching_gis[i] = torch.cat(matching_gis[i], dim=0) - matching_targets[i] = torch.cat(matching_targets[i], dim=0) - matching_anchs[i] = torch.cat(matching_anchs[i], dim=0) + if matching_targets[i] != []: + matching_bs[i] = torch.cat(matching_bs[i], dim=0) + matching_as[i] = torch.cat(matching_as[i], dim=0) + matching_gjs[i] = torch.cat(matching_gjs[i], dim=0) + matching_gis[i] = torch.cat(matching_gis[i], dim=0) + matching_targets[i] = torch.cat(matching_targets[i], dim=0) + matching_anchs[i] = torch.cat(matching_anchs[i], dim=0) + else: + matching_bs[i] = torch.tensor([], device='cuda:0', dtype=torch.int64) + matching_as[i] = torch.tensor([], device='cuda:0', dtype=torch.int64) + matching_gjs[i] = torch.tensor([], device='cuda:0', dtype=torch.int64) + matching_gis[i] = torch.tensor([], device='cuda:0', dtype=torch.int64) + matching_targets[i] = torch.tensor([], device='cuda:0', dtype=torch.int64) + matching_anchs[i] = torch.tensor([], device='cuda:0', dtype=torch.int64) return matching_bs, matching_as, matching_gjs, matching_gis, matching_targets, matching_anchs @@ -1403,12 +1411,20 @@ class ComputeLossAuxOTA: matching_anchs[i].append(all_anch[layer_idx]) for i in range(nl): - matching_bs[i] = torch.cat(matching_bs[i], dim=0) - matching_as[i] = torch.cat(matching_as[i], dim=0) - matching_gjs[i] = torch.cat(matching_gjs[i], dim=0) - matching_gis[i] = torch.cat(matching_gis[i], dim=0) - matching_targets[i] = torch.cat(matching_targets[i], dim=0) - matching_anchs[i] = torch.cat(matching_anchs[i], dim=0) + if matching_targets[i] != []: + matching_bs[i] = torch.cat(matching_bs[i], dim=0) + matching_as[i] = torch.cat(matching_as[i], dim=0) + matching_gjs[i] = torch.cat(matching_gjs[i], dim=0) + matching_gis[i] = torch.cat(matching_gis[i], dim=0) + matching_targets[i] = torch.cat(matching_targets[i], dim=0) + matching_anchs[i] = torch.cat(matching_anchs[i], dim=0) + else: + matching_bs[i] = torch.tensor([], device='cuda:0', dtype=torch.int64) + matching_as[i] = torch.tensor([], device='cuda:0', dtype=torch.int64) + matching_gjs[i] = torch.tensor([], device='cuda:0', dtype=torch.int64) + matching_gis[i] = torch.tensor([], device='cuda:0', dtype=torch.int64) + matching_targets[i] = torch.tensor([], device='cuda:0', dtype=torch.int64) + matching_anchs[i] = torch.tensor([], device='cuda:0', dtype=torch.int64) return matching_bs, matching_as, matching_gjs, matching_gis, matching_targets, matching_anchs @@ -1548,12 +1564,20 @@ class ComputeLossAuxOTA: matching_anchs[i].append(all_anch[layer_idx]) for i in range(nl): - matching_bs[i] = torch.cat(matching_bs[i], dim=0) - matching_as[i] = torch.cat(matching_as[i], dim=0) - matching_gjs[i] = torch.cat(matching_gjs[i], dim=0) - matching_gis[i] = torch.cat(matching_gis[i], dim=0) - matching_targets[i] = torch.cat(matching_targets[i], dim=0) - matching_anchs[i] = torch.cat(matching_anchs[i], dim=0) + if matching_targets[i] != []: + matching_bs[i] = torch.cat(matching_bs[i], dim=0) + matching_as[i] = torch.cat(matching_as[i], dim=0) + matching_gjs[i] = torch.cat(matching_gjs[i], dim=0) + matching_gis[i] = torch.cat(matching_gis[i], dim=0) + matching_targets[i] = torch.cat(matching_targets[i], dim=0) + matching_anchs[i] = torch.cat(matching_anchs[i], dim=0) + else: + matching_bs[i] = torch.tensor([], device='cuda:0', dtype=torch.int64) + matching_as[i] = torch.tensor([], device='cuda:0', dtype=torch.int64) + matching_gjs[i] = torch.tensor([], device='cuda:0', dtype=torch.int64) + matching_gis[i] = torch.tensor([], device='cuda:0', dtype=torch.int64) + matching_targets[i] = torch.tensor([], device='cuda:0', dtype=torch.int64) + matching_anchs[i] = torch.tensor([], device='cuda:0', dtype=torch.int64) return matching_bs, matching_as, matching_gjs, matching_gis, matching_targets, matching_anchs