updates
This commit is contained in:
@@ -36,6 +36,7 @@ def train(
|
||||
|
||||
# Get dataloader
|
||||
dataloader = LoadImagesAndLabels(train_path, batch_size, img_size, augment=True)
|
||||
# dataloader = torch.utils.data.DataLoader(dataloader, batch_size=batch_size, num_workers=0)
|
||||
|
||||
lr0 = 0.001 # initial learning rate
|
||||
cutoff = -1 # backbone reaches to cutoff layer
|
||||
@@ -81,7 +82,7 @@ def train(
|
||||
# Start training
|
||||
t0 = time.time()
|
||||
model_info(model)
|
||||
n_burnin = min(round(dataloader.nB / 5 + 1), 1000) # number of burn-in batches
|
||||
n_burnin = min(round(len(dataloader) / 5 + 1), 1000) # burn-in batches
|
||||
for epoch in range(epochs):
|
||||
model.train()
|
||||
epoch += start_epoch
|
||||
|
||||
Reference in New Issue
Block a user