提交 6641b3a4 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Add comments.

上级 0f2196fd
...@@ -45,6 +45,15 @@ def info(*msg): ...@@ -45,6 +45,15 @@ def info(*msg):
@gof.local_optimizer([None]) @gof.local_optimizer([None])
def remove_constants_and_unused_inputs_scan(node): def remove_constants_and_unused_inputs_scan(node):
'''
Move constants into the inner graph, and remove unused inputs.
Constants that are in the outer graph are represented by a free symbolic
variable in the inner graph. If we move them into the inner graph,
constant-folding can happen in the inner graph.
This is applied only on sequences and non-sequences,
not on initial states.
'''
if not isinstance(node.op, scan_op.Scan): if not isinstance(node.op, scan_op.Scan):
return False return False
op = node.op op = node.op
...@@ -54,9 +63,12 @@ def remove_constants_and_unused_inputs_scan(node): ...@@ -54,9 +63,12 @@ def remove_constants_and_unused_inputs_scan(node):
op.tap_array[:(op.n_mit_mot+op.n_mit_sot)] ])) op.tap_array[:(op.n_mit_mot+op.n_mit_sot)] ]))
st += op.n_sit_sot st += op.n_sit_sot
st += op.n_shared_outs st += op.n_shared_outs
op_ins, op_outs = scan_utils.reconstruct_graph(op.inputs, op.outputs, op_ins, op_outs = scan_utils.reconstruct_graph(op.inputs, op.outputs)
'')
# Corresponds to the initial states, which should stay untouched.
# We put those variables aside, and put them back at the end.
out_stuff_inner = op_ins[op.n_seqs:st] out_stuff_inner = op_ins[op.n_seqs:st]
non_seqs = op_ins[st:] non_seqs = op_ins[st:]
st = ( op.n_seqs + st = ( op.n_seqs +
op.n_mit_mot + op.n_mit_mot +
...@@ -67,9 +79,13 @@ def remove_constants_and_unused_inputs_scan(node): ...@@ -67,9 +79,13 @@ def remove_constants_and_unused_inputs_scan(node):
outer_non_seqs = node.inputs[st:] outer_non_seqs = node.inputs[st:]
out_stuff_outer = node.inputs[1+op.n_seqs:st] out_stuff_outer = node.inputs[1+op.n_seqs:st]
# To replace constants in the outer graph by clones in the inner graph
givens = {} givens = {}
# All the inputs of the inner graph of the new scan
nw_inner = [] nw_inner = []
# Same for the outer graph, initialized w/ number of steps
nw_outer = [node.inputs[0]] nw_outer = [node.inputs[0]]
all_ins = gof.graph.inputs(op_outs) all_ins = gof.graph.inputs(op_outs)
for idx in xrange(op.n_seqs): for idx in xrange(op.n_seqs):
if (isinstance(node.inputs[idx+1], tensor.TensorConstant) and if (isinstance(node.inputs[idx+1], tensor.TensorConstant) and
...@@ -84,6 +100,7 @@ def remove_constants_and_unused_inputs_scan(node): ...@@ -84,6 +100,7 @@ def remove_constants_and_unused_inputs_scan(node):
elif op_ins[idx] in all_ins: elif op_ins[idx] in all_ins:
nw_inner += [op_ins[idx]] nw_inner += [op_ins[idx]]
nw_outer += [node.inputs[idx+1]] nw_outer += [node.inputs[idx+1]]
nw_n_seqs = len(nw_inner) nw_n_seqs = len(nw_inner)
# Add outputs stuff # Add outputs stuff
nw_inner += out_stuff_inner nw_inner += out_stuff_inner
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论