提交 909e5a81 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

At Arnaud suggestion, I've altered scan such that if you provide a number of

steps, scan does that number of steps or raises an error ( i.e. it does not check the length of the sequences and pick the largest number of steps). It makes things more optimizable, in terms of scan mergings. It is also logical, if you provided the number of steps, you do expect those number of steps to be executed. I might need to check documentation to make sure it is in sync with scan.
上级 496d2727
......@@ -467,9 +467,14 @@ def scan( fn
'n_steps argument of scan or provide an input '
'sequence')
actual_n_steps = lengths_vec[0]
for contestant in lengths_vec[1:]:
actual_n_steps = tensor.minimum(actual_n_steps, contestant)
# If the user has provided the number of steps, do that regardless ( and
# raise an error if the sequences are not long enough )
if scan_utils.check_NaN_Inf_None(n_steps):
actual_n_steps = lengths_vec[0]
for contestant in lengths_vec[1:]:
actual_n_steps = tensor.minimum(actual_n_steps, contestant)
else:
actual_n_steps = tensor.as_tensor(n_steps)
# Add names -- it helps a lot when debugging
for (nw_seq, seq) in zip(scan_seqs, seqs):
......
......@@ -381,12 +381,26 @@ class Scan(Op):
# 1. Unzip the number of steps and sequences. If number of steps is
# negative flip sequences around, and make n_steps positive
n_steps = args[0]
seqs = []
if n_steps < 0:
n_steps = abs(n_steps)
seqs = [ seq[::-1] for seq in args[1:self.seqs_arg_offset]]
for idx, seq in enumerate(args[1:self.seqs_arg_offset]):
if seq.shape[0] < n_steps:
raise ValueError(('Sequence is shorter then the required '
'number of steps : (n_steps, seq, '
'seq.shape):'), n_steps,
node.inputs[1+idx],
seq.shape)
seqs.append(seq[::-1])
else:
seqs = args[1:self.seqs_arg_offset]
for idx, seq in enumerate(args[1:self.seqs_arg_offset]):
if seq.shape[0] < n_steps:
raise ValueError(('Sequence is shorter then the required '
'number of steps : (n_steps, seq, '
'seq.shape):'), n_steps,
node.inputs[1+idx],
seq.shape)
seqs.append(seq)
# 2. Allocate memory for the outputs. Construct the list:
# store_steps -- map containting the length of each output
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论