kaiming weight init

This commit is contained in:
Glenn Jocher
2020-04-11 10:45:33 -07:00
parent 2cf23c4aee
commit 58edfc4a84
3 changed files with 10 additions and 9 deletions
+9
View File
@@ -50,6 +50,15 @@ def time_synchronized():
return time.time()
def initialize_weights(model):
for m in model.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
def fuse_conv_and_bn(conv, bn):
# https://tehnokv.com/posts/fusing-batchnorm-and-conv/
with torch.no_grad():
-9
View File
@@ -93,15 +93,6 @@ def coco80_to_coco91_class(): # converts 80-index (val2014) to 91-index (paper)
return x
def weights_init_normal(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
torch.nn.init.normal_(m.weight.data, 0.0, 0.03)
elif classname.find('BatchNorm2d') != -1:
torch.nn.init.normal_(m.weight.data, 1.0, 0.03)
torch.nn.init.constant_(m.bias.data, 0.0)
def xyxy2xywh(x):
# Transform box coordinates from [x1, y1, x2, y2] (where xy1=top-left, xy2=bottom-right) to [x, y, w, h]
y = torch.zeros_like(x) if isinstance(x, torch.Tensor) else np.zeros_like(x)