new layers.py file
This commit is contained in:
@@ -0,0 +1,62 @@
|
||||
import torch.nn.functional as F
|
||||
|
||||
from utils.utils import *
|
||||
|
||||
|
||||
class WeightedFeatureFusion(nn.Module): # weighted sum of 2 or more layers https://arxiv.org/abs/1911.09070
|
||||
def __init__(self, layers, weight=False):
|
||||
super(WeightedFeatureFusion, self).__init__()
|
||||
self.layers = layers # layer indices
|
||||
self.weight = weight # apply weights boolean
|
||||
self.n = len(layers) + 1 # number of layers
|
||||
if weight:
|
||||
self.w = torch.nn.Parameter(torch.zeros(self.n), requires_grad=True) # layer weights
|
||||
|
||||
def forward(self, x, outputs):
|
||||
# Weights
|
||||
if self.weight:
|
||||
w = torch.sigmoid(self.w) * (2 / self.n) # sigmoid weights (0-1)
|
||||
x = x * w[0]
|
||||
|
||||
# Fusion
|
||||
nx = x.shape[1] # input channels
|
||||
for i in range(self.n - 1):
|
||||
a = outputs[self.layers[i]] * w[i + 1] if self.weight else outputs[self.layers[i]] # feature to add
|
||||
na = a.shape[1] # feature channels
|
||||
|
||||
# Adjust channels
|
||||
if nx == na: # same shape
|
||||
x = x + a
|
||||
elif nx > na: # slice input
|
||||
x[:, :na] = x[:, :na] + a # or a = nn.ZeroPad2d((0, 0, 0, 0, 0, dc))(a); x = x + a
|
||||
else: # slice feature
|
||||
x = x + a[:, :nx]
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class SwishImplementation(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, i):
|
||||
ctx.save_for_backward(i)
|
||||
return i * torch.sigmoid(i)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
sigmoid_i = torch.sigmoid(ctx.saved_variables[0])
|
||||
return grad_output * (sigmoid_i * (1 + ctx.saved_variables[0] * (1 - sigmoid_i)))
|
||||
|
||||
|
||||
class MemoryEfficientSwish(nn.Module):
|
||||
def forward(self, x):
|
||||
return SwishImplementation.apply(x)
|
||||
|
||||
|
||||
class Swish(nn.Module):
|
||||
def forward(self, x):
|
||||
return x.mul_(torch.sigmoid(x))
|
||||
|
||||
|
||||
class Mish(nn.Module): # https://github.com/digantamisra98/Mish
|
||||
def forward(self, x):
|
||||
return x.mul_(F.softplus(x).tanh())
|
||||
Reference in New Issue
Block a user