提交 0887bec2 authored 作者: Frederic Bastien's avatar Frederic Bastien

Fixup by using the good num_step used and using DisconnectedType instead of…

Fixup by using the good num_step used and using DisconnectedType instead of zeros (This change what was requested), remove +1. Test another corner case.
上级 06b78bde
...@@ -5527,7 +5527,7 @@ class ARange(Op): ...@@ -5527,7 +5527,7 @@ class ARange(Op):
return [[True], [False], [True]] return [[True], [False], [True]]
def grad(self, inputs, grads): def L_op(self, inputs, outputs, grads):
start, stop, step = inputs start, stop, step = inputs
gz, = grads gz, = grads
# start and step affect the output values # start and step affect the output values
...@@ -5541,10 +5541,11 @@ class ARange(Op): ...@@ -5541,10 +5541,11 @@ class ARange(Op):
DisconnectedType()(), DisconnectedType()(),
step.zeros_like()] step.zeros_like()]
else: else:
num_steps_taken = (stop-start)/step num_steps_taken = outputs[0].shape[0] # (stop-start)/step
return [gz.sum(dtype=config.floatX), return [gz.sum(dtype=config.floatX),
stop.zeros_like(dtype=config.floatX), DisconnectedType()(),
(gz*arange(num_steps_taken+1)).sum(dtype=config.floatX)] (gz * arange(num_steps_taken, dtype=self.dtype)).sum(
dtype=config.floatX)]
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
return [None] return [None]
......
...@@ -5912,15 +5912,18 @@ class TestARange(unittest.TestCase): ...@@ -5912,15 +5912,18 @@ class TestARange(unittest.TestCase):
assert np.all(f(0, 0, 1) == np.arange(0, 0, 1)) assert np.all(f(0, 0, 1) == np.arange(0, 0, 1))
def test_grads(self): def test_grads(self):
start, stop, step = fscalars('start', 'stop', 'step')
def f(start, stop, step): def f(start, stop, step):
return ARange(start.type.dtype)(start, stop, step) return ARange(start.type.dtype)(start, stop, step)
rng = np.random.RandomState(utt.fetch_seed()) rng = np.random.RandomState(utt.fetch_seed())
utt.verify_grad(f, [np.float32(0), # Due to the random projection, we should not use the exact
np.float32(5), # point that shage the shape of the output.
np.float32(1)], for start, stop, step in [(0, 4.9, 1),
rng=rng) (1, 5.1, 0.5)]:
utt.verify_grad(f, [np.float32(start),
np.float32(stop),
np.float32(step)],
rng=rng)
def test_integers(self): def test_integers(self):
# Test arange constructor, on integer outputs # Test arange constructor, on integer outputs
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论