diff --git a/models/common.py b/models/common.py index 82b348ae..a76fc628 100644 --- a/models/common.py +++ b/models/common.py @@ -314,9 +314,11 @@ class DetectMultiBackend(nn.Module): net = cv2.dnn.readNetFromONNX(w) elif onnx: # ONNX Runtime LOGGER.info(f'Loading {w} for ONNX Runtime inference...') - check_requirements(('onnx', 'onnxruntime-gpu' if torch.has_cuda else 'onnxruntime')) + cuda = torch.cuda.is_available() + check_requirements(('onnx', 'onnxruntime-gpu' if cuda else 'onnxruntime')) import onnxruntime - session = onnxruntime.InferenceSession(w, None) + providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if cuda else ['CPUExecutionProvider'] + session = onnxruntime.InferenceSession(w, providers=providers) else: # TensorFlow model (TFLite, pb, saved_model) import tensorflow as tf if pb: # https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt