From c24702941f04f973a3a440b8ba6d65a9add481e1 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Thu, 19 Sep 2019 18:05:04 +0200 Subject: [PATCH] updates --- models.py | 54 ++++++++++++++++++++++++++++++++++-------------------- test.py | 1 + train.py | 1 + 3 files changed, 36 insertions(+), 20 deletions(-) diff --git a/models.py b/models.py index 054f2d09..4933f178 100755 --- a/models.py +++ b/models.py @@ -291,27 +291,9 @@ def create_grids(self, img_size=416, ng=(13, 13), device='cpu', type=torch.float def load_darknet_weights(self, weights, cutoff=-1): # Parses and loads the weights stored in 'weights' - # cutoff: save layers between 0 and cutoff (if cutoff = -1 all are saved) + + # Establish cutoffs (load layers between 0 and cutoff. if cutoff = -1 all are loaded) file = Path(weights).name - - # Try to download weights if not available locally - msg = weights + ' missing, download from https://drive.google.com/drive/folders/1uxgUBemJVw9wZsdpboYbzUN4bcRhsuAI' - if not os.path.isfile(weights): - if file == 'yolov3-spp.weights': - gdrive_download(id='1oPCHKsM2JpM-zgyepQciGli9X0MTsJCO', name=weights) - elif file == 'darknet53.conv.74': - gdrive_download(id='18xqvs_uwAqfTXp-LJCYLYNHBOcrwbrp0', name=weights) - else: - try: # download from pjreddie.com - url = 'https://pjreddie.com/media/files/' + file - print('Downloading ' + url) - os.system('curl -f ' + url + ' -o ' + weights) - except IOError: - print(msg) - os.system('rm ' + weights) # remove partial downloads - assert os.path.exists(weights), msg # download missing weights from Google Drive - - # Establish cutoffs if file == 'darknet53.conv.74': cutoff = 75 elif file == 'yolov3-tiny.conv.15': @@ -417,3 +399,35 @@ def convert(cfg='cfg/yolov3-spp.cfg', weights='weights/yolov3-spp.weights'): else: print('Error: extension not supported.') + + +def attempt_download(weights): + # Attempt to download pretrained weights if not found locally + + msg = weights + ' missing, download from https://drive.google.com/drive/folders/1uxgUBemJVw9wZsdpboYbzUN4bcRhsuAI' + if not os.path.isfile(weights): + file = Path(weights).name + + if file == 'yolov3-spp.weights': + gdrive_download(id='1oPCHKsM2JpM-zgyepQciGli9X0MTsJCO', name=weights) + elif file == 'yolov3-spp.pt': + gdrive_download(id='1vFlbJ_dXPvtwaLLOu-twnjK4exdFiQ73', name=weights) + elif file == 'yolov3.pt': + gdrive_download(id='11uy0ybbOXA2hc-NJkJbbbkDwNX1QZDlz', name=weights) + elif file == 'yolov3-tiny.pt': + gdrive_download(id='1qKSgejNeNczgNNiCn9ZF_o55GFk1DjY_', name=weights) + elif file == 'darknet53.conv.74': + gdrive_download(id='18xqvs_uwAqfTXp-LJCYLYNHBOcrwbrp0', name=weights) + elif file == 'yolov3-tiny.conv.15': + gdrive_download(id='140PnSedCsGGgu3rOD6Ez4oI6cdDzerLC', name=weights) + + else: + try: # download from pjreddie.com + url = 'https://pjreddie.com/media/files/' + file + print('Downloading ' + url) + os.system('curl -f ' + url + ' -o ' + weights) + except IOError: + print(msg) + os.system('rm ' + weights) # remove partial downloads + + assert os.path.exists(weights), msg # download missing weights from Google Drive diff --git a/test.py b/test.py index 3f0b115a..100929a8 100644 --- a/test.py +++ b/test.py @@ -27,6 +27,7 @@ def test(cfg, model = Darknet(cfg, img_size).to(device) # Load weights + attempt_download(weights) if weights.endswith('.pt'): # pytorch format model.load_state_dict(torch.load(weights, map_location=device)['model']) else: # darknet format diff --git a/train.py b/train.py index ae31349b..aaef96c8 100644 --- a/train.py +++ b/train.py @@ -100,6 +100,7 @@ def train(): cutoff = -1 # backbone reaches to cutoff layer start_epoch = 0 best_fitness = 0. + attempt_download(weights) if weights.endswith('.pt'): # pytorch format # possible weights are 'last.pt', 'yolov3-spp.pt', 'yolov3-tiny.pt' etc. if opt.bucket: