Update Torch CUDA Synchronize (#1637)
This commit is contained in:
parent
7d9535f80e
commit
84ad6080ae
@ -13,6 +13,10 @@ import torch.nn as nn
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torchvision
|
import torchvision
|
||||||
|
|
||||||
|
try:
|
||||||
|
import thop # for FLOPS computation
|
||||||
|
except ImportError:
|
||||||
|
thop = None
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@ -32,44 +36,83 @@ def init_torch_seeds(seed=0):
|
|||||||
# Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html
|
# Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html
|
||||||
torch.manual_seed(seed)
|
torch.manual_seed(seed)
|
||||||
if seed == 0: # slower, more reproducible
|
if seed == 0: # slower, more reproducible
|
||||||
cudnn.deterministic = True
|
cudnn.benchmark, cudnn.deterministic = False, True
|
||||||
cudnn.benchmark = False
|
|
||||||
else: # faster, less reproducible
|
else: # faster, less reproducible
|
||||||
cudnn.deterministic = False
|
cudnn.benchmark, cudnn.deterministic = True, False
|
||||||
cudnn.benchmark = True
|
|
||||||
|
|
||||||
|
|
||||||
def select_device(device='', batch_size=None):
|
def select_device(device='', batch_size=None):
|
||||||
# device = 'cpu' or '0' or '0,1,2,3'
|
# device = 'cpu' or '0' or '0,1,2,3'
|
||||||
cpu_request = device.lower() == 'cpu'
|
s = f'Using torch {torch.__version__} ' # string
|
||||||
if device and not cpu_request: # if device requested other than 'cpu'
|
cpu = device.lower() == 'cpu'
|
||||||
|
if cpu:
|
||||||
|
os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False
|
||||||
|
elif device: # non-cpu device requested
|
||||||
os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable
|
os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable
|
||||||
assert torch.cuda.is_available(), 'CUDA unavailable, invalid device %s requested' % device # check availablity
|
assert torch.cuda.is_available(), f'CUDA unavailable, invalid device {device} requested' # check availability
|
||||||
|
|
||||||
cuda = False if cpu_request else torch.cuda.is_available()
|
cuda = torch.cuda.is_available() and not cpu
|
||||||
if cuda:
|
if cuda:
|
||||||
c = 1024 ** 2 # bytes to MB
|
n = torch.cuda.device_count()
|
||||||
ng = torch.cuda.device_count()
|
if n > 1 and batch_size: # check that batch_size is compatible with device_count
|
||||||
if ng > 1 and batch_size: # check that batch_size is compatible with device_count
|
assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}'
|
||||||
assert batch_size % ng == 0, 'batch-size %g not multiple of GPU count %g' % (batch_size, ng)
|
space = ' ' * len(s)
|
||||||
x = [torch.cuda.get_device_properties(i) for i in range(ng)]
|
for i, d in enumerate(device.split(',') if device else range(n)):
|
||||||
s = f'Using torch {torch.__version__} '
|
p = torch.cuda.get_device_properties(i)
|
||||||
for i in range(0, ng):
|
s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / 1024 ** 2}MB)\n" # bytes to MB
|
||||||
if i == 1:
|
|
||||||
s = ' ' * len(s)
|
|
||||||
logger.info("%sCUDA:%g (%s, %dMB)" % (s, i, x[i].name, x[i].total_memory / c))
|
|
||||||
else:
|
else:
|
||||||
logger.info(f'Using torch {torch.__version__} CPU')
|
s += 'CPU'
|
||||||
|
|
||||||
logger.info('') # skip a line
|
logger.info(f'{s}\n') # skip a line
|
||||||
return torch.device('cuda:0' if cuda else 'cpu')
|
return torch.device('cuda:0' if cuda else 'cpu')
|
||||||
|
|
||||||
|
|
||||||
def time_synchronized():
|
def time_synchronized():
|
||||||
torch.cuda.synchronize() if torch.cuda.is_available() else None
|
# pytorch-accurate time
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.synchronize()
|
||||||
return time.time()
|
return time.time()
|
||||||
|
|
||||||
|
|
||||||
|
def profile(x, ops, n=100, device=None):
|
||||||
|
# profile a pytorch module or list of modules. Example usage:
|
||||||
|
# x = torch.randn(16, 3, 640, 640) # input
|
||||||
|
# m1 = lambda x: x * torch.sigmoid(x)
|
||||||
|
# m2 = nn.SiLU()
|
||||||
|
# profile(x, [m1, m2], n=100) # profile speed over 100 iterations
|
||||||
|
|
||||||
|
device = device or torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
|
||||||
|
x = x.to(device)
|
||||||
|
x.requires_grad = True
|
||||||
|
print(torch.__version__, device.type, torch.cuda.get_device_properties(0) if device.type == 'cuda' else '')
|
||||||
|
print(f"\n{'Params':>12s}{'GFLOPS':>12s}{'forward (ms)':>16s}{'backward (ms)':>16s}{'input':>24s}{'output':>24s}")
|
||||||
|
for m in ops if isinstance(ops, list) else [ops]:
|
||||||
|
m = m.to(device) if hasattr(m, 'to') else m # device
|
||||||
|
m = m.half() if hasattr(m, 'half') and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m # type
|
||||||
|
dtf, dtb, t = 0., 0., [0., 0., 0.] # dt forward, backward
|
||||||
|
try:
|
||||||
|
flops = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2 # GFLOPS
|
||||||
|
except:
|
||||||
|
flops = 0
|
||||||
|
|
||||||
|
for _ in range(n):
|
||||||
|
t[0] = time_synchronized()
|
||||||
|
y = m(x)
|
||||||
|
t[1] = time_synchronized()
|
||||||
|
try:
|
||||||
|
_ = y.sum().backward()
|
||||||
|
t[2] = time_synchronized()
|
||||||
|
except: # no backward method
|
||||||
|
t[2] = float('nan')
|
||||||
|
dtf += (t[1] - t[0]) * 1000 / n # ms per op forward
|
||||||
|
dtb += (t[2] - t[1]) * 1000 / n # ms per op backward
|
||||||
|
|
||||||
|
s_in = tuple(x.shape) if isinstance(x, torch.Tensor) else 'list'
|
||||||
|
s_out = tuple(y.shape) if isinstance(y, torch.Tensor) else 'list'
|
||||||
|
p = sum(list(x.numel() for x in m.parameters())) if isinstance(m, nn.Module) else 0 # parameters
|
||||||
|
print(f'{p:12.4g}{flops:12.4g}{dtf:16.4g}{dtb:16.4g}{str(s_in):>24s}{str(s_out):>24s}')
|
||||||
|
|
||||||
|
|
||||||
def is_parallel(model):
|
def is_parallel(model):
|
||||||
return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
|
return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
|
||||||
|
|
||||||
@ -153,10 +196,10 @@ def model_info(model, verbose=False, img_size=640):
|
|||||||
try: # FLOPS
|
try: # FLOPS
|
||||||
from thop import profile
|
from thop import profile
|
||||||
stride = int(model.stride.max()) if hasattr(model, 'stride') else 32
|
stride = int(model.stride.max()) if hasattr(model, 'stride') else 32
|
||||||
img = torch.zeros((1, 3, stride, stride), device=next(model.parameters()).device) # input
|
img = torch.zeros((1, model.yaml.get('ch', 3), stride, stride), device=next(model.parameters()).device) # input
|
||||||
flops = profile(deepcopy(model), inputs=(img,), verbose=False)[0] / 1E9 * 2 # stride FLOPS
|
flops = profile(deepcopy(model), inputs=(img,), verbose=False)[0] / 1E9 * 2 # stride GFLOPS
|
||||||
img_size = img_size if isinstance(img_size, list) else [img_size, img_size] # expand if int/float
|
img_size = img_size if isinstance(img_size, list) else [img_size, img_size] # expand if int/float
|
||||||
fs = ', %.1f GFLOPS' % (flops * img_size[0] / stride * img_size[1] / stride) # 640x640 FLOPS
|
fs = ', %.1f GFLOPS' % (flops * img_size[0] / stride * img_size[1] / stride) # 640x640 GFLOPS
|
||||||
except (ImportError, Exception):
|
except (ImportError, Exception):
|
||||||
fs = ''
|
fs = ''
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user