minor fix
This commit is contained in:
+17
-1
@@ -84,6 +84,7 @@ class Detect(nn.Module):
|
||||
class IDetect(nn.Module):
|
||||
stride = None # strides computed during build
|
||||
export = False # onnx export
|
||||
include_nms = False
|
||||
|
||||
def __init__(self, nc=80, anchors=(), ch=()): # detection layer
|
||||
super(IDetect, self).__init__()
|
||||
@@ -139,7 +140,10 @@ class IDetect(nn.Module):
|
||||
y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
|
||||
z.append(y.view(bs, -1, self.no))
|
||||
|
||||
return x if self.training else (torch.cat(z, 1), x)
|
||||
if self.include_nms:
|
||||
z = self.convert(z)
|
||||
|
||||
return x if self.training else (z, ) if self.include_nms else (torch.cat(z, 1), x)
|
||||
|
||||
def fuse(self):
|
||||
print("IDetect.fuse")
|
||||
@@ -160,6 +164,18 @@ class IDetect(nn.Module):
|
||||
yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
|
||||
return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
|
||||
|
||||
def convert(self, z):
|
||||
z = torch.cat(z, 1)
|
||||
box = z[:, :, :4]
|
||||
conf = z[:, :, 4:5]
|
||||
score = z[:, :, 5:]
|
||||
score *= conf
|
||||
convert_matrix = torch.tensor([[1, 0, 1, 0], [0, 1, 0, 1], [-0.5, 0, 0.5, 0], [0, -0.5, 0, 0.5]],
|
||||
dtype=torch.float32,
|
||||
device=z.device)
|
||||
box @= convert_matrix
|
||||
return (box, score)
|
||||
|
||||
|
||||
class IKeypoint(nn.Module):
|
||||
stride = None # strides computed during build
|
||||
|
||||
Reference in New Issue
Block a user