From adc49abc711678d01b54c724eee1168fd9d62811 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sun, 6 Dec 2020 11:55:27 +0100 Subject: [PATCH] Implement default class names (#1592) --- models/yolo.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/models/yolo.py b/models/yolo.py index 8978fb95..c388fb2d 100644 --- a/models/yolo.py +++ b/models/yolo.py @@ -1,16 +1,16 @@ import argparse import logging -import math import sys from copy import deepcopy from pathlib import Path -sys.path.append('./') # to run '$ python *.py' files in subdirectories -logger = logging.getLogger(__name__) - +import math import torch import torch.nn as nn +sys.path.append('./') # to run '$ python *.py' files in subdirectories +logger = logging.getLogger(__name__) + from models.common import Conv, Bottleneck, SPP, DWConv, Focus, BottleneckCSP, Concat, NMS, autoShape from models.experimental import MixConv2d, CrossConv, C3 from utils.autoanchor import check_anchor_order @@ -82,6 +82,7 @@ class Model(nn.Module): logger.info('Overriding model.yaml nc=%g with nc=%g' % (self.yaml['nc'], nc)) self.yaml['nc'] = nc # override yaml value self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist, ch_out + self.names = [str(i) for i in range(self.yaml['nc'])] # default names # print([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))]) # Build strides, anchors