diff --git a/models.py b/models.py index 21e871e1..ef417603 100755 --- a/models.py +++ b/models.py @@ -105,8 +105,8 @@ class YOLOLayer(nn.Module): self.nA = len(anchors) # number of anchors (3) self.nC = nC # number of classes (80) self.img_size = 0 - self.nG, self.stride, self.grid_xy, self.anchor_vec, self.anchor_wh = \ - [], [], [], [], [] + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + create_grids(self, 32, 1, device=device) if ONNX_EXPORT: # grids must be computed in __init__ stride = [32, 16, 8][yolo_layer] # stride of this layer