EMA implemented by default
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import math
|
||||
import os
|
||||
import time
|
||||
from copy import deepcopy
|
||||
@@ -139,11 +140,12 @@ class ModelEMA:
|
||||
I've tested with the sequence in my own train.py for torch.DataParallel, apex.DDP, and single-GPU.
|
||||
"""
|
||||
|
||||
def __init__(self, model, decay=0.9998, device=''):
|
||||
def __init__(self, model, decay=0.9999, device=''):
|
||||
# make a copy of the model for accumulating moving average of weights
|
||||
self.ema = deepcopy(model)
|
||||
self.ema.eval()
|
||||
self.decay = decay
|
||||
self.updates = 0 # number of EMA updates
|
||||
self.decay = lambda x: decay * (1 - math.exp(-x / 1000)) # decay exponential ramp (to help early epochs)
|
||||
self.device = device # perform ema on different device from model if set
|
||||
if device:
|
||||
self.ema.to(device=device)
|
||||
@@ -151,7 +153,8 @@ class ModelEMA:
|
||||
p.requires_grad_(False)
|
||||
|
||||
def update(self, model):
|
||||
d = self.decay
|
||||
self.updates += 1
|
||||
d = self.decay(self.updates)
|
||||
with torch.no_grad():
|
||||
if type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel):
|
||||
msd, esd = model.module.state_dict(), self.ema.module.state_dict()
|
||||
|
||||
Reference in New Issue
Block a user