This commit is contained in:
Glenn Jocher
2019-12-04 23:02:32 -08:00
parent e27b124828
commit 63c2736c12
3 changed files with 45 additions and 31 deletions
+16 -14
View File
@@ -255,7 +255,7 @@ class LoadStreams: # multiple IP or RTSP cameras
class LoadImagesAndLabels(Dataset): # for training/testing
def __init__(self, path, img_size=416, batch_size=16, augment=False, hyp=None, rect=True, image_weights=False,
def __init__(self, path, img_size=416, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False,
cache_labels=False, cache_images=False):
path = str(Path(path)) # os-agnostic
with open(path, 'r') as f:
@@ -319,7 +319,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing
self.labels = [np.zeros((0, 5))] * n
extract_bounding_boxes = False
create_datasubset = False
pbar = tqdm(self.label_files, desc='Reading labels')
pbar = tqdm(self.label_files, desc='Caching labels')
nm, nf, ne, ns = 0, 0, 0, 0 # number missing, number found, number empty, number datasubset
for i, file in enumerate(pbar):
try:
@@ -370,13 +370,17 @@ class LoadImagesAndLabels(Dataset): # for training/testing
ne += 1 # print('empty labels for image %s' % self.img_files[i]) # file empty
# os.system("rm '%s' '%s'" % (self.img_files[i], self.label_files[i])) # remove
pbar.desc = 'Reading labels (%g found, %g missing, %g empty for %g images)' % (nf, nm, ne, n)
pbar.desc = 'Caching labels (%g found, %g missing, %g empty for %g images)' % (nf, nm, ne, n)
assert nf > 0, 'No labels found. Recommend correcting image and label paths.'
# Cache images into memory for faster training (~5GB)
if cache_images and augment: # if training
for i in tqdm(range(min(len(self.img_files), 10000)), desc='Reading images'): # max 10k images
# Cache images into memory for faster training (WARNING: Large datasets may exceed system RAM)
if cache_images: # if training
gb = 0 # Gigabytes of cached images
pbar = tqdm(range(len(self.img_files)), desc='Caching images')
for i in pbar: # max 10k images
self.imgs[i] = load_image(self, i)
gb += self.imgs[i].nbytes
pbar.desc = 'Caching images (%.1fGB)' % (gb / 1E9)
# Detect corrupted images https://medium.com/joelthchao/programmatically-detect-corrupted-image-8c1b2006c3d3
detect_corrupted_images = False
@@ -503,10 +507,10 @@ def load_image(self, index):
img_path = self.img_files[index]
img = cv2.imread(img_path) # BGR
assert img is not None, 'Image Not Found ' + img_path
r = self.img_size / max(img.shape) # size ratio
if self.augment: # if training (NOT testing), downsize to inference shape
r = self.img_size / max(img.shape) # resize image to img_size
if (r < 1) or ((r > 1) and self.augment): # always resize down, only resize up if training with augmentation
h, w = img.shape[:2]
img = cv2.resize(img, (int(w * r), int(h * r)), interpolation=cv2.INTER_LINEAR) # _LINEAR fastest
return cv2.resize(img, (int(w * r), int(h * r)), interpolation=cv2.INTER_LINEAR) # _LINEAR fastest
return img
@@ -569,13 +573,11 @@ def load_mosaic(self, index):
# Concat/clip labels
if len(labels4):
labels4 = np.concatenate(labels4, 0)
np.clip(labels4[:, 1:], 0, 2 * s, out=labels4[:, 1:]) # use before random_affine
# np.clip(labels4[:, 1:], s / 2, 1.5 * s, out=labels4[:, 1:])
# labels4[:, 1:] -= s / 2
# img4 = img4[s // 2: int(s * 1.5), s // 2:int(s * 1.5)]
# np.clip(labels4[:, 1:] - s / 2, 0, s, out=labels4[:, 1:]) # use with center crop
np.clip(labels4[:, 1:], 0, 2 * s, out=labels4[:, 1:]) # use with random_affine
# Augment
# img4 = img4[s // 2: int(s * 1.5), s // 2:int(s * 1.5)] # center crop (WARNING, requires box pruning)
img4, labels4 = random_affine(img4, labels4,
degrees=self.hyp['degrees'],
translate=self.hyp['translate'],