updates
This commit is contained in:
@@ -1,8 +1,10 @@
|
||||
import os
|
||||
import time
|
||||
from copy import deepcopy
|
||||
|
||||
import torch
|
||||
import torch.backends.cudnn as cudnn
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def init_seeds(seed=0):
|
||||
@@ -101,3 +103,48 @@ def load_classifier(name='resnet101', n=2):
|
||||
model.last_linear.weight = torch.nn.Parameter(torch.zeros(n, filters))
|
||||
model.last_linear.out_features = n
|
||||
return model
|
||||
|
||||
|
||||
class ModelEMA:
|
||||
""" Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
|
||||
Keep a moving average of everything in the model state_dict (parameters and buffers).
|
||||
This is intended to allow functionality like
|
||||
https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
|
||||
A smoothed version of the weights is necessary for some training schemes to perform well.
|
||||
E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc that use
|
||||
RMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 requires EMA
|
||||
smoothing of weights to match results. Pay attention to the decay constant you are using
|
||||
relative to your update count per epoch.
|
||||
To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but
|
||||
disable validation of the EMA weights. Validation will have to be done manually in a separate
|
||||
process, or after the training stops converging.
|
||||
This class is sensitive where it is initialized in the sequence of model init,
|
||||
GPU assignment and distributed training wrappers.
|
||||
I've tested with the sequence in my own train.py for torch.DataParallel, apex.DDP, and single-GPU.
|
||||
"""
|
||||
|
||||
def __init__(self, model, decay=0.9998, device=''):
|
||||
# make a copy of the model for accumulating moving average of weights
|
||||
self.ema = deepcopy(model)
|
||||
self.ema.eval()
|
||||
self.decay = decay
|
||||
self.device = device # perform ema on different device from model if set
|
||||
if device:
|
||||
self.ema.to(device=device)
|
||||
for p in self.ema.parameters():
|
||||
p.requires_grad_(False)
|
||||
|
||||
def update(self, model):
|
||||
d = self.decay
|
||||
with torch.no_grad():
|
||||
if type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel):
|
||||
msd, esd = model.module.state_dict(), self.ema.module.state_dict()
|
||||
else:
|
||||
msd, esd = model.state_dict(), self.ema.state_dict()
|
||||
|
||||
# self.ema.load_state_dict(
|
||||
# {k: esd[k] * d + (1 - d) * v.detach() for k, v in model.items() if v.dtype.is_floating_point})
|
||||
for k in msd.keys():
|
||||
if esd[k].dtype.is_floating_point:
|
||||
esd[k] *= d
|
||||
esd[k] += (1. - d) * msd[k].detach()
|
||||
|
||||
Reference in New Issue
Block a user