Update labels_to_image_weights() (#1576)
This commit is contained in:
parent
f28f862245
commit
bc5c898c93
@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
import glob
|
import glob
|
||||||
import logging
|
import logging
|
||||||
import math
|
|
||||||
import os
|
import os
|
||||||
import platform
|
import platform
|
||||||
import random
|
import random
|
||||||
@ -12,7 +11,7 @@ import time
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import matplotlib
|
import math
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torchvision
|
import torchvision
|
||||||
@ -22,13 +21,10 @@ from utils.google_utils import gsutil_getsize
|
|||||||
from utils.metrics import fitness
|
from utils.metrics import fitness
|
||||||
from utils.torch_utils import init_torch_seeds
|
from utils.torch_utils import init_torch_seeds
|
||||||
|
|
||||||
# Set printoptions
|
# Settings
|
||||||
torch.set_printoptions(linewidth=320, precision=5, profile='long')
|
torch.set_printoptions(linewidth=320, precision=5, profile='long')
|
||||||
np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
|
np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
|
||||||
matplotlib.rc('font', **{'size': 11})
|
cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
|
||||||
|
|
||||||
# Prevent OpenCV from multithreading (to use PyTorch DataLoader)
|
|
||||||
cv2.setNumThreads(0)
|
|
||||||
|
|
||||||
|
|
||||||
def set_logging(rank=-1):
|
def set_logging(rank=-1):
|
||||||
@ -121,9 +117,8 @@ def labels_to_class_weights(labels, nc=80):
|
|||||||
|
|
||||||
|
|
||||||
def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
|
def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
|
||||||
# Produces image weights based on class mAPs
|
# Produces image weights based on class_weights and image contents
|
||||||
n = len(labels)
|
class_counts = np.array([np.bincount(x[:, 0].astype(np.int), minlength=nc) for x in labels])
|
||||||
class_counts = np.array([np.bincount(labels[i][:, 0].astype(np.int), minlength=nc) for i in range(n)])
|
|
||||||
image_weights = (class_weights.reshape(1, nc) * class_counts).sum(1)
|
image_weights = (class_weights.reshape(1, nc) * class_counts).sum(1)
|
||||||
# index = random.choices(range(n), weights=image_weights, k=1) # weight image sample
|
# index = random.choices(range(n), weights=image_weights, k=1) # weight image sample
|
||||||
return image_weights
|
return image_weights
|
||||||
|
|||||||
@ -20,6 +20,7 @@ from utils.general import xywh2xyxy, xyxy2xywh
|
|||||||
from utils.metrics import fitness
|
from utils.metrics import fitness
|
||||||
|
|
||||||
# Settings
|
# Settings
|
||||||
|
matplotlib.rc('font', **{'size': 11})
|
||||||
matplotlib.use('Agg') # for writing to files only
|
matplotlib.use('Agg') # for writing to files only
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user