# greenhouse/utils/activations.py
# Activation functions

import torch
import torch.nn as nn
import torch.nn.functional as F


# Swish https://arxiv.org/pdf/1905.02244.pdf ---------------------------------------------------------------------------
class Swish(nn.Module):  # x * sigmoid(x)
    @staticmethod
    def forward(x):
        return x * torch.sigmoid(x)
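
# Note: PyTorch >= 1.7 ships this same x * sigmoid(x) activation as nn.SiLU();
# the class above keeps the code working on older PyTorch versions.
#   assert torch.allclose(Swish()(x), nn.SiLU()(x))  # PyTorch >= 1.7 only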


class Hardswish(nn.Module):  # export-friendly version of nn.Hardswish()
    @staticmethod
    def forward(x):
        # return x * F.hardsigmoid(x)  # for torchscript and CoreML
        return x * F.hardtanh(x + 3, 0., 6.) / 6.  # for torchscript, CoreML and ONNX
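
# Export sketch (hypothetical example model, for illustration only): the
# hardtanh form above traces cleanly, so a module built with this Hardswish
# can be exported to ONNX.
#   m = nn.Sequential(nn.Conv2d(3, 16, 3), Hardswish())
#   torch.onnx.export(m, torch.zeros(1, 3, 64, 64), 'hardswish.onnx')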


class MemoryEfficientSwish(nn.Module):  # Swish with a hand-written backward to reduce activation memory
    class F(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            ctx.save_for_backward(x)  # store only the input; intermediates are recomputed in backward
            return x * torch.sigmoid(x)

        @staticmethod
        def backward(ctx, grad_output):
            x = ctx.saved_tensors[0]
            sx = torch.sigmoid(x)
            # d/dx [x * sigmoid(x)] = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
            return grad_output * (sx * (1 + x * (1 - sx)))

    def forward(self, x):
        return self.F.apply(x)
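
# A minimal verification sketch: torch.autograd.gradcheck can confirm the
# hand-written backward above matches autograd's numerical gradient
# (double precision is needed for the default tolerances).
#   x = torch.randn(4, dtype=torch.double, requires_grad=True)
#   assert torch.autograd.gradcheck(MemoryEfficientSwish.F.apply, (x,))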


# Mish https://github.com/digantamisra98/Mish --------------------------------------------------------------------------
class Mish(nn.Module):  # x * tanh(softplus(x))
    @staticmethod
    def forward(x):
        return x * F.softplus(x).tanh()


class MemoryEfficientMish(nn.Module):  # Mish with a hand-written backward to reduce activation memory
    class F(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            ctx.save_for_backward(x)  # store only the input; intermediates are recomputed in backward
            return x.mul(torch.tanh(F.softplus(x)))  # x * tanh(ln(1 + exp(x)))

        @staticmethod
        def backward(ctx, grad_output):
            x = ctx.saved_tensors[0]
            sx = torch.sigmoid(x)  # sigmoid(x) is the derivative of softplus(x)
            fx = F.softplus(x).tanh()
            # d/dx [x * tanh(softplus(x))] = fx + x * sigmoid(x) * (1 - fx^2)
            return grad_output * (fx + x * sx * (1 - fx * fx))

    def forward(self, x):
        return self.F.apply(x)


# FReLU https://arxiv.org/abs/2007.11824 -------------------------------------------------------------------------------
class FReLU(nn.Module):  # max(x, T(x)) with a per-channel (depthwise) spatial condition T
    def __init__(self, c1, k=3):  # ch_in, kernel
        super().__init__()
        self.conv = nn.Conv2d(c1, c1, k, 1, k // 2, groups=c1)  # depthwise conv; k // 2 padding preserves H x W
        self.bn = nn.BatchNorm2d(c1)

    def forward(self, x):
        return torch.max(x, self.bn(self.conv(x)))
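

# A quick smoke-test sketch, run only as a script: the memory-efficient
# variants should agree with the plain ones, and FReLU should preserve the
# input's spatial shape.
if __name__ == '__main__':
    x = torch.randn(2, 8, 16, 16)
    assert torch.allclose(Swish()(x), MemoryEfficientSwish()(x), atol=1e-6)
    assert torch.allclose(Mish()(x), MemoryEfficientMish()(x), atol=1e-6)
    assert FReLU(8)(x).shape == x.shape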