diff --git a/train.py b/train.py index 1bf730d8..3c74bdfe 100644 --- a/train.py +++ b/train.py @@ -121,9 +121,7 @@ def train( if torch.cuda.device_count() > 1: dist.init_process_group(backend=opt.backend, init_method=opt.dist_url, world_size=opt.world_size, rank=opt.rank) model = torch.nn.parallel.DistributedDataParallel(model) - sampler = torch.utils.data.distributed.DistributedSampler(dataset) - else: - sampler = None + # sampler = torch.utils.data.distributed.DistributedSampler(dataset) # Dataloader dataloader = DataLoader(dataset, @@ -131,8 +129,7 @@ def train( num_workers=opt.num_workers, shuffle=True, pin_memory=True, - collate_fn=dataset.collate_fn, - sampler=sampler) + collate_fn=dataset.collate_fn) # Mixed precision training https://github.com/NVIDIA/apex # install help: https://github.com/NVIDIA/apex/issues/259