diff --git a/test.py b/test.py index 54cac3c1..3f0b115a 100644 --- a/test.py +++ b/test.py @@ -48,7 +48,7 @@ def test(cfg, dataset = LoadImagesAndLabels(test_path, img_size, batch_size) dataloader = DataLoader(dataset, batch_size=batch_size, - num_workers=min(os.cpu_count(), batch_size), + num_workers=min([os.cpu_count(), batch_size, 16]), pin_memory=True, collate_fn=dataset.collate_fn) diff --git a/train.py b/train.py index c85c8c7b..ae31349b 100644 --- a/train.py +++ b/train.py @@ -194,7 +194,7 @@ def train(): # Dataloader dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, - num_workers=min(os.cpu_count(), batch_size), + num_workers=min([os.cpu_count(), batch_size, 16]), shuffle=not opt.rect, # Shuffle=True unless rectangular training is used pin_memory=True, collate_fn=dataset.collate_fn)