From dc3e04087dc5f0e010dc6d8ce455723727a02881 Mon Sep 17 00:00:00 2001 From: Alexander <84590713+SashaAlderson@users.noreply.github.com> Date: Fri, 22 Jul 2022 00:07:08 +0700 Subject: [PATCH] fuse IDetect (#148) --- models/yolo.py | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/models/yolo.py b/models/yolo.py index 951452de..bdd29581 100644 --- a/models/yolo.py +++ b/models/yolo.py @@ -106,7 +106,41 @@ class IDetect(nn.Module): z.append(y.view(bs, -1, self.no)) return x if self.training else (torch.cat(z, 1), x) + + def fuseforward(self, x): + # x = x.copy() # for profiling + z = [] # inference output + self.training |= self.export + for i in range(self.nl): + x[i] = self.m[i](x[i]) # conv + bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85) + x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous() + if not self.training: # inference + if self.grid[i].shape[2:4] != x[i].shape[2:4]: + self.grid[i] = self._make_grid(nx, ny).to(x[i].device) + + y = x[i].sigmoid() + y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy + y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh + z.append(y.view(bs, -1, self.no)) + + return x if self.training else (torch.cat(z, 1), x) + + def fuse(self): + print("IDetect.fuse") + # fuse ImplicitA and Convolution + for i in range(len(self.m)): + c1,c2,_,_ = self.m[i].weight.shape + c1_,c2_, _,_ = self.ia[i].implicit.shape + self.m[i].bias += torch.matmul(self.m[i].weight.reshape(c1,c2),self.ia[i].implicit.reshape(c2_,c1_)).squeeze(1) + + # fuse ImplicitM and Convolution + for i in range(len(self.m)): + c1,c2, _,_ = self.im[i].implicit.shape + self.m[i].bias *= self.im[i].implicit.reshape(c2) + self.m[i].weight *= self.im[i].implicit.transpose(0,1) + @staticmethod def _make_grid(nx=20, ny=20): yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)]) @@ -537,6 +571,9 @@ class Model(nn.Module): m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv delattr(m, 'bn') # remove batchnorm m.forward = m.fuseforward # update forward + elif isinstance(m, IDetect): + m.fuse() + m.forward = m.fuseforward self.info() return self