From f501a0fc9dff39bb504b29b6e213ba11583287a8 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sat, 22 Jun 2019 15:50:04 +0200 Subject: [PATCH] updates --- train.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/train.py b/train.py index 666da626..b1f0489e 100644 --- a/train.py +++ b/train.py @@ -168,11 +168,12 @@ def train( collate_fn=dataset.collate_fn) # Mixed precision training https://github.com/NVIDIA/apex - # install help: https://github.com/NVIDIA/apex/issues/259 - mixed_precision = False - if mixed_precision: + try: from apex import amp model, optimizer = amp.initialize(model, optimizer, opt_level='O1') + mixed_precision = True + except: # not installed: install help: https://github.com/NVIDIA/apex/issues/259 + mixed_precision = False # Start training model.hyp = hyp # attach hyperparameters to model