提交 808da4a5 authored 作者: Caglar's avatar Caglar

removed to_keep and changed the reconstruction of to_keep_set

removed the profile decorator.
上级 5816665f
...@@ -249,14 +249,14 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -249,14 +249,14 @@ class PushOutNonSeqScan(gof.Optimizer):
to_remove_set = set() to_remove_set = set()
to_replace_set = set() to_replace_set = set()
to_replace_add = to_replace_set.add
to_replace_map = OrderedDict() to_replace_map = OrderedDict()
nto_replace = 0 nto_replace = 0
def add_to_replace(y, nto_replace): def add_to_replace(y):
to_replace_add(y) to_replace_set.add(y)
to_replace_map[y] = nto_replace to_replace_map[y] = add_to_replace.n
return nto_replace + 1 add_to_replace.n +=1
add_to_replace.n = 0
replace_with_in = [] replace_with_in = []
replace_with_out = [] replace_with_out = []
...@@ -317,7 +317,7 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -317,7 +317,7 @@ class PushOutNonSeqScan(gof.Optimizer):
# Step 2. Create variables for replacements # Step 2. Create variables for replacements
for idx, y in enumerate(nd.outputs): for idx, y in enumerate(nd.outputs):
y_place_holder = scan_utils.safe_new(y, '_replace') y_place_holder = scan_utils.safe_new(y, '_replace')
nto_replace = add_to_replace(y, nto_replace) add_to_replace(y)
replace_with_in.append(y_place_holder) replace_with_in.append(y_place_holder)
assert isinstance(y, type(nw_outer_node.outputs[idx])) assert isinstance(y, type(nw_outer_node.outputs[idx]))
replace_with_out.append(nw_outer_node.outputs[idx]) replace_with_out.append(nw_outer_node.outputs[idx])
...@@ -332,8 +332,10 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -332,8 +332,10 @@ class PushOutNonSeqScan(gof.Optimizer):
existent_nodes = [nd for nd in local_fgraph_topo existent_nodes = [nd for nd in local_fgraph_topo
if nd not in to_remove_set] if nd not in to_remove_set]
existent_nodes_set = set(existent_nodes) existent_nodes_set = set(existent_nodes)
to_keep = []; [to_keep.extend(nd.inputs) for nd in existent_nodes]
to_keep_set = set(to_keep) to_keep_set = set([])
for nd in existent_nodes:
to_keep_set.update(nd.inputs)
for out, idx in to_replace_map.items(): for out, idx in to_replace_map.items():
if (# If types are different, conversion Op will be inserted, if (# If types are different, conversion Op will be inserted,
...@@ -377,7 +379,7 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -377,7 +379,7 @@ class PushOutNonSeqScan(gof.Optimizer):
remove=[node], remove=[node],
reason='scanOp_pushout_nonseqs_ops') reason='scanOp_pushout_nonseqs_ops')
return True return True
elif not to_keep: elif not to_keep_set:
# Nothing in the inner graph should be kept # Nothing in the inner graph should be kept
replace_with = OrderedDict() replace_with = OrderedDict()
for out, idx in to_replace_map.items(): for out, idx in to_replace_map.items():
...@@ -441,14 +443,14 @@ class PushOutSeqScan(gof.Optimizer): ...@@ -441,14 +443,14 @@ class PushOutSeqScan(gof.Optimizer):
to_remove_set = set() to_remove_set = set()
to_replace_set = set() to_replace_set = set()
to_replace_add = to_replace_set.add
to_replace_map = OrderedDict() to_replace_map = OrderedDict()
nto_replace = 0 nto_replace = 0
def add_to_replace(y, nto_replace): def add_to_replace(y):
to_replace_add(y) to_replace_set.add(y)
to_replace_map[y] = nto_replace to_replace_map[y] = add_to_replace.n
return nto_replace + 1 add_to_replace.n += 1
add_to_replace.n = 0
replace_with_in = [] replace_with_in = []
replace_with_out = [] replace_with_out = []
...@@ -516,7 +518,7 @@ class PushOutSeqScan(gof.Optimizer): ...@@ -516,7 +518,7 @@ class PushOutSeqScan(gof.Optimizer):
# Step 2. Create variables for replacements # Step 2. Create variables for replacements
for idx, y in enumerate(nd.outputs): for idx, y in enumerate(nd.outputs):
y_place_holder = scan_utils.safe_new(y, '_replace') y_place_holder = scan_utils.safe_new(y, '_replace')
nto_replace = add_to_replace(y, nto_replace) add_to_replace(y)
replace_with_in.append(y_place_holder) replace_with_in.append(y_place_holder)
replace_with_out.append(nw_outer_node.outputs[idx]) replace_with_out.append(nw_outer_node.outputs[idx])
...@@ -540,7 +542,7 @@ class PushOutSeqScan(gof.Optimizer): ...@@ -540,7 +542,7 @@ class PushOutSeqScan(gof.Optimizer):
new_outer = outside_ins.dimshuffle(new_ord) new_outer = outside_ins.dimshuffle(new_ord)
y = nd.outputs[0] y = nd.outputs[0]
y_place_holder = scan_utils.safe_new(y, '_replace') y_place_holder = scan_utils.safe_new(y, '_replace')
nto_replace = add_to_replace(y, nto_replace) add_to_replace(y)
replace_with_in.append(y_place_holder) replace_with_in.append(y_place_holder)
replace_with_out.append(new_outer) replace_with_out.append(new_outer)
...@@ -561,8 +563,10 @@ class PushOutSeqScan(gof.Optimizer): ...@@ -561,8 +563,10 @@ class PushOutSeqScan(gof.Optimizer):
existent_nodes = [nd for nd in local_fgraph_topo existent_nodes = [nd for nd in local_fgraph_topo
if nd not in to_remove_set] if nd not in to_remove_set]
existent_nodes_set = set(existent_nodes) existent_nodes_set = set(existent_nodes)
to_keep = []; [to_keep.extend(nd.inputs) for nd in existent_nodes]
to_keep_set = set(to_keep) to_keep_set = set([])
for nd in existent_nodes:
to_keep_set.update(nd.inputs)
for out, idx in to_replace_map.items(): for out, idx in to_replace_map.items():
if (out in to_keep_set if (out in to_keep_set
...@@ -608,7 +612,7 @@ class PushOutSeqScan(gof.Optimizer): ...@@ -608,7 +612,7 @@ class PushOutSeqScan(gof.Optimizer):
remove=[node], remove=[node],
reason='scanOp_pushout_seqs_ops') reason='scanOp_pushout_seqs_ops')
return True return True
elif (not to_keep and elif (not to_keep_set and
not op.as_while and not op.as_while and
not op.outer_mitmot(node)): not op.outer_mitmot(node)):
# Nothing in the inner graph should be kept # Nothing in the inner graph should be kept
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论