updates
This commit is contained in:
@@ -37,14 +37,12 @@ def fuse_conv_and_bn(conv, bn):
|
||||
# https://tehnokv.com/posts/fusing-batchnorm-and-conv/
|
||||
with torch.no_grad():
|
||||
# init
|
||||
fusedconv = torch.nn.Conv2d(
|
||||
conv.in_channels,
|
||||
conv.out_channels,
|
||||
kernel_size=conv.kernel_size,
|
||||
stride=conv.stride,
|
||||
padding=conv.padding,
|
||||
bias=True
|
||||
)
|
||||
fusedconv = torch.nn.Conv2d(conv.in_channels,
|
||||
conv.out_channels,
|
||||
kernel_size=conv.kernel_size,
|
||||
stride=conv.stride,
|
||||
padding=conv.padding,
|
||||
bias=True)
|
||||
|
||||
# prepare filters
|
||||
w_conv = conv.weight.clone().view(conv.out_channels, -1)
|
||||
|
||||
Reference in New Issue
Block a user