提交 d039a055 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Remove unnecessary cloning from push_out_seq_scan

上级 2d2f2979
......@@ -410,12 +410,11 @@ def push_out_seq_scan(fgraph, node):
if not isinstance(node.op, Scan):
return False
# this flag tells if there was any change during the last iterations
clean_inputs, clean_outputs = reconstruct_graph(node.op.inputs, node.op.outputs)
node_inputs, node_outputs = node.op.inputs, node.op.outputs
local_fgraph_topo = io_toposort(clean_inputs, clean_outputs)
local_fgraph_outs_set = set(clean_outputs)
local_fgraph_outs_map = {v: k for k, v in enumerate(clean_outputs)}
local_fgraph_topo = io_toposort(node_inputs, node_outputs)
local_fgraph_outs_set = set(node_outputs)
local_fgraph_outs_map = {v: k for k, v in enumerate(node_outputs)}
to_remove_set = set()
to_replace_set = set()
......@@ -433,12 +432,12 @@ def push_out_seq_scan(fgraph, node):
op = node.op
# Construct the list of non_sequences to simplify a few things
inner_non_seqs = op.inner_non_seqs(clean_inputs)
inner_non_seqs = op.inner_non_seqs(node_inputs)
inner_non_seqs_set = set(inner_non_seqs)
inner_non_seqs_map = {v: k for k, v in enumerate(inner_non_seqs)}
outer_non_seqs = op.outer_non_seqs(node.inputs)
inner_seqs = op.inner_seqs(clean_inputs)
inner_seqs = op.inner_seqs(node_inputs)
inner_seqs_set = set(inner_seqs)
inner_seqs_map = {v: k for k, v in enumerate(inner_seqs)}
......@@ -467,26 +466,18 @@ def push_out_seq_scan(fgraph, node):
for x in nd.inputs:
if x in inner_non_seqs_set:
_idx = inner_non_seqs_map[x]
outside_ins.append(outer_non_seqs[_idx])
new_input = outer_non_seqs[_idx]
elif x in inner_seqs_set:
outside_ins.append(outer_seqs[inner_seqs_map[x]])
new_input = outer_seqs[inner_seqs_map[x]]
depends_on_seqs = True
elif x in to_replace_set:
outside_ins.append(replace_with_out[to_replace_map[x]])
new_input = replace_with_out[to_replace_map[x]]
depends_on_seqs = True
elif isinstance(x, Constant):
outside_ins.append(x.clone())
else:
raise Exception(
(
"Error in the `scan_pushout_seq_"
"operations`. The optimization tries "
"to move some computation from scan "
"which is not allowed to move. Report "
"this on aesara-users list"
),
x,
)
assert isinstance(x, Constant)
new_input = x
outside_ins.append(new_input)
if not depends_on_seqs:
# Removing this node from the inner graph of scan
......@@ -580,15 +571,15 @@ def push_out_seq_scan(fgraph, node):
clean_to_replace, clean_replace_with_in, clean_replace_with_out
):
if isinstance(repl_out, Constant):
repl_in = repl_out.clone()
repl_in = repl_out
else:
nw_inner.append(repl_in)
nw_outer.append(repl_out)
givens[to_repl] = repl_in
op_outs = clone_replace(clean_outputs, replace=givens)
op_ins = nw_inner + clean_inputs
op_outs = clone_replace(node_outputs, replace=givens)
op_ins = nw_inner + node_inputs
# Reconstruct node
nw_info = dataclasses.replace(op.info, n_seqs=op.info.n_seqs + len(nw_inner))
......@@ -621,7 +612,7 @@ def push_out_seq_scan(fgraph, node):
if out in local_fgraph_outs_set:
x = node.outputs[local_fgraph_outs_map[out]]
_y = replace_with_out[idx]
ls = clean_outputs
ls = node_outputs
if out in op.inner_mitsot_outs(ls):
odx = op.inner_mitsot_outs(ls).index(out)
inp = op.outer_mitsot(node.inputs)[odx]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论