提交 ad628f1f 作者: Pascal Lamblin 提交者: GitHub

Merge pull request #6452 from Faruk-Ahmed/ARange-gradient

ARange gradient (contd)
...@@ -5530,22 +5530,23 @@ class ARange(Op): ...@@ -5530,22 +5530,23 @@ class ARange(Op):
def L_op(self, inputs, outputs, grads): def L_op(self, inputs, outputs, grads):
start, stop, step = inputs start, stop, step = inputs
gz, = grads gz, = grads
# start and step affect the output values # `start` and `step` affect the output values
# but the outputs are integers so there's # but the outputs are integers so there's
# no gradient through them # no gradient through them.
# stop does not affect the output values, # When they are not integers, the gradients are
# just the output shape, so it is disconnected # as expressed below.
# `stop` does not affect the output values,
if (self.dtype in discrete_dtypes) and (step > 0): # just the output shape, so it is disconnected.
return [start.zeros_like(),
if self.dtype in discrete_dtypes:
return [start.zeros_like(dtype=config.floatX),
DisconnectedType()(), DisconnectedType()(),
step.zeros_like()] step.zeros_like(dtype=config.floatX)]
else: else:
num_steps_taken = outputs[0].shape[0] # (stop-start)/step num_steps_taken = outputs[0].shape[0]
return [gz.sum(dtype=config.floatX), return [gz.sum(),
DisconnectedType()(), DisconnectedType()(),
(gz * arange(num_steps_taken, dtype=self.dtype)).sum( (gz * arange(num_steps_taken, dtype=self.dtype)).sum()]
dtype=config.floatX)]
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
return [None] return [None]
......
...@@ -5917,12 +5917,13 @@ class TestARange(unittest.TestCase): ...@@ -5917,12 +5917,13 @@ class TestARange(unittest.TestCase):
rng = np.random.RandomState(utt.fetch_seed()) rng = np.random.RandomState(utt.fetch_seed())
# Due to the random projection, we should not use the exact # Due to the random projection, we should not use the exact
# point that shage the shape of the output. # point that change the shape of the output.
for start, stop, step in [(0, 4.9, 1), for start, stop, step in [(0, 4.9, 1),
(5.1, 0, -0.5),
(1, 5.1, 0.5)]: (1, 5.1, 0.5)]:
utt.verify_grad(f, [np.float32(start), utt.verify_grad(f, [np.asarray(start).astype(config.floatX),
np.float32(stop), np.asarray(stop).astype(config.floatX),
np.float32(step)], np.asarray(step).astype(config.floatX)],
rng=rng) rng=rng)
def test_integers(self): def test_integers(self):
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论