Pycocotools best.pt after COCO train (#1593)
This commit is contained in:
parent
adc49abc71
commit
4a07280884
5
test.py
5
test.py
@ -1,5 +1,4 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import glob
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -246,7 +245,7 @@ def test(data,
|
|||||||
# Save JSON
|
# Save JSON
|
||||||
if save_json and len(jdict):
|
if save_json and len(jdict):
|
||||||
w = Path(weights[0] if isinstance(weights, list) else weights).stem if weights is not None else '' # weights
|
w = Path(weights[0] if isinstance(weights, list) else weights).stem if weights is not None else '' # weights
|
||||||
anno_json = glob.glob('../coco/annotations/instances_val*.json')[0] # annotations json
|
anno_json = '../coco/annotations/instances_val2017.json' # annotations json
|
||||||
pred_json = str(save_dir / f"{w}_predictions.json") # predictions json
|
pred_json = str(save_dir / f"{w}_predictions.json") # predictions json
|
||||||
print('\nEvaluating pycocotools mAP... saving %s...' % pred_json)
|
print('\nEvaluating pycocotools mAP... saving %s...' % pred_json)
|
||||||
with open(pred_json, 'w') as f:
|
with open(pred_json, 'w') as f:
|
||||||
@ -266,7 +265,7 @@ def test(data,
|
|||||||
eval.summarize()
|
eval.summarize()
|
||||||
map, map50 = eval.stats[:2] # update results (mAP@0.5:0.95, mAP@0.5)
|
map, map50 = eval.stats[:2] # update results (mAP@0.5:0.95, mAP@0.5)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print('ERROR: pycocotools unable to run: %s' % e)
|
print(f'pycocotools unable to run: {e}')
|
||||||
|
|
||||||
# Return results
|
# Return results
|
||||||
if not training:
|
if not training:
|
||||||
|
|||||||
33
train.py
33
train.py
@ -22,6 +22,7 @@ from torch.utils.tensorboard import SummaryWriter
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
import test # import test.py to get mAP after each epoch
|
import test # import test.py to get mAP after each epoch
|
||||||
|
from models.experimental import attempt_load
|
||||||
from models.yolo import Model
|
from models.yolo import Model
|
||||||
from utils.autoanchor import check_anchors
|
from utils.autoanchor import check_anchors
|
||||||
from utils.datasets import create_dataloader
|
from utils.datasets import create_dataloader
|
||||||
@ -193,9 +194,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
|
|||||||
# Process 0
|
# Process 0
|
||||||
if rank in [-1, 0]:
|
if rank in [-1, 0]:
|
||||||
ema.updates = start_epoch * nb // accumulate # set EMA updates
|
ema.updates = start_epoch * nb // accumulate # set EMA updates
|
||||||
testloader = create_dataloader(test_path, imgsz_test, total_batch_size, gs, opt,
|
testloader = create_dataloader(test_path, imgsz_test, total_batch_size, gs, opt, # testloader
|
||||||
hyp=hyp, cache=opt.cache_images and not opt.notest, rect=True,
|
hyp=hyp, cache=opt.cache_images and not opt.notest, rect=True,
|
||||||
rank=-1, world_size=opt.world_size, workers=opt.workers)[0] # testloader
|
rank=-1, world_size=opt.world_size, workers=opt.workers, pad=0.5)[0]
|
||||||
|
|
||||||
if not opt.resume:
|
if not opt.resume:
|
||||||
labels = np.concatenate(dataset.labels, 0)
|
labels = np.concatenate(dataset.labels, 0)
|
||||||
@ -385,15 +386,12 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
|
|||||||
|
|
||||||
if rank in [-1, 0]:
|
if rank in [-1, 0]:
|
||||||
# Strip optimizers
|
# Strip optimizers
|
||||||
n = opt.name if opt.name.isnumeric() else ''
|
for f in [last, best]:
|
||||||
fresults, flast, fbest = save_dir / f'results{n}.txt', wdir / f'last{n}.pt', wdir / f'best{n}.pt'
|
if f.exists(): # is *.pt
|
||||||
for f1, f2 in zip([wdir / 'last.pt', wdir / 'best.pt', results_file], [flast, fbest, fresults]):
|
strip_optimizer(f) # strip optimizer
|
||||||
if f1.exists():
|
os.system('gsutil cp %s gs://%s/weights' % (f, opt.bucket)) if opt.bucket else None # upload
|
||||||
os.rename(f1, f2) # rename
|
|
||||||
if str(f2).endswith('.pt'): # is *.pt
|
# Plots
|
||||||
strip_optimizer(f2) # strip optimizer
|
|
||||||
os.system('gsutil cp %s gs://%s/weights' % (f2, opt.bucket)) if opt.bucket else None # upload
|
|
||||||
# Finish
|
|
||||||
if plots:
|
if plots:
|
||||||
plot_results(save_dir=save_dir) # save as results.png
|
plot_results(save_dir=save_dir) # save as results.png
|
||||||
if wandb:
|
if wandb:
|
||||||
@ -401,6 +399,19 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
|
|||||||
wandb.log({"Results": [wandb.Image(str(save_dir / f), caption=f) for f in files
|
wandb.log({"Results": [wandb.Image(str(save_dir / f), caption=f) for f in files
|
||||||
if (save_dir / f).exists()]})
|
if (save_dir / f).exists()]})
|
||||||
logger.info('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
|
logger.info('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
|
||||||
|
|
||||||
|
# Test best.pt
|
||||||
|
if opt.data.endswith('coco.yaml') and nc == 80: # if COCO
|
||||||
|
results, _, _ = test.test(opt.data,
|
||||||
|
batch_size=total_batch_size,
|
||||||
|
imgsz=imgsz_test,
|
||||||
|
model=attempt_load(best if best.exists() else last, device).half(),
|
||||||
|
single_cls=opt.single_cls,
|
||||||
|
dataloader=testloader,
|
||||||
|
save_dir=save_dir,
|
||||||
|
save_json=True, # use pycocotools
|
||||||
|
plots=False)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
dist.destroy_process_group()
|
dist.destroy_process_group()
|
||||||
|
|
||||||
|
|||||||
@ -17,7 +17,7 @@ def gsutil_getsize(url=''):
|
|||||||
|
|
||||||
def attempt_download(weights):
|
def attempt_download(weights):
|
||||||
# Attempt to download pretrained weights if not found locally
|
# Attempt to download pretrained weights if not found locally
|
||||||
weights = weights.strip().replace("'", '')
|
weights = str(weights).strip().replace("'", '')
|
||||||
file = Path(weights).name.lower()
|
file = Path(weights).name.lower()
|
||||||
|
|
||||||
msg = weights + ' missing, try downloading from https://github.com/ultralytics/yolov3/releases/'
|
msg = weights + ' missing, try downloading from https://github.com/ultralytics/yolov3/releases/'
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user