diff --git a/train.py b/train.py index edef2a0d..dc0eb480 100644 --- a/train.py +++ b/train.py @@ -167,7 +167,7 @@ def train( p.requires_grad = False if epoch == 0 else True # Update image weights (optional) - w = model.class_weights.cpu().numpy() * (1 - maps) ** 2 # class weights + w = model.class_weights.cpu().numpy() * (1 - maps) # class weights image_weights = labels_to_image_weights(dataset.labels, nc=nc, class_weights=w) dataset.indices = random.choices(range(dataset.n), weights=image_weights, k=dataset.n) # random weighted index