diff --git a/train.py b/train.py index 79dabb92..39ad63ad 100644 --- a/train.py +++ b/train.py @@ -70,10 +70,10 @@ def train( del chkpt else: # Initialize model with backbone (optional) - if cfg.endswith('yolov3.cfg'): - cutoff = load_darknet_weights(model, weights + 'darknet53.conv.74') - elif cfg.endswith('yolov3-tiny.cfg'): + if '-tiny.cfg' in cfg: cutoff = load_darknet_weights(model, weights + 'yolov3-tiny.conv.15') + else: + cutoff = load_darknet_weights(model, weights + 'darknet53.conv.74') # Set scheduler (reduce lr at epoch 250) scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[250], gamma=0.1, last_epoch=start_epoch - 1)