Extract seed and cuda initialization utils

This commit is contained in:
Guillermo García
2018-12-05 11:55:27 +01:00
parent 45ee668fd7
commit 5a566454f5
5 changed files with 53 additions and 15 deletions
+23
View File
@@ -0,0 +1,23 @@
import torch
def check_cuda():
return torch.cuda.is_available()
CUDA_AVAILABLE = check_cuda()
def init_seeds(seed=0):
torch.manual_seed(seed)
if CUDA_AVAILABLE:
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def select_device(force_cpu=False):
if force_cpu:
device = torch.device('cpu')
else:
device = torch.device('cuda:0' if CUDA_AVAILABLE else 'cpu')
return device
+8
View File
@@ -5,11 +5,19 @@ import numpy as np
import torch
import torch.nn.functional as F
from utils import torch_utils
# Set printoptions
torch.set_printoptions(linewidth=1320, precision=5, profile='long')
np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
def init_seeds(seed=0):
random.seed(seed)
np.random.seed(seed)
torch_utils.init_seeds(seed=seed)
def load_classes(path):
"""
Loads class labels at 'path'