Fix ONNX inference code (#1928)
parent c2c113e5eb
commit ae37b2daa7
@@ -314,9 +314,11 @@ class DetectMultiBackend(nn.Module):
             net = cv2.dnn.readNetFromONNX(w)
         elif onnx:  # ONNX Runtime
             LOGGER.info(f'Loading {w} for ONNX Runtime inference...')
-            check_requirements(('onnx', 'onnxruntime-gpu' if torch.has_cuda else 'onnxruntime'))
+            cuda = torch.cuda.is_available()
+            check_requirements(('onnx', 'onnxruntime-gpu' if cuda else 'onnxruntime'))
             import onnxruntime
-            session = onnxruntime.InferenceSession(w, None)
+            providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if cuda else ['CPUExecutionProvider']
+            session = onnxruntime.InferenceSession(w, providers=providers)
         else:  # TensorFlow model (TFLite, pb, saved_model)
             import tensorflow as tf
             if pb:  # https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
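The change swaps torch.has_cuda (true whenever the PyTorch build was compiled with CUDA, even with no GPU present) for torch.cuda.is_available() (true only when a GPU is actually usable at runtime), and passes an explicit providers list to onnxruntime.InferenceSession instead of None. In isolation, the resulting pattern looks like the minimal sketch below; the model path 'yolov5s.onnx' and the 1x3x640x640 dummy input are illustrative assumptions, not part of this commit.

    import numpy as np
    import onnxruntime
    import torch

    # Select execution providers at runtime, as the patched code does.
    cuda = torch.cuda.is_available()
    providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if cuda else ['CPUExecutionProvider']

    # 'yolov5s.onnx' is a hypothetical exported model path.
    session = onnxruntime.InferenceSession('yolov5s.onnx', providers=providers)

    # Run a single inference on a dummy input; the 640x640 shape is assumed.
    input_name = session.get_inputs()[0].name
    im = np.zeros((1, 3, 640, 640), dtype=np.float32)
    outputs = session.run(None, {input_name: im})

Listing both providers lets ONNX Runtime fall back to the CPU for any operator the CUDA provider cannot execute, rather than failing outright.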