hyperparameter updates
This commit is contained in:
@@ -26,3 +26,32 @@ def select_device(force_cpu=False):
|
||||
(i, x[i].name, x[i].total_memory / c))
|
||||
|
||||
return device
|
||||
|
||||
|
||||
def fuse_conv_and_bn(conv, bn):
|
||||
# https://tehnokv.com/posts/fusing-batchnorm-and-conv/
|
||||
with torch.no_grad():
|
||||
# init
|
||||
fusedconv = torch.nn.Conv2d(
|
||||
conv.in_channels,
|
||||
conv.out_channels,
|
||||
kernel_size=conv.kernel_size,
|
||||
stride=conv.stride,
|
||||
padding=conv.padding,
|
||||
bias=True
|
||||
)
|
||||
|
||||
# prepare filters
|
||||
w_conv = conv.weight.clone().view(conv.out_channels, -1)
|
||||
w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
|
||||
fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size()))
|
||||
|
||||
# prepare spatial bias
|
||||
if conv.bias is not None:
|
||||
b_conv = conv.bias
|
||||
else:
|
||||
b_conv = torch.zeros(conv.weight.size(0))
|
||||
b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
|
||||
fusedconv.bias.copy_(b_conv + b_bn)
|
||||
|
||||
return fusedconv
|
||||
|
||||
+6
-6
@@ -291,7 +291,7 @@ def build_targets(model, targets):
|
||||
|
||||
# iou of targets-anchors
|
||||
t, a = targets, []
|
||||
gwh = targets[:, 4:6] * layer.nG
|
||||
gwh = targets[:, 4:6] * layer.ng
|
||||
if nt:
|
||||
iou = [wh_iou(x, gwh) for x in layer.anchor_vec]
|
||||
iou, a = torch.stack(iou, 0).max(0) # best iou and anchor
|
||||
@@ -304,7 +304,7 @@ def build_targets(model, targets):
|
||||
|
||||
# Indices
|
||||
b, c = t[:, :2].long().t() # target image, class
|
||||
gxy = t[:, 2:4] * layer.nG
|
||||
gxy = t[:, 2:4] * layer.ng
|
||||
gi, gj = gxy.long().t() # grid_i, grid_j
|
||||
indices.append((b, a, gj, gi))
|
||||
|
||||
@@ -318,7 +318,7 @@ def build_targets(model, targets):
|
||||
# Class
|
||||
tcls.append(c)
|
||||
if c.shape[0]:
|
||||
assert c.max() <= layer.nC, 'Target classes exceed model classes'
|
||||
assert c.max() <= layer.nc, 'Target classes exceed model classes'
|
||||
|
||||
return txy, twh, tcls, indices
|
||||
|
||||
@@ -442,12 +442,12 @@ def strip_optimizer_from_checkpoint(filename='weights/best.pt'):
|
||||
|
||||
def coco_class_count(path='../coco/labels/train2014/'):
|
||||
# Histogram of occurrences per class
|
||||
nC = 80 # number classes
|
||||
x = np.zeros(nC, dtype='int32')
|
||||
nc = 80 # number classes
|
||||
x = np.zeros(nc, dtype='int32')
|
||||
files = sorted(glob.glob('%s/*.*' % path))
|
||||
for i, file in enumerate(files):
|
||||
labels = np.loadtxt(file, dtype=np.float32).reshape(-1, 5)
|
||||
x += np.bincount(labels[:, 0].astype('int32'), minlength=nC)
|
||||
x += np.bincount(labels[:, 0].astype('int32'), minlength=nc)
|
||||
print(i, len(files))
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user