提交 32ce6710 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

collect all arguments in a single inputs/ouputs list

The distinction between different types of inputs/outputs is not really necessary anymore, so in order to simplify the code I just create a single list of all inputs
上级 796e4cd2
......@@ -443,24 +443,20 @@ def scan(fn,
non_numeric_output_states.append(update_d[sv])
original_non_numeric_shared_variables.append(sv)
# 5.2 Collect and order inputs of the inner function
input_states_outer = []
output_states = []
memory_buffers_for_outputs = []
# 5.2 Collect inputs/outputs of the inner function
inputs = []
outputs = []
for n, mintap in enumerate(mintaps):
if mintap != 0:
input_state = states_and_outputs_info[n]['initial']
input_states_outer.append(input_state)
output_states.append(
tensor.set_subtensor(input_state[(t + 1) % lengths[n]],
inputs.append(input_state)
outputs.append(
tensor.set_subtensor(input_state[(t + mintap) % lengths[n]],
states_and_outputs[n]))
else:
output = scan_utils.allocate_memory(
mem_buffer = scan_utils.allocate_memory(
T, states_and_outputs_info[n], states_and_outputs[n])
memory_buffers_for_outputs.append(output)
inputs.append(output)
outputs.append(
tensor.set_subtensor(output[t % lengths[n]],
states_and_outputs[n])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论