CreateML fixes (#906)

Support createML format
This commit is contained in:
s_teja 2022-07-05 22:48:10 +02:00 committed by GitHub
parent eb603c29e1
commit 5bc7fb9a9c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 35 additions and 5 deletions

View File

@ -1123,6 +1123,7 @@ class MainWindow(QMainWindow, WindowMixin):
u"<p>Make sure <i>%s</i> is a valid label file.")
% (e, unicode_file_path))
self.status("Error reading %s" % unicode_file_path)
return False
self.image_data = self.label_file.image_data
self.line_color = QColor(*self.label_file.lineColor)
@ -1156,7 +1157,7 @@ class MainWindow(QMainWindow, WindowMixin):
self.paint_canvas()
self.add_recent_file(self.file_path)
self.toggle_actions(True)
self.show_bounding_box_from_annotation_file(file_path)
self.show_bounding_box_from_annotation_file(self.file_path)
counter = self.counter_str()
self.setWindowTitle(__appname__ + ' ' + file_path + ' ' + counter)
@ -1196,10 +1197,15 @@ class MainWindow(QMainWindow, WindowMixin):
else:
xml_path = os.path.splitext(file_path)[0] + XML_EXT
txt_path = os.path.splitext(file_path)[0] + TXT_EXT
json_path = os.path.splitext(file_path)[0] + JSON_EXT
if os.path.isfile(xml_path):
self.load_pascal_xml_by_filename(xml_path)
elif os.path.isfile(txt_path):
self.load_yolo_txt_by_filename(txt_path)
elif os.path.isfile(json_path):
self.load_create_ml_json_by_filename(json_path, file_path)
def resizeEvent(self, event):
if self.canvas and not self.image.isNull()\
@ -1300,10 +1306,13 @@ class MainWindow(QMainWindow, WindowMixin):
if dir_path is not None and len(dir_path) > 1:
self.default_save_dir = dir_path
self.show_bounding_box_from_annotation_file(self.file_path)
self.statusBar().showMessage('%s . Annotation will be saved to %s' %
('Change saved folder', self.default_save_dir))
self.statusBar().show()
def open_annotation_dialog(self, _value=False):
if self.file_path is None:
self.statusBar().showMessage('Please select image first')
@ -1320,6 +1329,17 @@ class MainWindow(QMainWindow, WindowMixin):
filename = filename[0]
self.load_pascal_xml_by_filename(filename)
elif self.label_file_format == LabelFileFormat.CREATE_ML:
filters = "Open Annotation JSON file (%s)" % ' '.join(['*.json'])
filename = ustr(QFileDialog.getOpenFileName(self, '%s - Choose a json file' % __appname__, path, filters))
if filename:
if isinstance(filename, (tuple, list)):
filename = filename[0]
self.load_create_ml_json_by_filename(filename, self.file_path)
def open_dir_dialog(self, _value=False, dir_path=None, silent=False):
if not self.may_continue():
return
@ -1337,6 +1357,9 @@ class MainWindow(QMainWindow, WindowMixin):
target_dir_path = ustr(default_open_dir_path)
self.last_open_dir = target_dir_path
self.import_dir_images(target_dir_path)
self.default_save_dir = target_dir_path
if self.file_path:
self.show_bounding_box_from_annotation_file(file_path=self.file_path)
def import_dir_images(self, dir_path):
if not self.may_continue() or not dir_path:

View File

@ -32,6 +32,7 @@ class CreateMLWriter:
output_image_dict = {
"image": self.filename,
"verified": self.verified,
"annotations": []
}
@ -107,12 +108,15 @@ class CreateMLReader:
with open(self.json_path, "r") as file:
input_data = file.read()
output_dict = json.loads(input_data)
self.verified = True
# Returns a list
output_list = json.loads(input_data)
if output_list:
self.verified = output_list[0].get("verified", False)
if len(self.shapes) > 0:
self.shapes = []
for image in output_dict:
for image in output_list:
if image["image"] == self.filename:
for shape in image["annotations"]:
self.add_shape(shape["label"], shape["coordinates"])

View File

@ -48,6 +48,7 @@ class LabelFile(object):
image_shape, shapes, filename, local_img_path=image_path)
writer.verified = self.verified
writer.write()
return
def save_pascal_voc_format(self, filename, shapes, image_path, image_data,

View File

@ -50,6 +50,8 @@ class TestCreateMLRW(unittest.TestCase):
writer = CreateMLWriter('tests', 'test.512.512.bmp', (512, 512, 1), shapes, output_file,
local_img_path='tests/test.512.512.bmp')
writer.verified = True
writer.write()
# check written json
@ -58,7 +60,7 @@ class TestCreateMLRW(unittest.TestCase):
import json
data_dict = json.loads(input_data)[0]
self.assertEqual(True, data_dict['verified'], 'verified tag not reflected')
self.assertEqual('test.512.512.bmp', data_dict['image'], 'filename not correct in .json')
self.assertEqual(2, len(data_dict['annotations']), 'output file contains to less annotations')
face = data_dict['annotations'][1]