This commit is contained in:
Glenn Jocher
2019-04-27 17:51:59 +02:00
parent 8f1becd55c
commit d25190e15b
2 changed files with 6 additions and 4 deletions
+2 -2
View File
@@ -49,11 +49,11 @@ def model_info(model):
print('Model Summary: %g layers, %g parameters, %g gradients' % (i + 1, n_p, n_g))
def labels_to_class_weights(labels):
def labels_to_class_weights(labels, nc=80):
# Get class weights (inverse frequency) from training labels
labels = np.concatenate(labels, 0) # labels.shape = (866643, 5) for COCO
classes = labels[:, 0].astype(np.int)
weights = 1 / (np.bincount(classes, minlength=classes.max() + 1) + 1e-6) # number of targets per class
weights = 1 / (np.bincount(classes, minlength=nc) + 1e-6) # number of targets per class
weights /= weights.sum()
return torch.Tensor(weights)