add FeatureConcat() module
This commit is contained in:
@@ -3,6 +3,16 @@ import torch.nn.functional as F
|
||||
from utils.utils import *
|
||||
|
||||
|
||||
class FeatureConcat(nn.Module):
|
||||
def __init__(self, layers):
|
||||
super(FeatureConcat, self).__init__()
|
||||
self.layers = layers # layer indices
|
||||
self.multiple = len(layers) > 1 # multiple layers flag
|
||||
|
||||
def forward(self, x, outputs):
|
||||
return torch.cat([outputs[i] for i in self.layers], 1) if self.multiple else outputs[self.layers[0]]
|
||||
|
||||
|
||||
class WeightedFeatureFusion(nn.Module): # weighted sum of 2 or more layers https://arxiv.org/abs/1911.09070
|
||||
def __init__(self, layers, weight=False):
|
||||
super(WeightedFeatureFusion, self).__init__()
|
||||
|
||||
Reference in New Issue
Block a user