提交 f21bd7d3 authored 作者: wazeerzulfikar's avatar wazeerzulfikar 提交者: abergeron

Changed grad() to L_op in given elemwise (#5431)

上级 d9e51499
...@@ -634,11 +634,7 @@ second dimension ...@@ -634,11 +634,7 @@ second dimension
return [[True for output in node.outputs] for ipt in node.inputs] return [[True for output in node.outputs] for ipt in node.inputs]
def grad(self, inputs, ograds): def L_op(self, inputs, outs, ograds):
outs = self(*inputs)
if not isinstance(outs, (list, tuple)):
outs = [outs]
# compute grad with respect to broadcasted input # compute grad with respect to broadcasted input
rval = self._bgrad(inputs, ograds) rval = self._bgrad(inputs, ograds)
...@@ -1927,12 +1923,10 @@ class Sum(CAReduceDtype): ...@@ -1927,12 +1923,10 @@ class Sum(CAReduceDtype):
str(self.acc_dtype) str(self.acc_dtype)
) )
def grad(self, inp, grads): def L_op(self, inp, out, grads):
x, = inp x, = inp
out = self(*inp) if out[0].dtype not in theano.tensor.continuous_dtypes:
if out.dtype not in theano.tensor.continuous_dtypes:
return [x.zeros_like(dtype=theano.config.floatX)] return [x.zeros_like(dtype=theano.config.floatX)]
gz, = grads gz, = grads
...@@ -1985,7 +1979,7 @@ class Prod(CAReduceDtype): ...@@ -1985,7 +1979,7 @@ class Prod(CAReduceDtype):
if 'no_zeros_in_input' not in dct: if 'no_zeros_in_input' not in dct:
self.no_zeros_in_input = False self.no_zeros_in_input = False
def grad(self, inp, grads): def L_op(self, inp, out, grads):
""" """
The grad of this Op could be very easy, if it is was not for the case The grad of this Op could be very easy, if it is was not for the case
where zeros are present in a given "group" (ie. elements reduced where zeros are present in a given "group" (ie. elements reduced
...@@ -2034,9 +2028,7 @@ class Prod(CAReduceDtype): ...@@ -2034,9 +2028,7 @@ class Prod(CAReduceDtype):
prod_in, = inp prod_in, = inp
gz, = grads gz, = grads
out = self(*inp) if (out[0].dtype in theano.tensor.discrete_dtypes or
if (out.dtype in theano.tensor.discrete_dtypes or
self.acc_dtype in theano.tensor.discrete_dtypes): self.acc_dtype in theano.tensor.discrete_dtypes):
# There is an int conversion in the way # There is an int conversion in the way
return [prod_in.zeros_like(dtype=theano.config.floatX)] return [prod_in.zeros_like(dtype=theano.config.floatX)]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论