EMA class updates
This commit is contained in:
+10
-6
@@ -141,10 +141,14 @@ class ModelEMA:
|
||||
msd, esd = model.module.state_dict(), self.ema.module.state_dict()
|
||||
else:
|
||||
msd, esd = model.state_dict(), self.ema.state_dict()
|
||||
# self.ema.load_state_dict({k: esd[k] * d + (1 - d) * v.detach() for k, v in model.items() if v.dtype.is_floating_point})
|
||||
for k, v in esd.items():
|
||||
if v.dtype.is_floating_point:
|
||||
v *= d
|
||||
v += (1. - d) * msd[k].detach()
|
||||
|
||||
# self.ema.load_state_dict(
|
||||
# {k: esd[k] * d + (1 - d) * v.detach() for k, v in model.items() if v.dtype.is_floating_point})
|
||||
for k in msd.keys():
|
||||
if esd[k].dtype.is_floating_point:
|
||||
esd[k] *= d
|
||||
esd[k] += (1. - d) * msd[k].detach()
|
||||
def update_attr(self, model):
|
||||
# Assign attributes (which may change during training)
|
||||
for k in model.__dict__.keys():
|
||||
if not k.startswith('_'):
|
||||
self.ema.__setattr__(k, model.getattr(k))
|
||||
|
||||
Reference in New Issue
Block a user