new find_modules() fcn
This commit is contained in:
@@ -59,6 +59,11 @@ def initialize_weights(model):
|
||||
m.momentum = 0.03
|
||||
|
||||
|
||||
def find_modules(model, mclass=nn.Conv2d):
|
||||
# finds layer indices matching module class 'mclass'
|
||||
return [i for i, m in enumerate(model.module_list) if isinstance(m, mclass)]
|
||||
|
||||
|
||||
def fuse_conv_and_bn(conv, bn):
|
||||
# https://tehnokv.com/posts/fusing-batchnorm-and-conv/
|
||||
with torch.no_grad():
|
||||
|
||||
Reference in New Issue
Block a user