diff --git a/models.py b/models.py index ff88eef1..a8abe99b 100755 --- a/models.py +++ b/models.py @@ -6,7 +6,7 @@ import torch.nn as nn from utils.parse_config import * from utils.utils import * -ONNX_EXPORT = False +ONNX_EXPORT = True def create_modules(module_defs): @@ -16,6 +16,7 @@ def create_modules(module_defs): hyperparams = module_defs.pop(0) output_filters = [int(hyperparams['channels'])] module_list = nn.ModuleList() + yolo_layer_count = 0 for i, module_def in enumerate(module_defs): modules = nn.Sequential() @@ -63,11 +64,12 @@ def create_modules(module_defs): anchors = [float(x) for x in module_def['anchors'].split(',')] anchors = [(anchors[i], anchors[i + 1]) for i in range(0, len(anchors), 2)] anchors = [anchors[i] for i in anchor_idxs] - num_classes = int(module_def['classes']) - img_height = int(hyperparams['height']) + nC = int(module_def['classes']) # number of classes + img_size = int(hyperparams['height']) # Define detection layer - yolo_layer = YOLOLayer(anchors, num_classes, img_height, anchor_idxs, cfg=hyperparams['cfg']) + yolo_layer = YOLOLayer(anchors, nC, img_size, yolo_layer_count, cfg=hyperparams['cfg']) modules.add_module('yolo_%d' % i, yolo_layer) + yolo_layer_count += 1 # Register module list and number of output filters module_list.append(modules) @@ -100,53 +102,40 @@ class Upsample(nn.Module): class YOLOLayer(nn.Module): - def __init__(self, anchors, nC, img_dim, anchor_idxs, cfg): + def __init__(self, anchors, nC, img_size, yolo_layer, cfg): + # TODO: img_size from hyperparams in cfg file, NOT from parser. Make dynamic super(YOLOLayer, self).__init__() - anchors = [(a_w, a_h) for a_w, a_h in anchors] # (pixels) nA = len(anchors) - - self.anchors = anchors + self.anchors = torch.FloatTensor(anchors) self.nA = nA # number of anchors (3) self.nC = nC # number of classes (80) - self.bbox_attrs = 5 + nC - self.img_dim = img_dim # TODO: from hyperparams in cfg file, NOT from parser. Make dynamic - self.initialized = False - # self.weights = class_weights() + self.img_size = 0 + # self.coco_class_weights = coco_class_weights() - if anchor_idxs[0] == (nA * 2): # 6 - stride = 32 - elif anchor_idxs[0] == nA: # 3 - stride = 16 + if ONNX_EXPORT: # grids must be computed in __init__ + stride = [32, 16, 8][yolo_layer] # stride of this layer + if cfg.endswith('yolov3-tiny.cfg'): + stride *= 2 + + self.nG = int(img_size / stride) # number grid points + create_grids(self, img_size, self.nG) + + def forward(self, p, img_size, targets=None, var=None): + if ONNX_EXPORT: + bs, nG = 1, self.nG # batch size, grid size else: - stride = 8 + bs, nG = p.shape[0], p.shape[-1] - if cfg.endswith('yolov3-tiny.cfg'): - stride *= 2 + if self.img_size != img_size: + create_grids(self, img_size, nG) - # Build anchor grids - nG = int(self.img_dim / stride) # number grid points - self.nG = nG - self.stride = stride - - self.grid_x = torch.arange(nG).repeat((nG, 1)).view((1, 1, nG, nG)).float() - self.grid_y = torch.arange(nG).repeat((nG, 1)).t().view((1, 1, nG, nG)).float() - self.anchor_wh = torch.FloatTensor([(a_w / stride, a_h / stride) for a_w, a_h in anchors]) # scale anchors - self.anchor_w = self.anchor_wh[:, 0].view((1, nA, 1, 1)) - self.anchor_h = self.anchor_wh[:, 1].view((1, nA, 1, 1)) - - def forward(self, p, targets=None, var=None): - bs = 1 if ONNX_EXPORT else p.shape[0] # batch size - nG = self.nG if ONNX_EXPORT else p.shape[-1] # number of grid points - - if not self.initialized: - self.initialized = True - if p.is_cuda: - self.grid_x, self.grid_y = self.grid_x.cuda(), self.grid_y.cuda() - self.anchor_w, self.anchor_h = self.anchor_w.cuda(), self.anchor_h.cuda() + if p.is_cuda: + self.grid_xy = self.grid_xy.cuda() + self.anchor_vector = self.anchor_vector.cuda() # p.view(bs, 255, 13, 13) -- > (bs, 3, 13, 13, 80) # (bs, anchors, grid, grid, classes + xywh) - p = p.view(bs, self.nA, self.bbox_attrs, nG, nG).permute(0, 1, 3, 4, 2).contiguous() # prediction + p = p.view(bs, self.nA, self.nC + 5, nG, nG).permute(0, 1, 3, 4, 2).contiguous() # prediction # Training if targets is not None: @@ -172,7 +161,7 @@ class YOLOLayer(nn.Module): # width = ((w.data * 2) ** 2) * self.anchor_w # height = ((h.data * 2) ** 2) * self.anchor_h - tx, ty, tw, th, mask, tcls = build_targets(targets, self.anchor_wh, self.nA, self.nC, nG) + tx, ty, tw, th, mask, tcls = build_targets(targets, self.anchor_vector, self.nA, self.nC, nG) tcls = tcls[mask] if x.is_cuda: @@ -204,12 +193,8 @@ class YOLOLayer(nn.Module): else: if ONNX_EXPORT: - anchor_w = self.anchor_w.repeat((1, 1, nG, nG)).view(1, -1, 1) - anchor_h = self.anchor_h.repeat((1, 1, nG, nG)).view(1, -1, 1) - grid_x = self.grid_x.repeat(1, self.nA, 1, 1).view(1, -1, 1) - grid_y = self.grid_y.repeat(1, self.nA, 1, 1).view(1, -1, 1) - grid_xy = torch.cat((grid_x, grid_y), 2) - anchor_wh = torch.cat((anchor_w, anchor_h), 2) / nG + grid_xy = self.grid_xy.repeat((1, self.nA, 1, 1, 1)).view((1, -1, 2)) + anchor_wh = self.anchor_wh.repeat((1, 1, nG, nG, 1)).view((1, -1, 2)) # p = p.view(-1, 85) # xy = torch.sigmoid(p[:, 0:2]) + self.grid_xy[0] # x, y @@ -230,10 +215,8 @@ class YOLOLayer(nn.Module): p_cls = p_cls.permute(2, 1, 0) return torch.cat((xy / nG, wh, p_conf, p_cls), 2).squeeze().t() - p[..., 0] = torch.sigmoid(p[..., 0]) + self.grid_x # x - p[..., 1] = torch.sigmoid(p[..., 1]) + self.grid_y # y - p[..., 2] = torch.exp(p[..., 2]) * self.anchor_w # width - p[..., 3] = torch.exp(p[..., 3]) * self.anchor_h # height + p[..., 0:2] = torch.sigmoid(p[..., 0:2]) + self.grid_xy # xy + p[..., 2:4] = torch.exp(p[..., 2:4]) * self.anchor_wh # wh p[..., 4] = torch.sigmoid(p[..., 4]) # p_conf p[..., :4] *= self.stride @@ -258,30 +241,30 @@ class Darknet(nn.Module): def forward(self, x, targets=None, var=0): self.losses = defaultdict(float) is_training = targets is not None + img_size = x.shape[-1] layer_outputs = [] output = [] for i, (module_def, module) in enumerate(zip(self.module_defs, self.module_list)): - if module_def['type'] in ['convolutional', 'upsample', 'maxpool']: + mtype = module_def['type'] + if mtype in ['convolutional', 'upsample', 'maxpool']: x = module(x) - elif module_def['type'] == 'route': + elif mtype == 'route': layer_i = [int(x) for x in module_def['layers'].split(',')] if len(layer_i) == 1: x = layer_outputs[layer_i[0]] else: x = torch.cat([layer_outputs[i] for i in layer_i], 1) - elif module_def['type'] == 'shortcut': + elif mtype == 'shortcut': layer_i = int(module_def['from']) x = layer_outputs[-1] + layer_outputs[layer_i] - elif module_def['type'] == 'yolo': - # Train phase: get loss - if is_training: - x, *losses = module[0](x, targets, var) + elif mtype == 'yolo': + if is_training: # get loss + x, *losses = module[0](x, img_size, targets, var) for name, loss in zip(self.loss_names, losses): self.losses[name] += loss - # Test phase: Get detections - else: - x = module(x) + else: # get detections + x = module[0](x, img_size) output.append(x) layer_outputs.append(x) @@ -295,6 +278,19 @@ class Darknet(nn.Module): return sum(output) if is_training else torch.cat(output, 1) +def create_grids(self, img_size, nG): + self.stride = img_size / nG + + # build xy offsets + grid_x = torch.arange(nG).repeat((nG, 1)).view((1, 1, nG, nG)).float() + grid_y = grid_x.permute(0, 1, 3, 2) + self.grid_xy = torch.stack((grid_x, grid_y), 4) + + # build wh gains + self.anchor_vec = self.anchors / self.stride + self.anchor_wh = self.anchor_vec.view(1, self.nA, 1, 1, 2) + + def load_darknet_weights(self, weights, cutoff=-1): # Parses and loads the weights stored in 'weights' # cutoff: save layers between 0 and cutoff (if cutoff = -1 all are saved) diff --git a/utils/utils.py b/utils/utils.py index b0b438dc..a7a61ff1 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -38,7 +38,7 @@ def model_info(model): # Plots a line-by-line description of a PyTorch model print('Model Summary: %g layers, %g parameters, %g gradients\n' % (i + 1, n_p, n_g)) -def class_weights(): # frequency of each class in coco train2014 +def coco_class_weights(): # frequency of each class in coco train2014 weights = 1 / torch.FloatTensor( [187437, 4955, 30920, 6033, 3838, 4332, 3160, 7051, 7677, 9167, 1316, 1372, 833, 6757, 7355, 3302, 3776, 4671, 6769, 5706, 3908, 903, 3686, 3596, 6200, 7920, 8779, 4505, 4272, 1862, 4698, 1962, 4403, 6659, 2402, 2689,