updates
This commit is contained in:
+52
-17
@@ -16,31 +16,61 @@ from utils.utils import xyxy2xywh
|
||||
|
||||
class LoadImages: # for inference
|
||||
def __init__(self, path, img_size=416):
|
||||
if os.path.isdir(path):
|
||||
image_format = ['.jpg', '.jpeg', '.png', '.tif']
|
||||
self.files = sorted(glob.glob('%s/*.*' % path))
|
||||
self.files = list(filter(lambda x: os.path.splitext(x)[1].lower() in image_format, self.files))
|
||||
elif os.path.isfile(path):
|
||||
self.files = [path]
|
||||
|
||||
self.nF = len(self.files) # number of image files
|
||||
self.height = img_size
|
||||
img_formats = ['.jpg', '.jpeg', '.png', '.tif']
|
||||
vid_formats = ['.mov', '.avi', '.mp4']
|
||||
|
||||
assert self.nF > 0, 'No images found in ' + path
|
||||
files = []
|
||||
if os.path.isdir(path):
|
||||
files = sorted(glob.glob('%s/*.*' % path))
|
||||
elif os.path.isfile(path):
|
||||
files = [path]
|
||||
|
||||
# self.files = list(filter(lambda x: os.path.splitext(x)[1].lower() in img_formats, files))
|
||||
images = [x for x in files if os.path.splitext(x)[-1].lower() in img_formats]
|
||||
videos = [x for x in files if os.path.splitext(x)[-1].lower() in vid_formats]
|
||||
self.files = images + videos
|
||||
self.nI, self.nV = len(images), len(videos)
|
||||
self.nF = self.nI + self.nV # number of files
|
||||
self.video_flag = [False] * self.nI + [True] * self.nV
|
||||
self.mode = 'images'
|
||||
if any(videos):
|
||||
self.new_video(videos[0]) # new video
|
||||
else:
|
||||
self.cap = None
|
||||
assert self.nF > 0, 'No images or videos found in ' + path
|
||||
|
||||
def __iter__(self):
|
||||
self.count = -1
|
||||
self.count = 0
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
self.count += 1
|
||||
if self.count == self.nF:
|
||||
raise StopIteration
|
||||
img_path = self.files[self.count]
|
||||
path = self.files[self.count]
|
||||
|
||||
# Read image
|
||||
img0 = cv2.imread(img_path) # BGR
|
||||
assert img0 is not None, 'File Not Found ' + img_path
|
||||
if self.video_flag[self.count]:
|
||||
self.mode = 'video'
|
||||
ret_val, img0 = self.cap.read()
|
||||
if not ret_val:
|
||||
self.count += 1
|
||||
self.cap.release()
|
||||
if self.count == self.nF: # last video
|
||||
raise StopIteration
|
||||
else:
|
||||
path = self.files[self.count]
|
||||
self.new_video(path)
|
||||
ret_val, img0 = self.cap.read()
|
||||
|
||||
self.frame += 1
|
||||
print('video %g/%g (%g/%g) %s: ' % (self.count + 1, self.nF, self.frame, self.nframes, path), end='')
|
||||
|
||||
else:
|
||||
# Read image
|
||||
self.count += 1
|
||||
img0 = cv2.imread(path) # BGR
|
||||
assert img0 is not None, 'File Not Found ' + path
|
||||
print('image %g/%g %s: ' % (self.count, self.nF, path), end='')
|
||||
|
||||
# Padded resize
|
||||
img, _, _, _ = letterbox(img0, height=self.height)
|
||||
@@ -50,8 +80,13 @@ class LoadImages: # for inference
|
||||
img = np.ascontiguousarray(img, dtype=np.float32) # uint8 to float32
|
||||
img /= 255.0 # 0 - 255 to 0.0 - 1.0
|
||||
|
||||
# cv2.imwrite(img_path + '.letterbox.jpg', 255 * img.transpose((1, 2, 0))[:, :, ::-1]) # save letterbox image
|
||||
return img_path, img, img0
|
||||
# cv2.imwrite(path + '.letterbox.jpg', 255 * img.transpose((1, 2, 0))[:, :, ::-1]) # save letterbox image
|
||||
return path, img, img0, self.cap
|
||||
|
||||
def new_video(self, path):
|
||||
self.frame = 0
|
||||
self.cap = cv2.VideoCapture(path)
|
||||
self.nframes = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
|
||||
def __len__(self):
|
||||
return self.nF # number of files
|
||||
|
||||
+5
-2
@@ -163,15 +163,18 @@ def ap_per_class(tp, conf, pred_cls, target_cls):
|
||||
|
||||
# Recall
|
||||
recall_curve = tpc / (n_gt + 1e-16)
|
||||
r.append(tpc[-1] / (n_gt + 1e-16))
|
||||
r.append(recall_curve[-1])
|
||||
|
||||
# Precision
|
||||
precision_curve = tpc / (tpc + fpc)
|
||||
p.append(tpc[-1] / (tpc[-1] + fpc[-1]))
|
||||
p.append(precision_curve[-1])
|
||||
|
||||
# AP from recall-precision curve
|
||||
ap.append(compute_ap(recall_curve, precision_curve))
|
||||
|
||||
# Plot
|
||||
# plt.plot(recall_curve, precision_curve)
|
||||
|
||||
return np.array(ap), unique_classes.astype('int32'), np.array(r), np.array(p)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user