YOLOv5 v5.0 release compatibility update for YOLOv3

commit 4d0c2e6eee (parent 47ac6833ca)
Author: Glenn Jocher
Date: 2021-05-30 18:55:56 +02:00
38 changed files with 1192 additions and 528 deletions

README.md

@@ -113,7 +113,7 @@ $ python detect.py --source data/images --weights yolov3.pt --conf 0.25
 ### PyTorch Hub

-To run **batched inference** with YOLOv5 and [PyTorch Hub](https://github.com/ultralytics/yolov5/issues/36):
+To run **batched inference** with YOLOv3 and [PyTorch Hub](https://github.com/ultralytics/yolov5/issues/36):

 ```python
 import torch
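
The diff truncates the README snippet after `import torch`; a minimal sketch of the full batched-inference usage it documents (image paths are illustrative, the entrypoint name follows hubconf.py below):

```python
import torch

# Load a pretrained YOLOv3 model from PyTorch Hub (weights auto-download on first use)
model = torch.hub.load('ultralytics/yolov3', 'yolov3', pretrained=True)

# Batched inference: a list mixing filenames, URLs, PIL images or numpy arrays
imgs = ['data/images/zidane.jpg', 'data/images/bus.jpg']
results = model(imgs)  # one forward pass over the whole batch
results.print()  # per-image detection summary
```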

data/GlobalWheat2020.yaml

@@ -0,0 +1,55 @@
# Global Wheat 2020 dataset http://www.global-wheat.com/
# Train command: python train.py --data GlobalWheat2020.yaml
# Default dataset location is next to YOLOv3:
#   /parent_folder
#     /datasets/GlobalWheat2020
#     /yolov3

# train and val data as 1) directory: path/images/, 2) file: path/images.txt, or 3) list: [path1/images/, path2/images/]
train: # 3422 images
  - ../datasets/GlobalWheat2020/images/arvalis_1
  - ../datasets/GlobalWheat2020/images/arvalis_2
  - ../datasets/GlobalWheat2020/images/arvalis_3
  - ../datasets/GlobalWheat2020/images/ethz_1
  - ../datasets/GlobalWheat2020/images/rres_1
  - ../datasets/GlobalWheat2020/images/inrae_1
  - ../datasets/GlobalWheat2020/images/usask_1

val: # 748 images (WARNING: train set contains ethz_1)
  - ../datasets/GlobalWheat2020/images/ethz_1

test: # 1276 images
  - ../datasets/GlobalWheat2020/images/utokyo_1
  - ../datasets/GlobalWheat2020/images/utokyo_2
  - ../datasets/GlobalWheat2020/images/nau_1
  - ../datasets/GlobalWheat2020/images/uq_1

# number of classes
nc: 1

# class names
names: [ 'wheat_head' ]

# download command/URL (optional) --------------------------------------------------------------------------------------
download: |
  from utils.general import download, Path

  # Download
  dir = Path('../datasets/GlobalWheat2020')  # dataset directory
  urls = ['https://zenodo.org/record/4298502/files/global-wheat-codalab-official.zip',
          'https://github.com/ultralytics/yolov5/releases/download/v1.0/GlobalWheat2020_labels.zip']
  download(urls, dir=dir)

  # Make Directories
  for p in 'annotations', 'images', 'labels':
      (dir / p).mkdir(parents=True, exist_ok=True)

  # Move
  for p in 'arvalis_1', 'arvalis_2', 'arvalis_3', 'ethz_1', 'rres_1', 'inrae_1', 'usask_1', \
           'utokyo_1', 'utokyo_2', 'nau_1', 'uq_1':
      (dir / p).rename(dir / 'images' / p)  # move to /images
      f = (dir / p).with_suffix('.json')  # json file
      if f.exists():
          f.rename((dir / 'annotations' / p).with_suffix('.json'))  # move to /annotations
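
The `download:` field above is plain Python that the repo runs when the dataset is missing. A minimal sketch of that mechanism, assuming a dataset checker along the lines of `utils.general.check_dataset` (names here are illustrative):

```python
import yaml

# Parse the dataset definition
with open('data/GlobalWheat2020.yaml') as f:
    data = yaml.safe_load(f)

# 'download' may be a URL/shell command or, as above, an inline Python recipe
s = data.get('download')
if s and not s.startswith('http'):
    exec(s)  # run the inline recipe verbatim (what the real checker effectively does)
```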

data/SKU-110K.yaml

@@ -0,0 +1,52 @@
# SKU-110K retail items dataset https://github.com/eg4000/SKU110K_CVPR19
# Train command: python train.py --data SKU-110K.yaml
# Default dataset location is next to YOLOv3:
#   /parent_folder
#     /datasets/SKU-110K
#     /yolov3

# train and val data as 1) directory: path/images/, 2) file: path/images.txt, or 3) list: [path1/images/, path2/images/]
train: ../datasets/SKU-110K/train.txt # 8219 images
val: ../datasets/SKU-110K/val.txt # 588 images
test: ../datasets/SKU-110K/test.txt # 2936 images

# number of classes
nc: 1

# class names
names: [ 'object' ]

# download command/URL (optional) --------------------------------------------------------------------------------------
download: |
  import shutil
  from tqdm import tqdm
  from utils.general import np, pd, Path, download, xyxy2xywh

  # Download
  datasets = Path('../datasets')  # download directory
  urls = ['http://trax-geometry.s3.amazonaws.com/cvpr_challenge/SKU110K_fixed.tar.gz']
  download(urls, dir=datasets, delete=False)

  # Rename directories
  dir = (datasets / 'SKU-110K')
  if dir.exists():
      shutil.rmtree(dir)
  (datasets / 'SKU110K_fixed').rename(dir)  # rename dir
  (dir / 'labels').mkdir(parents=True, exist_ok=True)  # create labels dir

  # Convert labels
  names = 'image', 'x1', 'y1', 'x2', 'y2', 'class', 'image_width', 'image_height'  # column names
  for d in 'annotations_train.csv', 'annotations_val.csv', 'annotations_test.csv':
      x = pd.read_csv(dir / 'annotations' / d, names=names).values  # annotations
      images, unique_images = x[:, 0], np.unique(x[:, 0])
      with open((dir / d).with_suffix('.txt').__str__().replace('annotations_', ''), 'w') as f:
          f.writelines(f'./images/{s}\n' for s in unique_images)
      for im in tqdm(unique_images, desc=f'Converting {dir / d}'):
          cls = 0  # single-class dataset
          with open((dir / 'labels' / im).with_suffix('.txt'), 'a') as f:
              for r in x[images == im]:
                  w, h = r[6], r[7]  # image width, height
                  xywh = xyxy2xywh(np.array([[r[1] / w, r[2] / h, r[3] / w, r[4] / h]]))[0]  # instance
                  f.write(f"{cls} {xywh[0]:.5f} {xywh[1]:.5f} {xywh[2]:.5f} {xywh[3]:.5f}\n")  # write label

data/VisDrone.yaml

@@ -0,0 +1,61 @@
# VisDrone2019-DET dataset https://github.com/VisDrone/VisDrone-Dataset
# Train command: python train.py --data VisDrone.yaml
# Default dataset location is next to YOLOv3:
#   /parent_folder
#     /VisDrone
#     /yolov3

# train and val data as 1) directory: path/images/, 2) file: path/images.txt, or 3) list: [path1/images/, path2/images/]
train: ../VisDrone/VisDrone2019-DET-train/images # 6471 images
val: ../VisDrone/VisDrone2019-DET-val/images # 548 images
test: ../VisDrone/VisDrone2019-DET-test-dev/images # 1610 images

# number of classes
nc: 10

# class names
names: [ 'pedestrian', 'people', 'bicycle', 'car', 'van', 'truck', 'tricycle', 'awning-tricycle', 'bus', 'motor' ]

# download command/URL (optional) --------------------------------------------------------------------------------------
download: |
  from utils.general import download, os, Path

  def visdrone2yolo(dir):
      from PIL import Image
      from tqdm import tqdm

      def convert_box(size, box):
          # Convert VisDrone box to YOLO xywh box
          dw = 1. / size[0]
          dh = 1. / size[1]
          return (box[0] + box[2] / 2) * dw, (box[1] + box[3] / 2) * dh, box[2] * dw, box[3] * dh

      (dir / 'labels').mkdir(parents=True, exist_ok=True)  # make labels directory
      pbar = tqdm((dir / 'annotations').glob('*.txt'), desc=f'Converting {dir}')
      for f in pbar:
          img_size = Image.open((dir / 'images' / f.name).with_suffix('.jpg')).size
          lines = []
          with open(f, 'r') as file:  # read annotation.txt
              for row in [x.split(',') for x in file.read().strip().splitlines()]:
                  if row[4] == '0':  # VisDrone 'ignored regions' class 0
                      continue
                  cls = int(row[5]) - 1
                  box = convert_box(img_size, tuple(map(int, row[:4])))
                  lines.append(f"{cls} {' '.join(f'{x:.6f}' for x in box)}\n")
                  with open(str(f).replace(os.sep + 'annotations' + os.sep, os.sep + 'labels' + os.sep), 'w') as fl:
                      fl.writelines(lines)  # write label.txt

  # Download
  dir = Path('../VisDrone')  # dataset directory
  urls = ['https://github.com/ultralytics/yolov5/releases/download/v1.0/VisDrone2019-DET-train.zip',
          'https://github.com/ultralytics/yolov5/releases/download/v1.0/VisDrone2019-DET-val.zip',
          'https://github.com/ultralytics/yolov5/releases/download/v1.0/VisDrone2019-DET-test-dev.zip',
          'https://github.com/ultralytics/yolov5/releases/download/v1.0/VisDrone2019-DET-test-challenge.zip']
  download(urls, dir=dir)

  # Convert
  for d in 'VisDrone2019-DET-train', 'VisDrone2019-DET-val', 'VisDrone2019-DET-test-dev':
      visdrone2yolo(dir / d)  # convert VisDrone annotations to YOLO labels
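
As a quick sanity check on `convert_box` above, a tiny standalone run with illustrative values (a 1000x500 image and a VisDrone box given as top-left x, top-left y, width, height in pixels):

```python
def convert_box(size, box):
    # VisDrone (x_left, y_top, w, h) in pixels -> YOLO (x_center, y_center, w, h), normalized
    dw, dh = 1. / size[0], 1. / size[1]
    return (box[0] + box[2] / 2) * dw, (box[1] + box[3] / 2) * dh, box[2] * dw, box[3] * dh

# center x = 100 + 25 = 125 -> 0.125; center y = 200 + 40 = 240 -> 0.48; w -> 0.05; h -> 0.16
out = convert_box((1000, 500), (100, 200, 50, 80))
print([round(v, 6) for v in out])  # [0.125, 0.48, 0.05, 0.16]
```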

data/argoverse_hd.yaml

@@ -1,9 +1,9 @@
 # Argoverse-HD dataset (ring-front-center camera) http://www.cs.cmu.edu/~mengtial/proj/streaming/
 # Train command: python train.py --data argoverse_hd.yaml
-# Default dataset location is next to /yolov5:
+# Default dataset location is next to YOLOv3:
 # /parent_folder
 #   /argoverse
-#   /yolov5
+#   /yolov3

 # download command/URL (optional)

data/coco.yaml

@@ -1,6 +1,6 @@
 # COCO 2017 dataset http://cocodataset.org
 # Train command: python train.py --data coco.yaml
-# Default dataset location is next to /yolov3:
+# Default dataset location is next to YOLOv3:
 # /parent_folder
 #   /coco
 #   /yolov3

@@ -30,6 +30,6 @@ names: [ 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', '
 # Print classes
 # with open('data/coco.yaml') as f:
-#   d = yaml.load(f, Loader=yaml.FullLoader)  # dict
+#   d = yaml.safe_load(f)  # dict
 #   for i, x in enumerate(d['names']):
 #     print(i, x)

data/coco128.yaml

@@ -1,6 +1,6 @@
 # COCO 2017 dataset http://cocodataset.org - first 128 training images
 # Train command: python train.py --data coco128.yaml
-# Default dataset location is next to /yolov3:
+# Default dataset location is next to YOLOv3:
 # /parent_folder
 #   /coco128
 #   /yolov3

data/hyp.finetune_objects365.yaml

@@ -0,0 +1,28 @@
lr0: 0.00258
lrf: 0.17
momentum: 0.779
weight_decay: 0.00058
warmup_epochs: 1.33
warmup_momentum: 0.86
warmup_bias_lr: 0.0711
box: 0.0539
cls: 0.299
cls_pw: 0.825
obj: 0.632
obj_pw: 1.0
iou_t: 0.2
anchor_t: 3.44
anchors: 3.2
fl_gamma: 0.0
hsv_h: 0.0188
hsv_s: 0.704
hsv_v: 0.36
degrees: 0.0
translate: 0.0902
scale: 0.491
shear: 0.0
perspective: 0.0
flipud: 0.0
fliplr: 0.5
mosaic: 1.0
mixup: 0.0
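
These keys are consumed by train.py via `--hyp`; a minimal sketch of how such a file is read (filename assumed, see the header above):

```python
import yaml

# Load the hyperparameter dict exactly as train.py does for --hyp <file>
with open('data/hyp.finetune_objects365.yaml') as f:  # filename assumed
    hyp = yaml.safe_load(f)

print(hyp['lr0'], hyp['momentum'], hyp['mosaic'])  # 0.00258 0.779 1.0
```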

data/objects365.yaml

@@ -0,0 +1,102 @@
# Objects365 dataset https://www.objects365.org/
# Train command: python train.py --data objects365.yaml
# Default dataset location is next to YOLOv3:
#   /parent_folder
#     /datasets/objects365
#     /yolov3

# train and val data as 1) directory: path/images/, 2) file: path/images.txt, or 3) list: [path1/images/, path2/images/]
train: ../datasets/objects365/images/train # 1742289 images
val: ../datasets/objects365/images/val # 5570 images
# number of classes
nc: 365
# class names
names: [ 'Person', 'Sneakers', 'Chair', 'Other Shoes', 'Hat', 'Car', 'Lamp', 'Glasses', 'Bottle', 'Desk', 'Cup',
         'Street Lights', 'Cabinet/shelf', 'Handbag/Satchel', 'Bracelet', 'Plate', 'Picture/Frame', 'Helmet', 'Book',
         'Gloves', 'Storage box', 'Boat', 'Leather Shoes', 'Flower', 'Bench', 'Potted Plant', 'Bowl/Basin', 'Flag',
         'Pillow', 'Boots', 'Vase', 'Microphone', 'Necklace', 'Ring', 'SUV', 'Wine Glass', 'Belt', 'Monitor/TV',
         'Backpack', 'Umbrella', 'Traffic Light', 'Speaker', 'Watch', 'Tie', 'Trash bin Can', 'Slippers', 'Bicycle',
         'Stool', 'Barrel/bucket', 'Van', 'Couch', 'Sandals', 'Basket', 'Drum', 'Pen/Pencil', 'Bus', 'Wild Bird',
         'High Heels', 'Motorcycle', 'Guitar', 'Carpet', 'Cell Phone', 'Bread', 'Camera', 'Canned', 'Truck',
         'Traffic cone', 'Cymbal', 'Lifesaver', 'Towel', 'Stuffed Toy', 'Candle', 'Sailboat', 'Laptop', 'Awning',
         'Bed', 'Faucet', 'Tent', 'Horse', 'Mirror', 'Power outlet', 'Sink', 'Apple', 'Air Conditioner', 'Knife',
         'Hockey Stick', 'Paddle', 'Pickup Truck', 'Fork', 'Traffic Sign', 'Balloon', 'Tripod', 'Dog', 'Spoon', 'Clock',
         'Pot', 'Cow', 'Cake', 'Dinning Table', 'Sheep', 'Hanger', 'Blackboard/Whiteboard', 'Napkin', 'Other Fish',
         'Orange/Tangerine', 'Toiletry', 'Keyboard', 'Tomato', 'Lantern', 'Machinery Vehicle', 'Fan',
         'Green Vegetables', 'Banana', 'Baseball Glove', 'Airplane', 'Mouse', 'Train', 'Pumpkin', 'Soccer', 'Skiboard',
         'Luggage', 'Nightstand', 'Tea pot', 'Telephone', 'Trolley', 'Head Phone', 'Sports Car', 'Stop Sign',
         'Dessert', 'Scooter', 'Stroller', 'Crane', 'Remote', 'Refrigerator', 'Oven', 'Lemon', 'Duck', 'Baseball Bat',
         'Surveillance Camera', 'Cat', 'Jug', 'Broccoli', 'Piano', 'Pizza', 'Elephant', 'Skateboard', 'Surfboard',
         'Gun', 'Skating and Skiing shoes', 'Gas stove', 'Donut', 'Bow Tie', 'Carrot', 'Toilet', 'Kite', 'Strawberry',
         'Other Balls', 'Shovel', 'Pepper', 'Computer Box', 'Toilet Paper', 'Cleaning Products', 'Chopsticks',
         'Microwave', 'Pigeon', 'Baseball', 'Cutting/chopping Board', 'Coffee Table', 'Side Table', 'Scissors',
         'Marker', 'Pie', 'Ladder', 'Snowboard', 'Cookies', 'Radiator', 'Fire Hydrant', 'Basketball', 'Zebra', 'Grape',
         'Giraffe', 'Potato', 'Sausage', 'Tricycle', 'Violin', 'Egg', 'Fire Extinguisher', 'Candy', 'Fire Truck',
         'Billiards', 'Converter', 'Bathtub', 'Wheelchair', 'Golf Club', 'Briefcase', 'Cucumber', 'Cigar/Cigarette',
         'Paint Brush', 'Pear', 'Heavy Truck', 'Hamburger', 'Extractor', 'Extension Cord', 'Tong', 'Tennis Racket',
         'Folder', 'American Football', 'earphone', 'Mask', 'Kettle', 'Tennis', 'Ship', 'Swing', 'Coffee Machine',
         'Slide', 'Carriage', 'Onion', 'Green beans', 'Projector', 'Frisbee', 'Washing Machine/Drying Machine',
         'Chicken', 'Printer', 'Watermelon', 'Saxophone', 'Tissue', 'Toothbrush', 'Ice cream', 'Hot-air balloon',
         'Cello', 'French Fries', 'Scale', 'Trophy', 'Cabbage', 'Hot dog', 'Blender', 'Peach', 'Rice', 'Wallet/Purse',
         'Volleyball', 'Deer', 'Goose', 'Tape', 'Tablet', 'Cosmetics', 'Trumpet', 'Pineapple', 'Golf Ball',
         'Ambulance', 'Parking meter', 'Mango', 'Key', 'Hurdle', 'Fishing Rod', 'Medal', 'Flute', 'Brush', 'Penguin',
         'Megaphone', 'Corn', 'Lettuce', 'Garlic', 'Swan', 'Helicopter', 'Green Onion', 'Sandwich', 'Nuts',
         'Speed Limit Sign', 'Induction Cooker', 'Broom', 'Trombone', 'Plum', 'Rickshaw', 'Goldfish', 'Kiwi fruit',
         'Router/modem', 'Poker Card', 'Toaster', 'Shrimp', 'Sushi', 'Cheese', 'Notepaper', 'Cherry', 'Pliers', 'CD',
         'Pasta', 'Hammer', 'Cue', 'Avocado', 'Hamimelon', 'Flask', 'Mushroom', 'Screwdriver', 'Soap', 'Recorder',
         'Bear', 'Eggplant', 'Board Eraser', 'Coconut', 'Tape Measure/Ruler', 'Pig', 'Showerhead', 'Globe', 'Chips',
         'Steak', 'Crosswalk Sign', 'Stapler', 'Camel', 'Formula 1', 'Pomegranate', 'Dishwasher', 'Crab',
         'Hoverboard', 'Meat ball', 'Rice Cooker', 'Tuba', 'Calculator', 'Papaya', 'Antelope', 'Parrot', 'Seal',
         'Butterfly', 'Dumbbell', 'Donkey', 'Lion', 'Urinal', 'Dolphin', 'Electric Drill', 'Hair Dryer', 'Egg tart',
         'Jellyfish', 'Treadmill', 'Lighter', 'Grapefruit', 'Game board', 'Mop', 'Radish', 'Baozi', 'Target', 'French',
         'Spring Rolls', 'Monkey', 'Rabbit', 'Pencil Case', 'Yak', 'Red Cabbage', 'Binoculars', 'Asparagus', 'Barbell',
         'Scallop', 'Noddles', 'Comb', 'Dumpling', 'Oyster', 'Table Tennis paddle', 'Cosmetics Brush/Eyeliner Pencil',
         'Chainsaw', 'Eraser', 'Lobster', 'Durian', 'Okra', 'Lipstick', 'Cosmetics Mirror', 'Curling', 'Table Tennis' ]
# download command/URL (optional) --------------------------------------------------------------------------------------
download: |
  from pycocotools.coco import COCO
  from tqdm import tqdm

  from utils.general import download, Path

  # Make Directories
  dir = Path('../datasets/objects365')  # dataset directory
  for p in 'images', 'labels':
      (dir / p).mkdir(parents=True, exist_ok=True)
      for q in 'train', 'val':
          (dir / p / q).mkdir(parents=True, exist_ok=True)

  # Download
  url = "https://dorc.ks3-cn-beijing.ksyun.com/data-set/2020Objects365%E6%95%B0%E6%8D%AE%E9%9B%86/train/"
  download([url + 'zhiyuan_objv2_train.tar.gz'], dir=dir, delete=False)  # annotations json
  download([url + f for f in [f'patch{i}.tar.gz' for i in range(51)]], dir=dir / 'images' / 'train',
           curl=True, delete=False, threads=8)

  # Move
  train = dir / 'images' / 'train'
  for f in tqdm(train.rglob('*.jpg'), desc=f'Moving images'):
      f.rename(train / f.name)  # move to /images/train

  # Labels
  coco = COCO(dir / 'zhiyuan_objv2_train.json')
  names = [x["name"] for x in coco.loadCats(coco.getCatIds())]
  for cid, cat in enumerate(names):
      catIds = coco.getCatIds(catNms=[cat])
      imgIds = coco.getImgIds(catIds=catIds)
      for im in tqdm(coco.loadImgs(imgIds), desc=f'Class {cid + 1}/{len(names)} {cat}'):
          width, height = im["width"], im["height"]
          path = Path(im["file_name"])  # image filename
          try:
              with open(dir / 'labels' / 'train' / path.with_suffix('.txt').name, 'a') as file:
                  annIds = coco.getAnnIds(imgIds=im["id"], catIds=catIds, iscrowd=None)
                  for a in coco.loadAnns(annIds):
                      x, y, w, h = a['bbox']  # bounding box in xywh (xy top-left corner)
                      x, y = x + w / 2, y + h / 2  # xy to center
                      file.write(f"{cid} {x / width:.5f} {y / height:.5f} {w / width:.5f} {h / height:.5f}\n")
          except Exception as e:
              print(e)
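
As a quick sanity check on the label lines the loop above writes (class id, then center x/y and width/height normalized by image size), here is a hypothetical round-trip back to pixel corners:

```python
# Hypothetical label line and image size, matching the format written above
width, height = 1024, 768
line = '55 0.35000 0.40000 0.10000 0.20000'

cid, xc, yc, w, h = line.split()
xc, w = float(xc) * width, float(w) * width    # denormalize x center and width
yc, h = float(yc) * height, float(h) * height  # denormalize y center and height
x1, y1, x2, y2 = xc - w / 2, yc - h / 2, xc + w / 2, yc + h / 2
print(int(cid), (x1, y1, x2, y2))  # ~ 55 (307.2, 230.4, 409.6, 384.0)
```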

data/scripts/get_argoverse_hd.sh

@@ -2,10 +2,10 @@
 # Argoverse-HD dataset (ring-front-center camera) http://www.cs.cmu.edu/~mengtial/proj/streaming/
 # Download command: bash data/scripts/get_argoverse_hd.sh
 # Train command: python train.py --data argoverse_hd.yaml
-# Default dataset location is next to /yolov5:
+# Default dataset location is next to YOLOv3:
 # /parent_folder
 #   /argoverse
-#   /yolov5
+#   /yolov3

 # Download/unzip images
 d='../argoverse/' # unzip directory

@@ -25,7 +25,7 @@ import json
 from pathlib import Path

 annotation_files = ["train.json", "val.json"]
-print("Converting annotations to YOLOv5 format...")
+print("Converting annotations to YOLOv3 format...")

 for val in annotation_files:
     a = json.load(open(val, "rb"))

@@ -36,7 +36,7 @@ for val in annotation_files:
         img_name = a['images'][img_id]['name']
         img_label_name = img_name[:-3] + "txt"

-        obj_class = annot['category_id']
+        cls = annot['category_id']  # instance class id
         x_center, y_center, width, height = annot['bbox']
         x_center = (x_center + width / 2) / 1920.  # offset and scale
         y_center = (y_center + height / 2) / 1200.  # offset and scale

@@ -46,11 +46,10 @@ for val in annotation_files:
         img_dir = "./labels/" + a['seq_dirs'][a['images'][annot['image_id']]['sid']]

         Path(img_dir).mkdir(parents=True, exist_ok=True)
-
         if img_dir + "/" + img_label_name not in label_dict:
             label_dict[img_dir + "/" + img_label_name] = []

-        label_dict[img_dir + "/" + img_label_name].append(f"{obj_class} {x_center} {y_center} {width} {height}\n")
+        label_dict[img_dir + "/" + img_label_name].append(f"{cls} {x_center} {y_center} {width} {height}\n")

 for filename in label_dict:
     with open(filename, "w") as file:

data/scripts/get_coco.sh

@@ -2,7 +2,7 @@
 # COCO 2017 dataset http://cocodataset.org
 # Download command: bash data/scripts/get_coco.sh
 # Train command: python train.py --data coco.yaml
-# Default dataset location is next to /yolov3:
+# Default dataset location is next to YOLOv3:
 # /parent_folder
 #   /coco
 #   /yolov3

data/scripts/get_coco128.sh

@@ -0,0 +1,17 @@
#!/bin/bash
# COCO128 dataset https://www.kaggle.com/ultralytics/coco128
# Download command: bash data/scripts/get_coco128.sh
# Train command: python train.py --data coco128.yaml
# Default dataset location is next to YOLOv3:
# /parent_folder
# /coco128
# /yolov3
# Download/unzip images and labels
d='../' # unzip directory
url=https://github.com/ultralytics/yolov5/releases/download/v1.0/
f='coco128.zip' # or 'coco2017labels-segments.zip', 68 MB
echo 'Downloading' $url$f ' ...'
curl -L $url$f -o $f && unzip -q $f -d $d && rm $f & # download, unzip, remove in background
wait # finish background tasks

data/scripts/get_voc.sh

@@ -2,10 +2,10 @@
 # PASCAL VOC dataset http://host.robots.ox.ac.uk/pascal/VOC/
 # Download command: bash data/scripts/get_voc.sh
 # Train command: python train.py --data voc.yaml
-# Default dataset location is next to /yolov5:
+# Default dataset location is next to YOLOv3:
 # /parent_folder
 #   /VOC
-#   /yolov5
+#   /yolov3

 start=$(date +%s)
 mkdir -p ../tmp

@@ -29,34 +29,27 @@ echo "Completed in" $runtime "seconds"
 echo "Splitting dataset..."
 python3 - "$@" <<END
-import xml.etree.ElementTree as ET
-import pickle
 import os
-from os import listdir, getcwd
-from os.path import join
+import xml.etree.ElementTree as ET
+from os import getcwd

-sets=[('2012', 'train'), ('2012', 'val'), ('2007', 'train'), ('2007', 'val'), ('2007', 'test')]
+sets = [('2012', 'train'), ('2012', 'val'), ('2007', 'train'), ('2007', 'val'), ('2007', 'test')]

-classes = ["aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"]
+classes = ["aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog",
+           "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"]

-def convert(size, box):
-    dw = 1./(size[0])
-    dh = 1./(size[1])
-    x = (box[0] + box[1])/2.0 - 1
-    y = (box[2] + box[3])/2.0 - 1
-    w = box[1] - box[0]
-    h = box[3] - box[2]
-    x = x*dw
-    w = w*dw
-    y = y*dh
-    h = h*dh
-    return (x,y,w,h)
+def convert_box(size, box):
+    dw = 1. / (size[0])
+    dh = 1. / (size[1])
+    x, y, w, h = (box[0] + box[1]) / 2.0 - 1, (box[2] + box[3]) / 2.0 - 1, box[1] - box[0], box[3] - box[2]
+    return x * dw, y * dh, w * dw, h * dh

 def convert_annotation(year, image_id):
-    in_file = open('VOCdevkit/VOC%s/Annotations/%s.xml'%(year, image_id))
-    out_file = open('VOCdevkit/VOC%s/labels/%s.txt'%(year, image_id), 'w')
-    tree=ET.parse(in_file)
+    in_file = open('VOCdevkit/VOC%s/Annotations/%s.xml' % (year, image_id))
+    out_file = open('VOCdevkit/VOC%s/labels/%s.txt' % (year, image_id), 'w')
+    tree = ET.parse(in_file)
     root = tree.getroot()
     size = root.find('size')
     w = int(size.find('width').text)

@@ -65,74 +58,58 @@ def convert_annotation(year, image_id):
     for obj in root.iter('object'):
         difficult = obj.find('difficult').text
         cls = obj.find('name').text
-        if cls not in classes or int(difficult)==1:
+        if cls not in classes or int(difficult) == 1:
             continue
         cls_id = classes.index(cls)
         xmlbox = obj.find('bndbox')
-        b = (float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text), float(xmlbox.find('ymin').text), float(xmlbox.find('ymax').text))
-        bb = convert((w,h), b)
+        b = (float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text), float(xmlbox.find('ymin').text),
+             float(xmlbox.find('ymax').text))
+        bb = convert_box((w, h), b)
         out_file.write(str(cls_id) + " " + " ".join([str(a) for a in bb]) + '\n')

-wd = getcwd()
+cwd = getcwd()

 for year, image_set in sets:
-    if not os.path.exists('VOCdevkit/VOC%s/labels/'%(year)):
-        os.makedirs('VOCdevkit/VOC%s/labels/'%(year))
-    image_ids = open('VOCdevkit/VOC%s/ImageSets/Main/%s.txt'%(year, image_set)).read().strip().split()
-    list_file = open('%s_%s.txt'%(year, image_set), 'w')
+    if not os.path.exists('VOCdevkit/VOC%s/labels/' % year):
+        os.makedirs('VOCdevkit/VOC%s/labels/' % year)
+    image_ids = open('VOCdevkit/VOC%s/ImageSets/Main/%s.txt' % (year, image_set)).read().strip().split()
+    list_file = open('%s_%s.txt' % (year, image_set), 'w')
     for image_id in image_ids:
-        list_file.write('%s/VOCdevkit/VOC%s/JPEGImages/%s.jpg\n'%(wd, year, image_id))
+        list_file.write('%s/VOCdevkit/VOC%s/JPEGImages/%s.jpg\n' % (cwd, year, image_id))
         convert_annotation(year, image_id)
     list_file.close()
 END

 cat 2007_train.txt 2007_val.txt 2012_train.txt 2012_val.txt >train.txt
 cat 2007_train.txt 2007_val.txt 2007_test.txt 2012_train.txt 2012_val.txt >train.all.txt

+mkdir ../VOC ../VOC/images ../VOC/images/train ../VOC/images/val
+mkdir ../VOC/labels ../VOC/labels/train ../VOC/labels/val

 python3 - "$@" <<END
+import shutil
 import os

-os.system('mkdir ../VOC/')
-os.system('mkdir ../VOC/images')
-os.system('mkdir ../VOC/images/train')
-os.system('mkdir ../VOC/images/val')
-os.system('mkdir ../VOC/labels')
-os.system('mkdir ../VOC/labels/train')
-os.system('mkdir ../VOC/labels/val')
-import os

 print(os.path.exists('../tmp/train.txt'))
-f = open('../tmp/train.txt', 'r')
-lines = f.readlines()
-
-for line in lines:
-    line = "/".join(line.split('/')[-5:]).strip()
-    if (os.path.exists("../" + line)):
-        os.system("cp ../"+ line + " ../VOC/images/train")
-    line = line.replace('JPEGImages', 'labels')
-    line = line.replace('jpg', 'txt')
-    if (os.path.exists("../" + line)):
-        os.system("cp ../"+ line + " ../VOC/labels/train")
+with open('../tmp/train.txt', 'r') as f:
+    for line in f.readlines():
+        line = "/".join(line.split('/')[-5:]).strip()
+        if os.path.exists("../" + line):
+            os.system("cp ../" + line + " ../VOC/images/train")
+        line = line.replace('JPEGImages', 'labels').replace('jpg', 'txt')
+        if os.path.exists("../" + line):
+            os.system("cp ../" + line + " ../VOC/labels/train")

 print(os.path.exists('../tmp/2007_test.txt'))
-f = open('../tmp/2007_test.txt', 'r')
-lines = f.readlines()
-
-for line in lines:
-    line = "/".join(line.split('/')[-5:]).strip()
-    if (os.path.exists("../" + line)):
-        os.system("cp ../"+ line + " ../VOC/images/val")
-    line = line.replace('JPEGImages', 'labels')
-    line = line.replace('jpg', 'txt')
-    if (os.path.exists("../" + line)):
-        os.system("cp ../"+ line + " ../VOC/labels/val")
+with open('../tmp/2007_test.txt', 'r') as f:
+    for line in f.readlines():
+        line = "/".join(line.split('/')[-5:]).strip()
+        if os.path.exists("../" + line):
+            os.system("cp ../" + line + " ../VOC/images/val")
+        line = line.replace('JPEGImages', 'labels').replace('jpg', 'txt')
+        if os.path.exists("../" + line):
+            os.system("cp ../" + line + " ../VOC/labels/val")
 END

 rm -rf ../tmp  # remove temporary directory

data/voc.yaml

@@ -1,6 +1,6 @@
 # PASCAL VOC dataset http://host.robots.ox.ac.uk/pascal/VOC/
 # Train command: python train.py --data voc.yaml
-# Default dataset location is next to /yolov3:
+# Default dataset location is next to YOLOv3:
 # /parent_folder
 #   /VOC
 #   /yolov3

detect.py

@@ -5,24 +5,24 @@ from pathlib import Path
 import cv2
 import torch
 import torch.backends.cudnn as cudnn
-from numpy import random

 from models.experimental import attempt_load
 from utils.datasets import LoadStreams, LoadImages
-from utils.general import check_img_size, check_requirements, check_imshow, non_max_suppression, apply_classifier, \
-    scale_coords, xyxy2xywh, strip_optimizer, set_logging, increment_path
-from utils.plots import plot_one_box
+from utils.general import check_img_size, check_requirements, check_imshow, non_max_suppression, apply_classifier, \
+    scale_coords, xyxy2xywh, strip_optimizer, set_logging, increment_path, save_one_box
+from utils.plots import colors, plot_one_box
 from utils.torch_utils import select_device, load_classifier, time_synchronized

-def detect(save_img=False):
+@torch.no_grad()
+def detect(opt):
     source, weights, view_img, save_txt, imgsz = opt.source, opt.weights, opt.view_img, opt.save_txt, opt.img_size
     save_img = not opt.nosave and not source.endswith('.txt')  # save inference images
     webcam = source.isnumeric() or source.endswith('.txt') or source.lower().startswith(
         ('rtsp://', 'rtmp://', 'http://', 'https://'))

     # Directories
-    save_dir = Path(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok))  # increment run
+    save_dir = increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok)  # increment run
     (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True)  # make dir

     # Initialize

@@ -34,6 +34,7 @@ def detect(save_img=False):
     model = attempt_load(weights, map_location=device)  # load FP32 model
     stride = int(model.stride.max())  # model stride
     imgsz = check_img_size(imgsz, s=stride)  # check img_size
+    names = model.module.names if hasattr(model, 'module') else model.names  # get class names
     if half:
         model.half()  # to FP16

@@ -52,10 +53,6 @@ def detect(save_img=False):
     else:
         dataset = LoadImages(source, img_size=imgsz, stride=stride)

-    # Get names and colors
-    names = model.module.names if hasattr(model, 'module') else model.names
-    colors = [[random.randint(0, 255) for _ in range(3)] for _ in names]
-
     # Run inference
     if device.type != 'cpu':
         model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters())))  # run once

@@ -72,7 +69,8 @@ def detect(save_img=False):
         pred = model(img, augment=opt.augment)[0]

         # Apply NMS
-        pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes, agnostic=opt.agnostic_nms)
+        pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, opt.classes, opt.agnostic_nms,
+                                   max_det=opt.max_det)
         t2 = time_synchronized()

         # Apply Classifier

@@ -82,15 +80,16 @@ def detect(save_img=False):
         # Process detections
         for i, det in enumerate(pred):  # detections per image
             if webcam:  # batch_size >= 1
-                p, s, im0, frame = path[i], '%g: ' % i, im0s[i].copy(), dataset.count
+                p, s, im0, frame = path[i], f'{i}: ', im0s[i].copy(), dataset.count
             else:
-                p, s, im0, frame = path, '', im0s, getattr(dataset, 'frame', 0)
+                p, s, im0, frame = path, '', im0s.copy(), getattr(dataset, 'frame', 0)

             p = Path(p)  # to Path
             save_path = str(save_dir / p.name)  # img.jpg
             txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}')  # img.txt
             s += '%gx%g ' % img.shape[2:]  # print string
             gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh
+            imc = im0.copy() if opt.save_crop else im0  # for opt.save_crop
             if len(det):
                 # Rescale boxes from img_size to im0 size
                 det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()

@@ -108,9 +107,12 @@ def detect(save_img=False):
                         with open(txt_path + '.txt', 'a') as f:
                             f.write(('%g ' * len(line)).rstrip() % line + '\n')

-                    if save_img or view_img:  # Add bbox to image
-                        label = f'{names[int(cls)]} {conf:.2f}'
-                        plot_one_box(xyxy, im0, label=label, color=colors[int(cls)], line_thickness=3)
+                    if save_img or opt.save_crop or view_img:  # Add bbox to image
+                        c = int(cls)  # integer class
+                        label = None if opt.hide_labels else (names[c] if opt.hide_conf else f'{names[c]} {conf:.2f}')
+                        plot_one_box(xyxy, im0, label=label, color=colors(c, True), line_thickness=opt.line_thickness)
+                        if opt.save_crop:
+                            save_one_box(xyxy, imc, file=save_dir / 'crops' / names[c] / f'{p.stem}.jpg', BGR=True)

             # Print time (inference + NMS)
             print(f'{s}Done. ({t2 - t1:.3f}s)')

@@ -153,10 +155,12 @@ if __name__ == '__main__':
     parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)')
     parser.add_argument('--conf-thres', type=float, default=0.25, help='object confidence threshold')
     parser.add_argument('--iou-thres', type=float, default=0.45, help='IOU threshold for NMS')
+    parser.add_argument('--max-det', type=int, default=1000, help='maximum number of detections per image')
     parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
     parser.add_argument('--view-img', action='store_true', help='display results')
     parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
     parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels')
+    parser.add_argument('--save-crop', action='store_true', help='save cropped prediction boxes')
     parser.add_argument('--nosave', action='store_true', help='do not save images/videos')
     parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --class 0, or --class 0 2 3')
     parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')

@@ -165,14 +169,16 @@ if __name__ == '__main__':
     parser.add_argument('--project', default='runs/detect', help='save results to project/name')
     parser.add_argument('--name', default='exp', help='save results to project/name')
     parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
+    parser.add_argument('--line-thickness', default=3, type=int, help='bounding box thickness (pixels)')
+    parser.add_argument('--hide-labels', default=False, action='store_true', help='hide labels')
+    parser.add_argument('--hide-conf', default=False, action='store_true', help='hide confidences')
     opt = parser.parse_args()
     print(opt)
-    check_requirements(exclude=('pycocotools', 'thop'))
+    check_requirements(exclude=('tensorboard', 'pycocotools', 'thop'))

-    with torch.no_grad():
-        if opt.update:  # update all models (to fix SourceChangeWarning)
-            for opt.weights in ['yolov3.pt', 'yolov3-spp.pt', 'yolov3-tiny.pt']:
-                detect()
-                strip_optimizer(opt.weights)
-        else:
-            detect()
+    if opt.update:  # update all models (to fix SourceChangeWarning)
+        for opt.weights in ['yolov3.pt', 'yolov3-spp.pt', 'yolov3-tiny.pt']:
+            detect(opt=opt)
+            strip_optimizer(opt.weights)
+    else:
+        detect(opt=opt)
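
With this refactor, `detect()` takes the parsed options object explicitly and applies `@torch.no_grad()` itself. A hypothetical programmatic driver (field names mirror the argparse flags above; `SimpleNamespace` stands in for `argparse.Namespace`):

```python
from types import SimpleNamespace

from detect import detect  # assumes the refactored detect.py above

opt = SimpleNamespace(
    source='data/images', weights='yolov3.pt', img_size=640,
    conf_thres=0.25, iou_thres=0.45, max_det=1000, device='',
    view_img=False, save_txt=False, save_conf=False, save_crop=True,
    nosave=False, classes=None, agnostic_nms=False, augment=False,
    update=False, project='runs/detect', name='exp', exist_ok=False,
    line_thickness=3, hide_labels=False, hide_conf=False)

detect(opt)  # no outer torch.no_grad() context needed any more
```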

hubconf.py

@@ -2,24 +2,13 @@
 Usage:
     import torch
-    model = torch.hub.load('ultralytics/yolov3', 'yolov3tiny')
+    model = torch.hub.load('ultralytics/yolov3', 'yolov3_tiny')
 """

-from pathlib import Path
-
 import torch

-from models.yolo import Model
-from utils.general import check_requirements, set_logging
-from utils.google_utils import attempt_download
-from utils.torch_utils import select_device
-
-dependencies = ['torch', 'yaml']
-check_requirements(Path(__file__).parent / 'requirements.txt', exclude=('pycocotools', 'thop'))
-set_logging()

-def create(name, pretrained, channels, classes, autoshape):
+def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
     """Creates a specified YOLOv3 model

     Arguments:

@@ -27,85 +16,81 @@ def create(name, pretrained, channels, classes, autoshape):
         pretrained (bool): load pretrained weights into the model
         channels (int): number of input channels
         classes (int): number of model classes
+        autoshape (bool): apply YOLOv3 .autoshape() wrapper to model
+        verbose (bool): print all information to screen
+        device (str, torch.device, None): device to use for model parameters

     Returns:
-        pytorch model
+        YOLOv3 pytorch model
     """
+    from pathlib import Path
+
+    from models.yolo import Model, attempt_load
+    from utils.general import check_requirements, set_logging
+    from utils.google_utils import attempt_download
+    from utils.torch_utils import select_device
+
+    check_requirements(Path(__file__).parent / 'requirements.txt', exclude=('tensorboard', 'pycocotools', 'thop'))
+    set_logging(verbose=verbose)
+
+    fname = Path(name).with_suffix('.pt')  # checkpoint filename
     try:
-        cfg = list((Path(__file__).parent / 'models').rglob(f'{name}.yaml'))[0]  # model.yaml path
-        model = Model(cfg, channels, classes)
-        if pretrained:
-            fname = f'{name}.pt'  # checkpoint filename
-            attempt_download(fname)  # download if not found locally
-            ckpt = torch.load(fname, map_location=torch.device('cpu'))  # load
-            msd = model.state_dict()  # model state_dict
-            csd = ckpt['model'].float().state_dict()  # checkpoint state_dict as FP32
-            csd = {k: v for k, v in csd.items() if msd[k].shape == v.shape}  # filter
-            model.load_state_dict(csd, strict=False)  # load
-            if len(ckpt['model'].names) == classes:
-                model.names = ckpt['model'].names  # set class names attribute
-            if autoshape:
-                model = model.autoshape()  # for file/URI/PIL/cv2/np inputs and NMS
-        device = select_device('0' if torch.cuda.is_available() else 'cpu')  # default to GPU if available
+        if pretrained and channels == 3 and classes == 80:
+            model = attempt_load(fname, map_location=torch.device('cpu'))  # download/load FP32 model
+        else:
+            cfg = list((Path(__file__).parent / 'models').rglob(f'{name}.yaml'))[0]  # model.yaml path
+            model = Model(cfg, channels, classes)  # create model
+            if pretrained:
+                ckpt = torch.load(attempt_download(fname), map_location=torch.device('cpu'))  # load
+                msd = model.state_dict()  # model state_dict
+                csd = ckpt['model'].float().state_dict()  # checkpoint state_dict as FP32
+                csd = {k: v for k, v in csd.items() if msd[k].shape == v.shape}  # filter
+                model.load_state_dict(csd, strict=False)  # load
+                if len(ckpt['model'].names) == classes:
+                    model.names = ckpt['model'].names  # set class names attribute
+        if autoshape:
+            model = model.autoshape()  # for file/URI/PIL/cv2/np inputs and NMS
+        device = select_device('0' if torch.cuda.is_available() else 'cpu') if device is None else torch.device(device)
        return model.to(device)

     except Exception as e:
         help_url = 'https://github.com/ultralytics/yolov5/issues/36'
-        s = 'Cache maybe be out of date, try force_reload=True. See %s for help.' % help_url
+        s = 'Cache may be out of date, try `force_reload=True`. See %s for help.' % help_url
         raise Exception(s) from e

-def custom(path_or_model='path/to/model.pt', autoshape=True):
-    """YOLOv3-custom model https://github.com/ultralytics/yolov3
-
-    Arguments (3 options):
-        path_or_model (str): 'path/to/model.pt'
-        path_or_model (dict): torch.load('path/to/model.pt')
-        path_or_model (nn.Module): torch.load('path/to/model.pt')['model']
-
-    Returns:
-        pytorch model
-    """
-    model = torch.load(path_or_model) if isinstance(path_or_model, str) else path_or_model  # load checkpoint
-    if isinstance(model, dict):
-        model = model['ema' if model.get('ema') else 'model']  # load model
-
-    hub_model = Model(model.yaml).to(next(model.parameters()).device)  # create
-    hub_model.load_state_dict(model.float().state_dict())  # load state_dict
-    hub_model.names = model.names  # class names
-    if autoshape:
-        hub_model = hub_model.autoshape()  # for file/URI/PIL/cv2/np inputs and NMS
-    device = select_device('0' if torch.cuda.is_available() else 'cpu')  # default to GPU if available
-    return hub_model.to(device)
+def custom(path='path/to/model.pt', autoshape=True, verbose=True, device=None):
+    # YOLOv3 custom or local model
+    return _create(path, autoshape=autoshape, verbose=verbose, device=device)

-def yolov3(pretrained=True, channels=3, classes=80, autoshape=True):
+def yolov3(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
     # YOLOv3 model https://github.com/ultralytics/yolov3
-    return create('yolov3', pretrained, channels, classes, autoshape)
+    return _create('yolov3', pretrained, channels, classes, autoshape, verbose, device)

-def yolov3_spp(pretrained=True, channels=3, classes=80, autoshape=True):
+def yolov3_spp(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
     # YOLOv3-SPP model https://github.com/ultralytics/yolov3
-    return create('yolov3-spp', pretrained, channels, classes, autoshape)
+    return _create('yolov3-spp', pretrained, channels, classes, autoshape, verbose, device)

-def yolov3_tiny(pretrained=True, channels=3, classes=80, autoshape=True):
+def yolov3_tiny(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
     # YOLOv3-tiny model https://github.com/ultralytics/yolov3
-    return create('yolov3-tiny', pretrained, channels, classes, autoshape)
+    return _create('yolov3-tiny', pretrained, channels, classes, autoshape, verbose, device)

 if __name__ == '__main__':
-    model = create(name='yolov3', pretrained=True, channels=3, classes=80, autoshape=True)  # pretrained example
-    # model = custom(path_or_model='path/to/model.pt')  # custom example
+    model = _create(name='yolov3', pretrained=True, channels=3, classes=80, autoshape=True, verbose=True)  # pretrained
+    # model = custom(path='path/to/model.pt')  # custom

     # Verify inference
+    import cv2
     import numpy as np
     from PIL import Image

-    imgs = [Image.open('data/images/bus.jpg'),  # PIL
-            'data/images/zidane.jpg',  # filename
-            'https://github.com/ultralytics/yolov3/raw/master/data/images/bus.jpg',  # URI
-            np.zeros((640, 480, 3))]  # numpy
+    imgs = ['data/images/zidane.jpg',  # filename
+            'https://github.com/ultralytics/yolov5/releases/download/v1.0/zidane.jpg',  # URI
+            cv2.imread('data/images/bus.jpg')[:, :, ::-1],  # OpenCV
+            Image.open('data/images/bus.jpg'),  # PIL
+            np.zeros((320, 640, 3))]  # numpy

     results = model(imgs)  # batched inference
     results.print()
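
The entrypoints now forward `autoshape`, `verbose` and `device` through `_create`, and `custom()` takes a plain `path`. Illustrative calls (repo and weights paths are examples):

```python
import torch

# Pretrained COCO model, pinned to CPU via the new device argument
model = torch.hub.load('ultralytics/yolov3', 'yolov3_tiny', device='cpu')

# Custom/local checkpoint through the simplified custom() entrypoint
# model = torch.hub.load('ultralytics/yolov3', 'custom', path='path/to/model.pt')

results = model('data/images/bus.jpg')
results.print()
```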

models/common.py

@@ -13,8 +13,8 @@ from PIL import Image
 from torch.cuda import amp

 from utils.datasets import letterbox
-from utils.general import non_max_suppression, make_divisible, scale_coords, increment_path, xyxy2xywh
-from utils.plots import color_list, plot_one_box
+from utils.general import non_max_suppression, make_divisible, scale_coords, increment_path, xyxy2xywh, save_one_box
+from utils.plots import colors, plot_one_box
 from utils.torch_utils import time_synchronized

@@ -215,32 +215,34 @@ class NMS(nn.Module):
     conf = 0.25  # confidence threshold
     iou = 0.45  # IoU threshold
     classes = None  # (optional list) filter by class
+    max_det = 1000  # maximum number of detections per image

     def __init__(self):
         super(NMS, self).__init__()

     def forward(self, x):
-        return non_max_suppression(x[0], conf_thres=self.conf, iou_thres=self.iou, classes=self.classes)
+        return non_max_suppression(x[0], self.conf, iou_thres=self.iou, classes=self.classes, max_det=self.max_det)

-class autoShape(nn.Module):
+class AutoShape(nn.Module):
     # input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS
     conf = 0.25  # NMS confidence threshold
     iou = 0.45  # NMS IoU threshold
     classes = None  # (optional list) filter by class
+    max_det = 1000  # maximum number of detections per image

     def __init__(self, model):
-        super(autoShape, self).__init__()
+        super(AutoShape, self).__init__()
         self.model = model.eval()

     def autoshape(self):
-        print('autoShape already enabled, skipping... ')  # model already converted to model.autoshape()
+        print('AutoShape already enabled, skipping... ')  # model already converted to model.autoshape()
         return self

     @torch.no_grad()
     def forward(self, imgs, size=640, augment=False, profile=False):
         # Inference from various sources. For height=640, width=1280, RGB images example inputs are:
-        #   filename:   imgs = 'data/samples/zidane.jpg'
+        #   filename:   imgs = 'data/images/zidane.jpg'
         #   URI:             = 'https://github.com/ultralytics/yolov5/releases/download/v1.0/zidane.jpg'
         #   OpenCV:          = cv2.imread('image.jpg')[:,:,::-1]  # HWC BGR to RGB x(640,1280,3)
         #   PIL:             = Image.open('image.jpg')  # HWC x(640,1280,3)

@@ -271,7 +273,7 @@ class autoShape(nn.Module):
                 shape0.append(s)  # image shape
                 g = (size / max(s))  # gain
                 shape1.append([y * g for y in s])
-                imgs[i] = im  # update
+                imgs[i] = im if im.data.contiguous else np.ascontiguousarray(im)  # update
             shape1 = [make_divisible(x, int(self.stride.max())) for x in np.stack(shape1, 0).max(0)]  # inference shape
             x = [letterbox(im, new_shape=shape1, auto=False)[0] for im in imgs]  # pad
             x = np.stack(x, 0) if n > 1 else x[0][None]  # stack

@@ -285,7 +287,7 @@ class autoShape(nn.Module):
             t.append(time_synchronized())

             # Post-process
-            y = non_max_suppression(y, conf_thres=self.conf, iou_thres=self.iou, classes=self.classes)  # NMS
+            y = non_max_suppression(y, self.conf, iou_thres=self.iou, classes=self.classes, max_det=self.max_det)  # NMS
             for i in range(n):
                 scale_coords(shape1, y[i][:, :4], shape0[i])

@@ -311,29 +313,32 @@ class Detections:
         self.t = tuple((times[i + 1] - times[i]) * 1000 / self.n for i in range(3))  # timestamps (ms)
         self.s = shape  # inference BCHW shape

-    def display(self, pprint=False, show=False, save=False, render=False, save_dir=''):
-        colors = color_list()
-        for i, (img, pred) in enumerate(zip(self.imgs, self.pred)):
-            str = f'image {i + 1}/{len(self.pred)}: {img.shape[0]}x{img.shape[1]} '
+    def display(self, pprint=False, show=False, save=False, crop=False, render=False, save_dir=Path('')):
+        for i, (im, pred) in enumerate(zip(self.imgs, self.pred)):
+            str = f'image {i + 1}/{len(self.pred)}: {im.shape[0]}x{im.shape[1]} '
             if pred is not None:
                 for c in pred[:, -1].unique():
                     n = (pred[:, -1] == c).sum()  # detections per class
                     str += f"{n} {self.names[int(c)]}{'s' * (n > 1)}, "  # add to string
-                if show or save or render:
+                if show or save or render or crop:
                     for *box, conf, cls in pred:  # xyxy, confidence, class
                         label = f'{self.names[int(cls)]} {conf:.2f}'
-                        plot_one_box(box, img, label=label, color=colors[int(cls) % 10])
-            img = Image.fromarray(img.astype(np.uint8)) if isinstance(img, np.ndarray) else img  # from np
+                        if crop:
+                            save_one_box(box, im, file=save_dir / 'crops' / self.names[int(cls)] / self.files[i])
+                        else:  # all others
+                            plot_one_box(box, im, label=label, color=colors(cls))
+            im = Image.fromarray(im.astype(np.uint8)) if isinstance(im, np.ndarray) else im  # from np
             if pprint:
                 print(str.rstrip(', '))
             if show:
-                img.show(self.files[i])  # show
+                im.show(self.files[i])  # show
             if save:
                 f = self.files[i]
-                img.save(Path(save_dir) / f)  # save
+                im.save(save_dir / f)  # save
                 print(f"{'Saved' * (i == 0)} {f}", end=',' if i < self.n - 1 else f' to {save_dir}\n')
             if render:
-                self.imgs[i] = np.asarray(img)
+                self.imgs[i] = np.asarray(im)

     def print(self):
         self.display(pprint=True)  # print results

@@ -343,10 +348,14 @@ class Detections:
         self.display(show=True)  # show results

     def save(self, save_dir='runs/hub/exp'):
-        save_dir = increment_path(save_dir, exist_ok=save_dir != 'runs/hub/exp')  # increment save_dir
-        Path(save_dir).mkdir(parents=True, exist_ok=True)
+        save_dir = increment_path(save_dir, exist_ok=save_dir != 'runs/hub/exp', mkdir=True)  # increment save_dir
         self.display(save=True, save_dir=save_dir)  # save results

+    def crop(self, save_dir='runs/hub/exp'):
+        save_dir = increment_path(save_dir, exist_ok=save_dir != 'runs/hub/exp', mkdir=True)  # increment save_dir
+        self.display(crop=True, save_dir=save_dir)  # crop results
+        print(f'Saved results to {save_dir}\n')
+
     def render(self):
         self.display(render=True)  # render results
         return self.imgs
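
The new `Detections.crop()` mirrors `save()`: it increments the run directory and writes one crop per detection under `crops/<class name>/` via `save_one_box`. Illustrative use (model and image paths are examples):

```python
import torch

model = torch.hub.load('ultralytics/yolov3', 'yolov3')  # AutoShape-wrapped model
results = model(['data/images/zidane.jpg', 'data/images/bus.jpg'])
results.crop(save_dir='runs/hub/exp')  # one cropped image per box, grouped by class
```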

models/experimental.py

@@ -110,25 +110,27 @@ class Ensemble(nn.ModuleList):
         return y, None  # inference, train output

-def attempt_load(weights, map_location=None):
+def attempt_load(weights, map_location=None, inplace=True):
+    from models.yolo import Detect, Model
+
     # Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
     model = Ensemble()
     for w in weights if isinstance(weights, list) else [weights]:
-        attempt_download(w)
-        ckpt = torch.load(w, map_location=map_location)  # load
+        ckpt = torch.load(attempt_download(w), map_location=map_location)  # load
         model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().fuse().eval())  # FP32 model

     # Compatibility updates
     for m in model.modules():
-        if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
-            m.inplace = True  # pytorch 1.7.0 compatibility
+        if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model]:
+            m.inplace = inplace  # pytorch 1.7.0 compatibility
         elif type(m) is Conv:
             m._non_persistent_buffers_set = set()  # pytorch 1.6.0 compatibility

     if len(model) == 1:
         return model[-1]  # return model
     else:
-        print('Ensemble created with %s\n' % weights)
-        for k in ['names', 'stride']:
+        print(f'Ensemble created with {weights}\n')
+        for k in ['names']:
             setattr(model, k, getattr(model[-1], k))
+        model.stride = model[torch.argmax(torch.tensor([m.stride.max() for m in model])).int()].stride  # max stride
         return model  # return ensemble
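
`attempt_load()` keeps its single-model behavior, but an `Ensemble` now takes its stride from the member with the largest max stride. Illustrative calls (weight filenames are examples):

```python
from models.experimental import attempt_load

model = attempt_load('yolov3.pt', map_location='cpu')  # single fused FP32 model
ensemble = attempt_load(['yolov3.pt', 'yolov3-spp.pt'], map_location='cpu')  # Ensemble
print(ensemble.stride)  # stride of the member with the largest max stride
```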

View File

@ -1,34 +1,43 @@
"""Exports a YOLOv3 *.pt model to ONNX and TorchScript formats """Exports a YOLOv3 *.pt model to TorchScript, ONNX, CoreML formats
Usage: Usage:
$ export PYTHONPATH="$PWD" && python models/export.py --weights ./weights/yolov3.pt --img 640 --batch 1 $ python path/to/models/export.py --weights yolov3.pt --img 640 --batch 1
""" """
import argparse import argparse
import sys import sys
import time import time
from pathlib import Path
sys.path.append('./') # to run '$ python *.py' files in subdirectories sys.path.append(Path(__file__).parent.parent.absolute().__str__()) # to run '$ python *.py' files in subdirectories
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.utils.mobile_optimizer import optimize_for_mobile
import models import models
from models.experimental import attempt_load from models.experimental import attempt_load
from utils.activations import Hardswish, SiLU from utils.activations import Hardswish, SiLU
from utils.general import set_logging, check_img_size from utils.general import colorstr, check_img_size, check_requirements, file_size, set_logging
from utils.torch_utils import select_device from utils.torch_utils import select_device
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--weights', type=str, default='./yolov3.pt', help='weights path') # from yolov3/models/ parser.add_argument('--weights', type=str, default='./yolov3.pt', help='weights path')
parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='image size') # height, width parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='image size') # height, width
parser.add_argument('--batch-size', type=int, default=1, help='batch size') parser.add_argument('--batch-size', type=int, default=1, help='batch size')
parser.add_argument('--dynamic', action='store_true', help='dynamic ONNX axes')
parser.add_argument('--grid', action='store_true', help='export Detect() layer grid')
parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
parser.add_argument('--include', nargs='+', default=['torchscript', 'onnx', 'coreml'], help='include formats')
parser.add_argument('--half', action='store_true', help='FP16 half-precision export')
parser.add_argument('--inplace', action='store_true', help='set YOLOv3 Detect() inplace=True')
parser.add_argument('--train', action='store_true', help='model.train() mode')
parser.add_argument('--optimize', action='store_true', help='optimize TorchScript for mobile') # TorchScript-only
parser.add_argument('--dynamic', action='store_true', help='dynamic ONNX axes') # ONNX-only
parser.add_argument('--simplify', action='store_true', help='simplify ONNX model') # ONNX-only
parser.add_argument('--opset-version', type=int, default=12, help='ONNX opset version') # ONNX-only
opt = parser.parse_args() opt = parser.parse_args()
opt.img_size *= 2 if len(opt.img_size) == 1 else 1 # expand opt.img_size *= 2 if len(opt.img_size) == 1 else 1 # expand
opt.include = [x.lower() for x in opt.include]
print(opt) print(opt)
set_logging() set_logging()
t = time.time() t = time.time()
@ -41,11 +50,16 @@ if __name__ == '__main__':
    # Checks
    gs = int(max(model.stride))  # grid size (max stride)
    opt.img_size = [check_img_size(x, gs) for x in opt.img_size]  # verify img_size are gs-multiples
    assert not (opt.device.lower() == 'cpu' and opt.half), '--half only compatible with GPU export, i.e. use --device 0'

    # Input
    img = torch.zeros(opt.batch_size, 3, *opt.img_size).to(device)  # image size(1,3,320,192) iDetection

    # Update model
    if opt.half:
        img, model = img.half(), model.half()  # to FP16
    if opt.train:
        model.train()  # training mode (no grid construction in Detect layer)
    for k, m in model.named_modules():
        m._non_persistent_buffers_set = set()  # pytorch 1.6.0 compatibility
        if isinstance(m, models.common.Conv):  # assign export-friendly activations
@ -53,52 +67,79 @@ if __name__ == '__main__':
                m.act = Hardswish()
            elif isinstance(m.act, nn.SiLU):
                m.act = SiLU()
        elif isinstance(m, models.yolo.Detect):
            m.inplace = opt.inplace
            m.onnx_dynamic = opt.dynamic
            # m.forward = m.forward_export  # assign forward (optional)

    for _ in range(2):
        y = model(img)  # dry runs
    print(f"\n{colorstr('PyTorch:')} starting from {opt.weights} ({file_size(opt.weights):.1f} MB)")
    # TorchScript export -----------------------------------------------------------------------------------------------
    if 'torchscript' in opt.include or 'coreml' in opt.include:
        prefix = colorstr('TorchScript:')
        try:
            print(f'\n{prefix} starting export with torch {torch.__version__}...')
            f = opt.weights.replace('.pt', '.torchscript.pt')  # filename
            ts = torch.jit.trace(model, img, strict=False)
            (optimize_for_mobile(ts) if opt.optimize else ts).save(f)
            print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
        except Exception as e:
            print(f'{prefix} export failure: {e}')

    # ONNX export ------------------------------------------------------------------------------------------------------
    if 'onnx' in opt.include:
        prefix = colorstr('ONNX:')
        try:
            import onnx

            print(f'{prefix} starting export with onnx {onnx.__version__}...')
            f = opt.weights.replace('.pt', '.onnx')  # filename
            torch.onnx.export(model, img, f, verbose=False, opset_version=opt.opset_version, input_names=['images'],
                              training=torch.onnx.TrainingMode.TRAINING if opt.train else torch.onnx.TrainingMode.EVAL,
                              do_constant_folding=not opt.train,
                              dynamic_axes={'images': {0: 'batch', 2: 'height', 3: 'width'},  # size(1,3,640,640)
                                            'output': {0: 'batch', 2: 'y', 3: 'x'}} if opt.dynamic else None)

            # Checks
            model_onnx = onnx.load(f)  # load onnx model
            onnx.checker.check_model(model_onnx)  # check onnx model
            # print(onnx.helper.printable_graph(model_onnx.graph))  # print

            # Simplify
            if opt.simplify:
                try:
                    check_requirements(['onnx-simplifier'])
                    import onnxsim

                    print(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...')
                    model_onnx, check = onnxsim.simplify(
                        model_onnx,
                        dynamic_input_shape=opt.dynamic,
                        input_shapes={'images': list(img.shape)} if opt.dynamic else None)
                    assert check, 'assert check failed'
                    onnx.save(model_onnx, f)
                except Exception as e:
                    print(f'{prefix} simplifier failure: {e}')
            print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
        except Exception as e:
            print(f'{prefix} export failure: {e}')

    # CoreML export ----------------------------------------------------------------------------------------------------
    if 'coreml' in opt.include:
        prefix = colorstr('CoreML:')
        try:
            import coremltools as ct

            print(f'{prefix} starting export with coremltools {ct.__version__}...')
            assert opt.train, 'CoreML exports should be placed in model.train() mode with `python export.py --train`'
            model = ct.convert(ts, inputs=[ct.ImageType('image', shape=img.shape, scale=1 / 255.0, bias=[0, 0, 0])])
            f = opt.weights.replace('.pt', '.mlmodel')  # filename
            model.save(f)
            print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
        except Exception as e:
            print(f'{prefix} export failure: {e}')

    # Finish
    print(f'\nExport complete ({time.time() - t:.2f}s). Visualize with https://github.com/lutzroeder/netron.')
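As a sanity check on the result, a minimal sketch of loading the exported TorchScript file back (assuming the export above succeeded with the default 640x640 shape; the filename follows the script's `.pt` -> `.torchscript.pt` convention):

```python
import torch

model = torch.jit.load('yolov3.torchscript.pt')  # traced model saved by the script above
img = torch.zeros(1, 3, 640, 640)                # dummy input matching the export shape
pred = model(img)                                # forward pass with no Python-side model code
```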


@ -1,11 +1,16 @@
"""YOLOv3-specific modules

Usage:
    $ python path/to/models/yolo.py --cfg yolov3.yaml
"""

import argparse
import logging
import sys
from copy import deepcopy
from pathlib import Path

sys.path.append(Path(__file__).parent.parent.absolute().__str__())  # to run '$ python *.py' files in subdirectories
logger = logging.getLogger(__name__)

from models.common import *
@ -23,9 +28,9 @@ except ImportError:
class Detect(nn.Module):
    stride = None  # strides computed during build
    onnx_dynamic = False  # ONNX export parameter

    def __init__(self, nc=80, anchors=(), ch=(), inplace=True):  # detection layer
        super(Detect, self).__init__()
        self.nc = nc  # number of classes
        self.no = nc + 5  # number of outputs per anchor
@ -36,23 +41,28 @@ class Detect(nn.Module):
        self.register_buffer('anchors', a)  # shape(nl,na,2)
        self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2))  # shape(nl,1,na,1,1,2)
        self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch)  # output conv
        self.inplace = inplace  # use in-place ops (e.g. slice assignment)

    def forward(self, x):
        # x = x.copy()  # for profiling
        z = []  # inference output
        for i in range(self.nl):
            x[i] = self.m[i](x[i])  # conv
            bs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)
            x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()

            if not self.training:  # inference
                if self.grid[i].shape[2:4] != x[i].shape[2:4] or self.onnx_dynamic:
                    self.grid[i] = self._make_grid(nx, ny).to(x[i].device)

                y = x[i].sigmoid()
                if self.inplace:
                    y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
                    y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
                else:  # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953
                    xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
                    wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i].view(1, self.na, 1, 1, 2)  # wh
                    y = torch.cat((xy, wh, y[..., 4:]), -1)
                z.append(y.view(bs, -1, self.no))

        return x if self.training else (torch.cat(z, 1), x)
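To make the box decoding concrete, a minimal sketch with made-up sizes (a 2x2 grid, one hypothetical 10x13 anchor, stride 32), mirroring the in-place branch above:

```python
import torch

yv, xv = torch.meshgrid(torch.arange(2), torch.arange(2))   # 2x2 grid, as in _make_grid()
grid = torch.stack((xv, yv), 2).view(1, 1, 2, 2, 2).float()
y = torch.randn(1, 1, 2, 2, 6).sigmoid()                    # (bs, na, ny, nx, no), no = 4 box + 1 obj + 1 cls
xy = (y[..., 0:2] * 2. - 0.5 + grid) * 32                   # cell offsets -> pixel centers
wh = (y[..., 2:4] * 2) ** 2 * torch.tensor([10., 13.])      # anchor-relative width/height
boxes = torch.cat((xy, wh, y[..., 4:]), -1)                 # decoded (x, y, w, h, obj, cls)
```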
@ -72,7 +82,7 @@ class Model(nn.Module):
            import yaml  # for torch hub
            self.yaml_file = Path(cfg).name
            with open(cfg) as f:
                self.yaml = yaml.safe_load(f)  # model dict

        # Define model
        ch = self.yaml['ch'] = self.yaml.get('ch', ch)  # input channels
@ -84,18 +94,20 @@ class Model(nn.Module):
            self.yaml['anchors'] = round(anchors)  # override yaml value
        self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch])  # model, savelist
        self.names = [str(i) for i in range(self.yaml['nc'])]  # default names
        self.inplace = self.yaml.get('inplace', True)
        # logger.info([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))])

        # Build strides, anchors
        m = self.model[-1]  # Detect()
        if isinstance(m, Detect):
            s = 256  # 2x min stride
            m.inplace = self.inplace
            m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))])  # forward
            m.anchors /= m.stride.view(-1, 1, 1)
            check_anchor_order(m)
            self.stride = m.stride
            self._initialize_biases()  # only run once
            # logger.info('Strides: %s' % m.stride.tolist())

        # Init weights, biases
        initialize_weights(self)
@ -104,24 +116,23 @@ class Model(nn.Module):
    def forward(self, x, augment=False, profile=False):
        if augment:
            return self.forward_augment(x)  # augmented inference, None
        else:
            return self.forward_once(x, profile)  # single-scale inference, train

    def forward_augment(self, x):
        img_size = x.shape[-2:]  # height, width
        s = [1, 0.83, 0.67]  # scales
        f = [None, 3, None]  # flips (2-ud, 3-lr)
        y = []  # outputs
        for si, fi in zip(s, f):
            xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max()))
            yi = self.forward_once(xi)[0]  # forward
            # cv2.imwrite(f'img_{si}.jpg', 255 * xi[0].cpu().numpy().transpose((1, 2, 0))[:, :, ::-1])  # save
            yi = self._descale_pred(yi, fi, si, img_size)
            y.append(yi)
        return torch.cat(y, 1), None  # augmented inference, train

    def forward_once(self, x, profile=False):
        y, dt = [], []  # outputs
        for m in self.model:
@ -134,15 +145,34 @@ class Model(nn.Module):
                for _ in range(10):
                    _ = m(x)
                dt.append((time_synchronized() - t) * 100)
                if m == self.model[0]:
                    logger.info(f"{'time (ms)':>10s} {'GFLOPS':>10s} {'params':>10s} {'module'}")
                logger.info(f'{dt[-1]:10.2f} {o:10.2f} {m.np:10.0f} {m.type}')

            x = m(x)  # run
            y.append(x if m.i in self.save else None)  # save output

        if profile:
            logger.info('%.1fms total' % sum(dt))
        return x
    def _descale_pred(self, p, flips, scale, img_size):
        # de-scale predictions following augmented inference (inverse operation)
        if self.inplace:
            p[..., :4] /= scale  # de-scale
            if flips == 2:
                p[..., 1] = img_size[0] - p[..., 1]  # de-flip ud
            elif flips == 3:
                p[..., 0] = img_size[1] - p[..., 0]  # de-flip lr
        else:
            x, y, wh = p[..., 0:1] / scale, p[..., 1:2] / scale, p[..., 2:4] / scale  # de-scale
            if flips == 2:
                y = img_size[0] - y  # de-flip ud
            elif flips == 3:
                x = img_size[1] - x  # de-flip lr
            p = torch.cat((x, y, wh, p[..., 4:]), -1)
        return p
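A quick numeric check of the non-inplace branch (made-up prediction on a 640x640 image, scale 0.5, left-right flip):

```python
import torch

img_size = (640, 640)                                 # (height, width)
p = torch.tensor([[100., 200., 30., 40., 0.9]])       # hypothetical (x, y, w, h, conf)
scale, flips = 0.5, 3                                 # lr flip
x, y, wh = p[..., 0:1] / scale, p[..., 1:2] / scale, p[..., 2:4] / scale  # de-scale
x = img_size[1] - x                                   # de-flip lr: 640 - 200 = 440
print(torch.cat((x, y, wh, p[..., 4:]), -1))          # [[440., 400., 60., 80., 0.9]]
```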
    def _initialize_biases(self, cf=None):  # initialize biases into Detect(), cf is class frequency
        # https://arxiv.org/abs/1708.02002 section 3.3
        # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.
@ -157,15 +187,16 @@ class Model(nn.Module):
        m = self.model[-1]  # Detect() module
        for mi in m.m:  # from
            b = mi.bias.detach().view(m.na, -1).T  # conv.bias(255) to (3,85)
            logger.info(
                ('%6g Conv2d.bias:' + '%10.3g' * 6) % (mi.weight.shape[1], *b[:5].mean(1).tolist(), b[5:].mean()))

    # def _print_weights(self):
    #     for m in self.model.modules():
    #         if type(m) is Bottleneck:
    #             logger.info('%10.3g' % (m.w.detach().sigmoid() * 2))  # shortcut weights

    def fuse(self):  # fuse model Conv2d() + BatchNorm2d() layers
        logger.info('Fusing layers... ')
        for m in self.model.modules():
            if type(m) is Conv and hasattr(m, 'bn'):
                m.conv = fuse_conv_and_bn(m.conv, m.bn)  # update conv
@ -177,20 +208,20 @@ class Model(nn.Module):
    def nms(self, mode=True):  # add or remove NMS module
        present = type(self.model[-1]) is NMS  # last layer is NMS
        if mode and not present:
            logger.info('Adding NMS... ')
            m = NMS()  # module
            m.f = -1  # from
            m.i = self.model[-1].i + 1  # index
            self.model.add_module(name='%s' % m.i, module=m)  # add
            self.eval()
        elif not mode and present:
            logger.info('Removing NMS... ')
            self.model = self.model[:-1]  # remove
        return self

    def autoshape(self):  # add AutoShape module
        logger.info('Adding AutoShape... ')
        m = AutoShape(self)  # wrap model
        copy_attr(m, self, include=('yaml', 'nc', 'hyp', 'names', 'stride'), exclude=())  # copy attributes
        return m
@ -266,12 +297,12 @@ if __name__ == '__main__':
    model.train()

    # Profile
    # img = torch.rand(8 if torch.cuda.is_available() else 1, 3, 320, 320).to(device)
    # y = model(img, profile=True)

    # Tensorboard (not working https://github.com/ultralytics/yolov5/issues/2898)
    # from torch.utils.tensorboard import SummaryWriter
    # tb_writer = SummaryWriter('.')
    # logger.info("Run 'tensorboard --logdir=models' to view tensorboard at http://localhost:6006/")
    # tb_writer.add_graph(torch.jit.trace(model, img, strict=False), [])  # add model graph
    # tb_writer.add_image('test', img[0], dataformats='CWH')  # add model to tensorboard


@ -21,9 +21,10 @@ pandas
# export --------------------------------------
# coremltools>=4.1
# onnx>=1.9.0
# scikit-learn==0.19.2  # for coreml quantization

# extras --------------------------------------
# Cython  # for pycocotools https://github.com/cocodataset/cocoapi/issues/172
pycocotools>=2.0  # COCO mAP
thop  # FLOPS computation

test.py

@ -18,6 +18,7 @@ from utils.plots import plot_images, output_to_target, plot_study_txt
from utils.torch_utils import select_device, time_synchronized


@torch.no_grad()
def test(data,
         weights=None,
         batch_size=32,
@ -38,7 +39,8 @@ def test(data,
         wandb_logger=None,
         compute_loss=None,
         half_precision=True,
         is_coco=False,
         opt=None):

    # Initialize/load model and set device
    training = model is not None
    if training:  # called by train.py
@ -49,7 +51,7 @@ def test(data,
        device = select_device(opt.device, batch_size=batch_size)

        # Directories
        save_dir = increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok)  # increment run
        (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True)  # make dir

        # Load model
@ -71,7 +73,7 @@ def test(data,
    if isinstance(data, str):
        is_coco = data.endswith('coco.yaml')
        with open(data) as f:
            data = yaml.safe_load(f)
    check_dataset(data)  # check
    nc = 1 if single_cls else int(data['nc'])  # number of classes
    iouv = torch.linspace(0.5, 0.95, 10).to(device)  # iou vector for mAP@0.5:0.95
@ -104,22 +106,21 @@ def test(data,
        targets = targets.to(device)
        nb, _, height, width = img.shape  # batch size, channels, height, width

        # Run model
        t = time_synchronized()
        out, train_out = model(img, augment=augment)  # inference and training outputs
        t0 += time_synchronized() - t

        # Compute loss
        if compute_loss:
            loss += compute_loss([x.float() for x in train_out], targets)[1][:3]  # box, obj, cls

        # Run NMS
        targets[:, 2:] *= torch.Tensor([width, height, width, height]).to(device)  # to pixels
        lb = [targets[targets[:, 0] == i, 1:] for i in range(nb)] if save_hybrid else []  # for autolabelling
        t = time_synchronized()
        out = non_max_suppression(out, conf_thres, iou_thres, labels=lb, multi_label=True, agnostic=single_cls)
        t1 += time_synchronized() - t

        # Statistics per image
        for si, pred in enumerate(out):
@ -135,6 +136,8 @@ def test(data,
                continue

            # Predictions
            if single_cls:
                pred[:, 5] = 0
            predn = pred.clone()
            scale_coords(img[si].shape[1:], predn[:, :4], shapes[si][0], shapes[si][1])  # native-space pred
@ -185,8 +188,8 @@ def test(data,
            # Per target class
            for cls in torch.unique(tcls_tensor):
                ti = (cls == tcls_tensor).nonzero(as_tuple=False).view(-1)  # target indices
                pi = (cls == pred[:, 5]).nonzero(as_tuple=False).view(-1)  # prediction indices

                # Search for detections
                if pi.shape[0]:
@ -307,7 +310,7 @@ if __name__ == '__main__':
    opt.save_json |= opt.data.endswith('coco.yaml')
    opt.data = check_file(opt.data)  # check file
    print(opt)
    check_requirements(exclude=('tensorboard', 'pycocotools', 'thop'))

    if opt.task in ('train', 'val', 'test'):  # run normally
        test(opt.data,
@ -323,11 +326,12 @@ if __name__ == '__main__':
             save_txt=opt.save_txt | opt.save_hybrid,
             save_hybrid=opt.save_hybrid,
             save_conf=opt.save_conf,
             opt=opt
             )

    elif opt.task == 'speed':  # speed benchmarks
        for w in opt.weights:
            test(opt.data, w, opt.batch_size, opt.img_size, 0.25, 0.45, save_json=False, plots=False, opt=opt)

    elif opt.task == 'study':  # run over a range of settings and save/plot
        # python test.py --task study --data coco.yaml --iou 0.7 --weights yolov3.pt yolov3-spp.pt yolov3-tiny.pt
@ -338,7 +342,7 @@ if __name__ == '__main__':
            for i in x:  # img-size
                print(f'\nRunning {f} point {i}...')
                r, _, t = test(opt.data, w, opt.batch_size, i, opt.conf_thres, opt.iou_thres, opt.save_json,
                               plots=False, opt=opt)
                y.append(r + t)  # results and times
            np.savetxt(f, y, fmt='%10.4g')  # save
        os.system('zip -r study.zip study_*.txt')


@ -32,7 +32,7 @@ from utils.general import labels_to_class_weights, increment_path, labels_to_ima
from utils.google_utils import attempt_download
from utils.loss import ComputeLoss
from utils.plots import plot_images, plot_labels, plot_results, plot_evolution
from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, de_parallel
from utils.wandb_logging.wandb_utils import WandbLogger, check_wandb_resume

logger = logging.getLogger(__name__)
@ -52,24 +52,23 @@ def train(hyp, opt, device, tb_writer=None):
    # Save run settings
    with open(save_dir / 'hyp.yaml', 'w') as f:
        yaml.safe_dump(hyp, f, sort_keys=False)
    with open(save_dir / 'opt.yaml', 'w') as f:
        yaml.safe_dump(vars(opt), f, sort_keys=False)

    # Configure
    plots = not opt.evolve  # create plots
    cuda = device.type != 'cpu'
    init_seeds(2 + rank)
    with open(opt.data) as f:
        data_dict = yaml.safe_load(f)  # data dict

    # Logging- Doing this before checking the dataset. Might update data_dict
    loggers = {'wandb': None}  # loggers dict
    if rank in [-1, 0]:
        opt.hyp = hyp  # add hyperparameters
        run_id = torch.load(weights).get('wandb_id') if weights.endswith('.pt') and os.path.isfile(weights) else None
        wandb_logger = WandbLogger(opt, save_dir.stem, run_id, data_dict)
        loggers['wandb'] = wandb_logger.wandb
        data_dict = wandb_logger.data_dict
        if wandb_logger.wandb:
@ -78,12 +77,13 @@ def train(hyp, opt, device, tb_writer=None):
    nc = 1 if opt.single_cls else int(data_dict['nc'])  # number of classes
    names = ['item'] if opt.single_cls and len(data_dict['names']) != 1 else data_dict['names']  # class names
    assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data)  # check
    is_coco = opt.data.endswith('coco.yaml') and nc == 80  # COCO dataset

    # Model
    pretrained = weights.endswith('.pt')
    if pretrained:
        with torch_distributed_zero_first(rank):
            weights = attempt_download(weights)  # download if not found locally
        ckpt = torch.load(weights, map_location=device)  # load checkpoint
        model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
        exclude = ['anchor'] if (opt.cfg or hyp.get('anchors')) and not opt.resume else []  # exclude keys
@ -330,9 +330,9 @@ def train(hyp, opt, device, tb_writer=None):
            if plots and ni < 3:
                f = save_dir / f'train_batch{ni}.jpg'  # filename
                Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start()
                if tb_writer:
                    tb_writer.add_graph(torch.jit.trace(de_parallel(model), imgs, strict=False), [])  # model graph
                    # tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
            elif plots and ni == 10 and wandb_logger.wandb:
                wandb_logger.log({"Mosaics": [wandb_logger.wandb.Image(str(x), caption=x.name) for x in
                                              save_dir.glob('train*.jpg') if x.exists()]})
@ -358,6 +358,7 @@ def train(hyp, opt, device, tb_writer=None):
                                                 single_cls=opt.single_cls,
                                                 dataloader=testloader,
                                                 save_dir=save_dir,
                                                 save_json=is_coco and final_epoch,
                                                 verbose=nc < 50 and final_epoch,
                                                 plots=plots and final_epoch,
                                                 wandb_logger=wandb_logger,
@ -367,8 +368,6 @@ def train(hyp, opt, device, tb_writer=None):
            # Write
            with open(results_file, 'a') as f:
                f.write(s + '%10.4g' * 7 % results + '\n')  # append metrics, val_loss

            # Log
            tags = ['train/box_loss', 'train/obj_loss', 'train/cls_loss',  # train loss
@ -392,7 +391,7 @@ def train(hyp, opt, device, tb_writer=None):
                ckpt = {'epoch': epoch,
                        'best_fitness': best_fitness,
                        'training_results': results_file.read_text(),
                        'model': deepcopy(de_parallel(model)).half(),
                        'ema': deepcopy(ema.ema).half(),
                        'updates': ema.updates,
                        'optimizer': optimizer.state_dict(),
@ -411,41 +410,38 @@ def train(hyp, opt, device, tb_writer=None):
        # end epoch ----------------------------------------------------------------------------------------------------
    # end training
    if rank in [-1, 0]:
        logger.info(f'{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.\n')
        if plots:
            plot_results(save_dir=save_dir)  # save as results.png
            if wandb_logger.wandb:
                files = ['results.png', 'confusion_matrix.png', *[f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]]
                wandb_logger.log({"Results": [wandb_logger.wandb.Image(str(save_dir / f), caption=f) for f in files
                                              if (save_dir / f).exists()]})

        if not opt.evolve:
            if is_coco:  # COCO dataset
                for m in [last, best] if best.exists() else [last]:  # speed, mAP tests
                    results, _, _ = test.test(opt.data,
                                              batch_size=batch_size * 2,
                                              imgsz=imgsz_test,
                                              conf_thres=0.001,
                                              iou_thres=0.7,
                                              model=attempt_load(m, device).half(),
                                              single_cls=opt.single_cls,
                                              dataloader=testloader,
                                              save_dir=save_dir,
                                              save_json=True,
                                              plots=False,
                                              is_coco=is_coco)

            # Strip optimizers
            for f in last, best:
                if f.exists():
                    strip_optimizer(f)  # strip optimizers
            if wandb_logger.wandb:  # Log the stripped model
                wandb_logger.wandb.log_artifact(str(best if best.exists() else last), type='model',
                                                name='run_' + wandb_logger.wandb_run.id + '_model',
                                                aliases=['latest', 'best', 'stripped'])
        wandb_logger.finish_run()
    else:
        dist.destroy_process_group()
@ -497,7 +493,7 @@ if __name__ == '__main__':
    set_logging(opt.global_rank)
    if opt.global_rank in [-1, 0]:
        check_git_status()
        check_requirements(exclude=('pycocotools', 'thop'))

    # Resume
    wandb_run = check_wandb_resume(opt)
@ -506,8 +502,9 @@ if __name__ == '__main__':
        assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
        apriori = opt.global_rank, opt.local_rank
        with open(Path(ckpt).parent.parent / 'opt.yaml') as f:
            opt = argparse.Namespace(**yaml.safe_load(f))  # replace
        opt.cfg, opt.weights, opt.resume, opt.batch_size, opt.global_rank, opt.local_rank = \
            '', ckpt, True, opt.total_batch_size, *apriori  # reinstate
        logger.info('Resuming training from %s' % ckpt)
    else:
        # opt.hyp = opt.hyp or ('hyp.finetune.yaml' if opt.weights else 'hyp.scratch.yaml')
@ -515,7 +512,7 @@ if __name__ == '__main__':
        assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
        opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size)))  # extend to 2 sizes (train, test)
        opt.name = 'evolve' if opt.evolve else opt.name
        opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok | opt.evolve))

    # DDP mode
    opt.total_batch_size = opt.batch_size
@ -526,11 +523,12 @@ if __name__ == '__main__':
        device = torch.device('cuda', opt.local_rank)
        dist.init_process_group(backend='nccl', init_method='env://')  # distributed backend
        assert opt.batch_size % opt.world_size == 0, '--batch-size must be multiple of CUDA device count'
        assert not opt.image_weights, '--image-weights argument is not compatible with DDP training'
        opt.batch_size = opt.total_batch_size // opt.world_size

    # Hyperparameters
    with open(opt.hyp) as f:
        hyp = yaml.safe_load(f)  # load hyps

    # Train
    logger.info(opt)


@ -19,23 +19,6 @@ class Hardswish(nn.Module): # export-friendly version of nn.Hardswish()
        return x * F.hardtanh(x + 3, 0., 6.) / 6.  # for torchscript, CoreML and ONNX


# Mish https://github.com/digantamisra98/Mish --------------------------------------------------------------------------
class Mish(nn.Module):
    @staticmethod
@ -70,3 +53,46 @@ class FReLU(nn.Module):
    def forward(self, x):
        return torch.max(x, self.bn(self.conv(x)))


# ACON https://arxiv.org/pdf/2009.04759.pdf ----------------------------------------------------------------------------
class AconC(nn.Module):
    r""" ACON activation (activate or not).
    AconC: (p1*x-p2*x) * sigmoid(beta*(p1*x-p2*x)) + p2*x, beta is a learnable parameter
    according to "Activate or Not: Learning Customized Activation" <https://arxiv.org/pdf/2009.04759.pdf>.
    """

    def __init__(self, c1):
        super().__init__()
        self.p1 = nn.Parameter(torch.randn(1, c1, 1, 1))
        self.p2 = nn.Parameter(torch.randn(1, c1, 1, 1))
        self.beta = nn.Parameter(torch.ones(1, c1, 1, 1))

    def forward(self, x):
        dpx = (self.p1 - self.p2) * x
        return dpx * torch.sigmoid(self.beta * dpx) + self.p2 * x


class MetaAconC(nn.Module):
    r""" ACON activation (activate or not).
    MetaAconC: (p1*x-p2*x) * sigmoid(beta*(p1*x-p2*x)) + p2*x, beta is generated by a small network
    according to "Activate or Not: Learning Customized Activation" <https://arxiv.org/pdf/2009.04759.pdf>.
    """

    def __init__(self, c1, k=1, s=1, r=16):  # ch_in, kernel, stride, r
        super().__init__()
        c2 = max(r, c1 // r)
        self.p1 = nn.Parameter(torch.randn(1, c1, 1, 1))
        self.p2 = nn.Parameter(torch.randn(1, c1, 1, 1))
        self.fc1 = nn.Conv2d(c1, c2, k, s, bias=True)
        self.fc2 = nn.Conv2d(c2, c1, k, s, bias=True)
        # self.bn1 = nn.BatchNorm2d(c2)
        # self.bn2 = nn.BatchNorm2d(c1)

    def forward(self, x):
        y = x.mean(dim=2, keepdims=True).mean(dim=3, keepdims=True)
        # batch-size 1 bug/instabilities https://github.com/ultralytics/yolov5/issues/2891
        # beta = torch.sigmoid(self.bn2(self.fc2(self.bn1(self.fc1(y)))))  # bug/unstable
        beta = torch.sigmoid(self.fc2(self.fc1(y)))  # bug patch (BN layers removed)
        dpx = (self.p1 - self.p2) * x
        return dpx * torch.sigmoid(beta * dpx) + self.p2 * x
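For reference, a minimal sketch of dropping AconC into a convolution block (shapes are illustrative; it assumes the AconC class above is in scope):

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(3, 16, 3, padding=1)
act = AconC(16)                   # one learnable (p1, p2, beta) triplet per channel
x = torch.randn(2, 3, 64, 64)     # dummy batch
out = act(conv(x))                # activated features, same spatial shape
print(out.shape)                  # torch.Size([2, 16, 64, 64])
```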


@ -3,7 +3,6 @@
import numpy as np
import torch
import yaml
from tqdm import tqdm

from utils.general import colorstr
@ -76,6 +75,8 @@ def kmean_anchors(path='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=10
    Usage:
        from utils.autoanchor import *; _ = kmean_anchors()
    """
    from scipy.cluster.vq import kmeans

    thr = 1. / thr
    prefix = colorstr('autoanchor: ')
@ -102,7 +103,7 @@ def kmean_anchors(path='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=10
    if isinstance(path, str):  # *.yaml file
        with open(path) as f:
            data_dict = yaml.safe_load(f)  # model dict
        from utils.datasets import LoadImagesAndLabels
        dataset = LoadImagesAndLabels(data_dict['train'], augment=True, rect=True)
    else:


@ -19,7 +19,7 @@ for last in path.rglob('*/**/last.pt'):
    # Load opt.yaml
    with open(last.parent.parent / 'opt.yaml') as f:
        opt = yaml.safe_load(f)

    # Get device count
    d = opt['device'].split(',')  # devices


@ -7,7 +7,7 @@
cd home/ubuntu
if [ ! -d yolov5 ]; then
  echo "Running first-time script."  # install dependencies, download COCO, pull Docker
  git clone https://github.com/ultralytics/yolov5 -b master && sudo chmod -R 777 yolov5
  cd yolov5
  bash data/scripts/get_coco.sh && echo "Data done." &
  sudo docker pull ultralytics/yolov5:latest && echo "Docker done." &


@ -1,6 +1,7 @@
# Dataset utils and dataloaders

import glob
import hashlib
import logging
import math
import os
@ -36,9 +37,12 @@ for orientation in ExifTags.TAGS.keys():
        break


def get_hash(paths):
    # Returns a single hash value of a list of paths (files or dirs)
    size = sum(os.path.getsize(p) for p in paths if os.path.exists(p))  # sizes
    h = hashlib.md5(str(size).encode())  # hash sizes
    h.update(''.join(paths).encode())  # hash paths
    return h.hexdigest()  # return hash
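A quick sketch of what the new hash buys: the digest now reflects both total file size and the path list, so renaming or replacing a file invalidates the label cache (paths below are hypothetical):

```python
# assumes get_hash() above is in scope; missing files simply contribute size 0
files = ['images/img1.jpg', 'labels/img1.txt']
print(get_hash(files))  # md5 hex digest; changes if any file is renamed or resized
```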

def exif_size(img):
@ -172,12 +176,12 @@ class LoadImages: # for inference
            ret_val, img0 = self.cap.read()

            self.frame += 1
            print(f'video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: ', end='')

        else:
            # Read image
            self.count += 1
            img0 = cv2.imread(path, -1)  # BGR (-1 is IMREAD_UNCHANGED)
            assert img0 is not None, 'Image Not Found ' + path
            print(f'image {self.count}/{self.nf} {path}: ', end='')
@ -193,7 +197,7 @@ class LoadImages: # for inference
    def new_video(self, path):
        self.frame = 0
        self.cap = cv2.VideoCapture(path)
        self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))

    def __len__(self):
        return self.nf  # number of files
@ -270,26 +274,27 @@ class LoadStreams: # multiple IP or RTSP cameras
            sources = [sources]

        n = len(sources)
        self.imgs, self.fps, self.frames, self.threads = [None] * n, [0] * n, [0] * n, [None] * n
        self.sources = [clean_str(x) for x in sources]  # clean source names for later
        for i, s in enumerate(sources):  # index, source
            # Start thread to read frames from video stream
            print(f'{i + 1}/{n}: {s}... ', end='')
            if 'youtube.com/' in s or 'youtu.be/' in s:  # if source is YouTube video
                check_requirements(('pafy', 'youtube_dl'))
                import pafy
                s = pafy.new(s).getbest(preftype="mp4").url  # YouTube URL
            s = eval(s) if s.isnumeric() else s  # i.e. s = '0' local webcam
            cap = cv2.VideoCapture(s)
            assert cap.isOpened(), f'Failed to open {s}'
            w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            self.fps[i] = max(cap.get(cv2.CAP_PROP_FPS) % 100, 0) or 30.0  # 30 FPS fallback
            self.frames[i] = max(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), 0) or float('inf')  # infinite stream fallback

            _, self.imgs[i] = cap.read()  # guarantee first frame
            self.threads[i] = Thread(target=self.update, args=([i, cap]), daemon=True)
            print(f" success ({self.frames[i]} frames {w}x{h} at {self.fps[i]:.2f} FPS)")
            self.threads[i].start()
        print('')  # newline

        # check for common shapes
@ -298,18 +303,17 @@ class LoadStreams: # multiple IP or RTSP cameras
        if not self.rect:
            print('WARNING: Different stream shapes detected. For optimal performance supply similarly-shaped streams.')

    def update(self, i, cap):
        # Read stream `i` frames in daemon thread
        n, f = 0, self.frames[i]
        while cap.isOpened() and n < f:
            n += 1
            # _, self.imgs[index] = cap.read()
            cap.grab()
            if n % 4:  # read every 4th frame
                success, im = cap.retrieve()
                self.imgs[i] = im if success else self.imgs[i] * 0
            time.sleep(1 / self.fps[i])  # wait time

    def __iter__(self):
        self.count = -1
@ -317,12 +321,12 @@ class LoadStreams: # multiple IP or RTSP cameras
    def __next__(self):
        self.count += 1
        if not all(x.is_alive() for x in self.threads) or cv2.waitKey(1) == ord('q'):  # q to quit
            cv2.destroyAllWindows()
            raise StopIteration

        # Letterbox
        img0 = self.imgs.copy()
        img = [letterbox(x, self.img_size, auto=self.rect, stride=self.stride)[0] for x in img0]

        # Stack
@ -383,7 +387,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing
        cache_path = (p if p.is_file() else Path(self.label_files[0]).parent).with_suffix('.cache')  # cached labels
        if cache_path.is_file():
            cache, exists = torch.load(cache_path), True  # load
            if cache['hash'] != get_hash(self.label_files + self.img_files):  # changed
                cache, exists = self.cache_labels(cache_path, prefix), False  # re-cache
        else:
            cache, exists = self.cache_labels(cache_path, prefix), False  # cache
@ -470,7 +474,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing
            if os.path.isfile(lb_file):
                nf += 1  # label found
                with open(lb_file, 'r') as f:
                    l = [x.split() for x in f.read().strip().splitlines() if len(x)]
                    if any([len(x) > 8 for x in l]):  # is segment
                        classes = np.array([x[0] for x in l], dtype=np.float32)
                        segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in l]  # (cls, xy1...)
@ -490,20 +494,23 @@ class LoadImagesAndLabels(Dataset): # for training/testing
                x[im_file] = [l, shape, segments]
            except Exception as e:
                nc += 1
                logging.info(f'{prefix}WARNING: Ignoring corrupted image and/or label {im_file}: {e}')

            pbar.desc = f"{prefix}Scanning '{path.parent / path.stem}' images and labels... " \
                        f"{nf} found, {nm} missing, {ne} empty, {nc} corrupted"
        pbar.close()

        if nf == 0:
            logging.info(f'{prefix}WARNING: No labels found in {path}. See {help_url}')

        x['hash'] = get_hash(self.label_files + self.img_files)
        x['results'] = nf, nm, ne, nc, i + 1
        x['version'] = 0.2  # cache version
        try:
            torch.save(x, path)  # save cache for next time
            logging.info(f'{prefix}New cache created: {path}')
        except Exception as e:
            logging.info(f'{prefix}WARNING: Cache directory {path.parent} is not writeable: {e}')  # path not writeable
        return x

    def __len__(self):
@ -634,10 +641,10 @@ def load_image(self, index):
        img = cv2.imread(path)  # BGR
        assert img is not None, 'Image Not Found ' + path
        h0, w0 = img.shape[:2]  # orig hw
        r = self.img_size / max(h0, w0)  # ratio
        if r != 1:  # if sizes are not equal
            img = cv2.resize(img, (int(w0 * r), int(h0 * r)),
                             interpolation=cv2.INTER_AREA if r < 1 and not self.augment else cv2.INTER_LINEAR)
        return img, (h0, w0), img.shape[:2]  # img, hw_original, hw_resized
    else:
        return self.imgs[index], self.img_hw0[index], self.img_hw[index]  # img, hw_original, hw_resized
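A quick numeric check of the resize logic above (hypothetical 1280x720 source, img_size 640):

```python
import cv2
import numpy as np

img = np.zeros((720, 1280, 3), dtype=np.uint8)  # dummy HxW image
r = 640 / max(img.shape[:2])                    # ratio = 0.5
img = cv2.resize(img, (int(1280 * r), int(720 * r)),
                 interpolation=cv2.INTER_AREA)  # downscale -> INTER_AREA
print(img.shape)                                # (360, 640, 3)
```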


@ -0,0 +1,68 @@
# Flask REST API
[REST](https://en.wikipedia.org/wiki/Representational_state_transfer) [API](https://en.wikipedia.org/wiki/API)s are commonly used to expose Machine Learning (ML) models to other services. This folder contains an example REST API created using Flask to expose the YOLOv5s model from [PyTorch Hub](https://pytorch.org/hub/ultralytics_yolov5/).
## Requirements
[Flask](https://palletsprojects.com/p/flask/) is required. Install with:
```shell
$ pip install Flask
```
## Run
After Flask installation run:
```shell
$ python3 restapi.py --port 5000
```
Then use [curl](https://curl.se/) to perform a request:
```shell
$ curl -X POST -F image=@zidane.jpg 'http://localhost:5000/v1/object-detection/yolov5s'
```
The model inference results are returned as a JSON response:
```json
[
  {
    "class": 0,
    "confidence": 0.8900438547,
    "height": 0.9318675399,
    "name": "person",
    "width": 0.3264600933,
    "xcenter": 0.7438579798,
    "ycenter": 0.5207948685
  },
  {
    "class": 0,
    "confidence": 0.8440024257,
    "height": 0.7155083418,
    "name": "person",
    "width": 0.6546785235,
    "xcenter": 0.427829951,
    "ycenter": 0.6334488392
  },
  {
    "class": 27,
    "confidence": 0.3771208823,
    "height": 0.3902671337,
    "name": "tie",
    "width": 0.0696444362,
    "xcenter": 0.3675483763,
    "ycenter": 0.7991207838
  },
  {
    "class": 27,
    "confidence": 0.3527112305,
    "height": 0.1540903747,
    "name": "tie",
    "width": 0.0336618312,
    "xcenter": 0.7814827561,
    "ycenter": 0.5065554976
  }
]
```
An example Python script to perform inference using [requests](https://docs.python-requests.org/en/master/) is given in `example_request.py`.


@ -0,0 +1,13 @@
"""Perform test request"""
import pprint
import requests
DETECTION_URL = "http://localhost:5000/v1/object-detection/yolov5s"
TEST_IMAGE = "zidane.jpg"
image_data = open(TEST_IMAGE, "rb").read()
response = requests.post(DETECTION_URL, files={"image": image_data}).json()
pprint.pprint(response)


@ -0,0 +1,37 @@
"""
Run a rest API exposing the yolov5s object detection model
"""
import argparse
import io
import torch
from PIL import Image
from flask import Flask, request
app = Flask(__name__)
DETECTION_URL = "/v1/object-detection/yolov5s"
@app.route(DETECTION_URL, methods=["POST"])
def predict():
    if not request.method == "POST":
        return

    if request.files.get("image"):
        image_file = request.files["image"]
        image_bytes = image_file.read()

        img = Image.open(io.BytesIO(image_bytes))

        results = model(img, size=640)  # reduce size=320 for faster inference
        return results.pandas().xyxy[0].to_json(orient="records")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Flask API exposing YOLOv3 model")
    parser.add_argument("--port", default=5000, type=int, help="port number")
    args = parser.parse_args()

    model = torch.hub.load("ultralytics/yolov5", "yolov5s", force_reload=True)  # force_reload to recache
    app.run(host="0.0.0.0", port=args.port)  # debug=True causes Restarting with stat

View File

@@ -9,11 +9,14 @@ import random
import re
import subprocess
import time
+from itertools import repeat
+from multiprocessing.pool import ThreadPool
from pathlib import Path

import cv2
import numpy as np
import pandas as pd
+import pkg_resources as pkg
import torch
import torchvision
import yaml
@@ -30,10 +33,10 @@ cv2.setNumThreads(0)  # prevent OpenCV from multithreading (incompatible with Py
os.environ['NUMEXPR_MAX_THREADS'] = str(min(os.cpu_count(), 8))  # NumExpr max threads


-def set_logging(rank=-1):
+def set_logging(rank=-1, verbose=True):
    logging.basicConfig(
        format="%(message)s",
-        level=logging.INFO if rank in [-1, 0] else logging.WARN)
+        level=logging.INFO if (verbose and rank in [-1, 0]) else logging.WARN)


def init_seeds(seed=0):
@@ -49,16 +52,30 @@ def get_latest_run(search_dir='.'):
    return max(last_list, key=os.path.getctime) if last_list else ''


-def isdocker():
+def is_docker():
    # Is environment a Docker container
    return Path('/workspace').exists()  # or Path('/.dockerenv').exists()


+def is_colab():
+    # Is environment a Google Colab instance
+    try:
+        import google.colab
+        return True
+    except Exception as e:
+        return False
+
+
def emojis(str=''):
    # Return platform-dependent emoji-safe version of string
    return str.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else str


+def file_size(file):
+    # Return file size in MB
+    return Path(file).stat().st_size / 1e6
+
+
def check_online():
    # Check internet connectivity
    import socket
@@ -74,7 +91,7 @@ def check_git_status():
    print(colorstr('github: '), end='')
    try:
        assert Path('.git').exists(), 'skipping check (not a git repository)'
-        assert not isdocker(), 'skipping check (Docker image)'
+        assert not is_docker(), 'skipping check (Docker image)'
        assert check_online(), 'skipping check (offline)'

        cmd = 'git fetch && git config --get remote.origin.url'
@@ -91,10 +108,19 @@ def check_git_status():
        print(e)


+def check_python(minimum='3.7.0', required=True):
+    # Check current python version vs. required python version
+    current = platform.python_version()
+    result = pkg.parse_version(current) >= pkg.parse_version(minimum)
+    if required:
+        assert result, f'Python {minimum} required by YOLOv3, but Python {current} is currently installed'
+    return result
+
+
def check_requirements(requirements='requirements.txt', exclude=()):
    # Check installed dependencies meet requirements (pass *.txt file or list of packages)
-    import pkg_resources as pkg
    prefix = colorstr('red', 'bold', 'requirements:')
+    check_python()  # check python version
    if isinstance(requirements, (str, Path)):  # requirements.txt file
        file = Path(requirements)
        if not file.exists():

@@ -110,8 +136,11 @@ def check_requirements(requirements='requirements.txt', exclude=()):
            pkg.require(r)
        except Exception as e:  # DistributionNotFound or VersionConflict if requirements not met
            n += 1
-            print(f"{prefix} {e.req} not found and is required by YOLOv3, attempting auto-update...")
-            print(subprocess.check_output(f"pip install '{e.req}'", shell=True).decode())
+            print(f"{prefix} {r} not found and is required by YOLOv3, attempting auto-update...")
+            try:
+                print(subprocess.check_output(f"pip install '{r}'", shell=True).decode())
+            except Exception as e:
+                print(f'{prefix} {e}')

    if n:  # if packages updated
        source = file.resolve() if 'file' in locals() else requirements
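A usage sketch for the new `check_python()` gate and the hardened auto-update loop (the version strings and package names below are illustrative):
```python
from utils.general import check_python, check_requirements

check_python(minimum='3.7.0')  # raises AssertionError on an older interpreter
ok = check_python(minimum='3.9.0', required=False)  # returns a bool instead of raising

check_requirements(['numpy', 'torch>=1.7.0'])  # list form; a *.txt path also works
```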
@@ -131,7 +160,8 @@ def check_img_size(img_size, s=32):
def check_imshow():
    # Check if environment supports image displays
    try:
-        assert not isdocker(), 'cv2.imshow() is disabled in Docker environments'
+        assert not is_docker(), 'cv2.imshow() is disabled in Docker environments'
+        assert not is_colab(), 'cv2.imshow() is disabled in Google Colab environments'
        cv2.imshow('test', np.zeros((1, 1, 3)))
        cv2.waitKey(1)
        cv2.destroyAllWindows()
@@ -143,12 +173,19 @@ def check_imshow():
def check_file(file):
-    # Search for file if not found
-    if Path(file).is_file() or file == '':
+    # Search/download file (if necessary) and return path
+    file = str(file)  # convert to str()
+    if Path(file).is_file() or file == '':  # exists
        return file
-    else:
+    elif file.startswith(('http://', 'https://')):  # download
+        url, file = file, Path(file).name
+        print(f'Downloading {url} to {file}...')
+        torch.hub.download_url_to_file(url, file)
+        assert Path(file).exists() and Path(file).stat().st_size > 0, f'File download failed: {url}'  # check
+        return file
+    else:  # search
        files = glob.glob('./**/' + file, recursive=True)  # find file
-        assert len(files), f'File Not Found: {file}'  # assert file was found
+        assert len(files), f'File not found: {file}'  # assert file was found
        assert len(files) == 1, f"Multiple files match '{file}', specify exact path: {files}"  # assert unique
        return files[0]  # return file
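A sketch of the three `check_file()` branches (the URL is illustrative):
```python
from utils.general import check_file

f1 = check_file('data/coco128.yaml')  # existing path: returned unchanged
f2 = check_file('coco128.yaml')  # not a path: searched for recursively under ./
f3 = check_file('https://ultralytics.com/images/zidane.jpg')  # URL: downloaded, local name returned
```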
@@ -161,18 +198,54 @@ def check_dataset(dict):
        if not all(x.exists() for x in val):
            print('\nWARNING: Dataset not found, nonexistent paths: %s' % [str(x) for x in val if not x.exists()])
            if s and len(s):  # download script
-                print('Downloading %s ...' % s)
                if s.startswith('http') and s.endswith('.zip'):  # URL
                    f = Path(s).name  # filename
+                    print(f'Downloading {s} ...')
                    torch.hub.download_url_to_file(s, f)
-                    r = os.system('unzip -q %s -d ../ && rm %s' % (f, f))  # unzip
-                else:  # bash script
+                    r = os.system(f'unzip -q {f} -d ../ && rm {f}')  # unzip
+                elif s.startswith('bash '):  # bash script
+                    print(f'Running {s} ...')
                    r = os.system(s)
-                print('Dataset autodownload %s\n' % ('success' if r == 0 else 'failure'))  # analyze return value
+                else:  # python script
+                    r = exec(s)  # return None
+                print('Dataset autodownload %s\n' % ('success' if r in (0, None) else 'failure'))  # print result
            else:
                raise Exception('Dataset not found.')


+def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1):
+    # Multi-threaded file download and unzip function
+    def download_one(url, dir):
+        # Download 1 file
+        f = dir / Path(url).name  # filename
+        if not f.exists():
+            print(f'Downloading {url} to {f}...')
+            if curl:
+                os.system(f"curl -L '{url}' -o '{f}' --retry 9 -C -")  # curl download, retry and resume on fail
+            else:
+                torch.hub.download_url_to_file(url, f, progress=True)  # torch download
+        if unzip and f.suffix in ('.zip', '.gz'):
+            print(f'Unzipping {f}...')
+            if f.suffix == '.zip':
+                s = f'unzip -qo {f} -d {dir} && rm {f}'  # unzip -quiet -overwrite
+            elif f.suffix == '.gz':
+                s = f'tar xfz {f} --directory {f.parent}'  # unzip
+            if delete:  # delete zip file after unzip
+                s += f' && rm {f}'
+            os.system(s)
+
+    dir = Path(dir)
+    dir.mkdir(parents=True, exist_ok=True)  # make directory
+    if threads > 1:
+        pool = ThreadPool(threads)
+        pool.imap(lambda x: download_one(*x), zip(url, repeat(dir)))  # multi-threaded
+        pool.close()
+        pool.join()
+    else:
+        for u in [url] if isinstance(url, str) else url:  # wrap a single URL string (iterating tuple(url) would yield characters)
+            download_one(u, dir)
+
+
def make_divisible(x, divisor):
    # Returns x evenly divisible by divisor
    return math.ceil(x / divisor) * divisor
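A usage sketch for the new multi-threaded `download()` helper, mirroring how the dataset YAMLs call it (the URL is one of the release assets and is illustrative):
```python
from pathlib import Path
from utils.general import download

urls = ['https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128.zip']
download(urls, dir=Path('../datasets'), unzip=True, delete=True, curl=False, threads=1)
```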
@@ -419,7 +492,7 @@ def wh_iou(wh1, wh2):
def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
-                        labels=()):
+                        labels=(), max_det=300):
    """Runs Non-Maximum Suppression (NMS) on inference results

    Returns:

@@ -429,9 +502,12 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non
    nc = prediction.shape[2] - 5  # number of classes
    xc = prediction[..., 4] > conf_thres  # candidates

+    # Checks
+    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
+    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'
+
    # Settings
    min_wh, max_wh = 2, 4096  # (pixels) minimum and maximum box width and height
-    max_det = 300  # maximum number of detections per image
    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
    time_limit = 10.0  # seconds to quit after
    redundant = True  # require redundant detections
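With `max_det` promoted to an argument, callers can now cap detections per image at the call site; a hedged sketch, where `pred` is assumed to be raw model output and is not defined here:
```python
from utils.general import non_max_suppression

# pred: raw model output of shape (batch, num_boxes, 5 + num_classes) -- assumed available
out = non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45, max_det=100)  # list of (n, 6) tensors per image
```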
@@ -550,14 +626,14 @@ def print_mutation(hyp, results, yaml_file='hyp_evolved.yaml', bucket=''):
        results = tuple(x[0, :7])
        c = '%10.4g' * len(results) % results  # results (P, R, mAP@0.5, mAP@0.5:0.95, val_losses x 3)
        f.write('# Hyperparameter Evolution Results\n# Generations: %g\n# Metrics: ' % len(x) + c + '\n\n')
-        yaml.dump(hyp, f, sort_keys=False)
+        yaml.safe_dump(hyp, f, sort_keys=False)

    if bucket:
        os.system('gsutil cp evolve.txt %s gs://%s' % (yaml_file, bucket))  # upload


def apply_classifier(x, model, img, im0):
-    # applies a second stage classifier to yolo outputs
+    # Apply a second stage classifier to yolo outputs
    im0 = [im0] if isinstance(im0, np.ndarray) else im0
    for i, d in enumerate(x):  # per image
        if d is not None and len(d):
@@ -591,14 +667,33 @@ def apply_classifier(x, model, img, im0):
    return x


+def save_one_box(xyxy, im, file='image.jpg', gain=1.02, pad=10, square=False, BGR=False, save=True):
+    # Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop
+    xyxy = torch.tensor(xyxy).view(-1, 4)
+    b = xyxy2xywh(xyxy)  # boxes
+    if square:
+        b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1)  # attempt rectangle to square
+    b[:, 2:] = b[:, 2:] * gain + pad  # box wh * gain + pad
+    xyxy = xywh2xyxy(b).long()
+    clip_coords(xyxy, im.shape)
+    crop = im[int(xyxy[0, 1]):int(xyxy[0, 3]), int(xyxy[0, 0]):int(xyxy[0, 2]), ::(1 if BGR else -1)]
+    if save:
+        cv2.imwrite(str(increment_path(file, mkdir=True).with_suffix('.jpg')), crop)
+    return crop
+
+
-def increment_path(path, exist_ok=True, sep=''):
-    # Increment path, i.e. runs/exp --> runs/exp{sep}0, runs/exp{sep}1 etc.
+def increment_path(path, exist_ok=False, sep='', mkdir=False):
+    # Increment file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
    path = Path(path)  # os-agnostic
-    if (path.exists() and exist_ok) or (not path.exists()):
-        return str(path)
-    else:
+    if path.exists() and not exist_ok:
+        suffix = path.suffix
+        path = path.with_suffix('')
        dirs = glob.glob(f"{path}{sep}*")  # similar paths
        matches = [re.search(rf"%s{sep}(\d+)" % path.stem, d) for d in dirs]
        i = [int(m.groups()[0]) for m in matches if m]  # indices
        n = max(i) + 1 if i else 2  # increment number
-        return f"{path}{sep}{n}"  # update path
+        path = Path(f"{path}{sep}{n}{suffix}")  # update path
+    dir = path if path.suffix == '' else path.parent  # directory
+    if not dir.exists() and mkdir:
+        dir.mkdir(parents=True, exist_ok=True)  # make directory
+    return path
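A sketch combining the two new helpers; `det` and `im0` are assumed to come from a detection loop and are not defined here:
```python
from utils.general import increment_path, save_one_box

save_dir = increment_path('runs/detect/exp', mkdir=True)  # -> runs/detect/exp2, exp3, ... if exp exists
for *xyxy, conf, cls in det:  # det: (n, 6) tensor from non_max_suppression (assumed)
    save_one_box(xyxy, im0, file=save_dir / 'crops' / 'crop.jpg', BGR=True)  # im0: original BGR image (assumed)
```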

View File

@@ -16,40 +16,57 @@ def gsutil_getsize(url=''):
    return eval(s.split(' ')[0]) if len(s) else 0  # bytes


+def safe_download(file, url, url2=None, min_bytes=1E0, error_msg=''):
+    # Attempts to download file from url or url2, checks and removes incomplete downloads < min_bytes
+    file = Path(file)
+    try:  # GitHub
+        print(f'Downloading {url} to {file}...')
+        torch.hub.download_url_to_file(url, str(file))
+        assert file.exists() and file.stat().st_size > min_bytes  # check
+    except Exception as e:  # GCP
+        file.unlink(missing_ok=True)  # remove partial downloads
+        print(f'Download error: {e}\nRe-attempting {url2 or url} to {file}...')
+        os.system(f"curl -L '{url2 or url}' -o '{file}' --retry 3 -C -")  # curl download, retry and resume on fail
+    finally:
+        if not file.exists() or file.stat().st_size < min_bytes:  # check
+            file.unlink(missing_ok=True)  # remove partial downloads
+            print(f'ERROR: Download failure: {error_msg or url}')
+        print('')
+
+
def attempt_download(file, repo='ultralytics/yolov3'):
    # Attempt file download if does not exist
-    file = Path(str(file).strip().replace("'", '').lower())
+    file = Path(str(file).strip().replace("'", ''))

    if not file.exists():
+        # URL specified
+        name = file.name
+        if str(file).startswith(('http:/', 'https:/')):  # download
+            url = str(file).replace(':/', '://')  # Pathlib turns :// -> :/
+            safe_download(file=name, url=url, min_bytes=1E5)
+            return name
+
+        # GitHub assets
+        file.parent.mkdir(parents=True, exist_ok=True)  # make parent dir (if required)
        try:
            response = requests.get(f'https://api.github.com/repos/{repo}/releases/latest').json()  # github api
            assets = [x['name'] for x in response['assets']]  # release assets, i.e. ['yolov5s.pt', 'yolov5m.pt', ...]
            tag = response['tag_name']  # i.e. 'v1.0'
        except:  # fallback plan
            assets = ['yolov3.pt', 'yolov3-spp.pt', 'yolov3-tiny.pt']
-            tag = subprocess.check_output('git tag', shell=True).decode().split()[-1]
+            try:
+                tag = subprocess.check_output('git tag', shell=True, stderr=subprocess.STDOUT).decode().split()[-1]
+            except:
+                tag = 'v9.5.0'  # current release

-        name = file.name
        if name in assets:
-            msg = f'{file} missing, try downloading from https://github.com/{repo}/releases/'
-            redundant = False  # second download option
-            try:  # GitHub
-                url = f'https://github.com/{repo}/releases/download/{tag}/{name}'
-                print(f'Downloading {url} to {file}...')
-                torch.hub.download_url_to_file(url, file)
-                assert file.exists() and file.stat().st_size > 1E6  # check
-            except Exception as e:  # GCP
-                print(f'Download error: {e}')
-                assert redundant, 'No secondary mirror'
-                url = f'https://storage.googleapis.com/{repo}/ckpt/{name}'
-                print(f'Downloading {url} to {file}...')
-                os.system(f'curl -L {url} -o {file}')  # torch.hub.download_url_to_file(url, weights)
-            finally:
-                if not file.exists() or file.stat().st_size < 1E6:  # check
-                    file.unlink(missing_ok=True)  # remove partial downloads
-                    print(f'ERROR: Download failure: {msg}')
-                print('')
-                return
+            safe_download(file,
+                          url=f'https://github.com/{repo}/releases/download/{tag}/{name}',
+                          # url2=f'https://storage.googleapis.com/{repo}/ckpt/{name}',  # backup url (optional)
+                          min_bytes=1E5,
+                          error_msg=f'{file} missing, try downloading from https://github.com/{repo}/releases/')
+
+    return str(file)


def gdrive_download(id='16TiPfZj7htmTyhntwcZyEEAejOUxuT6m', file='tmp.zip'):
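Caller usage is unchanged by the `safe_download()` refactor; a minimal sketch:
```python
from utils.google_utils import attempt_download

weights = attempt_download('yolov3.pt')  # fetched from the latest GitHub release if missing locally
```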

View File

@@ -145,7 +145,7 @@ class ConfusionMatrix:
            for i, gc in enumerate(gt_classes):
                j = m0 == i
                if n and sum(j) == 1:
-                    self.matrix[gc, detection_classes[m1[j]]] += 1  # correct
+                    self.matrix[detection_classes[m1[j]], gc] += 1  # correct
                else:
                    self.matrix[self.nc, gc] += 1  # background FP

View File

@@ -16,7 +16,6 @@ import seaborn as sns
import torch
import yaml
from PIL import Image, ImageDraw, ImageFont
-from scipy.signal import butter, filtfilt

from utils.general import xywh2xyxy, xyxy2xywh
from utils.metrics import fitness
@@ -26,12 +25,25 @@ matplotlib.rc('font', **{'size': 11})
matplotlib.use('Agg')  # for writing to files only


-def color_list():
-    # Return first 10 plt colors as (r,g,b) https://stackoverflow.com/questions/51350872/python-from-color-name-to-rgb
-    def hex2rgb(h):
-        return tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4))
-
-    return [hex2rgb(h) for h in matplotlib.colors.TABLEAU_COLORS.values()]  # or BASE_ (8), CSS4_ (148), XKCD_ (949)
+class Colors:
+    # Ultralytics color palette https://ultralytics.com/
+    def __init__(self):
+        # hex = matplotlib.colors.TABLEAU_COLORS.values()
+        hex = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB',
+               '2C99A8', '00C2FF', '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF', 'FF95C8', 'FF37C7')
+        self.palette = [self.hex2rgb('#' + c) for c in hex]
+        self.n = len(self.palette)
+
+    def __call__(self, i, bgr=False):
+        c = self.palette[int(i) % self.n]
+        return (c[2], c[1], c[0]) if bgr else c
+
+    @staticmethod
+    def hex2rgb(h):  # rgb order (PIL)
+        return tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4))
+
+
+colors = Colors()  # create instance for 'from utils.plots import colors'


def hist2d(x, y, n=100):
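A sketch of the new callable palette that replaces indexing into `color_list()`:
```python
from utils.plots import colors

rgb = colors(2)  # RGB tuple for class 2, wrapping modulo the 20-color palette
bgr = colors(2, bgr=True)  # same color in BGR order for OpenCV drawing
```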
@@ -44,6 +56,8 @@ def hist2d(x, y, n=100):
def butter_lowpass_filtfilt(data, cutoff=1500, fs=50000, order=5):
+    from scipy.signal import butter, filtfilt
+
    # https://stackoverflow.com/questions/28536191/how-to-filter-smooth-with-scipy-numpy
    def butter_lowpass(cutoff, fs, order):
        nyq = 0.5 * fs
@@ -54,32 +68,32 @@ def butter_lowpass_filtfilt(data, cutoff=1500, fs=50000, order=5):
    return filtfilt(b, a, data)  # forward-backward filter


-def plot_one_box(x, img, color=None, label=None, line_thickness=3):
-    # Plots one bounding box on image img
-    tl = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1  # line/font thickness
-    color = color or [random.randint(0, 255) for _ in range(3)]
+def plot_one_box(x, im, color=(128, 128, 128), label=None, line_thickness=3):
+    # Plots one bounding box on image 'im' using OpenCV
+    assert im.data.contiguous, 'Image not contiguous. Apply np.ascontiguousarray(im) to plot_on_box() input image.'
+    tl = line_thickness or round(0.002 * (im.shape[0] + im.shape[1]) / 2) + 1  # line/font thickness
    c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))
-    cv2.rectangle(img, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA)
+    cv2.rectangle(im, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA)
    if label:
        tf = max(tl - 1, 1)  # font thickness
        t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
        c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
-        cv2.rectangle(img, c1, c2, color, -1, cv2.LINE_AA)  # filled
-        cv2.putText(img, label, (c1[0], c1[1] - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
+        cv2.rectangle(im, c1, c2, color, -1, cv2.LINE_AA)  # filled
+        cv2.putText(im, label, (c1[0], c1[1] - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)


-def plot_one_box_PIL(box, img, color=None, label=None, line_thickness=None):
-    img = Image.fromarray(img)
-    draw = ImageDraw.Draw(img)
-    line_thickness = line_thickness or max(int(min(img.size) / 200), 2)
-    draw.rectangle(box, width=line_thickness, outline=tuple(color))  # plot
+def plot_one_box_PIL(box, im, color=(128, 128, 128), label=None, line_thickness=None):
+    # Plots one bounding box on image 'im' using PIL
+    im = Image.fromarray(im)
+    draw = ImageDraw.Draw(im)
+    line_thickness = line_thickness or max(int(min(im.size) / 200), 2)
+    draw.rectangle(box, width=line_thickness, outline=color)  # plot
    if label:
-        fontsize = max(round(max(img.size) / 40), 12)
-        font = ImageFont.truetype("Arial.ttf", fontsize)
+        font = ImageFont.truetype("Arial.ttf", size=max(round(max(im.size) / 40), 12))
        txt_width, txt_height = font.getsize(label)
-        draw.rectangle([box[0], box[1] - txt_height + 4, box[0] + txt_width, box[1]], fill=tuple(color))
+        draw.rectangle([box[0], box[1] - txt_height + 4, box[0] + txt_width, box[1]], fill=color)
        draw.text((box[0], box[1] - txt_height + 1), label, fill=(255, 255, 255), font=font)
-    return np.asarray(img)
+    return np.asarray(im)


def plot_wh_methods():  # from utils.plots import *; plot_wh_methods()
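A drawing sketch against the renamed signature (the image path is illustrative); the new assert means sliced inputs may need `np.ascontiguousarray` first:
```python
import cv2
import numpy as np
from utils.plots import plot_one_box, colors

im = np.ascontiguousarray(cv2.imread('zidane.jpg'))  # BGR image, made contiguous
plot_one_box([50, 40, 220, 300], im, color=colors(0, bgr=True), label='person 0.89')
cv2.imwrite('out.jpg', im)
```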
@@ -135,7 +149,6 @@ def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max
        h = math.ceil(scale_factor * h)
        w = math.ceil(scale_factor * w)

-    colors = color_list()  # list of colors
    mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8)  # init
    for i, img in enumerate(images):
        if i == max_subplots:  # if last batch has fewer images than we expect
@@ -166,7 +179,7 @@ def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max
                boxes[[1, 3]] += block_y
            for j, box in enumerate(boxes.T):
                cls = int(classes[j])
-                color = colors[cls % len(colors)]
+                color = colors(cls)
                cls = names[cls] if names else cls
                if labels or conf[j] > 0.25:  # 0.25 conf thresh
                    label = '%s' % cls if labels else '%s %.1f' % (cls, conf[j])
@@ -274,7 +287,6 @@ def plot_labels(labels, names=(), save_dir=Path(''), loggers=None):
    print('Plotting labels... ')
    c, b = labels[:, 0], labels[:, 1:].transpose()  # classes, boxes
    nc = int(c.max() + 1)  # number of classes
-    colors = color_list()
    x = pd.DataFrame(b.transpose(), columns=['x', 'y', 'width', 'height'])

    # seaborn correlogram
@@ -285,7 +297,8 @@ def plot_labels(labels, names=(), save_dir=Path(''), loggers=None):
    # matplotlib labels
    matplotlib.use('svg')  # faster
    ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
-    ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
+    y = ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
+    # [y[2].patches[i].set_color([x / 255 for x in colors(i)]) for i in range(nc)]  # update colors bug #3195
    ax[0].set_ylabel('instances')
    if 0 < len(names) < 30:
        ax[0].set_xticks(range(len(names)))
@@ -300,7 +313,7 @@ def plot_labels(labels, names=(), save_dir=Path(''), loggers=None):
    labels[:, 1:] = xywh2xyxy(labels[:, 1:]) * 2000
    img = Image.fromarray(np.ones((2000, 2000, 3), dtype=np.uint8) * 255)
    for cls, *box in labels[:1000]:
-        ImageDraw.Draw(img).rectangle(box, width=1, outline=colors[int(cls) % 10])  # plot
+        ImageDraw.Draw(img).rectangle(box, width=1, outline=colors(cls))  # plot
    ax[1].imshow(img)
    ax[1].axis('off')
@@ -321,7 +334,7 @@ def plot_labels(labels, names=(), save_dir=Path(''), loggers=None):
def plot_evolution(yaml_file='data/hyp.finetune.yaml'):  # from utils.plots import *; plot_evolution()
    # Plot hyperparameter evolution results in evolve.txt
    with open(yaml_file) as f:
-        hyp = yaml.load(f, Loader=yaml.SafeLoader)
+        hyp = yaml.safe_load(f)
    x = np.loadtxt('evolve.txt', ndmin=2)
    f = fitness(x)
    # weights = (f - f.min()) ** 2  # for weighted results

View File

@@ -72,11 +72,12 @@ def select_device(device='', batch_size=None):
    cuda = not cpu and torch.cuda.is_available()
    if cuda:
-        n = torch.cuda.device_count()
-        if n > 1 and batch_size:  # check that batch_size is compatible with device_count
+        devices = device.split(',') if device else range(torch.cuda.device_count())  # i.e. 0,1,6,7
+        n = len(devices)  # device count
+        if n > 1 and batch_size:  # check batch_size is divisible by device_count
            assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}'
        space = ' ' * len(s)
-        for i, d in enumerate(device.split(',') if device else range(n)):
+        for i, d in enumerate(devices):
            p = torch.cuda.get_device_properties(i)
            s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / 1024 ** 2}MB)\n"  # bytes to MB
    else:
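A usage sketch for the reworked device parsing:
```python
from utils.torch_utils import select_device

device = select_device('0,1', batch_size=16)  # batch size must be divisible by the GPU count
cpu = select_device('cpu')  # explicit CPU selection
```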
@@ -133,9 +134,15 @@ def profile(x, ops, n=100, device=None):
def is_parallel(model):
+    # Returns True if model is of type DP or DDP
    return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)


+def de_parallel(model):
+    # De-parallelize a model: returns single-GPU model if model is of type DP or DDP
+    return model.module if is_parallel(model) else model
+
+
def intersect_dicts(da, db, exclude=()):
    # Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
    return {k: v for k, v in da.items() if k in db and not any(x in k for x in exclude) and v.shape == db[k].shape}
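A sketch of `de_parallel()` ahead of checkpointing, so the saved model loads without the DP/DDP wrapper:
```python
import torch.nn as nn
from utils.torch_utils import de_parallel, is_parallel

model = nn.Linear(10, 2)
wrapped = nn.DataParallel(model)
assert is_parallel(wrapped) and not is_parallel(model)
ckpt = {'model': de_parallel(wrapped)}  # unwraps to the underlying module
```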

View File

@@ -9,7 +9,7 @@ WANDB_ARTIFACT_PREFIX = 'wandb-artifact://'
def create_dataset_artifact(opt):
    with open(opt.data) as f:
-        data = yaml.load(f, Loader=yaml.SafeLoader)  # data dict
+        data = yaml.safe_load(f)  # data dict
    logger = WandbLogger(opt, '', None, data, job_type='Dataset Creation')

@@ -17,7 +17,7 @@ if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', type=str, default='data/coco128.yaml', help='data.yaml path')
    parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
-    parser.add_argument('--project', type=str, default='YOLOv5', help='name of W&B Project')
+    parser.add_argument('--project', type=str, default='YOLOv3', help='name of W&B Project')
    opt = parser.parse_args()
    opt.resume = False  # Explicitly disallow resume check for dataset upload job

View File

@@ -1,3 +1,4 @@
+"""Utilities and tools for tracking runs with Weights & Biases."""
import json
import sys
from pathlib import Path

@@ -9,7 +10,7 @@ from tqdm import tqdm
sys.path.append(str(Path(__file__).parent.parent.parent))  # add utils/ to path
from utils.datasets import LoadImagesAndLabels
from utils.datasets import img2label_paths
-from utils.general import colorstr, xywh2xyxy, check_dataset
+from utils.general import colorstr, xywh2xyxy, check_dataset, check_file

try:
    import wandb
@@ -35,8 +36,9 @@ def get_run_info(run_path):
    run_path = Path(remove_prefix(run_path, WANDB_ARTIFACT_PREFIX))
    run_id = run_path.stem
    project = run_path.parent.stem
+    entity = run_path.parent.parent.stem
    model_artifact_name = 'run_' + run_id + '_model'
-    return run_id, project, model_artifact_name
+    return entity, project, run_id, model_artifact_name


def check_wandb_resume(opt):

@@ -44,9 +46,9 @@ def check_wandb_resume(opt):
    if isinstance(opt.resume, str):
        if opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
            if opt.global_rank not in [-1, 0]:  # For resuming DDP runs
-                run_id, project, model_artifact_name = get_run_info(opt.resume)
+                entity, project, run_id, model_artifact_name = get_run_info(opt.resume)
                api = wandb.Api()
-                artifact = api.artifact(project + '/' + model_artifact_name + ':latest')
+                artifact = api.artifact(entity + '/' + project + '/' + model_artifact_name + ':latest')
                modeldir = artifact.download()
                opt.weights = str(Path(modeldir) / "last.pt")
    return True
@@ -54,8 +56,8 @@ def check_wandb_resume(opt):
def process_wandb_config_ddp_mode(opt):
-    with open(opt.data) as f:
-        data_dict = yaml.load(f, Loader=yaml.SafeLoader)  # data dict
+    with open(check_file(opt.data)) as f:
+        data_dict = yaml.safe_load(f)  # data dict
    train_dir, val_dir = None, None
    if isinstance(data_dict['train'], str) and data_dict['train'].startswith(WANDB_ARTIFACT_PREFIX):
        api = wandb.Api()

@@ -73,11 +75,23 @@ def process_wandb_config_ddp_mode(opt):
    if train_dir or val_dir:
        ddp_data_path = str(Path(val_dir) / 'wandb_local_data.yaml')
        with open(ddp_data_path, 'w') as f:
-            yaml.dump(data_dict, f)
+            yaml.safe_dump(data_dict, f)
        opt.data = ddp_data_path
class WandbLogger():
+    """Log training runs, datasets, models, and predictions to Weights & Biases.
+
+    This logger sends information to W&B at wandb.ai. By default, this information
+    includes hyperparameters, system configuration and metrics, model metrics,
+    and basic data metrics and analyses.
+
+    By providing additional command line arguments to train.py, datasets,
+    models and predictions can also be logged.
+
+    For more on how this logger is used, see the Weights & Biases documentation:
+    https://docs.wandb.com/guides/integrations/yolov5
+    """
+
    def __init__(self, opt, name, run_id, data_dict, job_type='Training'):
        # Pre-training routine --
        self.job_type = job_type
@@ -85,16 +99,17 @@ class WandbLogger():
        # It's more elegant to stick to 1 wandb.init call, but useful config data is overwritten in the WandbLogger's wandb.init call
        if isinstance(opt.resume, str):  # checks resume from artifact
            if opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
-                run_id, project, model_artifact_name = get_run_info(opt.resume)
+                entity, project, run_id, model_artifact_name = get_run_info(opt.resume)
                model_artifact_name = WANDB_ARTIFACT_PREFIX + model_artifact_name
                assert wandb, 'install wandb to resume wandb runs'
                # Resume wandb-artifact:// runs here| workaround for not overwriting wandb.config
-                self.wandb_run = wandb.init(id=run_id, project=project, resume='allow')
+                self.wandb_run = wandb.init(id=run_id, project=project, entity=entity, resume='allow')
                opt.resume = model_artifact_name
        elif self.wandb:
            self.wandb_run = wandb.init(config=opt,
                                        resume="allow",
-                                        project='YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem,
+                                        project='YOLOv3' if opt.project == 'runs/train' else Path(opt.project).stem,
+                                        entity=opt.entity,
                                        name=name,
                                        job_type=job_type,
                                        id=run_id) if not wandb.run else wandb.run
@@ -110,17 +125,17 @@ class WandbLogger():
                self.data_dict = self.check_and_upload_dataset(opt)
        else:
            prefix = colorstr('wandb: ')
-            print(f"{prefix}Install Weights & Biases for YOLOv5 logging with 'pip install wandb' (recommended)")
+            print(f"{prefix}Install Weights & Biases for YOLOv3 logging with 'pip install wandb' (recommended)")

    def check_and_upload_dataset(self, opt):
        assert wandb, 'Install wandb to upload dataset'
        check_dataset(self.data_dict)
-        config_path = self.log_dataset_artifact(opt.data,
+        config_path = self.log_dataset_artifact(check_file(opt.data),
                                                opt.single_cls,
-                                                'YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem)
+                                                'YOLOv3' if opt.project == 'runs/train' else Path(opt.project).stem)
        print("Created dataset config file ", config_path)
        with open(config_path) as f:
-            wandb_data_dict = yaml.load(f, Loader=yaml.SafeLoader)
+            wandb_data_dict = yaml.safe_load(f)
        return wandb_data_dict

    def setup_training(self, opt, data_dict):
@@ -158,7 +173,8 @@ class WandbLogger():
    def download_dataset_artifact(self, path, alias):
        if isinstance(path, str) and path.startswith(WANDB_ARTIFACT_PREFIX):
-            dataset_artifact = wandb.use_artifact(remove_prefix(path, WANDB_ARTIFACT_PREFIX) + ":" + alias)
+            artifact_path = Path(remove_prefix(path, WANDB_ARTIFACT_PREFIX) + ":" + alias)
+            dataset_artifact = wandb.use_artifact(artifact_path.as_posix())
            assert dataset_artifact is not None, "'Error: W&B dataset artifact doesn\'t exist'"
            datadir = dataset_artifact.download()
            return datadir, dataset_artifact
@@ -171,8 +187,8 @@ class WandbLogger():
            modeldir = model_artifact.download()
            epochs_trained = model_artifact.metadata.get('epochs_trained')
            total_epochs = model_artifact.metadata.get('total_epochs')
-            assert epochs_trained < total_epochs, 'training to %g epochs is finished, nothing to resume.' % (
-                total_epochs)
+            is_finished = total_epochs is None
+            assert not is_finished, 'training is finished, can only resume incomplete runs.'
            return modeldir, model_artifact
        return None, None
@@ -187,18 +203,18 @@ class WandbLogger():
        })
        model_artifact.add_file(str(path / 'last.pt'), name='last.pt')
        wandb.log_artifact(model_artifact,
-                           aliases=['latest', 'epoch ' + str(self.current_epoch), 'best' if best_model else ''])
+                           aliases=['latest', 'last', 'epoch ' + str(self.current_epoch), 'best' if best_model else ''])
        print("Saving model artifact on epoch ", epoch + 1)

    def log_dataset_artifact(self, data_file, single_cls, project, overwrite_config=False):
        with open(data_file) as f:
-            data = yaml.load(f, Loader=yaml.SafeLoader)  # data dict
+            data = yaml.safe_load(f)  # data dict
        nc, names = (1, ['item']) if single_cls else (int(data['nc']), data['names'])
        names = {k: v for k, v in enumerate(names)}  # to index dictionary
        self.train_artifact = self.create_dataset_table(LoadImagesAndLabels(
-            data['train']), names, name='train') if data.get('train') else None
+            data['train'], rect=True, batch_size=1), names, name='train') if data.get('train') else None
        self.val_artifact = self.create_dataset_table(LoadImagesAndLabels(
-            data['val']), names, name='val') if data.get('val') else None
+            data['val'], rect=True, batch_size=1), names, name='val') if data.get('val') else None
        if data.get('train'):
            data['train'] = WANDB_ARTIFACT_PREFIX + str(Path(project) / 'train')
        if data.get('val'):
@@ -206,7 +222,7 @@ class WandbLogger():
        path = data_file if overwrite_config else '_wandb.'.join(data_file.rsplit('.', 1))  # updated data.yaml path
        data.pop('download', None)
        with open(path, 'w') as f:
-            yaml.dump(data, f)
+            yaml.safe_dump(data, f)

        if self.job_type == 'Training':  # builds correct artifact pipeline graph
            self.wandb_run.use_artifact(self.val_artifact)
@@ -243,16 +259,12 @@ class WandbLogger():
        table = wandb.Table(columns=["id", "train_image", "Classes", "name"])
        class_set = wandb.Classes([{'id': id, 'name': name} for id, name in class_to_id.items()])
        for si, (img, labels, paths, shapes) in enumerate(tqdm(dataset)):
-            height, width = shapes[0]
-            labels[:, 2:] = (xywh2xyxy(labels[:, 2:].view(-1, 4))) * torch.Tensor([width, height, width, height])
            box_data, img_classes = [], {}
-            for cls, *xyxy in labels[:, 1:].tolist():
+            for cls, *xywh in labels[:, 1:].tolist():
                cls = int(cls)
-                box_data.append({"position": {"minX": xyxy[0], "minY": xyxy[1], "maxX": xyxy[2], "maxY": xyxy[3]},
+                box_data.append({"position": {"middle": [xywh[0], xywh[1]], "width": xywh[2], "height": xywh[3]},
                                 "class_id": cls,
-                                 "box_caption": "%s" % (class_to_id[cls]),
-                                 "scores": {"acc": 1},
-                                 "domain": "pixel"})
+                                 "box_caption": "%s" % (class_to_id[cls])})
                img_classes[cls] = class_to_id[cls]
            boxes = {"ground_truth": {"box_data": box_data, "class_labels": class_to_id}}  # inference-space
            table.add_data(si, wandb.Image(paths, classes=class_set, boxes=boxes), json.dumps(img_classes),
@@ -294,7 +306,7 @@ class WandbLogger():
        if self.result_artifact:
            train_results = wandb.JoinedTable(self.val_table, self.result_table, "id")
            self.result_artifact.add(train_results, 'result')
-            wandb.log_artifact(self.result_artifact, aliases=['latest', 'epoch ' + str(self.current_epoch),
-                                                              ('best' if best_result else '')])
+            wandb.log_artifact(self.result_artifact, aliases=['latest', 'last', 'epoch ' + str(self.current_epoch),
+                                                              ('best' if best_result else '')])

            self.result_table = wandb.Table(["epoch", "id", "prediction", "avg_confidence"])
            self.result_artifact = wandb.Artifact("run_" + wandb.run.id + "_progress", "evaluation")