Update torch_utils.py (#1652)

This commit is contained in:
Glenn Jocher 2021-01-09 21:34:12 -08:00 committed by GitHub
parent d9b29951c1
commit 162773d968
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -3,9 +3,11 @@
import logging
import math
import os
import subprocess
import time
from contextlib import contextmanager
from copy import deepcopy
from pathlib import Path
import torch
import torch.backends.cudnn as cudnn
@ -41,9 +43,17 @@ def init_torch_seeds(seed=0):
cudnn.benchmark, cudnn.deterministic = True, False
def git_describe():
# return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe
if Path('.git').exists():
return subprocess.check_output('git describe --tags --long --always', shell=True).decode('utf-8')[:-1]
else:
return ''
def select_device(device='', batch_size=None):
# device = 'cpu' or '0' or '0,1,2,3'
s = f'Using torch {torch.__version__} ' # string
s = f'YOLOv3 {git_describe()} torch {torch.__version__} ' # string
cpu = device.lower() == 'cpu'
if cpu:
os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False
@ -61,9 +71,9 @@ def select_device(device='', batch_size=None):
p = torch.cuda.get_device_properties(i)
s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / 1024 ** 2}MB)\n" # bytes to MB
else:
s += 'CPU'
s += 'CPU\n'
logger.info(f'{s}\n') # skip a line
logger.info(s) # skip a line
return torch.device('cuda:0' if cuda else 'cpu')
@ -225,8 +235,8 @@ def load_classifier(name='resnet101', n=2):
return model
def scale_img(img, ratio=1.0, same_shape=False): # img(16,3,256,416), r=ratio
# scales img(bs,3,y,x) by ratio
def scale_img(img, ratio=1.0, same_shape=False, gs=32): # img(16,3,256,416)
# scales img(bs,3,y,x) by ratio constrained to gs-multiple
if ratio == 1.0:
return img
else:
@ -234,7 +244,6 @@ def scale_img(img, ratio=1.0, same_shape=False): # img(16,3,256,416), r=ratio
s = (int(h * ratio), int(w * ratio)) # new size
img = F.interpolate(img, size=s, mode='bilinear', align_corners=False) # resize
if not same_shape: # pad/crop img
gs = 32 # (pixels) grid size
h, w = [math.ceil(x * ratio / gs) * gs for x in (h, w)]
return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean