* YOLOv5 forward compatibility update * add data dir * ci test yolov3 * update build_targets() * update build_targets() * update build_targets() * update yolov3-spp.yaml * add yolov3-tiny.yaml * add yolov3-tiny.yaml * Update yolov3-tiny.yaml * thop bug fix * Detection() device bug fix * Use torchvision.ops.nms() * Remove redundant download mirror * CI tests with yolov3-tiny * Update README.md * Synch train and test iou_thresh * update requirements.txt * Cat apriori autolabels * Confusion matrix * Autosplit * Autosplit * Update README.md * AP no plot * Update caching * Update caching * Caching bug fix * --image-weights bug fix * datasets bug fix * mosaic plots bug fix * plot_study * boxes.max() * boxes.max() * boxes.max() * boxes.max() * boxes.max() * boxes.max() * update * Update README * Update README * Update README.md * Update README.md * results png * Update README * Targets scaling bug fix * update plot_study * update plot_study * update plot_study * update plot_study * Targets scaling bug fix * Finish Readme.md * Finish Readme.md * Finish Readme.md * Update README.md * Creado con Colaboratory
106 lines
3.5 KiB
Python
106 lines
3.5 KiB
Python
"""File for accessing YOLOv3 via PyTorch Hub https://pytorch.org/hub/
|
|
|
|
Usage:
|
|
import torch
|
|
model = torch.hub.load('ultralytics/yolov3', 'yolov3', pretrained=True, channels=3, classes=80)
|
|
"""
|
|
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
|
|
from models.yolo import Model
|
|
from utils.general import set_logging
|
|
from utils.google_utils import attempt_download
|
|
|
|
# Packages required to load the models below; torch.hub reads this list from hubconf.py.
dependencies = ["torch", "yaml"]

# Project helper from utils.general, invoked once when this module is imported.
set_logging()
|
|
|
|
|
|
def create(name, pretrained, channels, classes):
    """Creates a specified YOLOv3 model.

    Arguments:
        name (str): model config name in ./models without the '.yaml' suffix,
            e.g. 'yolov3', 'yolov3-spp' or 'yolov3-tiny'
        pretrained (bool): load pretrained weights from '{name}.pt' into the model
        channels (int): number of input channels
        classes (int): number of model classes

    Returns:
        pytorch model

    Raises:
        Exception: wraps any error raised while building the model or loading its
            weights, with a hint that the torch.hub cache may be stale.
    """
    config = Path(__file__).parent / 'models' / f'{name}.yaml'  # model.yaml path
    try:
        model = Model(config, channels, classes)
        if pretrained:
            fname = f'{name}.pt'  # checkpoint filename
            attempt_download(fname)  # download if not found locally
            # NOTE(review): torch.load unpickles the file, so the checkpoint must come
            # from a trusted source.
            ckpt = torch.load(fname, map_location=torch.device('cpu'))  # load
            state_dict = ckpt['model'].float().state_dict()  # to FP32
            # Build the model's state dict once instead of once per checkpoint entry.
            # Keep only tensors the model actually has with a matching shape. Keys
            # absent from the model used to raise KeyError. Shape mismatches
            # (presumably class-dependent layers when `classes` differs from the
            # checkpoint) are dropped and keep their fresh initialization.
            model_state = model.state_dict()
            state_dict = {k: v for k, v in state_dict.items()
                          if k in model_state and model_state[k].shape == v.shape}  # filter
            model.load_state_dict(state_dict, strict=False)  # load
            if len(ckpt['model'].names) == classes:
                model.names = ckpt['model'].names  # set class names attribute
        # For PIL/cv2/np inputs and NMS, callers can wrap the result with model.autoshape().
        return model

    except Exception as e:
        help_url = 'https://github.com/ultralytics/yolov5/issues/36'
        s = f'Cache may be out of date, try force_reload=True. See {help_url} for help.'
        raise Exception(s) from e
|
|
|
|
|
|
def yolov3(pretrained=False, channels=3, classes=80):
    """Build the standard YOLOv3 model (https://github.com/ultralytics/yolov3).

    Arguments:
        pretrained (bool): if True, load pretrained weights into the model (default: False)
        channels (int): number of input image channels (default: 3)
        classes (int): number of output classes (default: 80)

    Returns:
        pytorch model, as returned by create() for the 'yolov3' config
    """
    return create(name='yolov3', pretrained=pretrained, channels=channels, classes=classes)
|
|
|
|
|
|
def yolov3_spp(pretrained=False, channels=3, classes=80):
    """Build the YOLOv3-SPP model variant (https://github.com/ultralytics/yolov3).

    Arguments:
        pretrained (bool): if True, load pretrained weights into the model (default: False)
        channels (int): number of input image channels (default: 3)
        classes (int): number of output classes (default: 80)

    Returns:
        pytorch model, as returned by create() for the 'yolov3-spp' config
    """
    return create(name='yolov3-spp', pretrained=pretrained, channels=channels, classes=classes)
|
|
|
|
|
|
def yolov3_tiny(pretrained=False, channels=3, classes=80):
    """Build the YOLOv3-tiny model variant (https://github.com/ultralytics/yolov3).

    Arguments:
        pretrained (bool): if True, load pretrained weights into the model (default: False)
        channels (int): number of input image channels (default: 3)
        classes (int): number of output classes (default: 80)

    Returns:
        pytorch model, as returned by create() for the 'yolov3-tiny' config
    """
    return create(name='yolov3-tiny', pretrained=pretrained, channels=channels, classes=classes)
|
|
|
|
|
|
if __name__ == '__main__':
    # Example: pretrained YOLOv3, then fuse() and autoshape() so the model
    # accepts PIL/cv2/np inputs and applies NMS.
    hub_model = create(name='yolov3', pretrained=True, channels=3, classes=80).fuse().autoshape()

    # Verify inference on every .jpg found in data/images.
    from PIL import Image

    jpg_paths = Path('data/images').glob('*.jpg')
    images = [Image.open(p) for p in jpg_paths]
    predictions = hub_model(images)
    predictions.show()
    predictions.print()
|