add FeatureConcat() module

This commit is contained in:
Glenn Jocher
2020-04-05 14:47:41 -07:00
parent 968b2ec004
commit a657345b45
2 changed files with 13 additions and 17 deletions
+10
View File
@@ -3,6 +3,16 @@ import torch.nn.functional as F
from utils.utils import *
class FeatureConcat(nn.Module):
def __init__(self, layers):
super(FeatureConcat, self).__init__()
self.layers = layers # layer indices
self.multiple = len(layers) > 1 # multiple layers flag
def forward(self, x, outputs):
return torch.cat([outputs[i] for i in self.layers], 1) if self.multiple else outputs[self.layers[0]]
class WeightedFeatureFusion(nn.Module): # weighted sum of 2 or more layers https://arxiv.org/abs/1911.09070
def __init__(self, layers, weight=False):
super(WeightedFeatureFusion, self).__init__()