提交 9de8fcaf authored 作者: Razvan Pascanu's avatar Razvan Pascanu

take care of case when n_steps is None

上级 307a28ff
...@@ -187,13 +187,16 @@ def canonical_arguments(sequences, ...@@ -187,13 +187,16 @@ def canonical_arguments(sequences,
parameters = [tensor.as_tensor_variable(x) for x in to_list(non_sequences)] parameters = [tensor.as_tensor_variable(x) for x in to_list(non_sequences)]
inputs = [] inputs = []
if n_steps is not None:
negative_n_steps = tensor.lt(tensor.as_tensor_variable(n_steps), 0)
for input in to_list(sequences): for input in to_list(sequences):
if not isinstance(u, dict): if not isinstance(u, dict):
nw_input = tensor.as_tensor_variable(input) nw_input = tensor.as_tensor_variable(input)
if go_backwards: if go_backwards:
nw_input = nw_input[::-1] nw_input = nw_input[::-1]
if n_steps is not None: if n_steps is not None:
nw_input = tensor.switch(n_steps < 0, nw_input[::-1], nw_input = tensor.switch(negative_n_steps, nw_input[::-1],
nw_input) nw_input)
inputs.append(tensor.as_tensor_variable(nw_input)) inputs.append(tensor.as_tensor_variable(nw_input))
elif input.get('taps', True) is None: elif input.get('taps', True) is None:
...@@ -201,7 +204,7 @@ def canonical_arguments(sequences, ...@@ -201,7 +204,7 @@ def canonical_arguments(sequences,
if go_backwards: if go_backwards:
nw_input = nw_input[::-1] nw_input = nw_input[::-1]
if n_steps is not None: if n_steps is not None:
nw_input = tensor.switch(n_steps < 0, nw_input[::-1], nw_input = tensor.switch(negative_n_steps, nw_input[::-1],
nw_input) nw_input)
inputs.append(nw_input) inputs.append(nw_input)
elif input.get('taps', None): elif input.get('taps', None):
...@@ -211,7 +214,7 @@ def canonical_arguments(sequences, ...@@ -211,7 +214,7 @@ def canonical_arguments(sequences,
if go_backwards: if go_backwards:
orig_input = orig_input[::-1] orig_input = orig_input[::-1]
if n_steps is not None: if n_steps is not None:
orig_input = tensor.switch(n_steps < 0, orig_input[::-1], orig_input = tensor.switch(negative_n_steps, orig_input[::-1],
org_input) org_input)
for k in input['taps']: for k in input['taps']:
# We cut the sequence such that seq[i] to correspond to # We cut the sequence such that seq[i] to correspond to
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论