From 1afde520d1c7920c327a6e1ea38ec6c944880c02 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sat, 19 Dec 2020 19:01:15 -0800 Subject: [PATCH] Simplified PyTorch Hub loading of custom models (#1610) * Simplified PyTorch Hub loading of custom models * Update hubconf.py --- hubconf.py | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/hubconf.py b/hubconf.py index a0b66dc1..df7796d4 100644 --- a/hubconf.py +++ b/hubconf.py @@ -92,8 +92,29 @@ def yolov3_tiny(pretrained=False, channels=3, classes=80): return create('yolov3-tiny', pretrained, channels, classes) +def custom(path_or_model='path/to/model.pt'): + """YOLOv3-custom model from https://github.com/ultralytics/yolov3 + + Arguments (3 options): + path_or_model (str): 'path/to/model.pt' + path_or_model (dict): torch.load('path/to/model.pt') + path_or_model (nn.Module): torch.load('path/to/model.pt')['model'] + Returns: + pytorch model + """ + model = torch.load(path_or_model) if isinstance(path_or_model, str) else path_or_model # load checkpoint + if isinstance(model, dict): + model = model['model'] # load model + + hub_model = Model(model.yaml).to(next(model.parameters()).device) # create + hub_model.load_state_dict(model.float().state_dict()) # load state_dict + hub_model.names = model.names # class names + return hub_model + + if __name__ == '__main__': - model = create(name='yolov3', pretrained=True, channels=3, classes=80) # example + model = create(name='yolov3', pretrained=True, channels=3, classes=80) # pretrained example + # model = custom(path_or_model='path/to/model.pt') # custom example model = model.autoshape() # for PIL/cv2/np inputs and NMS # Verify inference