diff --git a/train.py b/train.py index 24a2f7ab..4208b4cd 100644 --- a/train.py +++ b/train.py @@ -211,13 +211,13 @@ def train(): torch_utils.model_info(model, report='summary') # 'full' or 'summary' print('Using %g dataloader workers' % nw) print('Starting training for %g epochs...' % epochs) - for epoch in range(start_epoch, epochs): # epoch ------------------------------ + for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------ model.train() # Prebias if prebias: if epoch < 3: # prebias - ps = 0.1, 0.9 # prebias settings (lr=0.1, momentum=0.9) + ps = np.interp(epoch, [0, 3], [0.1, hyp['lr0']]), 0.0 # prebias settings (lr=0.1, momentum=0.0) else: # normal training ps = hyp['lr0'], hyp['momentum'] # normal training settings print_model_biases(model)