diff --git a/train.py b/train.py index 5a229907..9aa6f683 100644 --- a/train.py +++ b/train.py @@ -199,7 +199,7 @@ def train(): # Dataloader batch_size = min(batch_size, len(dataset)) - nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 16]) + nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 16]) # number of workers print('Using %g dataloader workers' % nw) dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,