From f43170817cf1ebbf3edad08ba7743a70de7bea75 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Wed, 7 Aug 2019 16:45:13 +0200 Subject: [PATCH] updates --- train.py | 4 +++- utils/datasets.py | 3 +-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/train.py b/train.py index 4dfb53c3..68c99356 100644 --- a/train.py +++ b/train.py @@ -187,7 +187,8 @@ def train(cfg, augment=True, hyp=hyp, # augmentation hyperparameters rect=opt.rect, # rectangular training - image_weights=opt.img_weights) + image_weights=opt.img_weights, + cache_images=opt.cache_images) # Dataloader dataloader = torch.utils.data.DataLoader(dataset, @@ -352,6 +353,7 @@ if __name__ == '__main__': parser.add_argument('--evolve', action='store_true', help='evolve hyperparameters') parser.add_argument('--bucket', type=str, default='', help='gsutil bucket') parser.add_argument('--img-weights', action='store_true', help='select training images by weight') + parser.add_argument('--cache-images', action='store_true', help='cache images for faster training') opt = parser.parse_args() print(opt) diff --git a/utils/datasets.py b/utils/datasets.py index 16d08c40..83777ea3 100755 --- a/utils/datasets.py +++ b/utils/datasets.py @@ -156,7 +156,7 @@ class LoadWebcam: # for inference class LoadImagesAndLabels(Dataset): # for training/testing - def __init__(self, path, img_size=416, batch_size=16, augment=False, hyp=None, rect=True, image_weights=False): + def __init__(self, path, img_size=416, batch_size=16, augment=False, hyp=None, rect=True, image_weights=False, cache_images=False): path = str(Path(path)) # os-agnostic with open(path, 'r') as f: self.img_files = [x.replace('/', os.sep) for x in f.read().splitlines() # os-agnostic @@ -254,7 +254,6 @@ class LoadImagesAndLabels(Dataset): # for training/testing assert nf > 0, 'No labels found. Recommend correcting image and label paths.' # Cache images into memory for faster training (~5GB) - cache_images = False if cache_images and augment: # if training for i in tqdm(range(min(len(self.img_files), 10000)), desc='Reading images'): # max 10k images img_path = self.img_files[i]