提交 ec866914 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

using operators in scalar grads

上级 0d49394d
......@@ -2,6 +2,7 @@
from gof import opt
from elemwise import Broadcast
class InplaceOptimizer(opt.OpSpecificOptimizer):
opclass = Broadcast
......@@ -25,8 +26,15 @@ class InplaceOptimizer(opt.OpSpecificOptimizer):
inplace_optimizer = InplaceOptimizer()
# class ElemwisePatternOptimizer(opt.Optimizer):
# def __init__(self, scalar_opt):
# self.
# def find_elemwise_cliques(env):
......
......@@ -45,11 +45,6 @@ class Scalar(ResultBase):
def same_properties(self, other):
return other.dtype == self.dtype
# def mergeable(self, other):
# return getattr(self, 'constant', False) \
# and getattr(other, 'constant', False) \
# and self.data == other.data
def dtype_specs(self):
try:
return {'float32': (float, 'npy_float32', 'PyFloat_Check', 'PyFloat_AsDouble', 'PyFloat_FromDouble'),
......@@ -246,7 +241,7 @@ class Sub(BinaryScalarOp):
def c_code(self, (x, y), (z, ), sub):
return "%(z)s = %(x)s - %(y)s;" % locals()
def grad(self, (x, y), (gz, )):
return gz, neg(gz)
return gz, -gz
class Mul(BinaryScalarOp):
def impl(self, x, y):
......@@ -254,7 +249,7 @@ class Mul(BinaryScalarOp):
def c_code(self, (x, y), (z, ), sub):
return "%(z)s = %(x)s * %(y)s;" % locals()
def grad(self, (x, y), (gz, )):
return mul(y, gz), mul(x, gz)
return gz * y, gz * x
class Div(BinaryScalarOp):
def impl(self, x, y):
......@@ -262,7 +257,7 @@ class Div(BinaryScalarOp):
def c_code(self, (x, y), (z, ), sub):
return "%(z)s = %(x)s / %(y)s;" % locals()
def grad(self, (x, y), (gz, )):
return div(gz, y), neg(div(mul(x, gz), mul(y, y)))
return gz / y, -(gz * x) / (y * y)
class Pow(BinaryScalarOp):
def impl(self, x, y):
......@@ -270,7 +265,8 @@ class Pow(BinaryScalarOp):
def c_code(self, (x, y), (z, ), sub):
return "%(z)s = pow(%(x)s, %(y)s);" % locals()
def grad(self, (x, y), (gz, )):
return mul(gz, mul(y, pow(x, sub(y, as_scalar(1))))), mul(gz, mul(log(x), pow(x, y)))
return gz * y * x**(y - as_scalar(1)), gz * log(x) * x**y
# return mul(gz, mul(y, pow(x, sub(y, as_scalar(1))))), mul(gz, mul(log(x), pow(x, y)))
class First(BinaryScalarOp):
def impl(self, x, y):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论