fix a model inference problem on cpu (#502)
* fix model inference problem on cpu * Update keypoint.ipynb
This commit is contained in:
parent
36ce6b2087
commit
064c71e7c2
@ -25,10 +25,12 @@
|
|||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
|
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
|
||||||
"weigths = torch.load('yolov7-w6-pose.pt')\n",
|
"weigths = torch.load('yolov7-w6-pose.pt', map_location=device)\n",
|
||||||
"model = weigths['model']\n",
|
"model = weigths['model']\n",
|
||||||
"model = model.half().to(device)\n",
|
"_ = model.float().eval()\n",
|
||||||
"_ = model.eval()"
|
"\n",
|
||||||
|
"if torch.cuda.is_available():\n",
|
||||||
|
" model.half().to(device)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -43,9 +45,9 @@
|
|||||||
"image_ = image.copy()\n",
|
"image_ = image.copy()\n",
|
||||||
"image = transforms.ToTensor()(image)\n",
|
"image = transforms.ToTensor()(image)\n",
|
||||||
"image = torch.tensor(np.array([image.numpy()]))\n",
|
"image = torch.tensor(np.array([image.numpy()]))\n",
|
||||||
"image = image.to(device)\n",
|
|
||||||
"image = image.half()\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
|
"if torch.cuda.is_available():\n",
|
||||||
|
" image = image.half().to(device) \n",
|
||||||
"output, _ = model(image)"
|
"output, _ = model(image)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
@ -118,7 +120,7 @@
|
|||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.8.10"
|
"version": "3.9.12"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user