From ae37b2daa74c599d640a7b9698eeafd64265f999 Mon Sep 17 00:00:00 2001
From: Sahil Chachra <37156032+SahilChachra@users.noreply.github.com>
Date: Mon, 11 Apr 2022 16:10:56 +0530
Subject: [PATCH] Fix ONNX inference code (#1928)

---
 models/common.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/models/common.py b/models/common.py
index 82b348ae..a76fc628 100644
--- a/models/common.py
+++ b/models/common.py
@@ -314,9 +314,11 @@ class DetectMultiBackend(nn.Module):
             net = cv2.dnn.readNetFromONNX(w)
         elif onnx:  # ONNX Runtime
             LOGGER.info(f'Loading {w} for ONNX Runtime inference...')
-            check_requirements(('onnx', 'onnxruntime-gpu' if torch.has_cuda else 'onnxruntime'))
+            cuda = torch.cuda.is_available()
+            check_requirements(('onnx', 'onnxruntime-gpu' if cuda else 'onnxruntime'))
             import onnxruntime
-            session = onnxruntime.InferenceSession(w, None)
+            providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if cuda else ['CPUExecutionProvider']
+            session = onnxruntime.InferenceSession(w, providers=providers)
         else:  # TensorFlow model (TFLite, pb, saved_model)
             import tensorflow as tf
             if pb:  # https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt