support onnx to tensorrt convert (#114)
This commit is contained in:
parent
4f6e390c99
commit
96390ed201
@ -73,7 +73,7 @@ def detect(save_img=False):
|
|||||||
|
|
||||||
# Inference
|
# Inference
|
||||||
t1 = time_synchronized()
|
t1 = time_synchronized()
|
||||||
pred = model(img, augment=opt.augment)[0]
|
pred = model(img, augment=opt.augment)
|
||||||
|
|
||||||
# Apply NMS
|
# Apply NMS
|
||||||
pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes, agnostic=opt.agnostic_nms)
|
pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes, agnostic=opt.agnostic_nms)
|
||||||
|
|||||||
@ -21,6 +21,7 @@ if __name__ == '__main__':
|
|||||||
parser.add_argument('--dynamic', action='store_true', help='dynamic ONNX axes')
|
parser.add_argument('--dynamic', action='store_true', help='dynamic ONNX axes')
|
||||||
parser.add_argument('--grid', action='store_true', help='export Detect() layer grid')
|
parser.add_argument('--grid', action='store_true', help='export Detect() layer grid')
|
||||||
parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
|
parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
|
||||||
|
parser.add_argument('--simplify', action='store_true', help='simplify onnx model')
|
||||||
opt = parser.parse_args()
|
opt = parser.parse_args()
|
||||||
opt.img_size *= 2 if len(opt.img_size) == 1 else 1 # expand
|
opt.img_size *= 2 if len(opt.img_size) == 1 else 1 # expand
|
||||||
print(opt)
|
print(opt)
|
||||||
@ -68,6 +69,7 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
print('\nStarting ONNX export with onnx %s...' % onnx.__version__)
|
print('\nStarting ONNX export with onnx %s...' % onnx.__version__)
|
||||||
f = opt.weights.replace('.pt', '.onnx') # filename
|
f = opt.weights.replace('.pt', '.onnx') # filename
|
||||||
|
model.eval()
|
||||||
torch.onnx.export(model, img, f, verbose=False, opset_version=12, input_names=['images'],
|
torch.onnx.export(model, img, f, verbose=False, opset_version=12, input_names=['images'],
|
||||||
output_names=['classes', 'boxes'] if y is None else ['output'],
|
output_names=['classes', 'boxes'] if y is None else ['output'],
|
||||||
dynamic_axes={'images': {0: 'batch', 2: 'height', 3: 'width'}, # size(1,3,640,640)
|
dynamic_axes={'images': {0: 'batch', 2: 'height', 3: 'width'}, # size(1,3,640,640)
|
||||||
@ -76,6 +78,23 @@ if __name__ == '__main__':
|
|||||||
# Checks
|
# Checks
|
||||||
onnx_model = onnx.load(f) # load onnx model
|
onnx_model = onnx.load(f) # load onnx model
|
||||||
onnx.checker.check_model(onnx_model) # check onnx model
|
onnx.checker.check_model(onnx_model) # check onnx model
|
||||||
|
|
||||||
|
# # Metadata
|
||||||
|
# d = {'stride': int(max(model.stride))}
|
||||||
|
# for k, v in d.items():
|
||||||
|
# meta = onnx_model.metadata_props.add()
|
||||||
|
# meta.key, meta.value = k, str(v)
|
||||||
|
# onnx.save(onnx_model, f)
|
||||||
|
|
||||||
|
if opt.simplify:
|
||||||
|
try:
|
||||||
|
import onnxsim
|
||||||
|
|
||||||
|
print('\nStarting to simplify ONNX...')
|
||||||
|
onnx_model, check = onnxsim.simplify(onnx_model)
|
||||||
|
assert check, 'assert check failed'
|
||||||
|
except Exception as e:
|
||||||
|
print(f'Simplifier failure: {e}')
|
||||||
# print(onnx.helper.printable_graph(onnx_model.graph)) # print a human readable model
|
# print(onnx.helper.printable_graph(onnx_model.graph)) # print a human readable model
|
||||||
print('ONNX export success, saved as %s' % f)
|
print('ONNX export success, saved as %s' % f)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -50,11 +50,16 @@ class Detect(nn.Module):
|
|||||||
self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
|
self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
|
||||||
|
|
||||||
y = x[i].sigmoid()
|
y = x[i].sigmoid()
|
||||||
|
if not torch.onnx.is_in_onnx_export():
|
||||||
y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
|
y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
|
||||||
y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
|
y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
|
||||||
|
else:
|
||||||
|
xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
|
||||||
|
wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
|
||||||
|
y = torch.cat((xy, wh, y[..., 4:]), -1)
|
||||||
z.append(y.view(bs, -1, self.no))
|
z.append(y.view(bs, -1, self.no))
|
||||||
|
|
||||||
return x if self.training else (torch.cat(z, 1), x)
|
return x if self.training else torch.cat(z, 1)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _make_grid(nx=20, ny=20):
|
def _make_grid(nx=20, ny=20):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user