Update labels_to_image_weights() (#1576)

This commit is contained in:
Glenn Jocher 2020-11-28 12:25:57 +01:00 committed by GitHub
parent f28f862245
commit bc5c898c93
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 6 additions and 10 deletions

View File

@ -2,7 +2,6 @@
import glob import glob
import logging import logging
import math
import os import os
import platform import platform
import random import random
@ -12,7 +11,7 @@ import time
from pathlib import Path from pathlib import Path
import cv2 import cv2
import matplotlib import math
import numpy as np import numpy as np
import torch import torch
import torchvision import torchvision
@ -22,13 +21,10 @@ from utils.google_utils import gsutil_getsize
from utils.metrics import fitness from utils.metrics import fitness
from utils.torch_utils import init_torch_seeds from utils.torch_utils import init_torch_seeds
# Set printoptions # Settings
torch.set_printoptions(linewidth=320, precision=5, profile='long') torch.set_printoptions(linewidth=320, precision=5, profile='long')
np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5 np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
matplotlib.rc('font', **{'size': 11}) cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
# Prevent OpenCV from multithreading (to use PyTorch DataLoader)
cv2.setNumThreads(0)
def set_logging(rank=-1): def set_logging(rank=-1):
@ -121,9 +117,8 @@ def labels_to_class_weights(labels, nc=80):
def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)): def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
# Produces image weights based on class mAPs # Produces image weights based on class_weights and image contents
n = len(labels) class_counts = np.array([np.bincount(x[:, 0].astype(np.int), minlength=nc) for x in labels])
class_counts = np.array([np.bincount(labels[i][:, 0].astype(np.int), minlength=nc) for i in range(n)])
image_weights = (class_weights.reshape(1, nc) * class_counts).sum(1) image_weights = (class_weights.reshape(1, nc) * class_counts).sum(1)
# index = random.choices(range(n), weights=image_weights, k=1) # weight image sample # index = random.choices(range(n), weights=image_weights, k=1) # weight image sample
return image_weights return image_weights

View File

@ -20,6 +20,7 @@ from utils.general import xywh2xyxy, xyxy2xywh
from utils.metrics import fitness from utils.metrics import fitness
# Settings # Settings
matplotlib.rc('font', **{'size': 11})
matplotlib.use('Agg') # for writing to files only matplotlib.use('Agg') # for writing to files only