diff --git a/train.py b/train.py index 6f682609..7d7e4b17 100644 --- a/train.py +++ b/train.py @@ -62,12 +62,13 @@ def train(): epochs = opt.epochs # 500200 batches at bs 16, 117263 images = 273 epochs batch_size = opt.batch_size accumulate = opt.accumulate # effective bs = batch_size * accumulate = 16 * 4 = 64 + weights = opt.weights # initial training weights # Initialize init_seeds() - weights = 'weights' + os.sep - last = weights + 'last.pt' - best = weights + 'best.pt' + wdir = 'weights' + os.sep # weights dir + last = wdir + 'last.pt' + best = wdir + 'best.pt' device = torch_utils.select_device(apex=mixed_precision) multi_scale = opt.multi_scale @@ -94,26 +95,23 @@ def train(): cutoff = -1 # backbone reaches to cutoff layer start_epoch = 0 best_fitness = 0. - nf = int(model.module_defs[model.yolo_layers[0] - 1]['filters']) # yolo layer size (i.e. 255) - if opt.resume or opt.transfer: # Load previously saved model - if opt.transfer: # Transfer learning - chkpt = torch.load(weights + 'yolov3-spp.pt', map_location=device) - model.load_state_dict({k: v for k, v in chkpt['model'].items() if v.numel() > 1 and v.shape[0] != 255}, - strict=False) + if weights.endswith('.pt'): # pytorch format + # possible weights are 'last.pt', 'yolov3-spp.pt', 'yolov3-tiny.pt' etc. 
+ if opt.bucket: + os.system('gsutil cp gs://%s/last.pt %s' % (opt.bucket, last)) # download from bucket + chkpt = torch.load(weights, map_location=device) - for p in model.parameters(): - p.requires_grad = True if p.shape[0] == nf else False - - else: # resume from last.pt - if opt.bucket: - os.system('gsutil cp gs://%s/last.pt %s' % (opt.bucket, last)) # download from bucket - chkpt = torch.load(last, map_location=device) # load checkpoint - model.load_state_dict(chkpt['model']) + # load model + if opt.transfer: + chkpt['model'] = {k: v for k, v in chkpt['model'].items() if model.state_dict()[k].numel() == v.numel()} + model.load_state_dict(chkpt['model'], strict=False) + # load optimizer if chkpt['optimizer'] is not None: optimizer.load_state_dict(chkpt['optimizer']) best_fitness = chkpt['best_fitness'] + # load results if chkpt.get('training_results') is not None: with open('results.txt', 'w') as file: file.write(chkpt['training_results']) # write results.txt @@ -121,15 +119,14 @@ def train(): start_epoch = chkpt['epoch'] + 1 del chkpt - else: # Initialize model with backbone (optional) - if '-tiny.cfg' in cfg: - cutoff = load_darknet_weights(model, weights + 'yolov3-tiny.conv.15') - else: - cutoff = load_darknet_weights(model, weights + 'darknet53.conv.74') + elif weights.endswith('.weights'): # darknet format + # possible weights are 'yolov3.weights', 'yolov3-tiny.conv.15', 'darknet53.conv.74' etc. + cutoff = load_darknet_weights(model, weights) - # Remove old results - for f in glob.glob('*_batch*.jpg') + glob.glob('results.txt'): - os.remove(f) + if opt.transfer: # transfer learning + nf = int(model.module_defs[model.yolo_layers[0] - 1]['filters']) # yolo layer size (i.e. 
255) + for p in model.parameters(): + p.requires_grad = True if p.shape[0] == nf else False # Scheduler https://github.com/ultralytics/yolov3/issues/238 # lf = lambda x: 1 - x / epochs # linear ramp to zero @@ -181,6 +178,10 @@ def train(): pin_memory=True, collate_fn=dataset.collate_fn) + # Remove previous results + for f in glob.glob('*_batch*.jpg') + glob.glob('results.txt'): + os.remove(f) + # Start training model.nc = nc # attach number of classes to model model.hyp = hyp # attach hyperparameters to model @@ -327,7 +328,7 @@ def train(): # Save backup every 10 epochs (optional) if epoch > 0 and epoch % 10 == 0: - torch.save(chkpt, weights + 'backup%g.pt' % epoch) + torch.save(chkpt, wdir + 'backup%g.pt' % epoch) # Delete checkpoint del chkpt @@ -345,7 +346,7 @@ if __name__ == '__main__': parser.add_argument('--epochs', type=int, default=273) # 500200 batches at bs 16, 117263 images = 273 epochs parser.add_argument('--batch-size', type=int, default=32) # effective bs = batch_size * accumulate = 16 * 4 = 64 parser.add_argument('--accumulate', type=int, default=2, help='batches to accumulate before optimizing') - parser.add_argument('--cfg', type=str, default='cfg/yolov3-spp.cfg', help='cfg file path') + parser.add_argument('--cfg', type=str, default='cfg/yolov3-spp-1cls.cfg', help='cfg file path') parser.add_argument('--data', type=str, default='data/coco.data', help='*.data file path') parser.add_argument('--multi-scale', action='store_true', help='adjust (67% - 150%) img_size every 10 batches') parser.add_argument('--img-size', type=int, default=416, help='inference size (pixels)') @@ -358,7 +359,9 @@ if __name__ == '__main__': parser.add_argument('--bucket', type=str, default='', help='gsutil bucket') parser.add_argument('--img-weights', action='store_true', help='select training images by weight') parser.add_argument('--cache-images', action='store_true', help='cache images for faster training') + parser.add_argument('--weights', type=str, default='', 
help='initial weights') # i.e. weights/darknet53.conv.74 opt = parser.parse_args() + opt.weights = 'weights/last.pt' if opt.resume else opt.weights print(opt) tb_writer = None