From 84ad6080ae1471b3835e60e3ecb1f6a95e9394d4 Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Sun, 3 Jan 2021 14:37:22 -0800
Subject: [PATCH] Update Torch CUDA Synchronize (#1637)

---
 utils/torch_utils.py | 91 ++++++++++++++++++++++++++++++++------------
 1 file changed, 67 insertions(+), 24 deletions(-)

diff --git a/utils/torch_utils.py b/utils/torch_utils.py
index cde934af..69a31213 100644
--- a/utils/torch_utils.py
+++ b/utils/torch_utils.py
@@ -13,6 +13,10 @@ import torch.nn as nn
 import torch.nn.functional as F
 import torchvision
 
+try:
+    import thop  # for FLOPS computation
+except ImportError:
+    thop = None
 
 logger = logging.getLogger(__name__)
 
@@ -32,44 +36,83 @@ def init_torch_seeds(seed=0):
     # Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html
     torch.manual_seed(seed)
     if seed == 0:  # slower, more reproducible
-        cudnn.deterministic = True
-        cudnn.benchmark = False
+        cudnn.benchmark, cudnn.deterministic = False, True
     else:  # faster, less reproducible
-        cudnn.deterministic = False
-        cudnn.benchmark = True
+        cudnn.benchmark, cudnn.deterministic = True, False
 
 
 def select_device(device='', batch_size=None):
     # device = 'cpu' or '0' or '0,1,2,3'
-    cpu_request = device.lower() == 'cpu'
-    if device and not cpu_request:  # if device requested other than 'cpu'
+    s = f'Using torch {torch.__version__} '  # string
+    cpu = device.lower() == 'cpu'
+    if cpu:
+        os.environ['CUDA_VISIBLE_DEVICES'] = '-1'  # force torch.cuda.is_available() = False
+    elif device:  # non-cpu device requested
         os.environ['CUDA_VISIBLE_DEVICES'] = device  # set environment variable
-        assert torch.cuda.is_available(), 'CUDA unavailable, invalid device %s requested' % device  # check availablity
+        assert torch.cuda.is_available(), f'CUDA unavailable, invalid device {device} requested'  # check availability
 
-    cuda = False if cpu_request else torch.cuda.is_available()
+    cuda = torch.cuda.is_available() and not cpu
     if cuda:
-        c = 1024 ** 2  # bytes to MB
-        ng = torch.cuda.device_count()
-        if ng > 1 and batch_size:  # check that batch_size is compatible with device_count
-            assert batch_size % ng == 0, 'batch-size %g not multiple of GPU count %g' % (batch_size, ng)
-        x = [torch.cuda.get_device_properties(i) for i in range(ng)]
-        s = f'Using torch {torch.__version__} '
-        for i in range(0, ng):
-            if i == 1:
-                s = ' ' * len(s)
-            logger.info("%sCUDA:%g (%s, %dMB)" % (s, i, x[i].name, x[i].total_memory / c))
+        n = torch.cuda.device_count()
+        if n > 1 and batch_size:  # check that batch_size is compatible with device_count
+            assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}'
+        space = ' ' * len(s)
+        for i, d in enumerate(device.split(',') if device else range(n)):
+            p = torch.cuda.get_device_properties(i)
+            s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / 1024 ** 2}MB)\n"  # bytes to MB
     else:
-        logger.info(f'Using torch {torch.__version__} CPU')
+        s += 'CPU'
 
-    logger.info('')  # skip a line
+    logger.info(f'{s}\n')  # skip a line
     return torch.device('cuda:0' if cuda else 'cpu')
 
 
 def time_synchronized():
-    torch.cuda.synchronize() if torch.cuda.is_available() else None
+    # pytorch-accurate time
+    if torch.cuda.is_available():
+        torch.cuda.synchronize()
     return time.time()
 
 
+def profile(x, ops, n=100, device=None):
+    # profile a pytorch module or list of modules. Example usage:
+    #     x = torch.randn(16, 3, 640, 640)  # input
+    #     m1 = lambda x: x * torch.sigmoid(x)
+    #     m2 = nn.SiLU()
+    #     profile(x, [m1, m2], n=100)  # profile speed over 100 iterations
+
+    device = device or torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+    x = x.to(device)
+    x.requires_grad = True
+    print(torch.__version__, device.type, torch.cuda.get_device_properties(0) if device.type == 'cuda' else '')
+    print(f"\n{'Params':>12s}{'GFLOPS':>12s}{'forward (ms)':>16s}{'backward (ms)':>16s}{'input':>24s}{'output':>24s}")
+    for m in ops if isinstance(ops, list) else [ops]:
+        m = m.to(device) if hasattr(m, 'to') else m  # device
+        m = m.half() if hasattr(m, 'half') and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m  # type
+        dtf, dtb, t = 0., 0., [0., 0., 0.]  # dt forward, backward
+        try:
+            flops = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2  # GFLOPS
+        except:
+            flops = 0
+
+        for _ in range(n):
+            t[0] = time_synchronized()
+            y = m(x)
+            t[1] = time_synchronized()
+            try:
+                _ = y.sum().backward()
+                t[2] = time_synchronized()
+            except:  # no backward method
+                t[2] = float('nan')
+            dtf += (t[1] - t[0]) * 1000 / n  # ms per op forward
+            dtb += (t[2] - t[1]) * 1000 / n  # ms per op backward
+
+        s_in = tuple(x.shape) if isinstance(x, torch.Tensor) else 'list'
+        s_out = tuple(y.shape) if isinstance(y, torch.Tensor) else 'list'
+        p = sum(list(x.numel() for x in m.parameters())) if isinstance(m, nn.Module) else 0  # parameters
+        print(f'{p:12.4g}{flops:12.4g}{dtf:16.4g}{dtb:16.4g}{str(s_in):>24s}{str(s_out):>24s}')
+
+
 def is_parallel(model):
     return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
 
@@ -153,10 +196,10 @@ def model_info(model, verbose=False, img_size=640):
     try:  # FLOPS
         from thop import profile
         stride = int(model.stride.max()) if hasattr(model, 'stride') else 32
-        img = torch.zeros((1, 3, stride, stride), device=next(model.parameters()).device)  # input
-        flops = profile(deepcopy(model), inputs=(img,), verbose=False)[0] / 1E9 * 2  # stride FLOPS
+        img = torch.zeros((1, model.yaml.get('ch', 3), stride, stride), device=next(model.parameters()).device)  # input
+        flops = profile(deepcopy(model), inputs=(img,), verbose=False)[0] / 1E9 * 2  # stride GFLOPS
         img_size = img_size if isinstance(img_size, list) else [img_size, img_size]  # expand if int/float
-        fs = ', %.1f GFLOPS' % (flops * img_size[0] / stride * img_size[1] / stride)  # 640x640 FLOPS
+        fs = ', %.1f GFLOPS' % (flops * img_size[0] / stride * img_size[1] / stride)  # 640x640 GFLOPS
     except (ImportError, Exception):
         fs = ''
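
Note (not part of the patch): the time_synchronized() change above matters because CUDA kernels launch asynchronously, so reading time.time() without a prior torch.cuda.synchronize() can under-report GPU work. The following is a minimal, hypothetical sketch of that timing pattern as a standalone script; the Conv2d layer and input shape are arbitrary illustrations, not code from this repository.

# Minimal sketch (assumption: standalone script, not repository code).
import time

import torch
import torch.nn as nn


def time_synchronized():
    # same pattern as the patched helper: wait for queued CUDA kernels before reading the clock
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.time()


device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
m = nn.Conv2d(3, 64, kernel_size=3, padding=1).to(device)  # arbitrary example op
x = torch.randn(16, 3, 640, 640, device=device)  # arbitrary example input

t0 = time_synchronized()
y = m(x)
t1 = time_synchronized()
print(f'forward: {(t1 - t0) * 1000:.1f} ms on {device.type}')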