diff --git a/test.py b/test.py index 2b6b8f23..f6a68859 100644 --- a/test.py +++ b/test.py @@ -38,7 +38,7 @@ def test(cfg, else: # darknet format load_darknet_weights(model, weights) - if torch.cuda.device_count() > 1: + if device.type != 'cpu' and torch.cuda.device_count() > 1: model = nn.DataParallel(model) else: # called by train.py device = next(model.parameters()).device # get model device