From 954cde65ab74a9cfb17e82cf92cc409cdadecf3e Mon Sep 17 00:00:00 2001
From: AlexeyAB84
Date: Thu, 28 Jul 2022 22:03:09 +0300
Subject: [PATCH] Fuse IAuxDetect

---
 models/yolo.py | 74 ++++++++++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 71 insertions(+), 3 deletions(-)

diff --git a/models/yolo.py b/models/yolo.py
index ee250e6c..cf5ccadf 100644
--- a/models/yolo.py
+++ b/models/yolo.py
@@ -303,6 +303,8 @@ class IKeypoint(nn.Module):
 class IAuxDetect(nn.Module):
     stride = None  # strides computed during build
     export = False  # onnx export
+    end2end = False
+    include_nms = False
 
     def __init__(self, nc=80, anchors=(), ch=()):  # detection layer
         super(IAuxDetect, self).__init__()
@@ -338,17 +340,83 @@ class IAuxDetect(nn.Module):
                     self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
 
                 y = x[i].sigmoid()
-                y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
-                y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
+                if not torch.onnx.is_in_onnx_export():
+                    y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
+                    y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
+                else:
+                    xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
+                    wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i].data  # wh
+                    y = torch.cat((xy, wh, y[..., 4:]), -1)
                 z.append(y.view(bs, -1, self.no))
 
         return x if self.training else (torch.cat(z, 1), x[:self.nl])
 
+    def fuseforward(self, x):
+        # x = x.copy()  # for profiling
+        z = []  # inference output
+        self.training |= self.export
+        for i in range(self.nl):
+            x[i] = self.m[i](x[i])  # conv
+            bs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)
+            x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
+
+            if not self.training:  # inference
+                if self.grid[i].shape[2:4] != x[i].shape[2:4]:
+                    self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
+
+                y = x[i].sigmoid()
+                if not torch.onnx.is_in_onnx_export():
+                    y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
+                    y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
+                else:
+                    xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
+                    wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i].data  # wh
+                    y = torch.cat((xy, wh, y[..., 4:]), -1)
+                z.append(y.view(bs, -1, self.no))
+
+        if self.training:
+            out = x
+        elif self.end2end:
+            out = torch.cat(z, 1)
+        elif self.include_nms:
+            z = self.convert(z)
+            out = (z, )
+        else:
+            out = (torch.cat(z, 1), x)
+
+        return out
+
+    def fuse(self):
+        print("IAuxDetect.fuse")
+        # fuse ImplicitA and Convolution
+        for i in range(len(self.m)):
+            c1,c2,_,_ = self.m[i].weight.shape
+            c1_,c2_, _,_ = self.ia[i].implicit.shape
+            self.m[i].bias += torch.matmul(self.m[i].weight.reshape(c1,c2),self.ia[i].implicit.reshape(c2_,c1_)).squeeze(1)
+
+        # fuse ImplicitM and Convolution
+        for i in range(len(self.m)):
+            c1,c2, _,_ = self.im[i].implicit.shape
+            self.m[i].bias *= self.im[i].implicit.reshape(c2)
+            self.m[i].weight *= self.im[i].implicit.transpose(0,1)
+
     @staticmethod
     def _make_grid(nx=20, ny=20):
         yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
         return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
 
+    def convert(self, z):
+        z = torch.cat(z, 1)
+        box = z[:, :, :4]
+        conf = z[:, :, 4:5]
+        score = z[:, :, 5:]
+        score *= conf
+        convert_matrix = torch.tensor([[1, 0, 1, 0], [0, 1, 0, 1], [-0.5, 0, 0.5, 0], [0, -0.5, 0, 0.5]],
+                                      dtype=torch.float32,
+                                      device=z.device)
+        box @= convert_matrix
+        return (box, score)
+
 
 class IBin(nn.Module):
     stride = None  # strides computed during build
@@ -623,7 +691,7 @@ class Model(nn.Module):
                 m.conv = fuse_conv_and_bn(m.conv, m.bn)  # update conv
                 delattr(m, 'bn')  # remove batchnorm
                 m.forward = m.fuseforward  # update forward
-            elif isinstance(m, IDetect):
+            elif isinstance(m, (IDetect, IAuxDetect)):
                 m.fuse()
                 m.forward = m.fuseforward
         self.info()