提交 2d2f2979 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Refactor push_out_non_seq_scan and remove unnecessary cloning

上级 d47ce12b
...@@ -203,12 +203,11 @@ def push_out_non_seq_scan(fgraph, node): ...@@ -203,12 +203,11 @@ def push_out_non_seq_scan(fgraph, node):
if not isinstance(node.op, Scan): if not isinstance(node.op, Scan):
return False return False
# this flag tells if there was any change during the last iterations node_inputs, node_outputs = node.op.inputs, node.op.outputs
clean_inputs, clean_outputs = reconstruct_graph(node.op.inputs, node.op.outputs)
local_fgraph_topo = io_toposort(clean_inputs, clean_outputs) local_fgraph_topo = io_toposort(node_inputs, node_outputs)
local_fgraph_outs_set = set(clean_outputs) local_fgraph_outs_set = set(node_outputs)
local_fgraph_outs_map = {v: k for k, v in enumerate(clean_outputs)} local_fgraph_outs_map = {v: k for k, v in enumerate(node_outputs)}
to_remove_set = set() to_remove_set = set()
to_replace_set = set() to_replace_set = set()
...@@ -221,18 +220,21 @@ def push_out_non_seq_scan(fgraph, node): ...@@ -221,18 +220,21 @@ def push_out_non_seq_scan(fgraph, node):
add_to_replace.n = 0 add_to_replace.n = 0
# The variables that will replace the variables pushed-out of the
# inner-graph
replace_with_in = [] replace_with_in = []
# The variables that have been pushed-out of the graph
replace_with_out = [] replace_with_out = []
op = node.op op = node.op
# Construct the list of non_sequences to simplify a few things # Construct the list of non_sequences to simplify a few things
inner_non_seqs = op.inner_non_seqs(clean_inputs) inner_non_seqs = op.inner_non_seqs(node_inputs)
inner_non_seqs_set = set(inner_non_seqs) inner_non_seqs_set = set(inner_non_seqs)
inner_non_seqs_map = {v: k for k, v in enumerate(inner_non_seqs)} inner_non_seqs_map = {v: k for k, v in enumerate(inner_non_seqs)}
outer_non_seqs = op.outer_non_seqs(node.inputs) outer_non_seqs = op.outer_non_seqs(node.inputs)
inner_seqs = op.inner_seqs(clean_inputs) inner_seqs = op.inner_seqs(node_inputs)
outer_seqs = op.outer_seqs(node.inputs) outer_seqs = op.outer_seqs(node.inputs)
assert len(inner_non_seqs) == len(outer_non_seqs) assert len(inner_non_seqs) == len(outer_non_seqs)
...@@ -242,55 +244,60 @@ def push_out_non_seq_scan(fgraph, node): ...@@ -242,55 +244,60 @@ def push_out_non_seq_scan(fgraph, node):
if ( # we haven't already looked at this node if ( # we haven't already looked at this node
nd not in to_remove_set nd not in to_remove_set
and all( and all(
[ (
( (x in inner_non_seqs_set)
(x in inner_non_seqs_set) or (x.owner in to_remove_set)
or (x.owner in to_remove_set) or isinstance(x, Constant)
or isinstance(x, Constant) )
) for x in nd.inputs
for x in nd.inputs
]
) )
and # We can (supposedly) do this because the assumption is that a
# we can do this because the assumption is that a # `ViewOp` or `DeepCopyOp` will be just at the end of the
# viewOp or deepCopyOp will be just at the end of the # function and not somewhere in the middle
# function and not somewhere in the middle .. and not isinstance(nd.op, aesara.compile.ViewOp)
not isinstance(nd.op, aesara.compile.ViewOp)
and not isinstance(nd.op, aesara.compile.DeepCopyOp) and not isinstance(nd.op, aesara.compile.DeepCopyOp)
): ):
# We have a candidate node to remove from the inner-graph
# We have a candidate node to removable
# Step 1. Reconstruct it on outside # Step 1. Reconstruct the node using the relevant outer-inputs.
#
# More specifically, the node's current inputs are either
# a) inner-graph input place-holders for non-sequences,
# b) the outputs of other nodes being pushed out of the inner-graph,
# c) or constants.
to_remove_set.add(nd) to_remove_set.add(nd)
outside_ins = [] new_inputs = []
for x in nd.inputs: for old_input in nd.inputs:
if x in inner_non_seqs_set: if old_input in inner_non_seqs_set:
_idx = inner_non_seqs_map[x] # This is case a), so we want to use the corresponding
outside_ins.append(outer_non_seqs[_idx]) # outer-graph input as the input to our new pushed-out node
elif x in to_replace_set: _idx = inner_non_seqs_map[old_input]
outside_ins.append(replace_with_out[to_replace_map[x]]) new_input = outer_non_seqs[_idx]
elif isinstance(x, Constant): elif old_input in to_replace_set:
outside_ins.append(x.clone()) # This is case b), so we want to use the new pushed-out node
# as the input to this new pushed-out node
new_input = replace_with_out[to_replace_map[old_input]]
else: else:
# TODO: Explain why is this an error, and raise an assert isinstance(old_input, Constant)
# appropriate exception type. new_input = old_input
raise RuntimeError()
outside_ins = [
x.type.filter_variable(y) for x, y in zip(nd.inputs, outside_ins)
]
nw_outer_node = nd.op.make_node(*outside_ins) new_input = old_input.type.filter_variable(new_input)
new_inputs.append(new_input)
pushed_out_node = nd.op.make_node(*new_inputs)
if config.compute_test_value != "off": if config.compute_test_value != "off":
compute_test_value(nw_outer_node) compute_test_value(pushed_out_node)
# Step 2. Create variables for replacements # Step 2. Create variables to replace the old outputs of the node
# that we're pushing out of the inner-graph
for idx, y in enumerate(nd.outputs): for idx, y in enumerate(nd.outputs):
y_place_holder = safe_new(y, "_replace") y_place_holder = y.clone()
# y_place_holder = safe_new(y, "_replace")
add_to_replace(y) add_to_replace(y)
replace_with_in.append(y_place_holder) replace_with_in.append(y_place_holder)
assert isinstance(y, type(nw_outer_node.outputs[idx])) assert isinstance(y, type(pushed_out_node.outputs[idx]))
replace_with_out.append(nw_outer_node.outputs[idx]) replace_with_out.append(pushed_out_node.outputs[idx])
# We need to check all candidate replacements and choose those that # We need to check all candidate replacements and choose those that
# make sense for us # make sense for us
...@@ -326,14 +333,14 @@ def push_out_non_seq_scan(fgraph, node): ...@@ -326,14 +333,14 @@ def push_out_non_seq_scan(fgraph, node):
clean_to_replace, clean_replace_with_in, clean_replace_with_out clean_to_replace, clean_replace_with_in, clean_replace_with_out
): ):
if isinstance(repl_out, Constant): if isinstance(repl_out, Constant):
repl_in = repl_out.clone() repl_in = repl_out
else: else:
nw_inner.append(repl_in) nw_inner.append(repl_in)
nw_outer.append(repl_out) nw_outer.append(repl_out)
givens[to_repl] = repl_in givens[to_repl] = repl_in
op_outs = clone_replace(clean_outputs, replace=givens) op_outs = clone_replace(node_outputs, replace=givens)
op_ins = clean_inputs + nw_inner op_ins = node_inputs + nw_inner
# Reconstruct node # Reconstruct node
nwScan = Scan( nwScan = Scan(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论