From 86cc5d0c4f1e26f122dcba079ce564548c07fa86 Mon Sep 17 00:00:00 2001 From: ksnzh Date: Thu, 28 Jul 2022 00:25:06 +0800 Subject: [PATCH] fix IDetect fuseforrward onnx export (#332) --- models/yolo.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/models/yolo.py b/models/yolo.py index a3f05eb7..ee250e6c 100644 --- a/models/yolo.py +++ b/models/yolo.py @@ -146,8 +146,13 @@ class IDetect(nn.Module): self.grid[i] = self._make_grid(nx, ny).to(x[i].device) y = x[i].sigmoid() - y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy - y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh + if not torch.onnx.is_in_onnx_export(): + y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy + y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh + else: + xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy + wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i].data # wh + y = torch.cat((xy, wh, y[..., 4:]), -1) z.append(y.view(bs, -1, self.no)) if self.training: