提交 c28d2580 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

using operators for all grads in scalar.py

上级 3808dd33
......@@ -267,8 +267,7 @@ class Pow(BinaryScalarOp):
def c_code(self, (x, y), (z, ), sub):
return "%(z)s = pow(%(x)s, %(y)s);" % locals()
def grad(self, (x, y), (gz, )):
return gz * y * x**(y - as_scalar(1)), gz * log(x) * x**y
# return mul(gz, mul(y, pow(x, sub(y, as_scalar(1))))), mul(gz, mul(log(x), pow(x, y)))
return gz * y * x**(y - 1), gz * log(x) * x**y
class First(BinaryScalarOp):
def impl(self, x, y):
......@@ -299,7 +298,7 @@ class Neg(UnaryScalarOp):
def impl(self, x):
return -x
def grad(self, (x, ), (gz, )):
return neg(gz),
return -gz,
def c_code(self, (x, ), (z, ), sub):
return "%(z)s = -%(x)s;" % locals()
......@@ -307,7 +306,7 @@ class Abs(UnaryScalarOp):
def impl(self, x):
return numpy.abs(x)
def grad(self, (x, ), (gz, )):
return mul(gz, sgn(x)),
return gz * sgn(x),
def c_code(self, (x, ), (z, ), sub):
return "%(z)s = abs(%(x)s);" % locals()
......@@ -323,7 +322,7 @@ class Inv(UnaryScalarOp):
def impl(self, x):
return 1 / x
def grad(self, (x, ), (gz, )):
return div(neg(gz), mul(x, x)),
return -gz / (x * x),
def c_code(self, (x, ), (z, ), sub):
return "%(z)s = 1 / %(x)s;" % locals()
......@@ -331,7 +330,7 @@ class Log(UnaryScalarOp):
def impl(self, x):
return math.log(x)
def grad(self, (x, ), (gz, )):
return div(gz, x),
return gz / x,
def c_code(self, (x, ), (z, ), sub):
return "%(z)s = log(%(x)s);" % locals()
......@@ -339,7 +338,7 @@ class Log2(UnaryScalarOp):
def impl(self, x):
return numpy.log2(x)
def grad(self, (x, ), (gz, )):
return div(gz, mul(x, as_scalar(math.log(2.0)))),
return gz / (x * math.log(2.0)),
def c_code(self, (x, ), (z, ), sub):
return "%(z)s = log2(%(x)s);" % locals()
......@@ -347,7 +346,7 @@ class Exp(UnaryScalarOp):
def impl(self, x):
return math.exp(x)
def grad(self, (x, ), (gz, )):
return mul(gz, exp(x)),
return gz * exp(x),
def c_code(self, (x, ), (z, ), sub):
return "%(z)s = exp(%(x)s);" % locals()
......@@ -355,7 +354,7 @@ class Sqr(UnaryScalarOp):
def impl(self, x):
return x*x
def grad(self, (x, ), (gz, )):
return mul(gz, mul(x, as_scalar(2))),
return gz * x * 2,
def c_code(self, (x, ), (z, ), sub):
return "%(z)s = %(x)s * %(x)s;" % locals()
......@@ -363,7 +362,7 @@ class Sqrt(UnaryScalarOp):
def impl(self, x):
return math.sqrt(x)
def grad(self, (x, ), (gz, )):
return div(mul(gz, as_scalar(0.5)), sqrt(x)),
return (gz * 0.5) / sqrt(x),
def c_code(self, (x, ), (z, ), sub):
return "%(z)s = sqrt(%(x)s);" % locals()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论