提交 9de8fcaf authored 作者: Razvan Pascanu's avatar Razvan Pascanu

take care of case when n_steps is None

上级 307a28ff
......@@ -187,13 +187,16 @@ def canonical_arguments(sequences,
parameters = [tensor.as_tensor_variable(x) for x in to_list(non_sequences)]
inputs = []
if n_steps is not None:
negative_n_steps = tensor.lt(tensor.as_tensor_variable(n_steps), 0)
for input in to_list(sequences):
if not isinstance(u, dict):
nw_input = tensor.as_tensor_variable(input)
if go_backwards:
nw_input = nw_input[::-1]
if n_steps is not None:
nw_input = tensor.switch(n_steps < 0, nw_input[::-1],
nw_input = tensor.switch(negative_n_steps, nw_input[::-1],
nw_input)
inputs.append(tensor.as_tensor_variable(nw_input))
elif input.get('taps', True) is None:
......@@ -201,7 +204,7 @@ def canonical_arguments(sequences,
if go_backwards:
nw_input = nw_input[::-1]
if n_steps is not None:
nw_input = tensor.switch(n_steps < 0, nw_input[::-1],
nw_input = tensor.switch(negative_n_steps, nw_input[::-1],
nw_input)
inputs.append(nw_input)
elif input.get('taps', None):
......@@ -211,7 +214,7 @@ def canonical_arguments(sequences,
if go_backwards:
orig_input = orig_input[::-1]
if n_steps is not None:
orig_input = tensor.switch(n_steps < 0, orig_input[::-1],
orig_input = tensor.switch(negative_n_steps, orig_input[::-1],
org_input)
for k in input['taps']:
# We cut the sequence such that seq[i] to correspond to
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论