updates
This commit is contained in:
@@ -720,6 +720,46 @@ def print_mutation(hyp, results, bucket=''):
|
||||
os.system('gsutil cp evolve.txt gs://%s' % bucket) # upload evolve.txt
|
||||
|
||||
|
||||
def apply_classifier(x, model, img, im0):
|
||||
# applies a second stage classifier to yolo outputs
|
||||
|
||||
for i, d in enumerate(x): # per image
|
||||
if d is not None and len(d):
|
||||
d = d.clone()
|
||||
|
||||
# Reshape and pad cutouts
|
||||
b = xyxy2xywh(d[:, :4]) # boxes
|
||||
b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1) # rectangle to square
|
||||
b[:, 2:] = b[:, 2:] * 1.0 + 0 # pad
|
||||
d[:, :4] = xywh2xyxy(b).long()
|
||||
|
||||
# Rescale boxes from img_size to im0 size
|
||||
scale_coords(img.shape[2:], d[:, :4], im0.shape)
|
||||
|
||||
# Classes
|
||||
pred_cls1 = d[:, 6].long()
|
||||
ims = []
|
||||
j = 0
|
||||
for a in d: # per item
|
||||
j += 1
|
||||
cutout = im0[int(a[1]):int(a[3]), int(a[0]):int(a[2])]
|
||||
im = cv2.resize(cutout, (128, 128)) # BGR
|
||||
cv2.imwrite('test%i.jpg' % j, cutout)
|
||||
|
||||
im = im[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
|
||||
im = np.expand_dims(im, axis=0) # add batch dim
|
||||
im = np.ascontiguousarray(im, dtype=np.float32) # uint8 to float32
|
||||
im /= 255.0 # 0 - 255 to 0.0 - 1.0
|
||||
ims.append(im)
|
||||
|
||||
ims = torch.Tensor(np.concatenate(ims, 0)) # to torch
|
||||
pred_cls2 = model(ims).argmax(1) # classifier prediction
|
||||
|
||||
# x[i] = x[i][pred_cls1 == pred_cls2] # retain matching class detections
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def fitness(x):
|
||||
# Returns fitness (for use with results.txt or evolve.txt)
|
||||
return x[:, 2] * 0.8 + x[:, 3] * 0.2 # weighted mAP and F1 combination
|
||||
|
||||
Reference in New Issue
Block a user