提交 40012bce authored 作者: Razvan Pascanu's avatar Razvan Pascanu

shared variables become sit_sot if are of tensor type

上级 ee763488
...@@ -843,17 +843,31 @@ def scan(fn, ...@@ -843,17 +843,31 @@ def scan(fn,
shared_scan_inputs = [] shared_scan_inputs = []
shared_inner_inputs = [] shared_inner_inputs = []
shared_inner_outputs = [] shared_inner_outputs = []
sit_sot_shared = []
for input in dummy_f.maker.expanded_inputs: for input in dummy_f.maker.expanded_inputs:
if isinstance(input.variable, SharedVariable) and input.update: if isinstance(input.variable, SharedVariable) and input.update:
new_var = safe_new(input.variable) new_var = safe_new(input.variable)
if getattr(input.variable, 'name', None) is not None: if getattr(input.variable, 'name', None) is not None:
new_var.name = input.variable.name + '_copy' new_var.name = input.variable.name + '_copy'
shared_inner_inputs.append(new_var) if isinstance(new_var.type, tensor.TensorType):
shared_scan_inputs.append(input.variable) sit_sot_inner_inputs.append(new_var)
shared_inner_outputs.append(input.update) sit_sot_scan_inputs.append(
givens[input.variable] = new_var scan_utils.expand(
n_shared_outs += 1 tensor.unbroadcast(
tensor.shape_padleft(input.variable), 0),
actual_n_steps))
sit_sot_inner_outputs.append(input.update)
sit_sot_rightOrder.append(-1 - len(sit_sot_shared))
sit_sot_shared.append(input.variable)
givens[input.variable] = new_var
else:
shared_inner_inputs.append(new_var)
shared_scan_inputs.append(input.variable)
shared_inner_outputs.append(input.update)
givens[input.variable] = new_var
n_shared_outs += 1
n_sit_sot = len(sit_sot_inner_inputs)
## Step 5.4 Outputs with no taps used in the input ## Step 5.4 Outputs with no taps used in the input
n_nit_sot = 0 n_nit_sot = 0
nit_sot_inner_outputs = [] nit_sot_inner_outputs = []
...@@ -1041,10 +1055,13 @@ def scan(fn, ...@@ -1041,10 +1055,13 @@ def scan(fn,
nit_sot_rightOrder) nit_sot_rightOrder)
scan_out_list = [None] * len(rightOrder) scan_out_list = [None] * len(rightOrder)
for idx, pos in enumerate(rightOrder): for idx, pos in enumerate(rightOrder):
scan_out_list[pos] = _scan_out_list[idx] if pos >= 0:
scan_out_list[pos] = _scan_out_list[idx]
else:
update_map[sit_sot_shared[abs(pos)-1]] = _scan_out_list[idx][-1]
scan_out_list = [x for x in scan_out_list if x is not None]
if len(scan_out_list) == 1: if len(scan_out_list) == 1:
scan_out_list = scan_out_list[0] scan_out_list = scan_out_list[0]
elif len(scan_out_list) == 0: elif len(scan_out_list) == 0:
scan_out_list = None scan_out_list = None
return (scan_out_list, update_map) return (scan_out_list, update_map)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论