updates
This commit is contained in:
+2
-1
@@ -280,6 +280,7 @@ def compute_loss(p, targets, model): # predictions, targets, model
|
|||||||
|
|
||||||
def build_targets(model, targets):
|
def build_targets(model, targets):
|
||||||
# targets = [image, class, x, y, w, h]
|
# targets = [image, class, x, y, w, h]
|
||||||
|
iou_thres = model.hyp['iou_t'] # hyperparameter
|
||||||
if type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel):
|
if type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel):
|
||||||
model = model.module
|
model = model.module
|
||||||
|
|
||||||
@@ -298,7 +299,7 @@ def build_targets(model, targets):
|
|||||||
# reject below threshold ious (OPTIONAL, increases P, lowers R)
|
# reject below threshold ious (OPTIONAL, increases P, lowers R)
|
||||||
reject = True
|
reject = True
|
||||||
if reject:
|
if reject:
|
||||||
j = iou > model.hyp['iou_t'] # hyperparameter
|
j = iou > iou_thres
|
||||||
t, a, gwh = targets[j], a[j], gwh[j]
|
t, a, gwh = targets[j], a[j], gwh[j]
|
||||||
|
|
||||||
# Indices
|
# Indices
|
||||||
|
|||||||
Reference in New Issue
Block a user