class SwishImplementation(torch.autograd.Function):
    """Custom autograd function for Swish: ``f(x) = x * sigmoid(x)``.

    Only the input tensor is saved for the backward pass; the sigmoid is
    recomputed there rather than being stored during the forward pass.
    """

    @staticmethod
    def forward(ctx, i):
        """Return ``i * sigmoid(i)`` and save ``i`` for the backward pass.

        Args:
            ctx: Autograd context used to stash tensors for ``backward``.
            i: Input tensor.

        Returns:
            Tensor holding ``i * torch.sigmoid(i)``.
        """
        ctx.save_for_backward(i)
        return i * torch.sigmoid(i)

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradient of the loss with respect to the forward input.

        Args:
            ctx: Autograd context holding the tensor saved in ``forward``.
            grad_output: Gradient of the loss with respect to the output.

        Returns:
            ``grad_output`` scaled by the derivative of ``i * sigmoid(i)``.
        """
        # ``saved_tensors`` is the supported accessor (``saved_variables`` is
        # a deprecated alias); unpack once instead of indexing it twice.
        (i,) = ctx.saved_tensors
        sigmoid_i = torch.sigmoid(i)
        # d/di [i * s(i)] = s(i) + i * s(i) * (1 - s(i)) = s(i) * (1 + i * (1 - s(i)))
        return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))