updates
This commit is contained in:
+9
-6
@@ -130,7 +130,7 @@ class LoadWebcam: # for inference
|
||||
|
||||
|
||||
class LoadImagesAndLabels(Dataset): # for training/testing
|
||||
def __init__(self, path, img_size=416, batch_size=16, augment=False, rect=False):
|
||||
def __init__(self, path, img_size=416, batch_size=16, augment=False, rect=True):
|
||||
with open(path, 'r') as f:
|
||||
img_files = f.read().splitlines()
|
||||
self.img_files = list(filter(lambda x: len(x) > 0, img_files))
|
||||
@@ -181,8 +181,8 @@ class LoadImagesAndLabels(Dataset): # for training/testing
|
||||
self.batch = bi # batch index of image
|
||||
|
||||
# Preload images
|
||||
# if n < 200: # preload all images into memory if possible
|
||||
# self.imgs = [cv2.imread(img_files[i]) for i in range(n)]
|
||||
if n < 1001: # preload all images into memory if possible
|
||||
self.imgs = [cv2.imread(self.img_files[i]) for i in range(n)]
|
||||
|
||||
# Preload labels (required for weighted CE training)
|
||||
self.labels = [np.zeros((0, 5))] * n
|
||||
@@ -201,11 +201,14 @@ class LoadImagesAndLabels(Dataset): # for training/testing
|
||||
img_path = self.img_files[index]
|
||||
label_path = self.label_files[index]
|
||||
|
||||
# if hasattr(self, 'imgs'): # preloaded
|
||||
# img = self.imgs[index] # BGR
|
||||
img = cv2.imread(img_path) # BGR
|
||||
# Load image
|
||||
if hasattr(self, 'imgs'): # preloaded
|
||||
img = self.imgs[index]
|
||||
else:
|
||||
img = cv2.imread(img_path) # BGR
|
||||
assert img is not None, 'File Not Found ' + img_path
|
||||
|
||||
# Augment colorspace
|
||||
augment_hsv = True
|
||||
if self.augment and augment_hsv:
|
||||
# SV augmentation by 50%
|
||||
|
||||
+1
-1
@@ -265,7 +265,7 @@ def compute_loss(p, targets, model): # predictions, targets, model
|
||||
# Compute losses
|
||||
h = model.hyp # hyperparameters
|
||||
bs = p[0].shape[0] # batch size
|
||||
k = h['k'] * bs # loss gain
|
||||
k = bs # loss gain
|
||||
for i, pi0 in enumerate(p): # layer i predictions, i
|
||||
b, a, gj, gi = indices[i] # image, anchor, gridy, gridx
|
||||
tconf = torch.zeros_like(pi0[..., 0]) # conf
|
||||
|
||||
Reference in New Issue
Block a user