Extract seed and cuda initialization utils
This commit is contained in:
@@ -0,0 +1,23 @@
|
||||
import torch
|
||||
|
||||
|
||||
def check_cuda():
|
||||
return torch.cuda.is_available()
|
||||
|
||||
|
||||
CUDA_AVAILABLE = check_cuda()
|
||||
|
||||
|
||||
def init_seeds(seed=0):
|
||||
torch.manual_seed(seed)
|
||||
if CUDA_AVAILABLE:
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
|
||||
def select_device(force_cpu=False):
|
||||
if force_cpu:
|
||||
device = torch.device('cpu')
|
||||
else:
|
||||
device = torch.device('cuda:0' if CUDA_AVAILABLE else 'cpu')
|
||||
return device
|
||||
@@ -5,11 +5,19 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from utils import torch_utils
|
||||
|
||||
# Set printoptions
|
||||
torch.set_printoptions(linewidth=1320, precision=5, profile='long')
|
||||
np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
|
||||
|
||||
|
||||
def init_seeds(seed=0):
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch_utils.init_seeds(seed=seed)
|
||||
|
||||
|
||||
def load_classes(path):
|
||||
"""
|
||||
Loads class labels at 'path'
|
||||
|
||||
Reference in New Issue
Block a user