From fa78fc4e3413eca7014450602781598c243fd95d Mon Sep 17 00:00:00 2001 From: tjiagoM Date: Thu, 2 Jul 2020 22:35:20 +0100 Subject: [PATCH] partial support for dropout layer (#1366) --- models.py | 6 ++++++ utils/parse_config.py | 3 ++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/models.py b/models.py index d5cfc02a..dd2fd656 100755 --- a/models.py +++ b/models.py @@ -106,6 +106,9 @@ def create_modules(module_defs, img_size, cfg): # Initialize preceding Conv2d() bias (https://arxiv.org/pdf/1708.02002.pdf section 3.3) try: j = layers[yolo_index] if 'from' in mdef else -1 + # If previous layer is a dropout layer, get the one before + if module_list[j].__class__.__name__ == 'Dropout': + j -= 1 bias_ = module_list[j][0].bias # shape(255,) bias = bias_[:modules.no * modules.na].view(modules.na, -1) # shape(3,85) bias[:, 4] += -4.5 # obj @@ -114,6 +117,9 @@ def create_modules(module_defs, img_size, cfg): except: print('WARNING: smart bias initialization failure.') + elif mdef['type'] == 'dropout': + perc = float(mdef['probability']) + modules = nn.Dropout(p=perc) else: print('Warning: Unrecognized Layer Type: ' + mdef['type']) diff --git a/utils/parse_config.py b/utils/parse_config.py index 4208748e..88d7d7ed 100644 --- a/utils/parse_config.py +++ b/utils/parse_config.py @@ -31,6 +31,7 @@ def parse_model_cfg(path): mdefs[-1][key] = [int(x) for x in val.split(',')] else: val = val.strip() + # TODO: .isnumeric() actually fails to get the float case if val.isnumeric(): # return int or float mdefs[-1][key] = int(val) if (int(val) - float(val)) == 0 else float(val) else: @@ -40,7 +41,7 @@ def parse_model_cfg(path): supported = ['type', 'batch_normalize', 'filters', 'size', 'stride', 'pad', 'activation', 'layers', 'groups', 'from', 'mask', 'anchors', 'classes', 'num', 'jitter', 'ignore_thresh', 'truth_thresh', 'random', 'stride_x', 'stride_y', 'weights_type', 'weights_normalization', 'scale_x_y', 'beta_nms', 'nms_kind', - 'iou_loss', 'iou_normalizer', 'cls_normalizer', 'iou_thresh'] + 'iou_loss', 'iou_normalizer', 'cls_normalizer', 'iou_thresh', 'probability'] f = [] # fields for x in mdefs[1:]: