EMA implemented by default

This commit is contained in:
Glenn Jocher
2020-03-29 13:14:54 -07:00
parent dc8e56b9f3
commit 9c5e76b93d
2 changed files with 12 additions and 9 deletions
+6 -3
View File
@@ -1,3 +1,4 @@
import math
import os
import time
from copy import deepcopy
@@ -139,11 +140,12 @@ class ModelEMA:
I've tested with the sequence in my own train.py for torch.DataParallel, apex.DDP, and single-GPU.
"""
def __init__(self, model, decay=0.9998, device=''):
def __init__(self, model, decay=0.9999, device=''):
# make a copy of the model for accumulating moving average of weights
self.ema = deepcopy(model)
self.ema.eval()
self.decay = decay
self.updates = 0 # number of EMA updates
self.decay = lambda x: decay * (1 - math.exp(-x / 1000)) # decay exponential ramp (to help early epochs)
self.device = device # perform ema on different device from model if set
if device:
self.ema.to(device=device)
@@ -151,7 +153,8 @@ class ModelEMA:
p.requires_grad_(False)
def update(self, model):
d = self.decay
self.updates += 1
d = self.decay(self.updates)
with torch.no_grad():
if type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel):
msd, esd = model.module.state_dict(), self.ema.module.state_dict()