diff --git a/models.py b/models.py
index 9d2538e6..772430f6 100755
--- a/models.py
+++ b/models.py
@@ -8,6 +8,7 @@ ONNX_EXPORT = False
 
 def create_modules(module_defs, img_size):
     # Constructs module list of layer blocks from module configuration in module_defs
+    img_size = [img_size] * 2 if isinstance(img_size, int) else img_size  # expand if necessary
     hyperparams = module_defs.pop(0)
     output_filters = [int(hyperparams['channels'])]
     module_list = nn.ModuleList()
@@ -75,12 +76,13 @@ def create_modules(module_defs, img_size):
 
         elif mdef['type'] == 'yolo':
             yolo_index += 1
-            l = mdef['from'] if 'from' in mdef else []
+            stride = [32, 16, 8, 4, 2][yolo_index]  # P3-P7 stride
             modules = YOLOLayer(anchors=mdef['anchors'][mdef['mask']],  # anchor list
                                 nc=mdef['classes'],  # number of classes
                                 img_size=img_size,  # (416, 416)
                                 yolo_index=yolo_index,  # 0, 1, 2...
-                                layers=l)  # output layers
+                                layers=mdef['from'] if 'from' in mdef else [],  # output layers
+                                stride=stride)
 
             # Initialize preceding Conv2d() bias (https://arxiv.org/pdf/1708.02002.pdf section 3.3)
             try:
@@ -110,23 +112,34 @@ def create_modules(module_defs, img_size):
 
 
 class YOLOLayer(nn.Module):
-    def __init__(self, anchors, nc, img_size, yolo_index, layers):
+    def __init__(self, anchors, nc, img_size, yolo_index, layers, stride):
         super(YOLOLayer, self).__init__()
         self.anchors = torch.Tensor(anchors)
         self.index = yolo_index  # index of this layer in layers
         self.layers = layers  # model output layer indices
+        self.stride = stride  # layer stride
         self.nl = len(layers)  # number of output layers (3)
         self.na = len(anchors)  # number of anchors (3)
         self.nc = nc  # number of classes (80)
         self.no = nc + 5  # number of outputs (85)
-        self.nx = 0  # initialize number of x gridpoints
-        self.ny = 0  # initialize number of y gridpoints
+        self.nx, self.ny = 0, 0  # initialize number of x, y gridpoints
+        self.anchor_vec = self.anchors / self.stride
+        self.anchor_wh = self.anchor_vec.view(1, self.na, 1, 1, 2)
 
         if ONNX_EXPORT:
-            stride = [32, 16, 8][yolo_index]  # stride of this layer
-            nx = img_size[1] // stride  # number x grid points
-            ny = img_size[0] // stride  # number y grid points
-            create_grids(self, img_size, (nx, ny))
+            self.create_grids((img_size[1] // stride, img_size[0] // stride))  # number x, y grid points
+
+    def create_grids(self, ng=(13, 13), device='cpu'):
+        self.nx, self.ny = ng  # x and y grid size
+        self.ng = torch.Tensor(ng).to(device)
+
+        # build xy offsets
+        yv, xv = torch.meshgrid([torch.arange(self.ny), torch.arange(self.nx)])
+        self.grid_xy = torch.stack((xv, yv), 2).to(device).view((1, 1, self.ny, self.nx, 2))
+
+        if self.anchor_vec.device != device:
+            self.anchor_vec = self.anchor_vec.to(device)
+            self.anchor_wh = self.anchor_wh.to(device)
 
     def forward(self, p, img_size, out):
         ASFF = False  # https://arxiv.org/abs/1911.09516
@@ -135,7 +148,7 @@ class YOLOLayer(nn.Module):
             p = out[self.layers[i]]
             bs, _, ny, nx = p.shape  # bs, 255, 13, 13
             if (self.nx, self.ny) != (nx, ny):
-                create_grids(self, img_size, (nx, ny), p.device, p.dtype)
+                self.create_grids((nx, ny), p.device)
 
             # outputs and weights
             # w = F.softmax(p[:, -n:], 1)  # normalized weights
@@ -154,7 +167,7 @@ class YOLOLayer(nn.Module):
         else:
             bs, _, ny, nx = p.shape  # bs, 255, 13, 13
             if (self.nx, self.ny) != (nx, ny):
-                create_grids(self, img_size, (nx, ny), p.device, p.dtype)
+                self.create_grids((nx, ny), p.device)
 
             # p.view(bs, 255, 13, 13) -- > (bs, 3, 13, 13, 85)  # (bs, anchors, grid, grid, classes + xywh)
             p = p.view(bs, self.na, self.no, self.ny, self.nx).permute(0, 1, 3, 4, 2).contiguous()  # prediction
@@ -273,23 +286,6 @@ def get_yolo_layers(model):
     return [i for i, x in enumerate(model.module_defs) if x['type'] == 'yolo']  # [82, 94, 106] for yolov3
 
 
-def create_grids(self, img_size=416, ng=(13, 13), device='cpu', type=torch.float32):
-    nx, ny = ng  # x and y grid size
-    self.img_size = max(img_size)
-    self.stride = self.img_size / max(ng)
-
-    # build xy offsets
-    yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
-    self.grid_xy = torch.stack((xv, yv), 2).to(device).type(type).view((1, 1, ny, nx, 2))
-
-    # build wh gains
-    self.anchor_vec = self.anchors.to(device) / self.stride
-    self.anchor_wh = self.anchor_vec.view(1, self.na, 1, 1, 2).type(type)
-    self.ng = torch.Tensor(ng).to(device)
-    self.nx = nx
-    self.ny = ny
-
-
 def load_darknet_weights(self, weights, cutoff=-1):
     # Parses and loads the weights stored in 'weights'
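
Note for reviewers: the sketch below is not part of the patch. It is a minimal, standalone reproduction of what the new create_grids() method computes and how grid_xy / anchor_wh feed a YOLOv3-style decode; the anchor values, grid size, and dummy prediction are assumptions (a default P5 head at 416x416 input), not values read from a repository config.

import torch

na = 3                                    # anchors per grid cell
stride = 32                               # stride of this YOLO layer (yolo_index 0)
anchors = torch.Tensor([[116, 90], [156, 198], [373, 326]])  # pixel-space anchors (assumed)
anchor_vec = anchors / stride             # anchors in grid units, as in __init__()
anchor_wh = anchor_vec.view(1, na, 1, 1, 2)

nx, ny = 13, 13                           # grid size for a 416x416 input at stride 32
yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
grid_xy = torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2))   # per-cell xy offsets

p = torch.randn(1, na, ny, nx, 85)        # dummy prediction: (bs, na, ny, nx, classes + xywh)
xy = (torch.sigmoid(p[..., :2]) + grid_xy) * stride  # box centers, back to pixel space
wh = torch.exp(p[..., 2:4]) * anchor_wh * stride     # box sizes, back to pixel space

Because stride is now fixed per layer at construction time rather than re-derived from img_size / max(ng) on every grid rebuild, anchor_vec and anchor_wh can be computed once in __init__(), and create_grids() only needs the grid shape and device. That is what lets the img_size and dtype arguments drop out of the call sites in forward() above.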