提交 c44f2c16 authored 作者: Faruk-Ahmed's avatar Faruk-Ahmed

correct gradient for ARange

上级 926dec89
...@@ -5464,6 +5464,7 @@ class ARange(Op): ...@@ -5464,6 +5464,7 @@ class ARange(Op):
inputs = [start, stop, step] inputs = [start, stop, step]
outputs = [tensor(self.dtype, (False,))] outputs = [tensor(self.dtype, (False,))]
return Apply(self, inputs, outputs) return Apply(self, inputs, outputs)
@theano.configparser.change_flags(warn_float64='ignore') @theano.configparser.change_flags(warn_float64='ignore')
...@@ -5523,9 +5524,16 @@ class ARange(Op): ...@@ -5523,9 +5524,16 @@ class ARange(Op):
# no gradient through them # no gradient through them
# stop does not affect the output values, # stop does not affect the output values,
# just the output shape, so it is disconnected # just the output shape, so it is disconnected
return [start.zeros_like(),
DisconnectedType()(), if (self.dtype in discrete_dtypes) and (step > 0):
step.zeros_like()] return [start.zeros_like(),
DisconnectedType()(),
step.zeros_like()]
else:
num_steps_taken = (stop-start)/step
return [gz.sum(dtype=config.floatX),
stop.zeros_like(dtype=config.floatX),
(gz*arange(num_steps_taken+1)).sum(dtype=config.floatX)]
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
return [None] return [None]
......
...@@ -5868,6 +5868,17 @@ class TestARange(unittest.TestCase): ...@@ -5868,6 +5868,17 @@ class TestARange(unittest.TestCase):
assert np.all(f(10, 2, 2) == np.arange(10, 2, 2)) assert np.all(f(10, 2, 2) == np.arange(10, 2, 2))
assert np.all(f(0, 0, 1) == np.arange(0, 0, 1)) assert np.all(f(0, 0, 1) == np.arange(0, 0, 1))
def test_grads(self):
    # Check the gradient of ARange with utt.verify_grad at the point
    # start=0, stop=5, step=1, all given as float32 scalars.
    def f(start, stop, step):
        # The op's output dtype is taken from the `start` argument's type
        # (presumably the symbolic inputs built by verify_grad -- confirm).
        return ARange(start.type.dtype)(start, stop, step)

    # Seed comes from the test utilities rather than being hard-coded here.
    rng = np.random.RandomState(utt.fetch_seed())
    utt.verify_grad(
        f,
        [np.float32(0), np.float32(5), np.float32(1)],
        rng=rng,
    )
def test_integers(self): def test_integers(self):
# Test arange constructor, on integer outputs # Test arange constructor, on integer outputs
start, stop, step = iscalars('start', 'stop', 'step') start, stop, step = iscalars('start', 'stop', 'step')
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论