multi_gpu multi_scale
This commit is contained in:
+4
-12
@@ -90,7 +90,7 @@ class LoadWebcam: # for inference
|
||||
|
||||
|
||||
class LoadImagesAndLabels: # for training
|
||||
def __init__(self, path, batch_size=1, img_size=608, multi_scale=False, augment=False):
|
||||
def __init__(self, path, batch_size=1, img_size=608, augment=False):
|
||||
with open(path, 'r') as file:
|
||||
self.img_files = file.readlines()
|
||||
self.img_files = [x.replace('\n', '') for x in self.img_files]
|
||||
@@ -102,8 +102,7 @@ class LoadImagesAndLabels: # for training
|
||||
self.nF = len(self.img_files) # number of image files
|
||||
self.nB = math.ceil(self.nF / batch_size) # number of batches
|
||||
self.batch_size = batch_size
|
||||
self.height = img_size
|
||||
self.multi_scale = multi_scale
|
||||
self.img_size = img_size
|
||||
self.augment = augment
|
||||
|
||||
assert self.nF > 0, 'No images found in %s' % path
|
||||
@@ -121,13 +120,6 @@ class LoadImagesAndLabels: # for training
|
||||
ia = self.count * self.batch_size
|
||||
ib = min((self.count + 1) * self.batch_size, self.nF)
|
||||
|
||||
if self.multi_scale:
|
||||
# Multi-Scale YOLO Training
|
||||
height = random.choice(range(10, 20)) * 32 # 320 - 608 pixels
|
||||
else:
|
||||
# Fixed-Scale YOLO Training
|
||||
height = self.height
|
||||
|
||||
img_all, labels_all, img_paths, img_shapes = [], [], [], []
|
||||
for index, files_index in enumerate(range(ia, ib)):
|
||||
img_path = self.img_files[self.shuffled_vector[files_index]]
|
||||
@@ -159,7 +151,7 @@ class LoadImagesAndLabels: # for training
|
||||
cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR, dst=img)
|
||||
|
||||
h, w, _ = img.shape
|
||||
img, ratio, padw, padh = letterbox(img, height=height)
|
||||
img, ratio, padw, padh = letterbox(img, height=self.img_size)
|
||||
|
||||
# Load labels
|
||||
if os.path.isfile(label_path):
|
||||
@@ -189,7 +181,7 @@ class LoadImagesAndLabels: # for training
|
||||
nL = len(labels)
|
||||
if nL > 0:
|
||||
# convert xyxy to xywh
|
||||
labels[:, 1:5] = xyxy2xywh(labels[:, 1:5].copy()) / height
|
||||
labels[:, 1:5] = xyxy2xywh(labels[:, 1:5].copy()) / self.img_size
|
||||
|
||||
if self.augment:
|
||||
# random left-right flip
|
||||
|
||||
Reference in New Issue
Block a user