提交 ad628f1f authored 作者: Pascal Lamblin's avatar Pascal Lamblin 提交者: GitHub

Merge pull request #6452 from Faruk-Ahmed/ARange-gradient

ARange gradient (contd)
......@@ -5530,22 +5530,23 @@ class ARange(Op):
def L_op(self, inputs, outputs, grads):
start, stop, step = inputs
gz, = grads
# start and step affect the output values
# `start` and `step` affect the output values
# but the outputs are integers so there's
# no gradient through them
# stop does not affect the output values,
# just the output shape, so it is disconnected
if (self.dtype in discrete_dtypes) and (step > 0):
return [start.zeros_like(),
# no gradient through them.
# When they are not integers, the gradients are
# as expressed below.
# `stop` does not affect the output values,
# just the output shape, so it is disconnected.
if self.dtype in discrete_dtypes:
return [start.zeros_like(dtype=config.floatX),
DisconnectedType()(),
step.zeros_like()]
step.zeros_like(dtype=config.floatX)]
else:
num_steps_taken = outputs[0].shape[0] # (stop-start)/step
return [gz.sum(dtype=config.floatX),
num_steps_taken = outputs[0].shape[0]
return [gz.sum(),
DisconnectedType()(),
(gz * arange(num_steps_taken, dtype=self.dtype)).sum(
dtype=config.floatX)]
(gz * arange(num_steps_taken, dtype=self.dtype)).sum()]
def R_op(self, inputs, eval_points):
return [None]
......
......@@ -5917,12 +5917,13 @@ class TestARange(unittest.TestCase):
rng = np.random.RandomState(utt.fetch_seed())
# Due to the random projection, we should not use the exact
# point that shage the shape of the output.
# point that change the shape of the output.
for start, stop, step in [(0, 4.9, 1),
(5.1, 0, -0.5),
(1, 5.1, 0.5)]:
utt.verify_grad(f, [np.float32(start),
np.float32(stop),
np.float32(step)],
utt.verify_grad(f, [np.asarray(start).astype(config.floatX),
np.asarray(stop).astype(config.floatX),
np.asarray(step).astype(config.floatX)],
rng=rng)
def test_integers(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论