提交 f21bd7d3 authored 作者: wazeerzulfikar's avatar wazeerzulfikar 提交者: abergeron

Changed grad() to L_op in given elemwise (#5431)

上级 d9e51499
......@@ -634,11 +634,7 @@ second dimension
return [[True for output in node.outputs] for ipt in node.inputs]
def grad(self, inputs, ograds):
outs = self(*inputs)
if not isinstance(outs, (list, tuple)):
outs = [outs]
def L_op(self, inputs, outs, ograds):
# compute grad with respect to broadcasted input
rval = self._bgrad(inputs, ograds)
......@@ -1927,12 +1923,10 @@ class Sum(CAReduceDtype):
str(self.acc_dtype)
)
def grad(self, inp, grads):
def L_op(self, inp, out, grads):
x, = inp
out = self(*inp)
if out.dtype not in theano.tensor.continuous_dtypes:
if out[0].dtype not in theano.tensor.continuous_dtypes:
return [x.zeros_like(dtype=theano.config.floatX)]
gz, = grads
......@@ -1985,7 +1979,7 @@ class Prod(CAReduceDtype):
if 'no_zeros_in_input' not in dct:
self.no_zeros_in_input = False
def grad(self, inp, grads):
def L_op(self, inp, out, grads):
"""
The grad of this Op could be very easy, if it is was not for the case
where zeros are present in a given "group" (ie. elements reduced
......@@ -2034,9 +2028,7 @@ class Prod(CAReduceDtype):
prod_in, = inp
gz, = grads
out = self(*inp)
if (out.dtype in theano.tensor.discrete_dtypes or
if (out[0].dtype in theano.tensor.discrete_dtypes or
self.acc_dtype in theano.tensor.discrete_dtypes):
# There is an int conversion in the way
return [prod_in.zeros_like(dtype=theano.config.floatX)]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论