This commit is contained in:
Glenn Jocher
2019-02-10 22:01:53 +01:00
parent 62761cffe6
commit 22dc8c0ea6
3 changed files with 16 additions and 102 deletions
+2 -16
View File
@@ -214,7 +214,7 @@ def bbox_iou(box1, box2, x1y1x2y2=True):
return inter_area / (b1_area + b2_area - inter_area + 1e-16)
def build_targets(pred_boxes, pred_conf, pred_cls, target, anchor_wh, nA, nC, nG, batch_report):
def build_targets(target, anchor_wh, nA, nC, nG):
"""
returns nT, nCorrect, tx, ty, tw, th, tconf, tcls
"""
@@ -226,9 +226,6 @@ def build_targets(pred_boxes, pred_conf, pred_cls, target, anchor_wh, nA, nC, nG
th = torch.zeros(nB, nA, nG, nG)
tconf = torch.ByteTensor(nB, nA, nG, nG).fill_(0)
tcls = torch.ByteTensor(nB, nA, nG, nG, nC).fill_(0) # nC = number of classes
TP = torch.ByteTensor(nB, max(nT)).fill_(0)
FP = torch.ByteTensor(nB, max(nT)).fill_(0)
FN = torch.ByteTensor(nB, max(nT)).fill_(0)
TC = torch.ShortTensor(nB, max(nT)).fill_(-1) # target category
for b in range(nB):
@@ -293,18 +290,7 @@ def build_targets(pred_boxes, pred_conf, pred_cls, target, anchor_wh, nA, nC, nG
tcls[b, a, gj, gi, tc] = 1
tconf[b, a, gj, gi] = 1
if batch_report:
# predicted classes and confidence
tb = torch.cat((gx - gw / 2, gy - gh / 2, gx + gw / 2, gy + gh / 2)).view(4, -1).t() # target boxes
pcls = torch.argmax(pred_cls[b, a, gj, gi], 1).cpu()
pconf = torch.sigmoid(pred_conf[b, a, gj, gi]).cpu()
iou_pred = bbox_iou(tb, pred_boxes[b, a, gj, gi].cpu())
TP[b, i] = (pconf > 0.5) & (iou_pred > 0.5) & (pcls == tc)
FP[b, i] = (pconf > 0.5) & (TP[b, i] == 0) # coordinates or class are wrong
FN[b, i] = pconf <= 0.5 # confidence score is too low (set to zero)
return tx, ty, tw, th, tconf, tcls, TP, FP, FN, TC
return tx, ty, tw, th, tconf, tcls
def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4):