diff --git a/utils/utils.py b/utils/utils.py index 91d75e2c..8722aa01 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -337,21 +337,25 @@ def build_targets(model, targets): multi_gpu = type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel) nt = len(targets) - txy, twh, tcls, tbox, indices, anchor_vec = [], [], [], [], [], [] + txy, twh, tcls, tbox, indices, av = [], [], [], [], [], [] for i in model.yolo_layers: - layer = model.module.module_list[i] if multi_gpu else model.module_list[i] + # get number of grid points and anchor vec for this yolo layer + if multi_gpu: + ng, anchor_vec = model.module.module_list[i].ng, model.module.module_list[i].anchor_vec + else: + ng, anchor_vec = model.module_list[i].ng, model.module_list[i].anchor_vec # iou of targets-anchors t, a = targets, [] - gwh = t[:, 4:6] * layer.ng + gwh = t[:, 4:6] * ng if nt: - iou = torch.stack([wh_iou(x, gwh) for x in layer.anchor_vec], 0) + iou = torch.stack([wh_iou(x, gwh) for x in anchor_vec], 0) use_best_anchor = False if use_best_anchor: iou, a = iou.max(0) # best iou and anchor else: # use all anchors - na = len(layer.anchor_vec) # number of anchors + na = len(anchor_vec) # number of anchors a = torch.arange(na).view((-1, 1)).repeat([1, nt]).view(-1) t = targets.repeat([na, 1]) gwh = gwh.repeat([na, 1]) @@ -365,7 +369,7 @@ def build_targets(model, targets): # Indices b, c = t[:, :2].long().t() # target image, class - gxy = t[:, 2:4] * layer.ng # grid x, y + gxy = t[:, 2:4] * ng # grid x, y gi, gj = gxy.long().t() # grid x, y indices indices.append((b, a, gj, gi)) @@ -375,18 +379,18 @@ def build_targets(model, targets): # GIoU tbox.append(torch.cat((gxy, gwh), 1)) # xywh (grids) - anchor_vec.append(layer.anchor_vec[a]) + av.append(anchor_vec[a]) # anchor vec # Width and height - twh.append(torch.log(gwh / layer.anchor_vec[a])) # wh yolo method - # twh.append((gwh / layer.anchor_vec[a]) ** (1 / 3) / 2) # wh power method + twh.append(torch.log(gwh / anchor_vec[a])) # wh yolo method + # twh.append((gwh / anchor_vec[a]) ** (1 / 3) / 2) # wh power method # Class tcls.append(c) if c.shape[0]: # if any targets - assert c.max() <= layer.nc, 'Target classes exceed model classes' + assert c.max() <= model.nc, 'Target classes exceed model classes' - return txy, twh, tcls, tbox, indices, anchor_vec + return txy, twh, tcls, tbox, indices, av def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5):