diff --git a/train.py b/train.py index b87d6879..0c5deaa0 100644 --- a/train.py +++ b/train.py @@ -68,7 +68,7 @@ def train(hyp, opt, device, tb_writer=None): loggers = {'wandb': None} # loggers dict if rank in [-1, 0]: opt.hyp = hyp # add hyperparameters - run_id = torch.load(weights).get('wandb_id') if weights.endswith('.pt') and os.path.isfile(weights) else None + run_id = torch.load(weights, map_location=device).get('wandb_id') if weights.endswith('.pt') and os.path.isfile(weights) else None wandb_logger = WandbLogger(opt, Path(opt.save_dir).stem, run_id, data_dict) loggers['wandb'] = wandb_logger.wandb data_dict = wandb_logger.data_dict