Fix ONNX inference code (#1928)
This commit is contained in:
parent
c2c113e5eb
commit
ae37b2daa7
@ -314,9 +314,11 @@ class DetectMultiBackend(nn.Module):
|
|||||||
net = cv2.dnn.readNetFromONNX(w)
|
net = cv2.dnn.readNetFromONNX(w)
|
||||||
elif onnx: # ONNX Runtime
|
elif onnx: # ONNX Runtime
|
||||||
LOGGER.info(f'Loading {w} for ONNX Runtime inference...')
|
LOGGER.info(f'Loading {w} for ONNX Runtime inference...')
|
||||||
check_requirements(('onnx', 'onnxruntime-gpu' if torch.has_cuda else 'onnxruntime'))
|
cuda = torch.cuda.is_available()
|
||||||
|
check_requirements(('onnx', 'onnxruntime-gpu' if cuda else 'onnxruntime'))
|
||||||
import onnxruntime
|
import onnxruntime
|
||||||
session = onnxruntime.InferenceSession(w, None)
|
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if cuda else ['CPUExecutionProvider']
|
||||||
|
session = onnxruntime.InferenceSession(w, providers=providers)
|
||||||
else: # TensorFlow model (TFLite, pb, saved_model)
|
else: # TensorFlow model (TFLite, pb, saved_model)
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
if pb: # https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
|
if pb: # https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user