updates
This commit is contained in:
@@ -168,11 +168,12 @@ def train(
|
|||||||
collate_fn=dataset.collate_fn)
|
collate_fn=dataset.collate_fn)
|
||||||
|
|
||||||
# Mixed precision training https://github.com/NVIDIA/apex
|
# Mixed precision training https://github.com/NVIDIA/apex
|
||||||
# install help: https://github.com/NVIDIA/apex/issues/259
|
try:
|
||||||
mixed_precision = False
|
|
||||||
if mixed_precision:
|
|
||||||
from apex import amp
|
from apex import amp
|
||||||
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
|
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
|
||||||
|
mixed_precision = True
|
||||||
|
except: # not installed: install help: https://github.com/NVIDIA/apex/issues/259
|
||||||
|
mixed_precision = False
|
||||||
|
|
||||||
# Start training
|
# Start training
|
||||||
model.hyp = hyp # attach hyperparameters to model
|
model.hyp = hyp # attach hyperparameters to model
|
||||||
|
|||||||
Reference in New Issue
Block a user