提交 2302fc97 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

added a bunch of extra comments

上级 962e8884
...@@ -75,9 +75,10 @@ def remove_constants_and_unused_inputs_scan(node): ...@@ -75,9 +75,10 @@ def remove_constants_and_unused_inputs_scan(node):
if (isinstance(node.inputs[idx+1], tensor.TensorConstant) and if (isinstance(node.inputs[idx+1], tensor.TensorConstant) and
node.inputs[idx+1].tag.unique_value is not None): node.inputs[idx+1].tag.unique_value is not None):
try: try:
val = tensor.get_constant_value(node.inputs[idx+1], # This works if input is a constant that has all entries
return_ndarray = True) # equal
givens[op_ins[idx]] = tensor.constant(val[0]) val = tensor.get_constant_value(node.inputs[idx+1])
givens[op_ins[idx]] = node.inputs[idx+1].clone()[0]
except TypeError: except TypeError:
pass pass
elif op_ins[idx] in all_ins: elif op_ins[idx] in all_ins:
...@@ -99,6 +100,7 @@ def remove_constants_and_unused_inputs_scan(node): ...@@ -99,6 +100,7 @@ def remove_constants_and_unused_inputs_scan(node):
op_outs = scan_utils.clone(op_outs, replace = givens) op_outs = scan_utils.clone(op_outs, replace = givens)
nw_info = op.info.copy() nw_info = op.info.copy()
nw_info['n_seqs'] = nw_n_seqs nw_info['n_seqs'] = nw_n_seqs
# DEBUG CHECK
nwScan = scan_op.Scan(nw_inner, op_outs, nw_info) nwScan = scan_op.Scan(nw_inner, op_outs, nw_info)
nw_outs = nwScan.make_node(*nw_outer).outputs nw_outs = nwScan.make_node(*nw_outer).outputs
return nw_outs return nw_outs
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论