提交 80f150b2 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Fix a bug of grad reported by Michael Forbes

The bug is that we did not clip the sequences to n_steps, so when they are reversed to compute the gradient, the reversed sequences are misaligned and the gradient comes out wrong.
上级 5dd4c64a
...@@ -538,6 +538,7 @@ def scan( fn ...@@ -538,6 +538,7 @@ def scan( fn
if getattr(seq['input'],'name', None) is not None: if getattr(seq['input'],'name', None) is not None:
nw_seq.name = seq['input'].name + '[%d:]'%k nw_seq.name = seq['input'].name + '[%d:]'%k
scan_seqs = [ seq[:actual_n_steps] for seq in scan_seqs]
# Conventions : # Conventions :
# mit_mot = multiple input taps, multiple output taps ( only provided # mit_mot = multiple input taps, multiple output taps ( only provided
# by the gradient function ) # by the gradient function )
......
...@@ -2481,6 +2481,27 @@ class T_Scan(unittest.TestCase): ...@@ -2481,6 +2481,27 @@ class T_Scan(unittest.TestCase):
if isinstance(x.op, theano.scan_module.scan_op.Scan)] if isinstance(x.op, theano.scan_module.scan_op.Scan)]
assert len(lssc) == 1 assert len(lssc) == 1
def test_grad_multiple_seqs_different_nsteps(self):
    # Example provided by Michael Forbes.
    # This test ensures that the sequences are clipped to n_steps before
    # the gradient is computed, so that when they are reversed for the
    # gradient computation they still hold the right values.
    c = theano.tensor.vector('c')
    x = theano.tensor.scalar('x')
    # Deliberately much longer than the 4-element coefficient vector
    # passed below, so the two sequences handed to scan differ in length.
    _max_coefficients_supported = 100
    full_range = theano.tensor.arange(_max_coefficients_supported)
    # Each step computes one polynomial term: coeff * x ** power.
    components, updates = theano.scan(
        fn=lambda coeff, power, free_var: coeff * (free_var ** power),
        outputs_info=None,
        sequences=[c, full_range],
        non_sequences=x)
    P = components.sum()
    dP = theano.tensor.grad(P, x)
    tf = theano.function([c, x], dP)
    # d/dx (1 + 2x - 3x^2 + 4x^3) at x = 2 is 2 - 12 + 48 = 38.
    # Compare with a tolerance rather than exact equality, since the
    # result is a computed floating-point value.
    assert numpy.allclose(tf([1.0, 2.0, -3.0, 4.0], 2.0), 38)
def test_return_steps(self): def test_return_steps(self):
rng = numpy.random.RandomState(utt.fetch_seed()) rng = numpy.random.RandomState(utt.fetch_seed())
vW_in2 = asarrayX(rng.uniform(size = (2,), low = -5.,high = 5.)) vW_in2 = asarrayX(rng.uniform(size = (2,), low = -5.,high = 5.))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论