Commit a5927548, authored by abergeron, committed via GitHub

Merge pull request #6447 from Faruk-Ahmed/ARange-gradient

correct gradient for ARange
......@@ -5475,6 +5475,7 @@ class ARange(Op):
inputs = [start, stop, step]
outputs = [tensor(self.dtype, (False,))]
return Apply(self, inputs, outputs)
@theano.configparser.change_flags(warn_float64='ignore')
......@@ -5526,7 +5527,7 @@ class ARange(Op):
return [[True], [False], [True]]
def L_op(self, inputs, outputs, grads):
    """Gradient of ARange with respect to ``(start, stop, step)``.

    Each output element is ``out[i] = start + i * step``, so ``start``
    and ``step`` carry gradients ``sum(gz)`` and ``sum(i * gz)``.
    ``stop`` only affects the output *shape*, never its values, so it
    is disconnected.
    """
    start, stop, step = inputs
    gz, = grads
    if self.dtype in discrete_dtypes:
        # Integer outputs: no gradient flows through start/step.
        # NOTE(review): the original diff also tested `and (step > 0)`,
        # but `step` is a symbolic variable here and coercing the
        # comparison to a Python bool raises in Theano; the dtype
        # check alone decides the branch.
        return [start.zeros_like(),
                DisconnectedType()(),
                step.zeros_like()]
    else:
        # Number of emitted elements; symbolically ceil((stop-start)/step).
        num_steps_taken = outputs[0].shape[0]
        return [gz.sum(dtype=config.floatX),
                DisconnectedType()(),
                (gz * arange(num_steps_taken, dtype=self.dtype)).sum(
                    dtype=config.floatX)]
def R_op(self, inputs, eval_points):
    """R-operator for ARange.

    Not implemented: a single ``None`` is returned for the Op's one
    output, which tells Theano's Rop machinery that no directional
    derivative is computed here.
    """
    return [None]
......
......@@ -5911,6 +5911,20 @@ class TestARange(unittest.TestCase):
assert np.all(f(10, 2, 2) == np.arange(10, 2, 2))
assert np.all(f(0, 0, 1) == np.arange(0, 0, 1))
def test_grads(self):
    """Numerically verify ARange gradients on float inputs."""
    def make_range(start, stop, step):
        return ARange(start.type.dtype)(start, stop, step)

    rng = np.random.RandomState(utt.fetch_seed())
    # Due to the random projection, we should not use the exact
    # points that change the shape of the output.
    cases = [(0, 4.9, 1),
             (1, 5.1, 0.5)]
    for start, stop, step in cases:
        args = [np.float32(v) for v in (start, stop, step)]
        utt.verify_grad(make_range, args, rng=rng)
def test_integers(self):
# Test arange constructor, on integer outputs
start, stop, step = iscalars('start', 'stop', 'step')
......
Markdown 格式
0%
您将 0 人添加到了此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论