From 4f890d13ee1b23a79c8f58922b68a2c4856745a3 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Mon, 30 Nov 2020 16:47:28 +0100 Subject: [PATCH] Daemon thread plots (#1578) --- test.py | 23 ++++++++++++----------- train.py | 10 +++++----- utils/plots.py | 11 ++++++++--- 3 files changed, 25 insertions(+), 19 deletions(-) diff --git a/test.py b/test.py index d62afd7d..4120057a 100644 --- a/test.py +++ b/test.py @@ -3,6 +3,7 @@ import glob import json import os from pathlib import Path +from threading import Thread import numpy as np import torch @@ -206,10 +207,10 @@ def test(data, # Plot images if plots and batch_i < 3: - f = save_dir / f'test_batch{batch_i}_labels.jpg' # filename - plot_images(img, targets, paths, f, names) # labels - f = save_dir / f'test_batch{batch_i}_pred.jpg' - plot_images(img, output_to_target(output), paths, f, names) # predictions + f = save_dir / f'test_batch{batch_i}_labels.jpg' # labels + Thread(target=plot_images, args=(img, targets, paths, f, names), daemon=True).start() + f = save_dir / f'test_batch{batch_i}_pred.jpg' # predictions + Thread(target=plot_images, args=(img, output_to_target(output), paths, f, names), daemon=True).start() # Compute statistics stats = [np.concatenate(x, 0) for x in zip(*stats)] # to numpy @@ -221,13 +222,6 @@ def test(data, else: nt = torch.zeros(1) - # Plots - if plots: - confusion_matrix.plot(save_dir=save_dir, names=list(names.values())) - if wandb and wandb.run: - wandb.log({"Images": wandb_images}) - wandb.log({"Validation": [wandb.Image(str(f), caption=f.name) for f in sorted(save_dir.glob('test*.jpg'))]}) - # Print results pf = '%20s' + '%12.3g' * 6 # print format print(pf % ('all', seen, nt.sum(), mp, mr, map50, map)) @@ -242,6 +236,13 @@ def test(data, if not training: print('Speed: %.1f/%.1f/%.1f ms inference/NMS/total per %gx%g image at batch-size %g' % t) + # Plots + if plots: + confusion_matrix.plot(save_dir=save_dir, names=list(names.values())) + if wandb and wandb.run: + wandb.log({"Images": wandb_images}) + wandb.log({"Validation": [wandb.Image(str(f), caption=f.name) for f in sorted(save_dir.glob('test*.jpg'))]}) + # Save JSON if save_json and len(jdict): w = Path(weights[0] if isinstance(weights, list) else weights).stem if weights is not None else '' # weights diff --git a/train.py b/train.py index 3471496b..f0f778db 100644 --- a/train.py +++ b/train.py @@ -1,12 +1,13 @@ import argparse import logging -import math import os import random import time from pathlib import Path +from threading import Thread from warnings import warn +import math import numpy as np import torch.distributed as dist import torch.nn as nn @@ -134,6 +135,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): project='YOLOv3' if opt.project == 'runs/train' else Path(opt.project).stem, name=save_dir.stem, id=ckpt.get('wandb_id') if 'ckpt' in locals() else None) + loggers = {'wandb': wandb} # loggers dict # Resume start_epoch, best_fitness = 0, 0.0 @@ -201,11 +203,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): # cf = torch.bincount(c.long(), minlength=nc) + 1. # frequency # model._initialize_biases(cf.to(device)) if plots: - plot_labels(labels, save_dir=save_dir) + Thread(target=plot_labels, args=(labels, save_dir, loggers), daemon=True).start() if tb_writer: tb_writer.add_histogram('classes', c, 0) - if wandb: - wandb.log({"Labels": [wandb.Image(str(x), caption=x.name) for x in save_dir.glob('*labels*.png')]}) # Anchors if not opt.noautoanchor: @@ -311,7 +311,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): # Plot if plots and ni < 3: f = save_dir / f'train_batch{ni}.jpg' # filename - plot_images(images=imgs, targets=targets, paths=paths, fname=f) + Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start() # if tb_writer: # tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch) # tb_writer.add_graph(model, imgs) # add model to tensorboard diff --git a/utils/plots.py b/utils/plots.py index 9febcae5..8492b1a1 100644 --- a/utils/plots.py +++ b/utils/plots.py @@ -250,7 +250,7 @@ def plot_study_txt(path='', x=None): # from utils.plots import *; plot_study_tx plt.savefig('test_study.png', dpi=300) -def plot_labels(labels, save_dir=''): +def plot_labels(labels, save_dir=Path(''), loggers=None): # plot dataset labels c, b = labels[:, 0], labels[:, 1:].transpose() # classes, boxes nc = int(c.max() + 1) # number of classes @@ -264,7 +264,7 @@ def plot_labels(labels, save_dir=''): sns.pairplot(x, corner=True, diag_kind='hist', kind='scatter', markers='o', plot_kws=dict(s=3, edgecolor=None, linewidth=1, alpha=0.02), diag_kws=dict(bins=50)) - plt.savefig(Path(save_dir) / 'labels_correlogram.png', dpi=200) + plt.savefig(save_dir / 'labels_correlogram.png', dpi=200) plt.close() except Exception as e: pass @@ -292,9 +292,14 @@ def plot_labels(labels, save_dir=''): for a in [0, 1, 2, 3]: for s in ['top', 'right', 'left', 'bottom']: ax[a].spines[s].set_visible(False) - plt.savefig(Path(save_dir) / 'labels.png', dpi=200) + plt.savefig(save_dir / 'labels.png', dpi=200) plt.close() + # loggers + for k, v in loggers.items() or {}: + if k == 'wandb' and v: + v.log({"Labels": [v.Image(str(x), caption=x.name) for x in save_dir.glob('*labels*.png')]}) + def plot_evolution(yaml_file='data/hyp.finetune.yaml'): # from utils.plots import *; plot_evolution() # Plot hyperparameter evolution results in evolve.txt