diff --git a/train.py b/train.py
index f2c3f4f2..0a56a8c7 100644
--- a/train.py
+++ b/train.py
@@ -261,9 +261,8 @@ def train():
                 print('WARNING: nan loss detected, ending training')
                 return results
 
-            # Divide by accumulation count
-            if accumulate > 1:
-                loss /= accumulate
+            # Scale loss by nominal batch_size of 64
+            loss *= batch_size / 64
 
             # Compute gradient
             if mixed_precision:
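
For context, a minimal sketch of how this scaling interacts with gradient accumulation. The model, optimizer, and loop below are hypothetical stand-ins, not the repository's actual `train()` loop; only `batch_size`, `accumulate`, and the `loss *= batch_size / 64` line come from the diff. Scaling each mini-batch loss by `batch_size / 64` makes the gradients summed over `accumulate` steps approximate those of a single nominal 64-image batch, regardless of the actual mini-batch size.

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 1)                      # hypothetical stand-in model
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
batch_size = 16                               # actual mini-batch size
accumulate = 64 // batch_size                 # mini-batches per optimizer step

for i in range(100):                          # dummy data loader
    imgs = torch.randn(batch_size, 10)
    targets = torch.randn(batch_size, 1)

    loss = nn.functional.mse_loss(model(imgs), targets)

    # Scale loss by nominal batch_size of 64 (as in the diff): gradients
    # summed over `accumulate` mini-batches then match a 64-image batch.
    loss *= batch_size / 64
    loss.backward()

    # Step the optimizer once every `accumulate` mini-batches
    if (i + 1) % accumulate == 0:
        optimizer.step()
        optimizer.zero_grad()
```

Compared with the removed `loss /= accumulate`, this form keeps the effective gradient magnitude tied to the nominal batch size of 64 even when `batch_size * accumulate` differs from 64, at the cost of hard-coding that nominal value.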