提交 c28d2580 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

using operators for all grads in scalar.py

上级 3808dd33
...@@ -267,8 +267,7 @@ class Pow(BinaryScalarOp): ...@@ -267,8 +267,7 @@ class Pow(BinaryScalarOp):
def c_code(self, (x, y), (z, ), sub): def c_code(self, (x, y), (z, ), sub):
return "%(z)s = pow(%(x)s, %(y)s);" % locals() return "%(z)s = pow(%(x)s, %(y)s);" % locals()
def grad(self, (x, y), (gz, )): def grad(self, (x, y), (gz, )):
return gz * y * x**(y - as_scalar(1)), gz * log(x) * x**y return gz * y * x**(y - 1), gz * log(x) * x**y
# return mul(gz, mul(y, pow(x, sub(y, as_scalar(1))))), mul(gz, mul(log(x), pow(x, y)))
class First(BinaryScalarOp): class First(BinaryScalarOp):
def impl(self, x, y): def impl(self, x, y):
...@@ -299,7 +298,7 @@ class Neg(UnaryScalarOp): ...@@ -299,7 +298,7 @@ class Neg(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return -x return -x
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return neg(gz), return -gz,
def c_code(self, (x, ), (z, ), sub): def c_code(self, (x, ), (z, ), sub):
return "%(z)s = -%(x)s;" % locals() return "%(z)s = -%(x)s;" % locals()
...@@ -307,7 +306,7 @@ class Abs(UnaryScalarOp): ...@@ -307,7 +306,7 @@ class Abs(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return numpy.abs(x) return numpy.abs(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return mul(gz, sgn(x)), return gz * sgn(x),
def c_code(self, (x, ), (z, ), sub): def c_code(self, (x, ), (z, ), sub):
return "%(z)s = abs(%(x)s);" % locals() return "%(z)s = abs(%(x)s);" % locals()
...@@ -323,7 +322,7 @@ class Inv(UnaryScalarOp): ...@@ -323,7 +322,7 @@ class Inv(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return 1 / x return 1 / x
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return div(neg(gz), mul(x, x)), return -gz / (x * x),
def c_code(self, (x, ), (z, ), sub): def c_code(self, (x, ), (z, ), sub):
return "%(z)s = 1 / %(x)s;" % locals() return "%(z)s = 1 / %(x)s;" % locals()
...@@ -331,7 +330,7 @@ class Log(UnaryScalarOp): ...@@ -331,7 +330,7 @@ class Log(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return math.log(x) return math.log(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return div(gz, x), return gz / x,
def c_code(self, (x, ), (z, ), sub): def c_code(self, (x, ), (z, ), sub):
return "%(z)s = log(%(x)s);" % locals() return "%(z)s = log(%(x)s);" % locals()
...@@ -339,7 +338,7 @@ class Log2(UnaryScalarOp): ...@@ -339,7 +338,7 @@ class Log2(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return numpy.log2(x) return numpy.log2(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return div(gz, mul(x, as_scalar(math.log(2.0)))), return gz / (x * math.log(2.0)),
def c_code(self, (x, ), (z, ), sub): def c_code(self, (x, ), (z, ), sub):
return "%(z)s = log2(%(x)s);" % locals() return "%(z)s = log2(%(x)s);" % locals()
...@@ -347,7 +346,7 @@ class Exp(UnaryScalarOp): ...@@ -347,7 +346,7 @@ class Exp(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return math.exp(x) return math.exp(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return mul(gz, exp(x)), return gz * exp(x),
def c_code(self, (x, ), (z, ), sub): def c_code(self, (x, ), (z, ), sub):
return "%(z)s = exp(%(x)s);" % locals() return "%(z)s = exp(%(x)s);" % locals()
...@@ -355,7 +354,7 @@ class Sqr(UnaryScalarOp): ...@@ -355,7 +354,7 @@ class Sqr(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return x*x return x*x
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return mul(gz, mul(x, as_scalar(2))), return gz * x * 2,
def c_code(self, (x, ), (z, ), sub): def c_code(self, (x, ), (z, ), sub):
return "%(z)s = %(x)s * %(x)s;" % locals() return "%(z)s = %(x)s * %(x)s;" % locals()
...@@ -363,7 +362,7 @@ class Sqrt(UnaryScalarOp): ...@@ -363,7 +362,7 @@ class Sqrt(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return math.sqrt(x) return math.sqrt(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return div(mul(gz, as_scalar(0.5)), sqrt(x)), return (gz * 0.5) / sqrt(x),
def c_code(self, (x, ), (z, ), sub): def c_code(self, (x, ), (z, ), sub):
return "%(z)s = sqrt(%(x)s);" % locals() return "%(z)s = sqrt(%(x)s);" % locals()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论