updates
This commit is contained in:
+22
-15
@@ -1,35 +1,42 @@
|
||||
import numpy as np
|
||||
|
||||
|
||||
def parse_model_cfg(path):
|
||||
"""Parses the yolo-v3 layer configuration file and returns module definitions"""
|
||||
# Parses the yolo-v3 layer configuration file and returns module definitions
|
||||
file = open(path, 'r')
|
||||
lines = file.read().split('\n')
|
||||
lines = [x for x in lines if x and not x.startswith('#')]
|
||||
lines = [x.rstrip().lstrip() for x in lines] # get rid of fringe whitespaces
|
||||
module_defs = []
|
||||
mdefs = [] # module definitions
|
||||
for line in lines:
|
||||
if line.startswith('['): # This marks the start of a new block
|
||||
module_defs.append({})
|
||||
module_defs[-1]['type'] = line[1:-1].rstrip()
|
||||
if module_defs[-1]['type'] == 'convolutional':
|
||||
module_defs[-1]['batch_normalize'] = 0 # pre-populate with zeros (may be overwritten later)
|
||||
mdefs.append({})
|
||||
mdefs[-1]['type'] = line[1:-1].rstrip()
|
||||
if mdefs[-1]['type'] == 'convolutional':
|
||||
mdefs[-1]['batch_normalize'] = 0 # pre-populate with zeros (may be overwritten later)
|
||||
else:
|
||||
key, value = line.split("=")
|
||||
value = value.strip()
|
||||
module_defs[-1][key.rstrip()] = value.strip()
|
||||
key, val = line.split("=")
|
||||
key = key.rstrip()
|
||||
|
||||
return module_defs
|
||||
if 'anchors' in key:
|
||||
mdefs[-1][key] = np.array([float(x) for x in val.split(',')]).reshape((-1, 2)) # np anchors
|
||||
else:
|
||||
mdefs[-1][key] = val.strip()
|
||||
|
||||
return mdefs
|
||||
|
||||
|
||||
def parse_data_cfg(path):
|
||||
"""Parses the data configuration file"""
|
||||
# Parses the data configuration file
|
||||
options = dict()
|
||||
options['gpus'] = '0,1,2,3'
|
||||
options['num_workers'] = '10'
|
||||
with open(path, 'r') as fp:
|
||||
lines = fp.readlines()
|
||||
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
if line == '' or line.startswith('#'):
|
||||
continue
|
||||
key, value = line.split('=')
|
||||
options[key.strip()] = value.strip()
|
||||
key, val = line.split('=')
|
||||
options[key.strip()] = val.strip()
|
||||
|
||||
return options
|
||||
|
||||
@@ -10,7 +10,6 @@ import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
from . import torch_utils # , google_utils
|
||||
|
||||
Reference in New Issue
Block a user