提交 32ce6710 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

collect all arguments in a single inputs/ouputs list

The distinction between different types of inputs/outputs is not really necessary anymore, so in order to simplify the code I just create a single list of all inputs
上级 796e4cd2
...@@ -443,24 +443,20 @@ def scan(fn, ...@@ -443,24 +443,20 @@ def scan(fn,
non_numeric_output_states.append(update_d[sv]) non_numeric_output_states.append(update_d[sv])
original_non_numeric_shared_variables.append(sv) original_non_numeric_shared_variables.append(sv)
# 5.2 Collect and order inputs of the inner function # 5.2 Collect inputs/outputs of the inner function
input_states_outer = [] inputs = []
output_states = []
memory_buffers_for_outputs = []
outputs = [] outputs = []
for n, mintap in enumerate(mintaps): for n, mintap in enumerate(mintaps):
if mintap != 0: if mintap != 0:
input_state = states_and_outputs_info[n]['initial'] input_state = states_and_outputs_info[n]['initial']
input_states_outer.append(input_state) inputs.append(input_state)
output_states.append( outputs.append(
tensor.set_subtensor(input_state[(t + 1) % lengths[n]], tensor.set_subtensor(input_state[(t + mintap) % lengths[n]],
states_and_outputs[n])) states_and_outputs[n]))
else: else:
output = scan_utils.allocate_memory( mem_buffer = scan_utils.allocate_memory(
T, states_and_outputs_info[n], states_and_outputs[n]) T, states_and_outputs_info[n], states_and_outputs[n])
memory_buffers_for_outputs.append(output) inputs.append(output)
outputs.append( outputs.append(
tensor.set_subtensor(output[t % lengths[n]], tensor.set_subtensor(output[t % lengths[n]],
states_and_outputs[n]) states_and_outputs[n])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论