提交 377b1bea authored 作者: Caglar's avatar Caglar

replaced set update to set add.

上级 c05cf1c9
...@@ -230,12 +230,14 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -230,12 +230,14 @@ class PushOutNonSeqScan(gof.Optimizer):
max_iterations = 2 * len(local_fgraph_topo) + 3 max_iterations = 2 * len(local_fgraph_topo) + 3
counts = 0 counts = 0
to_remove_set = set({}) to_remove_set = set()
to_replace_set = set({}) to_remove_add = to_remove_set.add
to_replace_set = set()
to_replace_add = to_replace_set.add
to_replace_map = {} to_replace_map = {}
nto_replace = 0 nto_replace = 0
def add_to_replace(y, nto_replace): def add_to_replace(y, nto_replace):
to_replace_set.update([y]) to_replace_add(y)
to_replace_map[y] = nto_replace to_replace_map[y] = nto_replace
return nto_replace + 1 return nto_replace + 1
...@@ -274,7 +276,7 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -274,7 +276,7 @@ class PushOutNonSeqScan(gof.Optimizer):
# We have a candidate node to removable # We have a candidate node to removable
# Step 1. Reconstruct it on outside # Step 1. Reconstruct it on outside
to_remove_set.update([nd]) to_remove_add(nd)
outside_ins = [] outside_ins = []
for x in nd.inputs: for x in nd.inputs:
if x in inner_non_seqs_set: if x in inner_non_seqs_set:
...@@ -419,12 +421,15 @@ class PushOutSeqScan(gof.Optimizer): ...@@ -419,12 +421,15 @@ class PushOutSeqScan(gof.Optimizer):
max_iterations = 2 * len(local_fgraph_topo) + 3 max_iterations = 2 * len(local_fgraph_topo) + 3
counts = 0 counts = 0
to_remove_set = set({}) to_remove_set = set()
to_replace_set = set({}) to_remove_add = to_remove_set.add
to_replace_set = set()
to_replace_add = to_replace_set.add
to_replace_map = {} to_replace_map = {}
nto_replace = 0 nto_replace = 0
def add_to_replace(y, nto_replace): def add_to_replace(y, nto_replace):
to_replace_set.update([y]) to_replace_add(y)
to_replace_map[y] = nto_replace to_replace_map[y] = nto_replace
return nto_replace + 1 return nto_replace + 1
...@@ -440,7 +445,7 @@ class PushOutSeqScan(gof.Optimizer): ...@@ -440,7 +445,7 @@ class PushOutSeqScan(gof.Optimizer):
outer_non_seqs = op.outer_non_seqs(node.inputs) outer_non_seqs = op.outer_non_seqs(node.inputs)
inner_seqs = op.inner_seqs(clean_inputs) inner_seqs = op.inner_seqs(clean_inputs)
inner_seqs_set = set(inner_seqs) inner_seqs_set = set(inner_seqs)
inner_seqs_map = dict({v:k for k, v in enumerate(inner_seqs)}) inner_seqs_map = dict({v:k for k,v in enumerate(inner_seqs)})
outer_seqs = op.outer_seqs(node.inputs) outer_seqs = op.outer_seqs(node.inputs)
assert len(inner_non_seqs) == len(outer_non_seqs) assert len(inner_non_seqs) == len(outer_non_seqs)
...@@ -458,7 +463,7 @@ class PushOutSeqScan(gof.Optimizer): ...@@ -458,7 +463,7 @@ class PushOutSeqScan(gof.Optimizer):
(x in inner_seqs_set) (x in inner_seqs_set)
for x in nd.inputs]) and for x in nd.inputs]) and
not nd in to_remove_set): not nd in to_remove_set):
to_remove_set.update([nd]) to_remove_add(nd)
outside_ins = [] outside_ins = []
depends_on_seqs = False depends_on_seqs = False
...@@ -508,7 +513,7 @@ class PushOutSeqScan(gof.Optimizer): ...@@ -508,7 +513,7 @@ class PushOutSeqScan(gof.Optimizer):
(nd.inputs[0] in inner_seqs_set or (nd.inputs[0] in inner_seqs_set or
nd.inputs[0].owner in to_remove_set) and nd.inputs[0].owner in to_remove_set) and
not nd in to_remove_set): not nd in to_remove_set):
to_remove_set.update([nd]) to_remove_add(nd)
x = nd.inputs[0] x = nd.inputs[0]
if x in inner_seqs_set: if x in inner_seqs_set:
outside_ins = outer_seqs[inner_seqs_map[x]] outside_ins = outer_seqs[inner_seqs_map[x]]
...@@ -996,6 +1001,7 @@ class ScanInplaceOptimizer(Optimizer): ...@@ -996,6 +1001,7 @@ class ScanInplaceOptimizer(Optimizer):
info = copy.deepcopy(op.info) info = copy.deepcopy(op.info)
if not 'destroy_map' in info: if not 'destroy_map' in info:
info['destroy_map'] = OrderedDict() info['destroy_map'] = OrderedDict()
info['destroy_map'][pos] = [pos + 1 + op.info['n_seqs']] info['destroy_map'][pos] = [pos + 1 + op.info['n_seqs']]
# inputs corresponding to sequences and n_steps # inputs corresponding to sequences and n_steps
ls_begin = node.inputs[:1 + op.n_seqs] ls_begin = node.inputs[:1 + op.n_seqs]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论