This commit is contained in:
Glenn Jocher
2019-07-20 14:54:37 +02:00
parent cb30d60f4e
commit 4816969933
2 changed files with 32 additions and 21 deletions
+18 -14
View File
@@ -152,7 +152,7 @@ class LoadWebcam: # for inference
class LoadImagesAndLabels(Dataset): # for training/testing
def __init__(self, path, img_size=416, batch_size=16, augment=False, rect=True, image_weights=False):
def __init__(self, path, img_size=416, batch_size=16, augment=False, hyp=None, rect=True, image_weights=False):
with open(path, 'r') as f:
img_files = f.read().splitlines()
self.img_files = [x for x in img_files if os.path.splitext(x)[-1].lower() in img_formats]
@@ -166,6 +166,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing
self.batch = bi # batch index of image
self.img_size = img_size
self.augment = augment
self.hyp = hyp
self.image_weights = image_weights
self.rect = False if image_weights else rect
@@ -271,6 +272,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing
img_path = self.img_files[index]
label_path = self.label_files[index]
hyp = self.hyp
# Load image
img = self.imgs[index]
@@ -289,13 +291,12 @@ class LoadImagesAndLabels(Dataset): # for training/testing
augment_hsv = True
if self.augment and augment_hsv:
# SV augmentation by 50%
fraction = 0.50 # must be < 1.0
img_hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV) # hue, sat, val
S = img_hsv[:, :, 1].astype(np.float32) # saturation
V = img_hsv[:, :, 2].astype(np.float32) # value
a = random.uniform(-1, 1) * fraction + 1
b = random.uniform(-1, 1) * fraction + 1
a = random.uniform(-1, 1) * hyp['hsv_s'] + 1
b = random.uniform(-1, 1) * hyp['hsv_v'] + 1
S *= a
V *= b
@@ -331,7 +332,11 @@ class LoadImagesAndLabels(Dataset): # for training/testing
# Augment image and labels
if self.augment:
img, labels = random_affine(img, labels, degrees=(-3, 3), translate=(0.05, 0.05), scale=(0.90, 1.10))
img, labels = random_affine(img, labels,
degrees=hyp['degrees'],
translate=hyp['translate'],
scale=hyp['scale'],
shear=hyp['shear'])
nL = len(labels) # number of labels
if nL:
@@ -410,8 +415,7 @@ def letterbox(img, new_shape=416, color=(128, 128, 128), mode='auto'):
return img, ratiow, ratioh, dw, dh
def random_affine(img, targets=(), degrees=(-10, 10), translate=(.1, .1), scale=(.9, 1.1), shear=(-2, 2),
borderValue=(128, 128, 128)):
def random_affine(img, targets=(), degrees=10, translate=.1, scale=.1, shear=10):
# torchvision.transforms.RandomAffine(degrees=(-10, 10), translate=(.1, .1), scale=(.9, 1.1), shear=(-10, 10))
# https://medium.com/uruvideo/dataset-augmentation-with-random-homographies-a8f4b44830d4
@@ -423,24 +427,24 @@ def random_affine(img, targets=(), degrees=(-10, 10), translate=(.1, .1), scale=
# Rotation and Scale
R = np.eye(3)
a = random.uniform(degrees[0], degrees[1])
a = random.uniform(-degrees, degrees)
# a += random.choice([-180, -90, 0, 90]) # add 90deg rotations to small rotations
s = random.uniform(scale[0], scale[1])
s = random.uniform(1 - scale, 1 + scale)
R[:2] = cv2.getRotationMatrix2D(angle=a, center=(img.shape[1] / 2, img.shape[0] / 2), scale=s)
# Translation
T = np.eye(3)
T[0, 2] = random.uniform(-1, 1) * translate[0] * img.shape[0] + border # x translation (pixels)
T[1, 2] = random.uniform(-1, 1) * translate[1] * img.shape[1] + border # y translation (pixels)
T[0, 2] = random.uniform(-translate, translate) * img.shape[0] + border # x translation (pixels)
T[1, 2] = random.uniform(-translate, translate) * img.shape[1] + border # y translation (pixels)
# Shear
S = np.eye(3)
S[0, 1] = math.tan(random.uniform(shear[0], shear[1]) * math.pi / 180) # x shear (deg)
S[1, 0] = math.tan(random.uniform(shear[0], shear[1]) * math.pi / 180) # y shear (deg)
S[0, 1] = math.tan(random.uniform(-shear, shear) * math.pi / 180) # x shear (deg)
S[1, 0] = math.tan(random.uniform(-shear, shear) * math.pi / 180) # y shear (deg)
M = S @ T @ R # Combined rotation matrix. ORDER IS IMPORTANT HERE!!
imw = cv2.warpAffine(img, M[:2], dsize=(width, height), flags=cv2.INTER_AREA,
borderValue=borderValue) # BGR order borderValue
borderValue=(128, 128, 128)) # BGR order borderValue
# Return warped points also
if len(targets) > 0: