This commit is contained in:
Glenn Jocher
2019-05-08 13:06:24 +02:00
parent 9ee59fe694
commit a8f0a3fede
6 changed files with 5050 additions and 45 deletions
+9 -6
View File
@@ -130,7 +130,7 @@ class LoadWebcam: # for inference
class LoadImagesAndLabels(Dataset): # for training/testing
def __init__(self, path, img_size=416, batch_size=16, augment=False, rect=False):
def __init__(self, path, img_size=416, batch_size=16, augment=False, rect=True):
with open(path, 'r') as f:
img_files = f.read().splitlines()
self.img_files = list(filter(lambda x: len(x) > 0, img_files))
@@ -181,8 +181,8 @@ class LoadImagesAndLabels(Dataset): # for training/testing
self.batch = bi # batch index of image
# Preload images
# if n < 200: # preload all images into memory if possible
# self.imgs = [cv2.imread(img_files[i]) for i in range(n)]
if n < 1001: # preload all images into memory if possible
self.imgs = [cv2.imread(self.img_files[i]) for i in range(n)]
# Preload labels (required for weighted CE training)
self.labels = [np.zeros((0, 5))] * n
@@ -201,11 +201,14 @@ class LoadImagesAndLabels(Dataset): # for training/testing
img_path = self.img_files[index]
label_path = self.label_files[index]
# if hasattr(self, 'imgs'): # preloaded
# img = self.imgs[index] # BGR
img = cv2.imread(img_path) # BGR
# Load image
if hasattr(self, 'imgs'): # preloaded
img = self.imgs[index]
else:
img = cv2.imread(img_path) # BGR
assert img is not None, 'File Not Found ' + img_path
# Augment colorspace
augment_hsv = True
if self.augment and augment_hsv:
# SV augmentation by 50%
+1 -1
View File
@@ -265,7 +265,7 @@ def compute_loss(p, targets, model): # predictions, targets, model
# Compute losses
h = model.hyp # hyperparameters
bs = p[0].shape[0] # batch size
k = h['k'] * bs # loss gain
k = bs # loss gain
for i, pi0 in enumerate(p): # layer i predictions, i
b, a, gj, gi = indices[i] # image, anchor, gridy, gridx
tconf = torch.zeros_like(pi0[..., 0]) # conf