greenhouse/utils/parse_config.py

61 lines
2.2 KiB
Python
Raw Normal View History

2019-12-10 18:04:24 -08:00
import os
2019-08-15 18:15:27 +02:00
import numpy as np
2019-02-12 16:58:07 +01:00
def parse_model_cfg(path):
    """Parse a YOLO ``*.cfg`` file and return its module definitions.

    ``path`` may be 'cfg/yolov3.cfg', 'yolov3.cfg', or 'yolov3': a missing
    '.cfg' suffix is appended, and when the resulting file does not exist a
    'cfg' directory prefix is added (unless the path already starts with it).

    Returns:
        A list with one dict per ``[section]``, in file order. Each dict has
        a 'type' key holding the section name; 'convolutional' sections are
        pre-populated with ``batch_normalize = 0``. Keys containing 'anchors'
        map to an (N, 2) float numpy array; every other value is stored as
        the whitespace-stripped string.

    Raises:
        OSError: the resolved file cannot be opened.
        ValueError: a non-section line has no '=' (errors from converting a
            malformed anchors value also propagate).
        AssertionError: a section after the first one uses a field that is
            not in the supported list.
    """
    if not path.endswith('.cfg'):  # add .cfg suffix if omitted
        path += '.cfg'
    if not os.path.exists(path) and not path.startswith('cfg' + os.sep):  # add cfg/ prefix if omitted
        path = 'cfg' + os.sep + path

    with open(path, 'r') as cfg_file:
        # Strip BEFORE filtering so whitespace-only lines and indented
        # comments are dropped instead of reaching the key=value parser.
        lines = [raw.strip() for raw in cfg_file.read().split('\n')]
    lines = [x for x in lines if x and not x.startswith('#')]

    mdefs = []  # module definitions
    for line in lines:
        if line.startswith('['):  # This marks the start of a new block
            mdef = {'type': line[1:-1].rstrip()}
            if mdef['type'] == 'convolutional':
                mdef['batch_normalize'] = 0  # pre-populate with zeros (may be overwritten later)
            mdefs.append(mdef)
            continue

        # Split on the first '=' only, so values may themselves contain '='.
        key, sep, val = line.partition('=')
        if not sep:
            raise ValueError("Expected 'key=value' but got %r in %s" % (line, path))
        key = key.rstrip()
        if 'anchors' in key:
            mdefs[-1][key] = np.array([float(x) for x in val.split(',')]).reshape((-1, 2))  # np anchors
        else:
            mdefs[-1][key] = val.strip()

    # Check all fields are supported (the first section is deliberately skipped)
    supported = {'type', 'batch_normalize', 'filters', 'size', 'stride', 'pad', 'activation', 'layers', 'groups',
                 'from', 'mask', 'anchors', 'classes', 'num', 'jitter', 'ignore_thresh', 'truth_thresh', 'random',
                 'stride_x', 'stride_y'}
    unsupported = []  # unsupported fields, first-seen order, no duplicates
    for mdef in mdefs[1:]:
        for field in mdef:
            if field not in supported and field not in unsupported:
                unsupported.append(field)
    if unsupported:
        # Raised explicitly (rather than via `assert`) so the check still runs
        # under `python -O`; the exception type is unchanged for callers.
        raise AssertionError(
            "Unsupported fields %s in %s. See https://github.com/ultralytics/yolov3/issues/631" % (unsupported, path))

    return mdefs
2018-08-26 10:51:39 +02:00
2019-02-11 18:15:51 +01:00
2019-02-08 22:43:05 +01:00
def parse_data_cfg(path):
    """Parse a data configuration file of ``key = value`` lines.

    Blank lines and lines whose first non-whitespace character is '#' are
    ignored. Each remaining line is split on its FIRST '=' so that values
    (e.g. file paths) may themselves contain '='.

    Returns:
        dict mapping each stripped key to its stripped string value; a key
        that appears more than once keeps the last value.

    Raises:
        OSError: the file cannot be opened.
        ValueError: a non-blank, non-comment line contains no '='.
    """
    options = dict()
    with open(path, 'r') as fp:
        for raw in fp:
            line = raw.strip()
            if not line or line.startswith('#'):
                continue
            key, sep, val = line.partition('=')
            if not sep:
                raise ValueError("Expected 'key=value' but got %r in %s" % (line, path))
            options[key.strip()] = val.strip()
    return options