提交 222ac1c4 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

invert the sequences (if needed) before anything else

上级 08e7974b
...@@ -189,12 +189,30 @@ def canonical_arguments(sequences, ...@@ -189,12 +189,30 @@ def canonical_arguments(sequences,
inputs = [] inputs = []
for input in to_list(sequences): for input in to_list(sequences):
if not isinstance(u, dict): if not isinstance(u, dict):
inputs.append(tensor.as_tensor_variable(input)) nw_input = tensor.as_tensor_variable(input)
if go_backwards:
nw_input = nw_input[::-1]
if n_steps is not None:
nw_input = tensor.switch(n_steps < 0, nw_input[::-1],
nw_input)
inputs.append(tensor.as_tensor_variable(nw_input))
elif input.get('taps', True) is None: elif input.get('taps', True) is None:
inputs.append(tensor.as_tensor_variable(input['input'])) nw_input = tensor.as_tensor_variable(input['input'])
if go_backwards:
nw_input = nw_input[::-1]
if n_steps is not None:
nw_input = tensor.switch(n_steps < 0, nw_input[::-1],
nw_input)
inputs.append(nw_input)
elif input.get('taps', None): elif input.get('taps', None):
mintap = numpy.min(input['taps']) mintap = numpy.min(input['taps'])
maxtap = numpy.max(input['taps']) maxtap = numpy.max(input['taps'])
orig_input = tensor.as_tensor_variable(input['input'])
if go_backwards:
orig_input = orig_input[::-1]
if n_steps is not None:
orig_input = tensor.switch(n_steps < 0, orig_input[::-1],
org_input)
for k in input['taps']: for k in input['taps']:
# We cut the sequence such that seq[i] to correspond to # We cut the sequence such that seq[i] to correspond to
# seq[i-k] # seq[i-k]
...@@ -202,7 +220,7 @@ def canonical_arguments(sequences, ...@@ -202,7 +220,7 @@ def canonical_arguments(sequences,
offset = abs(maxtap) offset = abs(maxtap)
else: else:
offset = 0 offset = 0
nw_input = tensor.as_tensor_variable(input['input']) nw_input = orig_input
if maxtap == mintap and maxtap != 0: if maxtap == mintap and maxtap != 0:
nw_input = nw_input[:abs(maxtap)] nw_input = nw_input[:abs(maxtap)]
elif maxtap - k != 0: elif maxtap - k != 0:
...@@ -210,8 +228,6 @@ def canonical_arguments(sequences, ...@@ -210,8 +228,6 @@ def canonical_arguments(sequences,
-(maxtap - k)] -(maxtap - k)]
else: else:
nw_input = nw_input[offset + k - mintap:] nw_input = nw_input[offset + k - mintap:]
if go_backwards:
nw_input = nw_input[::-1]
inputs.append(nw_input) inputs.append(nw_input)
else: else:
raise ValueError('Provided sequence makes no sense', str(input)) raise ValueError('Provided sequence makes no sense', str(input))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论