Merge remote-tracking branch 'origin/master'

This commit is contained in:
Glenn Jocher
2019-06-26 11:28:06 +02:00
+3 -4
View File
@@ -15,7 +15,7 @@ def create_modules(module_defs):
hyperparams = module_defs.pop(0)
output_filters = [int(hyperparams['channels'])]
module_list = nn.ModuleList()
yolo_layer_count = 0
for i, module_def in enumerate(module_defs):
modules = nn.Sequential()
@@ -66,9 +66,8 @@ def create_modules(module_defs):
nc = int(module_def['classes']) # number of classes
img_size = hyperparams['height']
# Define detection layer
yolo_layer = YOLOLayer(anchors, nc, img_size, yolo_layer_count, cfg=hyperparams['cfg'])
yolo_layer = YOLOLayer(anchors, nc, img_size, cfg=hyperparams['cfg'])
modules.add_module('yolo_%d' % i, yolo_layer)
yolo_layer_count += 1
# Register module list and number of output filters
module_list.append(modules)
@@ -100,7 +99,7 @@ class Upsample(nn.Module):
class YOLOLayer(nn.Module):
def __init__(self, anchors, nc, img_size, yolo_layer, cfg):
def __init__(self, anchors, nc, img_size, cfg):
super(YOLOLayer, self).__init__()
self.anchors = torch.Tensor(anchors)