提交 222ac1c4 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

invert the sequences (if needed) before anything else

上级 08e7974b
......@@ -189,12 +189,30 @@ def canonical_arguments(sequences,
inputs = []
for input in to_list(sequences):
if not isinstance(u, dict):
inputs.append(tensor.as_tensor_variable(input))
nw_input = tensor.as_tensor_variable(input)
if go_backwards:
nw_input = nw_input[::-1]
if n_steps is not None:
nw_input = tensor.switch(n_steps < 0, nw_input[::-1],
nw_input)
inputs.append(tensor.as_tensor_variable(nw_input))
elif input.get('taps', True) is None:
inputs.append(tensor.as_tensor_variable(input['input']))
nw_input = tensor.as_tensor_variable(input['input'])
if go_backwards:
nw_input = nw_input[::-1]
if n_steps is not None:
nw_input = tensor.switch(n_steps < 0, nw_input[::-1],
nw_input)
inputs.append(nw_input)
elif input.get('taps', None):
mintap = numpy.min(input['taps'])
maxtap = numpy.max(input['taps'])
orig_input = tensor.as_tensor_variable(input['input'])
if go_backwards:
orig_input = orig_input[::-1]
if n_steps is not None:
orig_input = tensor.switch(n_steps < 0, orig_input[::-1],
org_input)
for k in input['taps']:
# We cut the sequence such that seq[i] to correspond to
# seq[i-k]
......@@ -202,7 +220,7 @@ def canonical_arguments(sequences,
offset = abs(maxtap)
else:
offset = 0
nw_input = tensor.as_tensor_variable(input['input'])
nw_input = orig_input
if maxtap == mintap and maxtap != 0:
nw_input = nw_input[:abs(maxtap)]
elif maxtap - k != 0:
......@@ -210,8 +228,6 @@ def canonical_arguments(sequences,
-(maxtap - k)]
else:
nw_input = nw_input[offset + k - mintap:]
if go_backwards:
nw_input = nw_input[::-1]
inputs.append(nw_input)
else:
raise ValueError('Provided sequence makes no sense', str(input))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论