# Original listing metadata: 34 lines, 1.2 KiB, Python.
import torch
|
|
from models.yolo import *
|
|
|
|
# Convert a PyTorch YOLOv3 checkpoint into Darknet's binary .weights format.
#
# Darknet .weights layout (Darknet src/parser.c, save/load_weights):
#   header : int32 major, int32 minor, int32 revision, then the images-seen
#            counter (int64 when major*10 + minor >= 2, otherwise int32)
#   then, for every [convolutional] layer in .cfg order, float32 arrays:
#     with batch_normalize: BN beta, BN gamma, running mean, running variance
#     without             : conv bias
#     followed by the conv weights flattened in (out, in/groups, kh, kw) order.
# There is no per-layer header: Darknet derives every array size from the .cfg.

CFG_PATH = "models/yolov3.yaml"
CHECKPOINT_PATH = (
    "/home/parallels/ros2_ws/src/darknet_ros_fp16/darknet_ros/darknet_ros/"
    "yolo_network_config/weights/pipe_yolo3.pt"
)
OUTPUT_PATH = "yolov3_pipe.weights"

# (major, minor, revision) for the file header. 0.2.0 is what Darknet's own
# save_weights emits, and it selects a 64-bit "seen" counter.
DARKNET_VERSION = (0, 2, 0)


def _conv_bn_pairs(model):
    """Return ``[(conv, bn_or_None), ...]`` for every Conv2d in *model*.

    Layers are taken in module-registration order (``model.modules()``).
    A BatchNorm2d is attached to a Conv2d only when it is the very next
    Conv2d/BatchNorm2d module, i.e. a conv -> bn (-> activation) block.

    NOTE(review): Darknet reads layers in .cfg order. If the PyTorch model
    registers its convolutions in a different order (e.g. detection heads
    grouped into one module at the end), the output will not line up with
    the target .cfg -- verify against the .cfg used by darknet_ros.

    Raises:
        ValueError: if a BatchNorm2d has no preceding Conv2d, or its feature
            count differs from that Conv2d's output channels.
    """
    layers = [
        m for m in model.modules()
        if isinstance(m, (torch.nn.Conv2d, torch.nn.BatchNorm2d))
    ]
    pairs = []
    i = 0
    while i < len(layers):
        conv = layers[i]
        if not isinstance(conv, torch.nn.Conv2d):
            raise ValueError(f"BatchNorm2d at position {i} has no preceding Conv2d")
        bn = None
        if i + 1 < len(layers) and isinstance(layers[i + 1], torch.nn.BatchNorm2d):
            bn = layers[i + 1]
            if bn.num_features != conv.out_channels:
                raise ValueError(
                    f"BatchNorm2d with {bn.num_features} features follows "
                    f"Conv2d with {conv.out_channels} output channels"
                )
            i += 1
        pairs.append((conv, bn))
        i += 1
    return pairs


def _write_floats(f, tensor):
    """Append *tensor* to the binary file *f* as flat little-endian float32."""
    # .float() also up-casts fp16 tensors: Darknet always reads float32.
    f.write(tensor.detach().cpu().float().numpy().astype("<f4").tobytes())


def write_darknet_weights(model, path, *, version=DARKNET_VERSION, seen=0):
    """Serialize *model*'s convolutional layers to a Darknet .weights file.

    Args:
        model: torch.nn.Module whose Conv2d/BatchNorm2d layers are exported.
        path: output file path; overwritten if it exists.
        version: (major, minor, revision) stored in the file header.
        seen: images-seen counter stored in the file header.

    Returns:
        The number of convolutional layers written.

    Raises:
        ValueError: if the Conv2d/BatchNorm2d layout cannot be paired, or a
            BatchNorm2d has no running statistics to export.
    """
    import struct

    # Validate the whole layout before creating/truncating the output file.
    pairs = _conv_bn_pairs(model)

    major, minor, revision = version
    # Same condition Darknet uses when reading the "seen" field back.
    wide_seen = major * 10 + minor >= 2 and major < 1000 and minor < 1000
    header = struct.pack("<3iq" if wide_seen else "<4i", major, minor, revision, seen)

    with open(path, "wb") as f:
        f.write(header)
        for conv, bn in pairs:
            n_out = conv.out_channels
            if bn is None:
                # Darknet conv layers without BN always store a bias vector.
                bias = conv.bias if conv.bias is not None else torch.zeros(n_out)
                _write_floats(f, bias)
            else:
                if bn.running_mean is None or bn.running_var is None:
                    raise ValueError("BatchNorm2d without running statistics cannot be exported")
                mean = bn.running_mean
                if conv.bias is not None:
                    # Darknet has no conv bias when BN is enabled. BN(x + b) with
                    # mean m equals BN(x) with mean m - b, so fold it in exactly.
                    mean = mean - conv.bias
                gamma = bn.weight if bn.weight is not None else torch.ones(n_out)
                beta = bn.bias if bn.bias is not None else torch.zeros(n_out)
                # NOTE(review): .weights stores no BN epsilon; Darknet applies its
                # own fixed value, so a differing bn.eps shifts outputs slightly.
                for tensor in (beta, gamma, mean, bn.running_var):
                    _write_floats(f, tensor)
            _write_floats(f, conv.weight)
    return len(pairs)


def main():
    """Load the checkpoint into the YAML-defined YOLOv3 and export it."""
    model = Model(CFG_PATH)

    # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
    state = torch.load(CHECKPOINT_PATH, map_location="cpu")
    # NOTE(review): accepts a bare state_dict, or a checkpoint dict holding the
    # weights (or a pickled model) under "model" -- confirm the .pt layout.
    if isinstance(state, dict) and "model" in state:
        state = state["model"]
    if isinstance(state, torch.nn.Module):
        state = state.float().state_dict()
    model.load_state_dict(state)

    count = write_darknet_weights(model, OUTPUT_PATH)
    print(f"Wrote {count} convolutional layers to {OUTPUT_PATH}")


if __name__ == "__main__":
    main()