Update .pre-commit-config.yaml (#2019)
* Update .pre-commit-config.yaml * Update __init__.py * Update .pre-commit-config.yaml * Precommit updates
This commit is contained in:
+8
-8
@@ -380,11 +380,11 @@ class DetectMultiBackend(nn.Module):
|
||||
w = next(Path(w).glob('*.xml')) # get *.xml file from *_openvino_model dir
|
||||
network = ie.read_model(model=w, weights=Path(w).with_suffix('.bin'))
|
||||
if network.get_parameters()[0].get_layout().empty:
|
||||
network.get_parameters()[0].set_layout(Layout("NCHW"))
|
||||
network.get_parameters()[0].set_layout(Layout('NCHW'))
|
||||
batch_dim = get_batch(network)
|
||||
if batch_dim.is_static:
|
||||
batch_size = batch_dim.get_length()
|
||||
executable_network = ie.compile_model(network, device_name="CPU") # device_name="MYRIAD" for Intel NCS2
|
||||
executable_network = ie.compile_model(network, device_name='CPU') # device_name="MYRIAD" for Intel NCS2
|
||||
stride, names = self._load_metadata(Path(w).with_suffix('.yaml')) # load metadata
|
||||
elif engine: # TensorRT
|
||||
LOGGER.info(f'Loading {w} for TensorRT inference...')
|
||||
@@ -431,7 +431,7 @@ class DetectMultiBackend(nn.Module):
|
||||
import tensorflow as tf
|
||||
|
||||
def wrap_frozen_graph(gd, inputs, outputs):
|
||||
x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=""), []) # wrapped
|
||||
x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=''), []) # wrapped
|
||||
ge = x.graph.as_graph_element
|
||||
return x.prune(tf.nest.map_structure(ge, inputs), tf.nest.map_structure(ge, outputs))
|
||||
|
||||
@@ -445,7 +445,7 @@ class DetectMultiBackend(nn.Module):
|
||||
gd = tf.Graph().as_graph_def() # TF GraphDef
|
||||
with open(w, 'rb') as f:
|
||||
gd.ParseFromString(f.read())
|
||||
frozen_func = wrap_frozen_graph(gd, inputs="x:0", outputs=gd_outputs(gd))
|
||||
frozen_func = wrap_frozen_graph(gd, inputs='x:0', outputs=gd_outputs(gd))
|
||||
elif tflite or edgetpu: # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
|
||||
try: # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu
|
||||
from tflite_runtime.interpreter import Interpreter, load_delegate
|
||||
@@ -467,9 +467,9 @@ class DetectMultiBackend(nn.Module):
|
||||
output_details = interpreter.get_output_details() # outputs
|
||||
# load metadata
|
||||
with contextlib.suppress(zipfile.BadZipFile):
|
||||
with zipfile.ZipFile(w, "r") as model:
|
||||
with zipfile.ZipFile(w, 'r') as model:
|
||||
meta_file = model.namelist()[0]
|
||||
meta = ast.literal_eval(model.read(meta_file).decode("utf-8"))
|
||||
meta = ast.literal_eval(model.read(meta_file).decode('utf-8'))
|
||||
stride, names = int(meta['stride']), meta['names']
|
||||
elif tfjs: # TF.js
|
||||
raise NotImplementedError('ERROR: YOLOv3 TF.js inference is not supported')
|
||||
@@ -491,7 +491,7 @@ class DetectMultiBackend(nn.Module):
|
||||
check_requirements('tritonclient[all]')
|
||||
from utils.triton import TritonRemoteModel
|
||||
model = TritonRemoteModel(url=w)
|
||||
nhwc = model.runtime.startswith("tensorflow")
|
||||
nhwc = model.runtime.startswith('tensorflow')
|
||||
else:
|
||||
raise NotImplementedError(f'ERROR: {w} is not a supported format')
|
||||
|
||||
@@ -608,7 +608,7 @@ class DetectMultiBackend(nn.Module):
|
||||
url = urlparse(p) # if url may be Triton inference server
|
||||
types = [s in Path(p).name for s in sf]
|
||||
types[8] &= not types[9] # tflite &= not edgetpu
|
||||
triton = not any(types) and all([any(s in url.scheme for s in ["http", "grpc"]), url.netloc])
|
||||
triton = not any(types) and all([any(s in url.scheme for s in ['http', 'grpc']), url.netloc])
|
||||
return types + [triton]
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -45,4 +45,4 @@ head:
|
||||
[-1, 3, C3, [1024, False]], # 23 (P5/32-large)
|
||||
|
||||
[[17, 20, 23], 1, Segment, [nc, anchors, 32, 256]], # Detect(P3, P4, P5)
|
||||
]
|
||||
]
|
||||
|
||||
@@ -45,4 +45,4 @@ head:
|
||||
[-1, 3, C3, [1024, False]], # 23 (P5/32-large)
|
||||
|
||||
[[17, 20, 23], 1, Segment, [nc, anchors, 32, 256]], # Detect(P3, P4, P5)
|
||||
]
|
||||
]
|
||||
|
||||
+6
-6
@@ -356,7 +356,7 @@ class TFUpsample(keras.layers.Layer):
|
||||
# TF version of torch.nn.Upsample()
|
||||
def __init__(self, size, scale_factor, mode, w=None): # warning: all arguments needed including 'w'
|
||||
super().__init__()
|
||||
assert scale_factor % 2 == 0, "scale_factor must be multiple of 2"
|
||||
assert scale_factor % 2 == 0, 'scale_factor must be multiple of 2'
|
||||
self.upsample = lambda x: tf.image.resize(x, (x.shape[1] * scale_factor, x.shape[2] * scale_factor), mode)
|
||||
# self.upsample = keras.layers.UpSampling2D(size=scale_factor, interpolation=mode)
|
||||
# with default arguments: align_corners=False, half_pixel_centers=False
|
||||
@@ -371,7 +371,7 @@ class TFConcat(keras.layers.Layer):
|
||||
# TF version of torch.concat()
|
||||
def __init__(self, dimension=1, w=None):
|
||||
super().__init__()
|
||||
assert dimension == 1, "convert only NCHW to NHWC concat"
|
||||
assert dimension == 1, 'convert only NCHW to NHWC concat'
|
||||
self.d = 3
|
||||
|
||||
def call(self, inputs):
|
||||
@@ -523,17 +523,17 @@ class AgnosticNMS(keras.layers.Layer):
|
||||
selected_boxes = tf.gather(boxes, selected_inds)
|
||||
padded_boxes = tf.pad(selected_boxes,
|
||||
paddings=[[0, topk_all - tf.shape(selected_boxes)[0]], [0, 0]],
|
||||
mode="CONSTANT",
|
||||
mode='CONSTANT',
|
||||
constant_values=0.0)
|
||||
selected_scores = tf.gather(scores_inp, selected_inds)
|
||||
padded_scores = tf.pad(selected_scores,
|
||||
paddings=[[0, topk_all - tf.shape(selected_boxes)[0]]],
|
||||
mode="CONSTANT",
|
||||
mode='CONSTANT',
|
||||
constant_values=-1.0)
|
||||
selected_classes = tf.gather(class_inds, selected_inds)
|
||||
padded_classes = tf.pad(selected_classes,
|
||||
paddings=[[0, topk_all - tf.shape(selected_boxes)[0]]],
|
||||
mode="CONSTANT",
|
||||
mode='CONSTANT',
|
||||
constant_values=-1.0)
|
||||
valid_detections = tf.shape(selected_inds)[0]
|
||||
return padded_boxes, padded_scores, padded_classes, valid_detections
|
||||
@@ -603,6 +603,6 @@ def main(opt):
|
||||
run(**vars(opt))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if __name__ == '__main__':
|
||||
opt = parse_opt()
|
||||
main(opt)
|
||||
|
||||
Reference in New Issue
Block a user