提交 fd628c5a authored 作者: affanv14's avatar affanv14

Added L_op to scalar/basic.py

上级 f3844589
差异被折叠。
...@@ -602,7 +602,7 @@ second dimension ...@@ -602,7 +602,7 @@ second dimension
ograds = [x.zeros_like() for x in outs] ograds = [x.zeros_like() for x in outs]
ograds[idx] = theano.tensor.ones_like(out) ograds[idx] = theano.tensor.ones_like(out)
bgrads = self._bgrad(inputs, ograds) bgrads = self._bgrad(inputs, outs, ograds)
rop_out = None rop_out = None
for jdx, (inp, eval_point) in enumerate(izip(inputs, for jdx, (inp, eval_point) in enumerate(izip(inputs,
...@@ -636,7 +636,7 @@ second dimension ...@@ -636,7 +636,7 @@ second dimension
def L_op(self, inputs, outs, ograds): def L_op(self, inputs, outs, ograds):
# compute grad with respect to broadcasted input # compute grad with respect to broadcasted input
rval = self._bgrad(inputs, ograds) rval = self._bgrad(inputs, outs, ograds)
# TODO: make sure that zeros are clearly identifiable # TODO: make sure that zeros are clearly identifiable
# to the gradient.grad method when the outputs have # to the gradient.grad method when the outputs have
...@@ -684,7 +684,7 @@ second dimension ...@@ -684,7 +684,7 @@ second dimension
return rval return rval
def _bgrad(self, inputs, ograds): def _bgrad(self, inputs, outputs, ograds):
# returns grad, with respect to broadcasted versions of inputs # returns grad, with respect to broadcasted versions of inputs
with change_flags(compute_test_value='off'): with change_flags(compute_test_value='off'):
...@@ -695,7 +695,7 @@ second dimension ...@@ -695,7 +695,7 @@ second dimension
scalar_inputs = list(map(as_scalar, inputs)) scalar_inputs = list(map(as_scalar, inputs))
scalar_ograds = list(map(as_scalar, ograds)) scalar_ograds = list(map(as_scalar, ograds))
scalar_igrads = self.scalar_op.grad(scalar_inputs, scalar_ograds) scalar_igrads = self.scalar_op.L_op(scalar_inputs, outputs, scalar_ograds)
for igrad in scalar_igrads: for igrad in scalar_igrads:
assert igrad is not None, self.scalar_op assert igrad is not None, self.scalar_op
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论