multi_thread dataloader
This commit is contained in:
+8
-3
@@ -7,7 +7,6 @@ import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
# from torch.utils.data import Dataset
|
||||
from utils.utils import xyxy2xywh
|
||||
|
||||
|
||||
@@ -114,10 +113,11 @@ class LoadImagesAndLabels: # for training
|
||||
|
||||
def __getitem__(self, index):
|
||||
imgs, labels0, img_paths, img_shapes = self.load_images(index, index + 1)
|
||||
labels0[:,0] = index % self.batch_size
|
||||
|
||||
labels0[:, 0] = index % self.batch_size
|
||||
labels = torch.zeros(100, 6)
|
||||
labels[:min(len(labels0), 100)] = labels0 # max 100 labels per image
|
||||
|
||||
return imgs.squeeze(0), labels, img_paths, img_shapes
|
||||
|
||||
def __next__(self):
|
||||
@@ -225,7 +225,12 @@ class LoadImagesAndLabels: # for training
|
||||
img_all = np.ascontiguousarray(img_all, dtype=np.float32) # uint8 to float32
|
||||
img_all /= 255.0 # 0 - 255 to 0.0 - 1.0
|
||||
|
||||
labels_all = torch.from_numpy(np.concatenate(labels_all, 0))
|
||||
if len(labels_all) > 0:
|
||||
labels_all = np.concatenate(labels_all, 0)
|
||||
else:
|
||||
labels_all = np.zeros((1, 6), dtype='float32')
|
||||
|
||||
labels_all = torch.from_numpy(labels_all)
|
||||
return torch.from_numpy(img_all), labels_all, img_paths, img_shapes
|
||||
|
||||
def __len__(self):
|
||||
|
||||
+1
-1
@@ -40,7 +40,7 @@ def model_info(model):
|
||||
print('\n%5s %38s %9s %12s %20s %12s %12s' % ('layer', 'name', 'gradient', 'parameters', 'shape', 'mu', 'sigma'))
|
||||
for i, (name, p) in enumerate(model.named_parameters()):
|
||||
name = name.replace('module_list.', '')
|
||||
print('%5g %38s %9s %12g %20s %12.3g %12.3g' % (
|
||||
print('%5g %40s %9s %12g %20s %10.3g %10.3g' % (
|
||||
i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std()))
|
||||
print('Model Summary: %g layers, %g parameters, %g gradients' % (i + 1, n_p, n_g))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user