diff --git a/detect.py b/detect.py index 5648a91b..5e0c4416 100644 --- a/detect.py +++ b/detect.py @@ -84,7 +84,8 @@ def detect(save_img=False): # Inference t1 = time_synchronized() - pred = model(img, augment=opt.augment)[0] + with torch.no_grad(): # Calculating gradients would cause a GPU memory leak + pred = model(img, augment=opt.augment)[0] t2 = time_synchronized() # Apply NMS