From 50df252c4b21540b5e878ced79f8d7959ced51d4 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Fri, 12 Apr 2019 14:58:19 +0200 Subject: [PATCH] updates --- train.py | 5 ++--- utils/utils.py | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/train.py b/train.py index 720e0b1e..06878e6c 100644 --- a/train.py +++ b/train.py @@ -121,10 +121,9 @@ def train( for i, (imgs, targets, _, _) in enumerate(dataloader): imgs = imgs.to(device) targets = targets.to(device) - nt = len(targets) - if nt == 0: # if no targets continue - continue + # if nt == 0: # if no targets continue + # continue # Plot images with bounding boxes if epoch == 0 and i == 0: diff --git a/utils/utils.py b/utils/utils.py index cb2ea0ff..d060b9c1 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -321,7 +321,7 @@ def build_targets(model, targets): # Class tcls.append(c) - if nt: + if c.shape[0]: assert c.max() <= layer.nC, 'Target classes exceed model classes' return txy, twh, tcls, indices