提交 2302fc97 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

added a bunch of extra comments

上级 962e8884
......@@ -75,9 +75,10 @@ def remove_constants_and_unused_inputs_scan(node):
if (isinstance(node.inputs[idx+1], tensor.TensorConstant) and
node.inputs[idx+1].tag.unique_value is not None):
try:
val = tensor.get_constant_value(node.inputs[idx+1],
return_ndarray = True)
givens[op_ins[idx]] = tensor.constant(val[0])
# This works if input is a constant that has all entries
# equal
val = tensor.get_constant_value(node.inputs[idx+1])
givens[op_ins[idx]] = node.inputs[idx+1].clone()[0]
except TypeError:
pass
elif op_ins[idx] in all_ins:
......@@ -99,6 +100,7 @@ def remove_constants_and_unused_inputs_scan(node):
op_outs = scan_utils.clone(op_outs, replace = givens)
nw_info = op.info.copy()
nw_info['n_seqs'] = nw_n_seqs
# DEBUG CHECK
nwScan = scan_op.Scan(nw_inner, op_outs, nw_info)
nw_outs = nwScan.make_node(*nw_outer).outputs
return nw_outs
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论