updates
This commit is contained in:
+6
-8
@@ -130,7 +130,7 @@ class LoadWebcam: # for inference
|
||||
|
||||
|
||||
class LoadImagesAndLabels(Dataset): # for training/testing
|
||||
def __init__(self, path, img_size=416, batch_size=16, augment=False, rect=True, image_weights=False, cache=False,
|
||||
def __init__(self, path, img_size=416, batch_size=16, augment=False, rect=True, image_weights=False,
|
||||
multi_scale=False):
|
||||
with open(path, 'r') as f:
|
||||
img_files = f.read().splitlines()
|
||||
@@ -190,11 +190,8 @@ class LoadImagesAndLabels(Dataset): # for training/testing
|
||||
|
||||
self.batch_shapes = np.ceil(np.array(shapes) * img_size / 32.).astype(np.int) * 32
|
||||
|
||||
# Preload images
|
||||
if cache and (n < 1001): # preload all images into memory if possible
|
||||
self.imgs = [cv2.imread(self.img_files[i]) for i in tqdm(range(n), desc='Reading images')]
|
||||
|
||||
# Preload labels (required for weighted CE training)
|
||||
self.imgs = [None] * n
|
||||
self.labels = [np.zeros((0, 5))] * n
|
||||
iter = tqdm(self.label_files, desc='Reading labels') if n > 1000 else self.label_files
|
||||
for i, file in enumerate(iter):
|
||||
@@ -227,10 +224,11 @@ class LoadImagesAndLabels(Dataset): # for training/testing
|
||||
label_path = self.label_files[index]
|
||||
|
||||
# Load image
|
||||
if hasattr(self, 'imgs'): # preloaded
|
||||
img = self.imgs[index]
|
||||
else:
|
||||
img = self.imgs[index]
|
||||
if img is None:
|
||||
img = cv2.imread(img_path) # BGR
|
||||
if self.n < 1001:
|
||||
self.imgs[index] = img # cache image into memory
|
||||
assert img is not None, 'File Not Found ' + img_path
|
||||
|
||||
# Augment colorspace
|
||||
|
||||
Reference in New Issue
Block a user