This commit is contained in:
Glenn Jocher
2019-05-23 13:15:44 +02:00
parent 68b9df4dd4
commit 001193b9c7
2 changed files with 8 additions and 12 deletions
+6 -8
View File
@@ -130,7 +130,7 @@ class LoadWebcam: # for inference
class LoadImagesAndLabels(Dataset): # for training/testing
def __init__(self, path, img_size=416, batch_size=16, augment=False, rect=True, image_weights=False, cache=False,
def __init__(self, path, img_size=416, batch_size=16, augment=False, rect=True, image_weights=False,
multi_scale=False):
with open(path, 'r') as f:
img_files = f.read().splitlines()
@@ -190,11 +190,8 @@ class LoadImagesAndLabels(Dataset): # for training/testing
self.batch_shapes = np.ceil(np.array(shapes) * img_size / 32.).astype(np.int) * 32
# Preload images
if cache and (n < 1001): # preload all images into memory if possible
self.imgs = [cv2.imread(self.img_files[i]) for i in tqdm(range(n), desc='Reading images')]
# Preload labels (required for weighted CE training)
self.imgs = [None] * n
self.labels = [np.zeros((0, 5))] * n
iter = tqdm(self.label_files, desc='Reading labels') if n > 1000 else self.label_files
for i, file in enumerate(iter):
@@ -227,10 +224,11 @@ class LoadImagesAndLabels(Dataset): # for training/testing
label_path = self.label_files[index]
# Load image
if hasattr(self, 'imgs'): # preloaded
img = self.imgs[index]
else:
img = self.imgs[index]
if img is None:
img = cv2.imread(img_path) # BGR
if self.n < 1001:
self.imgs[index] = img # cache image into memory
assert img is not None, 'File Not Found ' + img_path
# Augment colorspace