Update torch_utils.py (#1652)
This commit is contained in:
parent
d9b29951c1
commit
162773d968
@ -3,9 +3,11 @@
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import subprocess
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.backends.cudnn as cudnn
|
||||
@ -41,9 +43,17 @@ def init_torch_seeds(seed=0):
|
||||
cudnn.benchmark, cudnn.deterministic = True, False
|
||||
|
||||
|
||||
def git_describe():
|
||||
# return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe
|
||||
if Path('.git').exists():
|
||||
return subprocess.check_output('git describe --tags --long --always', shell=True).decode('utf-8')[:-1]
|
||||
else:
|
||||
return ''
|
||||
|
||||
|
||||
def select_device(device='', batch_size=None):
|
||||
# device = 'cpu' or '0' or '0,1,2,3'
|
||||
s = f'Using torch {torch.__version__} ' # string
|
||||
s = f'YOLOv3 {git_describe()} torch {torch.__version__} ' # string
|
||||
cpu = device.lower() == 'cpu'
|
||||
if cpu:
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False
|
||||
@ -61,9 +71,9 @@ def select_device(device='', batch_size=None):
|
||||
p = torch.cuda.get_device_properties(i)
|
||||
s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / 1024 ** 2}MB)\n" # bytes to MB
|
||||
else:
|
||||
s += 'CPU'
|
||||
s += 'CPU\n'
|
||||
|
||||
logger.info(f'{s}\n') # skip a line
|
||||
logger.info(s) # skip a line
|
||||
return torch.device('cuda:0' if cuda else 'cpu')
|
||||
|
||||
|
||||
@ -225,8 +235,8 @@ def load_classifier(name='resnet101', n=2):
|
||||
return model
|
||||
|
||||
|
||||
def scale_img(img, ratio=1.0, same_shape=False): # img(16,3,256,416), r=ratio
|
||||
# scales img(bs,3,y,x) by ratio
|
||||
def scale_img(img, ratio=1.0, same_shape=False, gs=32): # img(16,3,256,416)
|
||||
# scales img(bs,3,y,x) by ratio constrained to gs-multiple
|
||||
if ratio == 1.0:
|
||||
return img
|
||||
else:
|
||||
@ -234,7 +244,6 @@ def scale_img(img, ratio=1.0, same_shape=False): # img(16,3,256,416), r=ratio
|
||||
s = (int(h * ratio), int(w * ratio)) # new size
|
||||
img = F.interpolate(img, size=s, mode='bilinear', align_corners=False) # resize
|
||||
if not same_shape: # pad/crop img
|
||||
gs = 32 # (pixels) grid size
|
||||
h, w = [math.ceil(x * ratio / gs) * gs for x in (h, w)]
|
||||
return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user