This commit is contained in:
Glenn Jocher
2019-02-08 22:43:05 +01:00
parent d6abdaf8d0
commit c2436d8197
7 changed files with 107 additions and 161 deletions
+8 -20
View File
@@ -13,7 +13,7 @@ from utils.utils import xyxy2xywh
class load_images(): # for inference
def __init__(self, path, batch_size=1, img_size=416):
def __init__(self, path, img_size=416):
if os.path.isdir(path):
image_format = ['.jpg', '.jpeg', '.png', '.tif']
self.files = sorted(glob.glob('%s/*.*' % path))
@@ -22,43 +22,37 @@ class load_images(): # for inference
self.files = [path]
self.nF = len(self.files) # number of image files
self.nB = math.ceil(self.nF / batch_size) # number of batches
self.batch_size = batch_size
self.height = img_size
assert self.nF > 0, 'No images found in path %s' % path
# RGB normalization values
# self.rgb_mean = np.array([60.134, 49.697, 40.746], dtype=np.float32).reshape((3, 1, 1))
# self.rgb_std = np.array([29.99, 24.498, 22.046], dtype=np.float32).reshape((3, 1, 1))
def __iter__(self):
self.count = -1
return self
def __next__(self):
self.count += 1
if self.count == self.nB:
if self.count == self.nF:
raise StopIteration
img_path = self.files[self.count]
# Read image
img = cv2.imread(img_path) # BGR
img0 = cv2.imread(img_path) # BGR
assert img0 is not None, 'Failed to load ' + img_path
# Padded resize
img, _, _, _ = resize_square(img, height=self.height, color=(127.5, 127.5, 127.5))
img, _, _, _ = resize_square(img0, height=self.height, color=(127.5, 127.5, 127.5))
# Normalize RGB
img = img[:, :, ::-1].transpose(2, 0, 1)
img = np.ascontiguousarray(img, dtype=np.float32)
# img -= self.rgb_mean
# img /= self.rgb_std
img /= 255.0
return [img_path], img
# cv2.imwrite(img_path + '.letterbox.jpg', 255 * img.transpose((1, 2, 0))[:, :, ::-1]) # save letterbox image
return img_path, img, img0
def __len__(self):
return self.nB # number of batches
return self.nF # number of files
class load_images_and_labels(): # for training
@@ -81,10 +75,6 @@ class load_images_and_labels(): # for training
assert self.nB > 0, 'No images found in path %s' % path
# RGB normalization values
# self.rgb_mean = np.array([60.134, 49.697, 40.746], dtype=np.float32).reshape((1, 3, 1, 1))
# self.rgb_std = np.array([29.99, 24.498, 22.046], dtype=np.float32).reshape((1, 3, 1, 1))
def __iter__(self):
self.count = -1
self.shuffled_vector = np.random.permutation(self.nF) if self.augment else np.arange(self.nF)
@@ -191,8 +181,6 @@ class load_images_and_labels(): # for training
# Normalize
img_all = np.stack(img_all)[:, :, :, ::-1].transpose(0, 3, 1, 2) # BGR to RGB and cv2 to pytorch
img_all = np.ascontiguousarray(img_all, dtype=np.float32)
# img_all -= self.rgb_mean
# img_all /= self.rgb_std
img_all /= 255.0
return torch.from_numpy(img_all), labels_all
+1 -1
View File
@@ -20,7 +20,7 @@ def parse_model_config(path):
return module_defs
def parse_data_config(path):
def parse_data_cfg(path):
"""Parses the data configuration file"""
options = dict()
options['gpus'] = '0,1,2,3'
+1
View File
@@ -21,4 +21,5 @@ def select_device(force_cpu=False):
device = torch.device('cpu')
else:
device = torch.device('cuda:0' if CUDA_AVAILABLE else 'cpu')
print('Using ' + str(device) + '\n')
return device