Fix ONNX inference code (#1928)

Sahil Chachra 2022-04-11 16:10:56 +05:30 committed by GitHub
parent c2c113e5eb
commit ae37b2daa7


@@ -314,9 +314,11 @@ class DetectMultiBackend(nn.Module):
             net = cv2.dnn.readNetFromONNX(w)
         elif onnx:  # ONNX Runtime
             LOGGER.info(f'Loading {w} for ONNX Runtime inference...')
-            check_requirements(('onnx', 'onnxruntime-gpu' if torch.has_cuda else 'onnxruntime'))
+            cuda = torch.cuda.is_available()
+            check_requirements(('onnx', 'onnxruntime-gpu' if cuda else 'onnxruntime'))
             import onnxruntime
-            session = onnxruntime.InferenceSession(w, None)
+            providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if cuda else ['CPUExecutionProvider']
+            session = onnxruntime.InferenceSession(w, providers=providers)
         else:  # TensorFlow model (TFLite, pb, saved_model)
             import tensorflow as tf
             if pb:  # https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
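
For context, a minimal standalone sketch of the provider-selection pattern this commit introduces. The previous code checked torch.has_cuda (a build-time flag) and passed None as the second positional argument of InferenceSession (which is sess_options, not providers), so CUDA was never actually selected. The model path and the 1x3x640x640 input shape below are placeholder assumptions, not part of this commit:

    import numpy as np
    import onnxruntime
    import torch

    # Runtime check for a usable GPU, unlike the build-time torch.has_cuda flag
    cuda = torch.cuda.is_available()
    providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if cuda else ['CPUExecutionProvider']
    session = onnxruntime.InferenceSession('model.onnx', providers=providers)  # placeholder path

    # Run one inference: feed a dummy image to the first declared input,
    # fetch every declared output.
    input_name = session.get_inputs()[0].name
    output_names = [o.name for o in session.get_outputs()]
    dummy = np.zeros((1, 3, 640, 640), dtype=np.float32)  # assumed NCHW input shape
    outputs = session.run(output_names, {input_name: dummy})
    print([o.shape for o in outputs])

Listing CPUExecutionProvider after CUDAExecutionProvider lets ONNX Runtime fall back to the CPU for any operator the CUDA provider cannot handle, so the same code path works on both GPU and CPU-only machines.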