updates
This commit is contained in:
+8
-20
@@ -13,7 +13,7 @@ from utils.utils import xyxy2xywh
|
||||
|
||||
|
||||
class load_images(): # for inference
|
||||
def __init__(self, path, batch_size=1, img_size=416):
|
||||
def __init__(self, path, img_size=416):
|
||||
if os.path.isdir(path):
|
||||
image_format = ['.jpg', '.jpeg', '.png', '.tif']
|
||||
self.files = sorted(glob.glob('%s/*.*' % path))
|
||||
@@ -22,43 +22,37 @@ class load_images(): # for inference
|
||||
self.files = [path]
|
||||
|
||||
self.nF = len(self.files) # number of image files
|
||||
self.nB = math.ceil(self.nF / batch_size) # number of batches
|
||||
self.batch_size = batch_size
|
||||
self.height = img_size
|
||||
|
||||
assert self.nF > 0, 'No images found in path %s' % path
|
||||
|
||||
# RGB normalization values
|
||||
# self.rgb_mean = np.array([60.134, 49.697, 40.746], dtype=np.float32).reshape((3, 1, 1))
|
||||
# self.rgb_std = np.array([29.99, 24.498, 22.046], dtype=np.float32).reshape((3, 1, 1))
|
||||
|
||||
def __iter__(self):
|
||||
self.count = -1
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
self.count += 1
|
||||
if self.count == self.nB:
|
||||
if self.count == self.nF:
|
||||
raise StopIteration
|
||||
img_path = self.files[self.count]
|
||||
|
||||
# Read image
|
||||
img = cv2.imread(img_path) # BGR
|
||||
img0 = cv2.imread(img_path) # BGR
|
||||
assert img0 is not None, 'Failed to load ' + img_path
|
||||
|
||||
# Padded resize
|
||||
img, _, _, _ = resize_square(img, height=self.height, color=(127.5, 127.5, 127.5))
|
||||
img, _, _, _ = resize_square(img0, height=self.height, color=(127.5, 127.5, 127.5))
|
||||
|
||||
# Normalize RGB
|
||||
img = img[:, :, ::-1].transpose(2, 0, 1)
|
||||
img = np.ascontiguousarray(img, dtype=np.float32)
|
||||
# img -= self.rgb_mean
|
||||
# img /= self.rgb_std
|
||||
img /= 255.0
|
||||
|
||||
return [img_path], img
|
||||
# cv2.imwrite(img_path + '.letterbox.jpg', 255 * img.transpose((1, 2, 0))[:, :, ::-1]) # save letterbox image
|
||||
return img_path, img, img0
|
||||
|
||||
def __len__(self):
|
||||
return self.nB # number of batches
|
||||
return self.nF # number of files
|
||||
|
||||
|
||||
class load_images_and_labels(): # for training
|
||||
@@ -81,10 +75,6 @@ class load_images_and_labels(): # for training
|
||||
|
||||
assert self.nB > 0, 'No images found in path %s' % path
|
||||
|
||||
# RGB normalization values
|
||||
# self.rgb_mean = np.array([60.134, 49.697, 40.746], dtype=np.float32).reshape((1, 3, 1, 1))
|
||||
# self.rgb_std = np.array([29.99, 24.498, 22.046], dtype=np.float32).reshape((1, 3, 1, 1))
|
||||
|
||||
def __iter__(self):
|
||||
self.count = -1
|
||||
self.shuffled_vector = np.random.permutation(self.nF) if self.augment else np.arange(self.nF)
|
||||
@@ -191,8 +181,6 @@ class load_images_and_labels(): # for training
|
||||
# Normalize
|
||||
img_all = np.stack(img_all)[:, :, :, ::-1].transpose(0, 3, 1, 2) # BGR to RGB and cv2 to pytorch
|
||||
img_all = np.ascontiguousarray(img_all, dtype=np.float32)
|
||||
# img_all -= self.rgb_mean
|
||||
# img_all /= self.rgb_std
|
||||
img_all /= 255.0
|
||||
|
||||
return torch.from_numpy(img_all), labels_all
|
||||
|
||||
@@ -20,7 +20,7 @@ def parse_model_config(path):
|
||||
|
||||
return module_defs
|
||||
|
||||
def parse_data_config(path):
|
||||
def parse_data_cfg(path):
|
||||
"""Parses the data configuration file"""
|
||||
options = dict()
|
||||
options['gpus'] = '0,1,2,3'
|
||||
|
||||
@@ -21,4 +21,5 @@ def select_device(force_cpu=False):
|
||||
device = torch.device('cpu')
|
||||
else:
|
||||
device = torch.device('cuda:0' if CUDA_AVAILABLE else 'cpu')
|
||||
print('Using ' + str(device) + '\n')
|
||||
return device
|
||||
|
||||
Reference in New Issue
Block a user