greenhouse/utils/torch_utils.py

286 lines
11 KiB
Python
Raw Normal View History

# PyTorch utils
import logging
2020-03-29 13:14:54 -07:00
import math
2019-09-10 14:59:45 +02:00
import os
2020-03-04 10:26:35 -08:00
import time
from contextlib import contextmanager
2020-03-13 20:12:54 -07:00
from copy import deepcopy
2019-09-13 16:00:52 +02:00
import torch
2020-02-27 22:50:26 -08:00
import torch.backends.cudnn as cudnn
2020-03-13 20:12:54 -07:00
import torch.nn as nn
2020-03-15 18:39:54 -07:00
import torch.nn.functional as F
import torchvision
2021-01-03 14:37:22 -08:00
logger = logging.getLogger(__name__)  # module-wide logger for summary messages

# thop is an optional dependency used only for FLOPS estimates; when it is not
# installed the name is bound to None so callers can detect its absence.
try:
    import thop  # for FLOPS computation
except ImportError:
    thop = None


@contextmanager
def torch_distributed_zero_first(local_rank: int):
    """Context manager that lets the local master (rank 0) run a section before the other ranks.

    Ranks other than -1 and 0 wait on a barrier before entering the ``with`` body. Rank 0
    enters immediately and calls the matching barrier when the body finishes, which releases
    the waiting ranks. ``local_rank == -1`` (non-distributed) performs no synchronization.

    NOTE(review): there is no try/finally, so if the body raises on rank 0 the closing
    barrier is never reached and the waiting ranks may hang.
    """
    is_local_master = local_rank in (-1, 0)
    if not is_local_master:
        torch.distributed.barrier()  # wait for rank 0 to finish the body first
    yield
    if local_rank == 0:
        torch.distributed.barrier()  # release the ranks blocked above


def init_torch_seeds(seed=0):
    """Seed torch's RNG and choose the cuDNN speed/reproducibility trade-off.

    ``seed == 0`` selects deterministic cuDNN kernels without autotuning (slower, more
    reproducible). Any other seed enables cuDNN benchmarking (faster, less reproducible).
    See https://pytorch.org/docs/stable/notes/randomness.html
    """
    torch.manual_seed(seed)
    reproducible = seed == 0
    cudnn.benchmark = not reproducible
    cudnn.deterministic = reproducible


def select_device(device='', batch_size=None):
    """Select the torch device to run on and log a short hardware summary.

    Args:
        device: ``'cpu'``, ``''`` (all visible GPUs, or CPU if none), or a comma-separated
            list of CUDA indices such as ``'0'`` or ``'0,1,2,3'``. Also accepted and
            normalized to that form: ``None`` (same as ``''``), integers (``0``),
            ``'cuda:N'`` strings, ``torch.device`` objects, and surrounding whitespace /
            upper case (``' CPU '``).
        batch_size: if given and more than one GPU is used, it must be divisible by the
            GPU count (checked with ``assert``).

    Returns:
        ``torch.device('cuda:0')`` when CUDA is used, otherwise ``torch.device('cpu')``.

    Side effects:
        Sets the ``CUDA_VISIBLE_DEVICES`` environment variable when ``'cpu'`` or an explicit
        GPU list is requested.
    """
    s = f'Using torch {torch.__version__} '  # string
    # Normalize the spelling so ints / 'cuda:0' / torch.device never reach the env var raw.
    if device is None:
        device = ''
    device = str(device).strip().lower().replace('cuda:', '').replace(' ', '')
    cpu = device == 'cpu'
    if cpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = '-1'  # force torch.cuda.is_available() = False
    elif device:  # non-cpu device requested
        os.environ['CUDA_VISIBLE_DEVICES'] = device  # set environment variable
        assert torch.cuda.is_available(), f'CUDA unavailable, invalid device {device} requested'  # check availability

    cuda = torch.cuda.is_available() and not cpu
    if cuda:
        n = torch.cuda.device_count()
        if n > 1 and batch_size:  # check that batch_size is compatible with device_count
            assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}'
        space = ' ' * len(s)
        # Properties are queried by position i, while d is the user-facing label;
        # presumably because CUDA_VISIBLE_DEVICES renumbers visible GPUs from 0 — TODO confirm.
        for i, d in enumerate(device.split(',') if device else range(n)):
            p = torch.cuda.get_device_properties(i)
            s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / 1024 ** 2}MB)\n"  # bytes to MB
    else:
        s += 'CPU'

    logger.info(f'{s}\n')  # skip a line
    return torch.device('cuda:0' if cuda else 'cpu')


