new find_modules() fcn

This commit is contained in:
Glenn Jocher
2020-04-13 17:48:30 -07:00
parent 77e6bdd3c1
commit b8574add37
2 changed files with 6 additions and 5 deletions
+5
View File
@@ -59,6 +59,11 @@ def initialize_weights(model):
m.momentum = 0.03
def find_modules(model, mclass=nn.Conv2d):
# finds layer indices matching module class 'mclass'
return [i for i, m in enumerate(model.module_list) if isinstance(m, mclass)]
def fuse_conv_and_bn(conv, bn):
# https://tehnokv.com/posts/fusing-batchnorm-and-conv/
with torch.no_grad():