updates
This commit is contained in:
+2
-2
@@ -49,11 +49,11 @@ def model_info(model):
|
||||
print('Model Summary: %g layers, %g parameters, %g gradients' % (i + 1, n_p, n_g))
|
||||
|
||||
|
||||
def labels_to_class_weights(labels):
|
||||
def labels_to_class_weights(labels, nc=80):
|
||||
# Get class weights (inverse frequency) from training labels
|
||||
labels = np.concatenate(labels, 0) # labels.shape = (866643, 5) for COCO
|
||||
classes = labels[:, 0].astype(np.int)
|
||||
weights = 1 / (np.bincount(classes, minlength=classes.max() + 1) + 1e-6) # number of targets per class
|
||||
weights = 1 / (np.bincount(classes, minlength=nc) + 1e-6) # number of targets per class
|
||||
weights /= weights.sum()
|
||||
return torch.Tensor(weights)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user