def time_synchronized():
    """Return ``time.time()``, first waiting for outstanding CUDA work when a GPU is available.

    Calling ``torch.cuda.synchronize()`` before reading the clock makes elapsed-time
    measurements include GPU work that was queued but not yet finished.
    """
    cuda_present = torch.cuda.is_available()
    if cuda_present:
        torch.cuda.synchronize()
    return time.time()


def profile(x, ops, n=100, device=None):
    """Benchmark one or more ops on input ``x`` and print a cost table.

    Example:
        x = torch.randn(16, 3, 640, 640)  # input
        m1 = lambda x: x * torch.sigmoid(x)
        m2 = nn.SiLU()
        profile(x, [m1, m2], n=100)  # profile speed over 100 iterations

    Args:
        x: input tensor. A detached copy is moved to ``device`` and marked
            ``requires_grad`` so the backward pass can be timed; the caller's tensor is
            not modified.
        ops: a callable / ``nn.Module``, or a list of them.
        n: number of timed forward/backward iterations per op; must be >= 1.
        device: ``torch.device``, device string (``'cpu'``, ``'cuda:0'``) or index.
            Defaults to ``cuda:0`` when available, else CPU.

    Raises:
        ValueError: if ``n < 1`` (no iteration would run, so there is no output to report).

    Prints one row per op with these columns:
        - parameter count;
        - GFLOPS (0 when thop is missing or cannot profile the op);
        - mean forward and backward time in ms (backward is ``nan`` when the output cannot
          be back-propagated);
        - input and output shapes.
    """
    if n < 1:
        raise ValueError(f'n must be >= 1, got {n}')
    if device:
        device = torch.device(device)  # also accepts strings / indices, not just torch.device
    else:
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    # detach() gives a fresh leaf so setting requires_grad does not alter the caller's tensor
    x = x.to(device).detach().requires_grad_(True)
    print(torch.__version__, device.type, torch.cuda.get_device_properties(device) if device.type == 'cuda' else '')
    print(f"\n{'Params':>12s}{'GFLOPS':>12s}{'forward (ms)':>16s}{'backward (ms)':>16s}{'input':>24s}{'output':>24s}")
    for m in ops if isinstance(ops, list) else [ops]:
        m = m.to(device) if hasattr(m, 'to') else m  # device
        m = m.half() if hasattr(m, 'half') and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m  # type
        dtf, dtb, t = 0., 0., [0., 0., 0.]  # dt forward, backward
        flops = 0
        if thop is not None:  # thop is optional
            try:
                flops = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2  # GFLOPS
            except Exception:  # thop cannot trace every callable (e.g. plain lambdas)
                flops = 0

        for _ in range(n):
            t[0] = time_synchronized()
            y = m(x)
            t[1] = time_synchronized()
            try:
                _ = y.sum().backward()
                t[2] = time_synchronized()
            except Exception:  # no backward method
                t[2] = float('nan')
            dtf += (t[1] - t[0]) * 1000 / n  # ms per op forward
            dtb += (t[2] - t[1]) * 1000 / n  # ms per op backward

        s_in = tuple(x.shape) if isinstance(x, torch.Tensor) else 'list'
        s_out = tuple(y.shape) if isinstance(y, torch.Tensor) else 'list'
        p = sum(q.numel() for q in m.parameters()) if isinstance(m, nn.Module) else 0  # parameters
        print(f'{p:12.4g}{flops:12.4g}{dtf:16.4g}{dtb:16.4g}{str(s_in):>24s}{str(s_out):>24s}')


