From 798a7396f1851f06f3ab6971b73f48baefff9b0c Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Mon, 26 Aug 2019 16:24:19 +0200 Subject: [PATCH] weight_decay fix --- train.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/train.py b/train.py index f2c3f4f2..0a56a8c7 100644 --- a/train.py +++ b/train.py @@ -261,9 +261,8 @@ def train(): print('WARNING: nan loss detected, ending training') return results - # Divide by accumulation count - if accumulate > 1: - loss /= accumulate + # Scale loss by nominal batch_size of 64 + loss *= batch_size / 64 # Compute gradient if mixed_precision: