multi_thread dataloader

This commit is contained in:
Glenn Jocher
2019-03-21 14:48:40 +02:00
parent be38caf284
commit 70fe2204b4
4 changed files with 21 additions and 14 deletions
+8 -3
View File
@@ -7,7 +7,6 @@ import cv2
import numpy as np
import torch
# from torch.utils.data import Dataset
from utils.utils import xyxy2xywh
@@ -114,10 +113,11 @@ class LoadImagesAndLabels: # for training
def __getitem__(self, index):
imgs, labels0, img_paths, img_shapes = self.load_images(index, index + 1)
labels0[:,0] = index % self.batch_size
labels0[:, 0] = index % self.batch_size
labels = torch.zeros(100, 6)
labels[:min(len(labels0), 100)] = labels0 # max 100 labels per image
return imgs.squeeze(0), labels, img_paths, img_shapes
def __next__(self):
@@ -225,7 +225,12 @@ class LoadImagesAndLabels: # for training
img_all = np.ascontiguousarray(img_all, dtype=np.float32) # uint8 to float32
img_all /= 255.0 # 0 - 255 to 0.0 - 1.0
labels_all = torch.from_numpy(np.concatenate(labels_all, 0))
if len(labels_all) > 0:
labels_all = np.concatenate(labels_all, 0)
else:
labels_all = np.zeros((1, 6), dtype='float32')
labels_all = torch.from_numpy(labels_all)
return torch.from_numpy(img_all), labels_all, img_paths, img_shapes
def __len__(self):
+1 -1
View File
@@ -40,7 +40,7 @@ def model_info(model):
print('\n%5s %38s %9s %12s %20s %12s %12s' % ('layer', 'name', 'gradient', 'parameters', 'shape', 'mu', 'sigma'))
for i, (name, p) in enumerate(model.named_parameters()):
name = name.replace('module_list.', '')
print('%5g %38s %9s %12g %20s %12.3g %12.3g' % (
print('%5g %40s %9s %12g %20s %10.3g %10.3g' % (
i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std()))
print('Model Summary: %g layers, %g parameters, %g gradients' % (i + 1, n_p, n_g))