提交 5001661d authored 作者: Razvan Pascanu's avatar Razvan Pascanu

added comment about negative index

上级 c15933f2
...@@ -857,6 +857,10 @@ def scan(fn, ...@@ -857,6 +857,10 @@ def scan(fn,
tensor.shape_padleft(input.variable), 0), tensor.shape_padleft(input.variable), 0),
actual_n_steps)) actual_n_steps))
sit_sot_inner_outputs.append(input.update) sit_sot_inner_outputs.append(input.update)
# Not that pos is not a negative index. The sign of pos is used
# as a flag to indicate if this output should be part of the
# update rules or part of the standard outputs of scan. Its
# absolute value (-1 in case is negative) represents an index
sit_sot_rightOrder.append(-1 - len(sit_sot_shared)) sit_sot_rightOrder.append(-1 - len(sit_sot_shared))
sit_sot_shared.append(input.variable) sit_sot_shared.append(input.variable)
givens[input.variable] = new_var givens[input.variable] = new_var
...@@ -1058,7 +1062,11 @@ def scan(fn, ...@@ -1058,7 +1062,11 @@ def scan(fn,
if pos >= 0: if pos >= 0:
scan_out_list[pos] = _scan_out_list[idx] scan_out_list[pos] = _scan_out_list[idx]
else: else:
update_map[sit_sot_shared[abs(pos)-1]] = _scan_out_list[idx][-1] # Not that pos is not a negative index. The sign of pos is used
# as a flag to indicate if this output should be part of the
# update rules or part of the standard outputs of scan. Its
# absolute value (-1 in case is negative) represents an index
update_map[sit_sot_shared[abs(pos) - 1]] = _scan_out_list[idx][-1]
scan_out_list = [x for x in scan_out_list if x is not None] scan_out_list = [x for x in scan_out_list if x is not None]
if len(scan_out_list) == 1: if len(scan_out_list) == 1:
scan_out_list = scan_out_list[0] scan_out_list = scan_out_list[0]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论