This commit is contained in:
Glenn Jocher
2019-08-17 14:08:10 +02:00
parent 321bd95764
commit a1200ef130
3 changed files with 3 additions and 4 deletions
+1 -1
View File
@@ -62,7 +62,7 @@ def labels_to_class_weights(labels, nc=80):
weights[weights == 0] = 1 # replace empty bins with 1
weights = 1 / weights # number of targets per class
weights /= weights.sum() # normalize
return torch.Tensor(weights)
return torch.from_numpy(weights)
def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):