def is_parallel(model):
    """Return True if ``model``'s exact type is DataParallel or DistributedDataParallel.

    The comparison is on the exact type, so subclasses of those wrappers return False.
    """
    wrapper_types = (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
    return any(type(model) is wrapper for wrapper in wrapper_types)


def intersect_dicts(da, db, exclude=()):
    """Return the entries of ``da`` whose key also exists in ``db`` with an equal ``.shape``.

    Keys containing any substring listed in ``exclude`` are dropped. Values come from ``da``,
    and insertion order follows ``da``.
    """
    matched = {}
    for key, value in da.items():
        if key not in db:
            continue
        if any(pattern in key for pattern in exclude):
            continue
        if value.shape != db[key].shape:
            continue
        matched[key] = value
    return matched


def initialize_weights(model):
    """Apply in-place layer settings to every module of ``model``.

    Matching is on the exact module type (subclasses are not affected):
      * ``nn.BatchNorm2d``: ``eps = 1e-3`` and ``momentum = 0.03``.
      * ``nn.Hardswish``, ``nn.LeakyReLU``, ``nn.ReLU``, ``nn.ReLU6``: ``inplace = True``.
      * ``nn.Conv2d`` (and everything else): left as-is.
    """
    inplace_activations = (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6)
    for module in model.modules():
        kind = type(module)
        if kind is nn.BatchNorm2d:
            module.eps = 1e-3
            module.momentum = 0.03
        elif kind in inplace_activations:
            module.inplace = True
        # nn.Conv2d weights are not re-initialized; a disabled alternative was
        # nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')


def find_modules(model, mclass=nn.Conv2d):
    """Return the indices of the entries of ``model.module_list`` that are ``mclass`` instances."""
    hits = []
    for position, layer in enumerate(model.module_list):
        if isinstance(layer, mclass):
            hits.append(position)
    return hits


def sparsity(model):
    """Return the global fraction of exactly-zero values across all parameters of ``model``.

    When the model has parameters the result is a 0-dim tensor (zero count / element count).
    A model with no parameters returns ``0.0`` instead of raising ``ZeroDivisionError``.
    """
    total, zeros = 0., 0.
    for p in model.parameters():
        total += p.numel()
        zeros += (p == 0).sum()
    if total == 0:  # no parameters: nothing to be sparse, and avoid dividing by zero
        return 0.0
    return zeros / total


def prune(model, amount=0.3):
    """Prune the weights of every ``nn.Conv2d`` in ``model`` in place and report sparsity.

    Each Conv2d weight is pruned independently with ``l1_unstructured`` (``amount`` is the
    per-layer fraction), and the mask is then made permanent with ``remove``. The printed
    figure is ``sparsity(model)``, measured over all model parameters, so it is generally
    not equal to ``amount``.
    """
    import torch.nn.utils.prune as torch_prune  # alias avoids shadowing this function's name

    print('Pruning model... ', end='')
    conv_layers = (layer for _, layer in model.named_modules() if isinstance(layer, nn.Conv2d))
    for layer in conv_layers:
        torch_prune.l1_unstructured(layer, name='weight', amount=amount)  # prune
        torch_prune.remove(layer, 'weight')  # make permanent
    print(' %.3g global sparsity' % sparsity(model))


def fuse_conv_and_bn(conv, bn):
    """Fold a BatchNorm2d into the preceding Conv2d and return the fused convolution.

    https://tehnokv.com/posts/fusing-batchnorm-and-conv/

    With BN scale ``s = gamma / sqrt(running_var + eps)``:
        W' = s * W                           (per output channel)
        b' = s * b + beta - s * running_mean
    so ``fused(x)`` matches ``bn(conv(x))`` when ``bn`` normalizes with its running statistics.

    Args:
        conv: the ``nn.Conv2d`` to fuse. Its bias is optional, and stride, padding,
            dilation, groups and padding mode are all preserved.
        bn: the ``nn.BatchNorm2d`` applied to ``conv``'s output. ``affine=False`` is
            supported (gamma = 1, beta = 0).

    Returns:
        A new ``nn.Conv2d`` with bias, parameters with ``requires_grad=False``, placed on
        ``conv.weight``'s device.
    """
    fusedconv = nn.Conv2d(conv.in_channels,
                          conv.out_channels,
                          kernel_size=conv.kernel_size,
                          stride=conv.stride,
                          padding=conv.padding,
                          dilation=conv.dilation,  # omitting this gave wrong outputs for dilated convs
                          groups=conv.groups,
                          padding_mode=conv.padding_mode,
                          bias=True).requires_grad_(False).to(conv.weight.device)

    with torch.no_grad():
        # BN affine terms; affine=False BatchNorm has weight/bias set to None
        gamma = bn.weight if bn.weight is not None else torch.ones_like(bn.running_var)
        beta = bn.bias if bn.bias is not None else torch.zeros_like(bn.running_mean)
        scale = gamma / torch.sqrt(bn.running_var + bn.eps)  # one factor per output channel

        # prepare filters: scale each output channel's kernel (broadcast instead of a diag matmul)
        w_conv = conv.weight.reshape(conv.out_channels, -1)
        fusedconv.weight.copy_((scale.reshape(-1, 1) * w_conv).reshape(fusedconv.weight.shape))

        # prepare spatial bias
        if conv.bias is None:
            b_conv = torch.zeros(conv.out_channels, device=conv.weight.device)
        else:
            b_conv = conv.bias
        fusedconv.bias.copy_(scale * b_conv + beta - scale * bn.running_mean)

    return fusedconv


def model_info(model, verbose=False, img_size=640):
    """Log a one-line model summary: layer count, parameter counts and, when possible, GFLOPS.

    Args:
        model: the ``nn.Module`` to summarize.
        verbose: if True, also print one row per named parameter (shape, mean, std).
        img_size: int, or a 2-element list/tuple such as ``[640, 320]``, used to scale the
            FLOPS estimate.

    How the FLOPS estimate is computed:
        - thop profiles a ``(1, ch, stride, stride)`` dummy input.
        - ``stride`` is ``model.stride.max()`` if present, else 32.
        - ``ch`` is ``model.yaml['ch']`` if present, else 3. This assumes ``model.yaml`` is
          dict-like when it exists.
        - The result is scaled to ``img_size``.
    The estimate is best-effort: any failure (thop missing, tracing error, model without
    parameters) just omits the GFLOPS part.
    """
    n_p = sum(x.numel() for x in model.parameters())  # number parameters
    n_g = sum(x.numel() for x in model.parameters() if x.requires_grad)  # number gradients
    if verbose:
        print('%5s %40s %9s %12s %20s %10s %10s' % ('layer', 'name', 'gradient', 'parameters', 'shape', 'mu', 'sigma'))
        for i, (name, p) in enumerate(model.named_parameters()):
            name = name.replace('module_list.', '')
            print('%5g %40s %9s %12g %20s %10.3g %10.3g' %
                  (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std()))

    try:  # FLOPS
        from thop import profile
        stride = int(model.stride.max()) if hasattr(model, 'stride') else 32
        ch = getattr(model, 'yaml', {}).get('ch', 3)  # input channels; models without .yaml default to 3
        img = torch.zeros((1, ch, stride, stride), device=next(model.parameters()).device)  # input
        flops = profile(deepcopy(model), inputs=(img,), verbose=False)[0] / 1E9 * 2  # stride GFLOPS
        # expand if int/float; accept tuples as well as lists
        img_size = list(img_size) if isinstance(img_size, (list, tuple)) else [img_size, img_size]
        fs = ', %.1f GFLOPS' % (flops * img_size[0] / stride * img_size[1] / stride)  # img_size GFLOPS
    except Exception:  # FLOPS are optional in the summary (ImportError is an Exception too)
        fs = ''

    logger.info(f"Model Summary: {len(list(model.modules()))} layers, {n_p} parameters, {n_g} gradients{fs}")


def load_classifier(name='resnet101', n=2):
    """Load a pretrained torchvision classifier and reshape its ``fc`` head to ``n`` outputs.

    The new head weight ``(n, in_features)`` and bias ``(n,)`` are zero-initialized,
    trainable ``nn.Parameter``s. Only models exposing an ``fc`` layer with a 2-D weight
    (e.g. torchvision ResNets) are supported.

    ResNet model properties:
        input_size = [3, 224, 224], input_space = 'RGB', input_range = [0, 1]
        mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225]
    """
    model = torchvision.models.__dict__[name](pretrained=True)

    head = model.fc
    in_features = head.weight.shape[1]
    head.bias = nn.Parameter(torch.zeros(n), requires_grad=True)
    head.weight = nn.Parameter(torch.zeros(n, in_features), requires_grad=True)
    head.out_features = n
    return model


def scale_img(img, ratio=1.0, same_shape=False, gs=32):  # img(16,3,256,416), r=ratio
    """Resize an image batch ``img(bs, 3, y, x)`` by ``ratio``, then pad or crop the result.

    Args:
        img: image batch tensor of shape ``(bs, 3, y, x)``.
        ratio: scale factor. ``1.0`` returns ``img`` itself, unchanged.
        same_shape: if True, pad back to the original ``(y, x)``, or crop when
            ``ratio > 1``. Otherwise pad up to the next multiple of ``gs``.
        gs: grid size in pixels used when ``same_shape`` is False. Defaults to 32, the
            previously hard-coded value.

    Padding uses the constant 0.447 (imagenet mean).
    """
    if ratio == 1.0:
        return img
    h, w = img.shape[2:]
    s = (int(h * ratio), int(w * ratio))  # new size
    img = F.interpolate(img, size=s, mode='bilinear', align_corners=False)  # resize
    if not same_shape:  # pad/crop img to a multiple of the grid size
        h, w = [math.ceil(x * ratio / gs) * gs for x in (h, w)]
    # negative padding (same_shape with ratio > 1) crops instead of padding
    return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447)  # value = imagenet mean


