diff --git a/train.py b/train.py
index b34c2293..e2ba4b0b 100644
--- a/train.py
+++ b/train.py
@@ -190,13 +190,19 @@ def train():
                                              pin_memory=True,
                                              collate_fn=dataset.collate_fn)
 
-    # Start training
-    nb = len(dataloader)  # number of batches
-    prebias = start_epoch == 0
+    # Model parameters
     model.nc = nc  # attach number of classes to model
     model.arc = opt.arc  # attach yolo architecture
     model.hyp = hyp  # attach hyperparameters to model
+    model.gr = 0.0  # giou loss ratio (obj_loss = 1.0 or giou)
     model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device)  # attach class weights
+
+    # Model EMA
+    # ema = torch_utils.ModelEMA(model, decay=0.9997)
+
+    # Start training
+    nb = len(dataloader)  # number of batches
+    prebias = start_epoch == 0
     maps = np.zeros(nc)  # mAP per class
     # torch.autograd.set_detect_anomaly(True)
     results = (0, 0, 0, 0, 0, 0, 0)  # 'P', 'R', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification'
diff --git a/utils/torch_utils.py b/utils/torch_utils.py
index 941c22a4..c706f7f5 100644
--- a/utils/torch_utils.py
+++ b/utils/torch_utils.py
@@ -1,8 +1,10 @@
 import os
 import time
+from copy import deepcopy
 
 import torch
 import torch.backends.cudnn as cudnn
+import torch.nn as nn
 
 
 def init_seeds(seed=0):
@@ -101,3 +103,48 @@ def load_classifier(name='resnet101', n=2):
     model.last_linear.weight = torch.nn.Parameter(torch.zeros(n, filters))
     model.last_linear.out_features = n
     return model
+
+
+class ModelEMA:
+    """ Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
+    Keep a moving average of everything in the model state_dict (parameters and buffers).
+    This is intended to allow functionality like
+    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
+    A smoothed version of the weights is necessary for some training schemes to perform well.
+    E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc. that use
+    RMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 require EMA
+    smoothing of weights to match results. Pay attention to the decay constant you are using
+    relative to your update count per epoch.
+    To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but
+    disable validation of the EMA weights. Validation will have to be done manually in a separate
+    process, or after the training stops converging.
+    This class is sensitive to where it is initialized in the sequence of model init,
+    GPU assignment and distributed training wrappers.
+    I've tested with the sequence in my own train.py for torch.DataParallel, apex.DDP, and single-GPU.
+    """
+
+    def __init__(self, model, decay=0.9998, device=''):
+        # make a copy of the model for accumulating moving average of weights
+        self.ema = deepcopy(model)
+        self.ema.eval()
+        self.decay = decay
+        self.device = device  # perform ema on different device from model if set
+        if device:
+            self.ema.to(device=device)
+        for p in self.ema.parameters():
+            p.requires_grad_(False)
+
+    def update(self, model):
+        d = self.decay
+        with torch.no_grad():
+            if type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel):
+                msd, esd = model.module.state_dict(), self.ema.module.state_dict()
+            else:
+                msd, esd = model.state_dict(), self.ema.state_dict()
+
+            # self.ema.load_state_dict(
+            #     {k: esd[k] * d + (1 - d) * v.detach() for k, v in msd.items() if v.dtype.is_floating_point})
+            for k in msd.keys():
+                if esd[k].dtype.is_floating_point:
+                    esd[k] *= d
+                    esd[k] += (1. - d) * msd[k].detach()
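
Below is a minimal, self-contained usage sketch (not part of the diff) showing how ModelEMA is meant to be wired in, consistent with the commented-out "ema = torch_utils.ModelEMA(model, decay=0.9997)" line in train.py above: build the EMA copy once the model is constructed and on its device, call update() after every optimizer step, and validate or checkpoint with the smoothed weights. The toy model, optimizer, and loop are illustrative assumptions, not code from this change.

    # usage sketch (assumed wiring); stand-in model/optimizer, not the YOLOv3 Darknet model
    import torch
    import torch.nn as nn
    from utils.torch_utils import ModelEMA

    model = nn.Linear(10, 2)                                   # stand-in for the real model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    ema = ModelEMA(model, decay=0.9998)                        # create EMA copy after model init / device placement

    for _ in range(100):                                       # stand-in for the batch loop
        x, y = torch.randn(16, 10), torch.randn(16, 2)
        loss = nn.functional.mse_loss(model(x), y)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        ema.update(model)                                      # accumulate moving average of the live weights

    # evaluate / save using the smoothed weights
    ema_model = ema.ema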