diff --git a/train.py b/train.py
index 06878e6c..f4ce729c 100644
--- a/train.py
+++ b/train.py
@@ -97,6 +97,12 @@ def train(
                             collate_fn=dataset.collate_fn,
                             sampler=sampler)
 
+    # Mixed precision training https://github.com/NVIDIA/apex
+    mixed_precision = False
+    if mixed_precision:
+        from apex import amp
+        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')  # 'O1' is letter O + one, not '01'
+
     # Start training
     t = time.time()
     model_info(model)
@@ -145,7 +151,11 @@ def train(
             loss, loss_dict = compute_loss(pred, target_list)
 
             # Compute gradient
-            loss.backward()
+            if mixed_precision:
+                with amp.scale_loss(loss, optimizer) as scaled_loss:
+                    scaled_loss.backward()
+            else:
+                loss.backward()
 
             # Accumulate gradient for x batches before optimizing
             if (i + 1) % accumulate == 0 or (i + 1) == nB: