From 77c6c01970495a1636dce4090ff57f80c44eb72a Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Mon, 16 Mar 2020 17:51:40 -0700 Subject: [PATCH] EMA class updates --- utils/torch_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/utils/torch_utils.py b/utils/torch_utils.py index 187d5142..ac38249c 100644 --- a/utils/torch_utils.py +++ b/utils/torch_utils.py @@ -152,7 +152,7 @@ class ModelEMA: msd, esd = model.module.state_dict(), self.ema.module.state_dict() else: msd, esd = model.state_dict(), self.ema.state_dict() - # self.ema.load_state_dict({k: esd[k] * d + (1 - d) * v.detach() for k, v in model.items() if v.dtype.is_floating_point}) + for k, v in esd.items(): if v.dtype.is_floating_point: v *= d @@ -162,4 +162,4 @@ class ModelEMA: # Assign attributes (which may change during training) for k in model.__dict__.keys(): if not k.startswith('_'): - setattr(model, k, getattr(model, k)) + setattr(self.ema, k, getattr(model, k))