diff --git a/tools/keypoint.ipynb b/tools/keypoint.ipynb index 82a9aa4d..38127337 100644 --- a/tools/keypoint.ipynb +++ b/tools/keypoint.ipynb @@ -57,7 +57,8 @@ "outputs": [], "source": [ "output = non_max_suppression_kpt(output, 0.25, 0.65, nc=model.yaml['nc'], nkpt=model.yaml['nkpt'], kpt_label=True)\n", - "output = output_to_keypoint(output)\n", + "with torch.no_grad():\n", + " output = output_to_keypoint(output)\n", "nimg = image[0].permute(1, 2, 0) * 255\n", "nimg = nimg.cpu().numpy().astype(np.uint8)\n", "nimg = cv2.cvtColor(nimg, cv2.COLOR_RGB2BGR)\n",