提交 e4e5d824 authored 作者: Glexin's avatar Glexin 提交者: Pascal Lamblin

Fix one step scan use wrong seq slice bug (#6598)

* Fix bug of incorrect seq slice in scan with one step.
上级 6339ba19
...@@ -490,6 +490,12 @@ def scan(fn, ...@@ -490,6 +490,12 @@ def scan(fn,
# go through the indicated slice # go through the indicated slice
mintap = np.min(seq['taps']) mintap = np.min(seq['taps'])
maxtap = np.max(seq['taps']) maxtap = np.max(seq['taps'])
# We cut the sequence such that seq[i] to correspond to
# seq[i-k]. For the purposes of cutting the sequences, we
# need to pretend tap 0 is used to avoid cutting the sequences
# too long if the taps are all lower or all higher than 0.
maxtap_proxy = max(maxtap, 0)
mintap_proxy = min(mintap, 0)
for k in seq['taps']: for k in seq['taps']:
# create one slice of the input # create one slice of the input
# Later on, if we decide not to use scan because we are # Later on, if we decide not to use scan because we are
...@@ -500,9 +506,9 @@ def scan(fn, ...@@ -500,9 +506,9 @@ def scan(fn,
# If not we need to use copies, that will be replaced at # If not we need to use copies, that will be replaced at
# each frame by the corresponding slice # each frame by the corresponding slice
actual_slice = seq['input'][k - mintap] actual_slice = seq['input'][k - mintap_proxy]
_seq_val = tensor.as_tensor_variable(seq['input']) _seq_val = tensor.as_tensor_variable(seq['input'])
_seq_val_slice = _seq_val[k - mintap] _seq_val_slice = _seq_val[k - mintap_proxy]
nw_slice = _seq_val_slice.type() nw_slice = _seq_val_slice.type()
# Try to transfer test_value to the new variable # Try to transfer test_value to the new variable
...@@ -529,12 +535,6 @@ def scan(fn, ...@@ -529,12 +535,6 @@ def scan(fn,
nw_name = seq['input'].name + '[t%d]' % k nw_name = seq['input'].name + '[t%d]' % k
nw_slice.name = nw_name nw_slice.name = nw_name
# We cut the sequence such that seq[i] to correspond to
# seq[i-k]. For the purposes of cutting the sequences, we
# need to pretend tap 0 is used to avoid cutting the sequences
# too long if the taps are all lower or all higher than 0.
maxtap_proxy = max(maxtap, 0)
mintap_proxy = min(mintap, 0)
start = (k - mintap_proxy) start = (k - mintap_proxy)
nw_name = None nw_name = None
if k == maxtap_proxy: if k == maxtap_proxy:
......
...@@ -5738,3 +5738,23 @@ def test_condition_hidden_inp(): ...@@ -5738,3 +5738,23 @@ def test_condition_hidden_inp():
outputs=rs) outputs=rs)
_sum, total_steps = f(100, 100) _sum, total_steps = f(100, 100)
def test_mintap_onestep():
seq = theano.tensor.ivector("seq")
seq_info = dict(input=seq, taps=[2])
def accum(seq_t, prev_sum):
new_sum = prev_sum + seq_t
return new_sum
rs, updates = theano.scan(fn=accum,
sequences=seq_info,
outputs_info=0,
n_steps=1)
f = theano.function(inputs=[seq],
outputs=rs)
_seq = np.arange(20).astype("int32")
_sum = f(_seq)
print("sum %f" % _sum)
assert _sum == 2
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论