updates
This commit is contained in:
@@ -38,7 +38,7 @@ def test(cfg,
|
|||||||
else: # darknet format
|
else: # darknet format
|
||||||
load_darknet_weights(model, weights)
|
load_darknet_weights(model, weights)
|
||||||
|
|
||||||
if torch.cuda.device_count() > 1:
|
if device.type != 'cpu' and torch.cuda.device_count() > 1:
|
||||||
model = nn.DataParallel(model)
|
model = nn.DataParallel(model)
|
||||||
else: # called by train.py
|
else: # called by train.py
|
||||||
device = next(model.parameters()).device # get model device
|
device = next(model.parameters()).device # get model device
|
||||||
|
|||||||
Reference in New Issue
Block a user