EMA class updates

This commit is contained in:
Glenn Jocher
2020-03-14 16:23:14 -07:00
parent 666ba85ed3
commit b89cc396af
2 changed files with 12 additions and 9 deletions
+10 -6
View File
@@ -141,10 +141,14 @@ class ModelEMA:
msd, esd = model.module.state_dict(), self.ema.module.state_dict()
else:
msd, esd = model.state_dict(), self.ema.state_dict()
# self.ema.load_state_dict({k: esd[k] * d + (1 - d) * v.detach() for k, v in model.items() if v.dtype.is_floating_point})
for k, v in esd.items():
if v.dtype.is_floating_point:
v *= d
v += (1. - d) * msd[k].detach()
# self.ema.load_state_dict(
# {k: esd[k] * d + (1 - d) * v.detach() for k, v in model.items() if v.dtype.is_floating_point})
for k in msd.keys():
if esd[k].dtype.is_floating_point:
esd[k] *= d
esd[k] += (1. - d) * msd[k].detach()
def update_attr(self, model):
# Assign attributes (which may change during training)
for k in model.__dict__.keys():
if not k.startswith('_'):
self.ema.__setattr__(k, model.getattr(k))