diff --git a/detect.py b/detect.py index e7070b56..5fbe3299 100755 --- a/detect.py +++ b/detect.py @@ -30,9 +30,7 @@ def detect( if weights.endswith('.pt'): # pytorch format if weights.endswith('weights/yolov3.pt') and not os.path.isfile(weights): os.system('wget https://storage.googleapis.com/ultralytics/yolov3.pt -O ' + weights) - checkpoint = torch.load(weights, map_location='cpu') - model.load_state_dict(checkpoint['model']) - del checkpoint + model.load_state_dict(torch.load(weights, map_location='cpu')['model']) else: # darknet format load_darknet_weights(model, weights) diff --git a/test.py b/test.py index bcb54f99..8b429660 100644 --- a/test.py +++ b/test.py @@ -29,9 +29,7 @@ def test( # Load weights if weights.endswith('.pt'): # pytorch format - checkpoint = torch.load(weights, map_location='cpu') - model.load_state_dict(checkpoint['model']) - del checkpoint + model.load_state_dict(torch.load(weights, map_location='cpu')['model']) else: # darknet format load_darknet_weights(model, weights)