diff --git a/detect.py b/detect.py index 7953f30e..2f720cd8 100755 --- a/detect.py +++ b/detect.py @@ -38,12 +38,12 @@ def main(opt): if weights_path.endswith('.pt'): # pytorch format if weights_path.endswith('weights/yolov3.pt') and not os.path.isfile(weights_path): os.system('wget https://storage.googleapis.com/ultralytics/yolov3.pt -O ' + weights_path) - else: # darknet format - load_weights(model, weights_path) checkpoint = torch.load(weights_path, map_location='cpu') model.load_state_dict(checkpoint['model']) del checkpoint + else: # darknet format + load_weights(model, weights_path) # current = model.state_dict() # saved = checkpoint['model']