提交 2d2f2979 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Refactor push_out_non_seq_scan and remove unnecessary cloning

上级 d47ce12b
......@@ -203,12 +203,11 @@ def push_out_non_seq_scan(fgraph, node):
if not isinstance(node.op, Scan):
return False
# this flag tells if there was any change during the last iterations
clean_inputs, clean_outputs = reconstruct_graph(node.op.inputs, node.op.outputs)
node_inputs, node_outputs = node.op.inputs, node.op.outputs
local_fgraph_topo = io_toposort(clean_inputs, clean_outputs)
local_fgraph_outs_set = set(clean_outputs)
local_fgraph_outs_map = {v: k for k, v in enumerate(clean_outputs)}
local_fgraph_topo = io_toposort(node_inputs, node_outputs)
local_fgraph_outs_set = set(node_outputs)
local_fgraph_outs_map = {v: k for k, v in enumerate(node_outputs)}
to_remove_set = set()
to_replace_set = set()
......@@ -221,18 +220,21 @@ def push_out_non_seq_scan(fgraph, node):
add_to_replace.n = 0
# The variables that will replace the variables pushed-out of the
# inner-graph
replace_with_in = []
# The variables that have been pushed-out of the graph
replace_with_out = []
op = node.op
# Construct the list of non_sequences to simplify a few things
inner_non_seqs = op.inner_non_seqs(clean_inputs)
inner_non_seqs = op.inner_non_seqs(node_inputs)
inner_non_seqs_set = set(inner_non_seqs)
inner_non_seqs_map = {v: k for k, v in enumerate(inner_non_seqs)}
outer_non_seqs = op.outer_non_seqs(node.inputs)
inner_seqs = op.inner_seqs(clean_inputs)
inner_seqs = op.inner_seqs(node_inputs)
outer_seqs = op.outer_seqs(node.inputs)
assert len(inner_non_seqs) == len(outer_non_seqs)
......@@ -242,55 +244,60 @@ def push_out_non_seq_scan(fgraph, node):
if ( # we haven't already looked at this node
nd not in to_remove_set
and all(
[
(
(x in inner_non_seqs_set)
or (x.owner in to_remove_set)
or isinstance(x, Constant)
)
for x in nd.inputs
]
)
and
# we can do this because the assumption is that a
# viewOp or deepCopyOp will be just at the end of the
# function and not somewhere in the middle ..
not isinstance(nd.op, aesara.compile.ViewOp)
# We can (supposedly) do this because the assumption is that a
# `ViewOp` or `DeepCopyOp` will be just at the end of the
# function and not somewhere in the middle
and not isinstance(nd.op, aesara.compile.ViewOp)
and not isinstance(nd.op, aesara.compile.DeepCopyOp)
):
# We have a candidate node to removable
# Step 1. Reconstruct it on outside
# We have a candidate node to remove from the inner-graph
# Step 1. Reconstruct the node using the relevant outer-inputs.
#
# More specifically, the node's current inputs are either
# a) inner-graph input place-holders for non-sequences,
# b) the outputs of other nodes being pushed out of the inner-graph,
# c) or constants.
to_remove_set.add(nd)
outside_ins = []
for x in nd.inputs:
if x in inner_non_seqs_set:
_idx = inner_non_seqs_map[x]
outside_ins.append(outer_non_seqs[_idx])
elif x in to_replace_set:
outside_ins.append(replace_with_out[to_replace_map[x]])
elif isinstance(x, Constant):
outside_ins.append(x.clone())
new_inputs = []
for old_input in nd.inputs:
if old_input in inner_non_seqs_set:
# This is case a), so we want to use the corresponding
# outer-graph input as the input to our new pushed-out node
_idx = inner_non_seqs_map[old_input]
new_input = outer_non_seqs[_idx]
elif old_input in to_replace_set:
# This is case b), so we want to use the new pushed-out node
# as the input to this new pushed-out node
new_input = replace_with_out[to_replace_map[old_input]]
else:
# TODO: Explain why is this an error, and raise an
# appropriate exception type.
raise RuntimeError()
outside_ins = [
x.type.filter_variable(y) for x, y in zip(nd.inputs, outside_ins)
]
assert isinstance(old_input, Constant)
new_input = old_input
nw_outer_node = nd.op.make_node(*outside_ins)
new_input = old_input.type.filter_variable(new_input)
new_inputs.append(new_input)
pushed_out_node = nd.op.make_node(*new_inputs)
if config.compute_test_value != "off":
compute_test_value(nw_outer_node)
compute_test_value(pushed_out_node)
# Step 2. Create variables for replacements
# Step 2. Create variables to replace the old outputs of the node
# that we're pushing out of the inner-graph
for idx, y in enumerate(nd.outputs):
y_place_holder = safe_new(y, "_replace")
y_place_holder = y.clone()
# y_place_holder = safe_new(y, "_replace")
add_to_replace(y)
replace_with_in.append(y_place_holder)
assert isinstance(y, type(nw_outer_node.outputs[idx]))
replace_with_out.append(nw_outer_node.outputs[idx])
assert isinstance(y, type(pushed_out_node.outputs[idx]))
replace_with_out.append(pushed_out_node.outputs[idx])
# We need to check all candidate replacements and choose those that
# make sense for us
......@@ -326,14 +333,14 @@ def push_out_non_seq_scan(fgraph, node):
clean_to_replace, clean_replace_with_in, clean_replace_with_out
):
if isinstance(repl_out, Constant):
repl_in = repl_out.clone()
repl_in = repl_out
else:
nw_inner.append(repl_in)
nw_outer.append(repl_out)
givens[to_repl] = repl_in
op_outs = clone_replace(clean_outputs, replace=givens)
op_ins = clean_inputs + nw_inner
op_outs = clone_replace(node_outputs, replace=givens)
op_ins = node_inputs + nw_inner
# Reconstruct node
nwScan = Scan(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论