diff --git a/train.py b/train.py index ebf85095..cf02236b 100644 --- a/train.py +++ b/train.py @@ -223,7 +223,7 @@ def train(cfg, # Update image weights (optional) if dataset.image_weights: - w = model.class_weights.cpu().numpy() * (1 - maps) # class weights + w = model.class_weights.cpu().numpy() * (1 - maps) ** 2 # class weights image_weights = labels_to_image_weights(dataset.labels, nc=nc, class_weights=w) dataset.indices = random.choices(range(dataset.n), weights=image_weights, k=dataset.n) # rand weighted idx