def copy_attr(a, b, include=(), exclude=()):
    """Copy instance attributes from ``b`` onto ``a``.

    Names starting with ``_`` or listed in ``exclude`` are skipped. When ``include`` is
    non-empty, only the names it lists are copied.
    """
    for key, value in b.__dict__.items():
        allowed = len(include) == 0 or key in include
        if allowed and not key.startswith('_') and key not in exclude:
            setattr(a, key, value)


class ModelEMA:
    """Model Exponential Moving Average, from https://github.com/rwightman/pytorch-image-models

    Keeps a moving average of everything in the model ``state_dict`` (parameters and
    buffers), in the spirit of
    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage.
    A smoothed version of the weights is necessary for some training schemes to perform well.

    The average lives in ``self.ema``, a ``deepcopy`` of the model taken in ``__init__``
    (unwrapped from DataParallel/DDP). Because of that copy, this class is sensitive to
    where it is created relative to model init, GPU assignment and distributed training
    wrappers.

    Attributes:
        ema: the averaged model copy, in eval mode, with ``requires_grad=False`` parameters.
        updates: number of EMA updates performed (starts at the ``updates`` argument).
        decay: callable ``x -> decay * (1 - exp(-x / 2000))``. It ramps the effective decay
            up from 0, so early updates follow the live model closely.
    """

    def __init__(self, model, decay=0.9999, updates=0):
        # Create EMA
        base = model.module if is_parallel(model) else model  # unwrap DataParallel / DDP
        self.ema = deepcopy(base).eval()  # kept in the source model's precision
        # if next(model.parameters()).device.type != 'cpu':
        #     self.ema.half()  # FP16 EMA
        self.updates = updates  # number of EMA updates

        def ramped_decay(x):
            # decay exponential ramp (to help early epochs)
            return decay * (1 - math.exp(-x / 2000))

        self.decay = ramped_decay
        for param in self.ema.parameters():
            param.requires_grad_(False)

    def update(self, model):
        """Blend ``model``'s current weights into the EMA: ``ema = d * ema + (1 - d) * model``.

        ``self.updates`` is incremented first and ``d = self.decay(self.updates)``. Only
        floating-point ``state_dict`` entries are averaged; other entries (e.g. integer
        buffers) are left untouched.
        """
        with torch.no_grad():
            self.updates += 1
            d = self.decay(self.updates)

            live = model.module if is_parallel(model) else model
            live_state = live.state_dict()  # model state_dict
            for key, ema_tensor in self.ema.state_dict().items():
                if not ema_tensor.dtype.is_floating_point:
                    continue
                ema_tensor *= d
                ema_tensor += (1. - d) * live_state[key].detach()

    def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
        """Copy public attributes of ``model`` onto ``self.ema`` via ``copy_attr`` (``exclude`` names skipped)."""
        copy_attr(self.ema, model, include, exclude)