提交 d039a055 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Remove unnecessary cloning from push_out_seq_scan

上级 2d2f2979
...@@ -410,12 +410,11 @@ def push_out_seq_scan(fgraph, node): ...@@ -410,12 +410,11 @@ def push_out_seq_scan(fgraph, node):
if not isinstance(node.op, Scan): if not isinstance(node.op, Scan):
return False return False
# this flag tells if there was any change during the last iterations node_inputs, node_outputs = node.op.inputs, node.op.outputs
clean_inputs, clean_outputs = reconstruct_graph(node.op.inputs, node.op.outputs)
local_fgraph_topo = io_toposort(clean_inputs, clean_outputs) local_fgraph_topo = io_toposort(node_inputs, node_outputs)
local_fgraph_outs_set = set(clean_outputs) local_fgraph_outs_set = set(node_outputs)
local_fgraph_outs_map = {v: k for k, v in enumerate(clean_outputs)} local_fgraph_outs_map = {v: k for k, v in enumerate(node_outputs)}
to_remove_set = set() to_remove_set = set()
to_replace_set = set() to_replace_set = set()
...@@ -433,12 +432,12 @@ def push_out_seq_scan(fgraph, node): ...@@ -433,12 +432,12 @@ def push_out_seq_scan(fgraph, node):
op = node.op op = node.op
# Construct the list of non_sequences to simplify a few things # Construct the list of non_sequences to simplify a few things
inner_non_seqs = op.inner_non_seqs(clean_inputs) inner_non_seqs = op.inner_non_seqs(node_inputs)
inner_non_seqs_set = set(inner_non_seqs) inner_non_seqs_set = set(inner_non_seqs)
inner_non_seqs_map = {v: k for k, v in enumerate(inner_non_seqs)} inner_non_seqs_map = {v: k for k, v in enumerate(inner_non_seqs)}
outer_non_seqs = op.outer_non_seqs(node.inputs) outer_non_seqs = op.outer_non_seqs(node.inputs)
inner_seqs = op.inner_seqs(clean_inputs) inner_seqs = op.inner_seqs(node_inputs)
inner_seqs_set = set(inner_seqs) inner_seqs_set = set(inner_seqs)
inner_seqs_map = {v: k for k, v in enumerate(inner_seqs)} inner_seqs_map = {v: k for k, v in enumerate(inner_seqs)}
...@@ -467,26 +466,18 @@ def push_out_seq_scan(fgraph, node): ...@@ -467,26 +466,18 @@ def push_out_seq_scan(fgraph, node):
for x in nd.inputs: for x in nd.inputs:
if x in inner_non_seqs_set: if x in inner_non_seqs_set:
_idx = inner_non_seqs_map[x] _idx = inner_non_seqs_map[x]
outside_ins.append(outer_non_seqs[_idx]) new_input = outer_non_seqs[_idx]
elif x in inner_seqs_set: elif x in inner_seqs_set:
outside_ins.append(outer_seqs[inner_seqs_map[x]]) new_input = outer_seqs[inner_seqs_map[x]]
depends_on_seqs = True depends_on_seqs = True
elif x in to_replace_set: elif x in to_replace_set:
outside_ins.append(replace_with_out[to_replace_map[x]]) new_input = replace_with_out[to_replace_map[x]]
depends_on_seqs = True depends_on_seqs = True
elif isinstance(x, Constant):
outside_ins.append(x.clone())
else: else:
raise Exception( assert isinstance(x, Constant)
( new_input = x
"Error in the `scan_pushout_seq_"
"operations`. The optimization tries " outside_ins.append(new_input)
"to move some computation from scan "
"which is not allowed to move. Report "
"this on aesara-users list"
),
x,
)
if not depends_on_seqs: if not depends_on_seqs:
# Removing this node from the inner graph of scan # Removing this node from the inner graph of scan
...@@ -580,15 +571,15 @@ def push_out_seq_scan(fgraph, node): ...@@ -580,15 +571,15 @@ def push_out_seq_scan(fgraph, node):
clean_to_replace, clean_replace_with_in, clean_replace_with_out clean_to_replace, clean_replace_with_in, clean_replace_with_out
): ):
if isinstance(repl_out, Constant): if isinstance(repl_out, Constant):
repl_in = repl_out.clone() repl_in = repl_out
else: else:
nw_inner.append(repl_in) nw_inner.append(repl_in)
nw_outer.append(repl_out) nw_outer.append(repl_out)
givens[to_repl] = repl_in givens[to_repl] = repl_in
op_outs = clone_replace(clean_outputs, replace=givens) op_outs = clone_replace(node_outputs, replace=givens)
op_ins = nw_inner + clean_inputs op_ins = nw_inner + node_inputs
# Reconstruct node # Reconstruct node
nw_info = dataclasses.replace(op.info, n_seqs=op.info.n_seqs + len(nw_inner)) nw_info = dataclasses.replace(op.info, n_seqs=op.info.n_seqs + len(nw_inner))
...@@ -621,7 +612,7 @@ def push_out_seq_scan(fgraph, node): ...@@ -621,7 +612,7 @@ def push_out_seq_scan(fgraph, node):
if out in local_fgraph_outs_set: if out in local_fgraph_outs_set:
x = node.outputs[local_fgraph_outs_map[out]] x = node.outputs[local_fgraph_outs_map[out]]
_y = replace_with_out[idx] _y = replace_with_out[idx]
ls = clean_outputs ls = node_outputs
if out in op.inner_mitsot_outs(ls): if out in op.inner_mitsot_outs(ls):
odx = op.inner_mitsot_outs(ls).index(out) odx = op.inner_mitsot_outs(ls).index(out)
inp = op.outer_mitsot(node.inputs)[odx] inp = op.outer_mitsot(node.inputs)[odx]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论