This commit is contained in:
Glenn Jocher
2019-09-13 16:00:52 +02:00
parent 5452bb7036
commit 4286bba40f
3 changed files with 8 additions and 5 deletions
+5 -3
View File
@@ -1,4 +1,5 @@
import os
import torch
@@ -14,9 +15,10 @@ def init_seeds(seed=0):
torch.backends.cudnn.benchmark = False
def select_device(device=None, force_cpu=False, apex=False):
# Set environment variable if device is specified
if device:
def select_device(device=None, apex=False):
if device == 'cpu':
force_cpu = True
elif device: # Set environment variable if device is specified
os.environ['CUDA_VISIBLE_DEVICES'] = device
# apex if mixed precision training https://github.com/NVIDIA/apex