xy and wh losses respectively merged
This commit is contained in:
+13
-15
@@ -220,10 +220,8 @@ def build_targets(target, anchor_wh, nA, nC, nG):
|
||||
"""
|
||||
nB = len(target) # number of images in batch
|
||||
nT = [len(x) for x in target]
|
||||
tx = torch.zeros(nB, nA, nG, nG) # batch size, anchors, grid size
|
||||
ty = torch.zeros(nB, nA, nG, nG)
|
||||
tw = torch.zeros(nB, nA, nG, nG)
|
||||
th = torch.zeros(nB, nA, nG, nG)
|
||||
txy = torch.zeros(nB, nA, nG, nG, 2) # batch size, anchors, grid size
|
||||
twh = torch.zeros(nB, nA, nG, nG, 2)
|
||||
tconf = torch.ByteTensor(nB, nA, nG, nG).fill_(0)
|
||||
tcls = torch.ByteTensor(nB, nA, nG, nG, nC).fill_(0) # nC = number of classes
|
||||
|
||||
@@ -274,22 +272,22 @@ def build_targets(target, anchor_wh, nA, nC, nG):
|
||||
tc, gx, gy, gw, gh = t[:, 0].long(), t[:, 1] * nG, t[:, 2] * nG, t[:, 3] * nG, t[:, 4] * nG
|
||||
|
||||
# Coordinates
|
||||
tx[b, a, gj, gi] = gx - gi.float()
|
||||
ty[b, a, gj, gi] = gy - gj.float()
|
||||
txy[b, a, gj, gi, 0] = gx - gi.float()
|
||||
txy[b, a, gj, gi, 1] = gy - gj.float()
|
||||
|
||||
# Width and height (yolo method)
|
||||
tw[b, a, gj, gi] = torch.log(gw / anchor_wh[a, 0])
|
||||
th[b, a, gj, gi] = torch.log(gh / anchor_wh[a, 1])
|
||||
twh[b, a, gj, gi, 0] = torch.log(gw / anchor_wh[a, 0])
|
||||
twh[b, a, gj, gi, 1] = torch.log(gh / anchor_wh[a, 1])
|
||||
|
||||
# Width and height (power method)
|
||||
# tw[b, a, gj, gi] = torch.sqrt(gw / anchor_wh[a, 0]) / 2
|
||||
# th[b, a, gj, gi] = torch.sqrt(gh / anchor_wh[a, 1]) / 2
|
||||
# twh[b, a, gj, gi, 0] = torch.sqrt(gw / anchor_wh[a, 0]) / 2
|
||||
# twh[b, a, gj, gi, 1] = torch.sqrt(gh / anchor_wh[a, 1]) / 2
|
||||
|
||||
# One-hot encoding of label
|
||||
tcls[b, a, gj, gi, tc] = 1
|
||||
tconf[b, a, gj, gi] = 1
|
||||
|
||||
return tx, ty, tw, th, tconf, tcls
|
||||
return txy, twh, tconf, tcls
|
||||
|
||||
|
||||
def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4):
|
||||
@@ -447,13 +445,13 @@ def plot_results():
|
||||
# import os; os.system('rm -rf results.txt && wget https://storage.googleapis.com/ultralytics/results_v1_0.txt')
|
||||
|
||||
plt.figure(figsize=(16, 8))
|
||||
s = ['X', 'Y', 'Width', 'Height', 'Confidence', 'Classification', 'Total Loss', 'mAP', 'Recall', 'Precision']
|
||||
s = ['XY', 'Width-Height', 'Confidence', 'Classification', 'Total Loss', 'mAP', 'Recall', 'Precision']
|
||||
files = sorted(glob.glob('results*.txt'))
|
||||
for f in files:
|
||||
results = np.loadtxt(f, usecols=[2, 3, 4, 5, 6, 7, 8, 11, 12, 13]).T # column 13 is mAP
|
||||
results = np.loadtxt(f, usecols=[2, 3, 4, 5, 6, 11, 12, 13]).T # column 11 is mAP
|
||||
n = results.shape[1]
|
||||
for i in range(10):
|
||||
plt.subplot(2, 5, i + 1)
|
||||
for i in range(8):
|
||||
plt.subplot(2, 4, i + 1)
|
||||
plt.plot(range(1, n), results[i, 1:], marker='.', label=f)
|
||||
plt.title(s[i])
|
||||
if i == 0:
|
||||
|
||||
Reference in New Issue
Block a user