greenhouse/yolov3/convert_pt_to_weight.py

34 lines
1.2 KiB
Python

import struct

import torch

from models.yolo import *
# Convert a PyTorch YOLOv3 checkpoint (.pt) into a Darknet ``.weights`` file.
#
# Darknet's binary layout is: a file header of three little-endian int32
# values (major, minor, revision) followed by the "seen" image counter
# (uint64 when major*10 + minor >= 2, otherwise int32), then, for every
# convolutional layer in network order, raw float32 arrays:
#   * with batch norm:    bn bias, bn scale, running mean, running var, conv weights
#   * without batch norm: conv bias, conv weights
# There are no per-layer headers; Darknet takes the shapes from its .cfg file.

MODEL_CFG = 'models/yolov3.yaml'
PT_PATH = '/home/parallels/ros2_ws/src/darknet_ros_fp16/darknet_ros/darknet_ros/yolo_network_config/weights/pipe_yolo3.pt'
WEIGHTS_PATH = 'yolov3_pipe.weights'


def _extract_state_dict(checkpoint):
    """Return a plain state dict from whatever ``torch.load`` produced.

    Accepts a bare state dict (returned unchanged), a pickled ``nn.Module``,
    or a checkpoint dict that stores the module / state dict under the
    ``'model'`` or ``'state_dict'`` key.
    """
    if isinstance(checkpoint, torch.nn.Module):
        return checkpoint.float().state_dict()
    if isinstance(checkpoint, dict):
        for key in ('model', 'state_dict'):
            inner = checkpoint.get(key)
            if isinstance(inner, torch.nn.Module):
                return inner.float().state_dict()
            if isinstance(inner, dict):
                return inner
    return checkpoint


def _conv_bn_pairs(model):
    """Yield ``(conv, bn_or_None)`` for each Conv2d in module-traversal order.

    A Conv2d is paired with the BatchNorm2d that immediately follows it when
    the conv has no bias of its own and the channel counts agree. BatchNorm
    layers that do not follow such a conv are ignored.
    """
    layers = [m for m in model.modules()
              if isinstance(m, (torch.nn.Conv2d, torch.nn.BatchNorm2d))]
    for idx, layer in enumerate(layers):
        if not isinstance(layer, torch.nn.Conv2d):
            continue
        follower = layers[idx + 1] if idx + 1 < len(layers) else None
        if (isinstance(follower, torch.nn.BatchNorm2d)
                and layer.bias is None
                and follower.num_features == layer.out_channels):
            yield layer, follower
        else:
            yield layer, None


def _to_bytes(tensor):
    """Serialize a tensor as contiguous little-endian float32 bytes."""
    array = tensor.detach().cpu().float().contiguous().numpy()
    return array.astype('<f4').tobytes()


def write_darknet_weights(model, path, version=(0, 2, 0), seen=0):
    """Write the convolution / batch-norm weights of ``model`` to ``path``.

    Args:
        model: ``torch.nn.Module`` whose Conv2d / BatchNorm2d layers are dumped.
        path: Destination file; it is overwritten.
        version: ``(major, minor, revision)`` stored in the file header.
        seen: Non-negative "images seen" counter stored in the file header.

    NOTE(review): layers are written in ``model.modules()`` order, which must
    match the layer order of the Darknet .cfg the file will be loaded with.
    Models that group their detection convolutions at the end (instead of
    next to each scale's branch) will presumably need reordering - confirm
    against the target cfg.
    """
    major, minor, revision = version
    # Darknet reads a 64-bit counter for format version >= 0.2, else 32-bit.
    seen_format = '<Q' if major * 10 + minor >= 2 else '<i'
    with open(path, 'wb') as f:
        f.write(struct.pack('<3i', major, minor, revision))
        f.write(struct.pack(seen_format, seen))
        for conv, bn in _conv_bn_pairs(model):
            if bn is not None:
                # running_mean / running_var are buffers, not parameters,
                # so they have to be read from the module directly.
                for tensor in (bn.bias, bn.weight,
                               bn.running_mean, bn.running_var):
                    f.write(_to_bytes(tensor))
            elif conv.bias is not None:
                f.write(_to_bytes(conv.bias))
            else:
                # Darknet always reads one bias per filter.
                f.write(_to_bytes(torch.zeros(conv.out_channels)))
            f.write(_to_bytes(conv.weight))


def main():
    """Build the YOLOv3 model, load the .pt weights and export them."""
    model = Model(MODEL_CFG)
    # SECURITY: torch.load unpickles arbitrary objects - only use it on
    # checkpoint files from a trusted source.
    # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
    checkpoint = torch.load(PT_PATH, map_location='cpu')
    model.load_state_dict(_extract_state_dict(checkpoint))
    write_darknet_weights(model, WEIGHTS_PATH)


if __name__ == '__main__':
    main()