提交 ec866914 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

using operators in scalar grads

上级 0d49394d
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
from gof import opt from gof import opt
from elemwise import Broadcast from elemwise import Broadcast
class InplaceOptimizer(opt.OpSpecificOptimizer): class InplaceOptimizer(opt.OpSpecificOptimizer):
opclass = Broadcast opclass = Broadcast
...@@ -25,6 +26,13 @@ class InplaceOptimizer(opt.OpSpecificOptimizer): ...@@ -25,6 +26,13 @@ class InplaceOptimizer(opt.OpSpecificOptimizer):
inplace_optimizer = InplaceOptimizer() inplace_optimizer = InplaceOptimizer()
# class ElemwisePatternOptimizer(opt.Optimizer):
# def __init__(self, scalar_opt):
# self.
# def find_elemwise_cliques(env):
......
...@@ -45,11 +45,6 @@ class Scalar(ResultBase): ...@@ -45,11 +45,6 @@ class Scalar(ResultBase):
def same_properties(self, other): def same_properties(self, other):
return other.dtype == self.dtype return other.dtype == self.dtype
# def mergeable(self, other):
# return getattr(self, 'constant', False) \
# and getattr(other, 'constant', False) \
# and self.data == other.data
def dtype_specs(self): def dtype_specs(self):
try: try:
return {'float32': (float, 'npy_float32', 'PyFloat_Check', 'PyFloat_AsDouble', 'PyFloat_FromDouble'), return {'float32': (float, 'npy_float32', 'PyFloat_Check', 'PyFloat_AsDouble', 'PyFloat_FromDouble'),
...@@ -246,7 +241,7 @@ class Sub(BinaryScalarOp): ...@@ -246,7 +241,7 @@ class Sub(BinaryScalarOp):
def c_code(self, (x, y), (z, ), sub): def c_code(self, (x, y), (z, ), sub):
return "%(z)s = %(x)s - %(y)s;" % locals() return "%(z)s = %(x)s - %(y)s;" % locals()
def grad(self, (x, y), (gz, )): def grad(self, (x, y), (gz, )):
return gz, neg(gz) return gz, -gz
class Mul(BinaryScalarOp): class Mul(BinaryScalarOp):
def impl(self, x, y): def impl(self, x, y):
...@@ -254,7 +249,7 @@ class Mul(BinaryScalarOp): ...@@ -254,7 +249,7 @@ class Mul(BinaryScalarOp):
def c_code(self, (x, y), (z, ), sub): def c_code(self, (x, y), (z, ), sub):
return "%(z)s = %(x)s * %(y)s;" % locals() return "%(z)s = %(x)s * %(y)s;" % locals()
def grad(self, (x, y), (gz, )): def grad(self, (x, y), (gz, )):
return mul(y, gz), mul(x, gz) return gz * y, gz * x
class Div(BinaryScalarOp): class Div(BinaryScalarOp):
def impl(self, x, y): def impl(self, x, y):
...@@ -262,7 +257,7 @@ class Div(BinaryScalarOp): ...@@ -262,7 +257,7 @@ class Div(BinaryScalarOp):
def c_code(self, (x, y), (z, ), sub): def c_code(self, (x, y), (z, ), sub):
return "%(z)s = %(x)s / %(y)s;" % locals() return "%(z)s = %(x)s / %(y)s;" % locals()
def grad(self, (x, y), (gz, )): def grad(self, (x, y), (gz, )):
return div(gz, y), neg(div(mul(x, gz), mul(y, y))) return gz / y, -(gz * x) / (y * y)
class Pow(BinaryScalarOp): class Pow(BinaryScalarOp):
def impl(self, x, y): def impl(self, x, y):
...@@ -270,7 +265,8 @@ class Pow(BinaryScalarOp): ...@@ -270,7 +265,8 @@ class Pow(BinaryScalarOp):
def c_code(self, (x, y), (z, ), sub): def c_code(self, (x, y), (z, ), sub):
return "%(z)s = pow(%(x)s, %(y)s);" % locals() return "%(z)s = pow(%(x)s, %(y)s);" % locals()
def grad(self, (x, y), (gz, )): def grad(self, (x, y), (gz, )):
return mul(gz, mul(y, pow(x, sub(y, as_scalar(1))))), mul(gz, mul(log(x), pow(x, y))) return gz * y * x**(y - as_scalar(1)), gz * log(x) * x**y
# return mul(gz, mul(y, pow(x, sub(y, as_scalar(1))))), mul(gz, mul(log(x), pow(x, y)))
class First(BinaryScalarOp): class First(BinaryScalarOp):
def impl(self, x, y): def impl(self, x, y):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论