add MixConv2d() layer
This commit is contained in:
@@ -35,6 +35,35 @@ class WeightedFeatureFusion(nn.Module): # weighted sum of 2 or more layers http
|
||||
return x
|
||||
|
||||
|
||||
class MixConv2d(nn.Module): # MixConv: Mixed Depthwise Convolutional Kernels https://arxiv.org/abs/1907.09595
|
||||
def __init__(self, in_ch, out_ch, k=(3, 5, 7), stride=1, dilation=1, bias=True, method='equal_params'):
|
||||
super(MixConv2d, self).__init__()
|
||||
|
||||
groups = len(k)
|
||||
if method == 'equal_ch': # equal channels per group
|
||||
i = torch.linspace(0, groups - 1E-6, out_ch).floor() # out_ch indices
|
||||
ch = [(i == g).sum() for g in range(groups)]
|
||||
else: # 'equal_params': equal parameter count per group
|
||||
b = [out_ch] + [0] * groups
|
||||
a = np.eye(groups + 1, groups, k=-1)
|
||||
a -= np.roll(a, 1, axis=1)
|
||||
a *= np.array(k) ** 2
|
||||
a[0] = 1
|
||||
ch = np.linalg.lstsq(a, b, rcond=None)[0].round().astype(int) # solve for equal weight indices, ax = b
|
||||
|
||||
self.m = nn.ModuleList([torch.nn.Conv2d(in_channels=in_ch,
|
||||
out_channels=ch[g],
|
||||
kernel_size=k[g],
|
||||
stride=stride,
|
||||
padding=(k[g] - 1) // 2, # 'same' pad
|
||||
dilation=dilation,
|
||||
bias=bias) for g in range(groups)])
|
||||
|
||||
def forward(self, x):
|
||||
return torch.cat([m(x) for m in self.m], 1)
|
||||
|
||||
|
||||
# Activation functions below -------------------------------------------------------------------------------------------
|
||||
class SwishImplementation(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, i):
|
||||
|
||||
@@ -27,7 +27,7 @@ def parse_model_cfg(path):
|
||||
|
||||
if key == 'anchors': # return nparray
|
||||
mdefs[-1][key] = np.array([float(x) for x in val.split(',')]).reshape((-1, 2)) # np anchors
|
||||
elif key in ['from', 'layers', 'mask']: # return array
|
||||
elif (key in ['from', 'layers', 'mask']) or (key == 'size' and ',' in val): # return array
|
||||
mdefs[-1][key] = [int(x) for x in val.split(',')]
|
||||
else:
|
||||
val = val.strip()
|
||||
|
||||
Reference in New Issue
Block a user