diff --git a/utils/plots.py b/utils/plots.py index fdacc438..fdd8d0e8 100644 --- a/utils/plots.py +++ b/utils/plots.py @@ -439,8 +439,8 @@ def output_to_keypoint(output): for i, o in enumerate(output): kpts = o[:,6:] o = o[:,:6] - for index, (*box, conf, cls) in enumerate(o.cpu().numpy()): - targets.append([i, cls, *list(*xyxy2xywh(np.array(box)[None])), conf, *list(kpts.cpu().numpy()[index])]) + for index, (*box, conf, cls) in enumerate(o.detach().cpu().numpy()): + targets.append([i, cls, *list(*xyxy2xywh(np.array(box)[None])), conf, *list(kpts.detach().cpu().numpy()[index])]) return np.array(targets)