提交 40012bce authored 作者: Razvan Pascanu's avatar Razvan Pascanu

shared variables become sit_sot if are of tensor type

上级 ee763488
......@@ -843,17 +843,31 @@ def scan(fn,
shared_scan_inputs = []
shared_inner_inputs = []
shared_inner_outputs = []
sit_sot_shared = []
for input in dummy_f.maker.expanded_inputs:
if isinstance(input.variable, SharedVariable) and input.update:
new_var = safe_new(input.variable)
if getattr(input.variable, 'name', None) is not None:
new_var.name = input.variable.name + '_copy'
shared_inner_inputs.append(new_var)
shared_scan_inputs.append(input.variable)
shared_inner_outputs.append(input.update)
givens[input.variable] = new_var
n_shared_outs += 1
if isinstance(new_var.type, tensor.TensorType):
sit_sot_inner_inputs.append(new_var)
sit_sot_scan_inputs.append(
scan_utils.expand(
tensor.unbroadcast(
tensor.shape_padleft(input.variable), 0),
actual_n_steps))
sit_sot_inner_outputs.append(input.update)
sit_sot_rightOrder.append(-1 - len(sit_sot_shared))
sit_sot_shared.append(input.variable)
givens[input.variable] = new_var
else:
shared_inner_inputs.append(new_var)
shared_scan_inputs.append(input.variable)
shared_inner_outputs.append(input.update)
givens[input.variable] = new_var
n_shared_outs += 1
n_sit_sot = len(sit_sot_inner_inputs)
## Step 5.4 Outputs with no taps used in the input
n_nit_sot = 0
nit_sot_inner_outputs = []
......@@ -1041,10 +1055,13 @@ def scan(fn,
nit_sot_rightOrder)
scan_out_list = [None] * len(rightOrder)
for idx, pos in enumerate(rightOrder):
scan_out_list[pos] = _scan_out_list[idx]
if pos >= 0:
scan_out_list[pos] = _scan_out_list[idx]
else:
update_map[sit_sot_shared[abs(pos)-1]] = _scan_out_list[idx][-1]
scan_out_list = [x for x in scan_out_list if x is not None]
if len(scan_out_list) == 1:
scan_out_list = scan_out_list[0]
elif len(scan_out_list) == 0:
scan_out_list = None
return (scan_out_list, update_map)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论