提交 f4af10da authored 作者: Frederic's avatar Frederic

Use OrderedDict in scan op to try to fix schochastic opt order problem.

上级 813dde8a
...@@ -86,7 +86,7 @@ def remove_constants_and_unused_inputs_scan(node): ...@@ -86,7 +86,7 @@ def remove_constants_and_unused_inputs_scan(node):
out_stuff_outer = node.inputs[1 + op.n_seqs:st] out_stuff_outer = node.inputs[1 + op.n_seqs:st]
# To replace constants in the outer graph by clones in the inner graph # To replace constants in the outer graph by clones in the inner graph
givens = {} givens = OrderedDict()
# All the inputs of the inner graph of the new scan # All the inputs of the inner graph of the new scan
nw_inner = [] nw_inner = []
# Same for the outer graph, initialized w/ number of steps # Same for the outer graph, initialized w/ number of steps
...@@ -257,7 +257,7 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -257,7 +257,7 @@ class PushOutNonSeqScan(gof.Optimizer):
if len(clean_to_replace) > 0: if len(clean_to_replace) > 0:
# We can finally put an end to all this madness # We can finally put an end to all this madness
givens = {} givens = OrderedDict()
nw_outer = [] nw_outer = []
nw_inner = [] nw_inner = []
for to_repl, repl_in, repl_out in zip(clean_to_replace, for to_repl, repl_in, repl_out in zip(clean_to_replace,
...@@ -284,7 +284,7 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -284,7 +284,7 @@ class PushOutNonSeqScan(gof.Optimizer):
return True return True
elif to_keep == []: elif to_keep == []:
# Nothing in the inner graph should be kept # Nothing in the inner graph should be kept
replace_with = {} replace_with = OrderedDict()
for idx, out in enumerate(to_replace): for idx, out in enumerate(to_replace):
if out in local_fgraph.outputs: if out in local_fgraph.outputs:
x = node.outputs[local_fgraph.outputs.index(out)] x = node.outputs[local_fgraph.outputs.index(out)]
...@@ -439,7 +439,7 @@ class PushOutSeqScan(gof.Optimizer): ...@@ -439,7 +439,7 @@ class PushOutSeqScan(gof.Optimizer):
if len(clean_to_replace) > 0: if len(clean_to_replace) > 0:
# We can finally put an end to all this madness # We can finally put an end to all this madness
givens = {} givens = OrderedDict()
nw_outer = [] nw_outer = []
nw_inner = [] nw_inner = []
for to_repl, repl_in, repl_out in zip(clean_to_replace, for to_repl, repl_in, repl_out in zip(clean_to_replace,
...@@ -529,7 +529,7 @@ class ScanInplaceOptimizer(Optimizer): ...@@ -529,7 +529,7 @@ class ScanInplaceOptimizer(Optimizer):
for pos in xrange(n_outs): for pos in xrange(n_outs):
info = copy.deepcopy(op.info) info = copy.deepcopy(op.info)
if not 'destroy_map' in info: if not 'destroy_map' in info:
info['destroy_map'] = {} info['destroy_map'] = OrderedDict()
info['destroy_map'][pos] = [pos + 1 + op.info['n_seqs']] info['destroy_map'][pos] = [pos + 1 + op.info['n_seqs']]
# inputs corresponding to sequences and n_steps # inputs corresponding to sequences and n_steps
ls_begin = node.inputs[:1 + op.n_seqs] ls_begin = node.inputs[:1 + op.n_seqs]
...@@ -600,7 +600,7 @@ class ScanSaveMem(gof.Optimizer): ...@@ -600,7 +600,7 @@ class ScanSaveMem(gof.Optimizer):
# Each access to shape_of is in a try..except block in order to # Each access to shape_of is in a try..except block in order to
# use a default version when the variable is not in the shape_of # use a default version when the variable is not in the shape_of
# dictionary. # dictionary.
shape_of = {} shape_of = OrderedDict()
# 1. Initialization of variables # 1. Initialization of variables
# Note 1) We do not actually care about outputs representing shared # Note 1) We do not actually care about outputs representing shared
# variables (those have no intermediate values) so it is safer to # variables (those have no intermediate values) so it is safer to
...@@ -923,7 +923,7 @@ class ScanSaveMem(gof.Optimizer): ...@@ -923,7 +923,7 @@ class ScanSaveMem(gof.Optimizer):
# 3.5 Remove unwanted orphane outputs # 3.5 Remove unwanted orphane outputs
(inps, outs, info, node_ins, compress_map) = \ (inps, outs, info, node_ins, compress_map) = \
scan_utils.compress_outs(op, not_required, nw_inputs) scan_utils.compress_outs(op, not_required, nw_inputs)
inv_compress_map = {} inv_compress_map = OrderedDict()
for k, v in compress_map.items(): for k, v in compress_map.items():
inv_compress_map[v] = k inv_compress_map[v] = k
...@@ -1053,7 +1053,7 @@ class ScanMerge(gof.Optimizer): ...@@ -1053,7 +1053,7 @@ class ScanMerge(gof.Optimizer):
else: else:
as_while = False as_while = False
info = {} info = OrderedDict()
info['tap_array'] = [] info['tap_array'] = []
info['n_seqs'] = sum([nd.op.n_seqs for nd in nodes]) info['n_seqs'] = sum([nd.op.n_seqs for nd in nodes])
info['n_mit_mot'] = sum([nd.op.n_mit_mot for nd in nodes]) info['n_mit_mot'] = sum([nd.op.n_mit_mot for nd in nodes])
...@@ -1228,7 +1228,7 @@ def has_duplicates(l): ...@@ -1228,7 +1228,7 @@ def has_duplicates(l):
def make_equiv(lo, li): def make_equiv(lo, li):
"""builds a dictionary of equivalences between inner inputs based on """builds a dictionary of equivalences between inner inputs based on
the equivalence of their corresponding outer inputs.""" the equivalence of their corresponding outer inputs."""
seeno = {} seeno = OrderedDict()
left = [] left = []
right = [] right = []
for o, i in zip(lo, li): for o, i in zip(lo, li):
...@@ -1248,7 +1248,7 @@ def scan_merge_inouts(node): ...@@ -1248,7 +1248,7 @@ def scan_merge_inouts(node):
a = scan_args(node.inputs, node.outputs, a = scan_args(node.inputs, node.outputs,
node.op.inputs, node.op.outputs, node.op.info) node.op.inputs, node.op.outputs, node.op.info)
inp_equiv = {} inp_equiv = OrderedDict()
if has_duplicates(a.outer_in_seqs): if has_duplicates(a.outer_in_seqs):
new_outer_seqs = [] new_outer_seqs = []
...@@ -1310,7 +1310,7 @@ def scan_merge_inouts(node): ...@@ -1310,7 +1310,7 @@ def scan_merge_inouts(node):
left += _left left += _left
right += _right right += _right
if has_duplicates(na.outer_in_mit_mot): if has_duplicates(na.outer_in_mit_mot):
seen = {} seen = OrderedDict()
for omm, imm, _sl in zip(na.outer_in_mit_mot, for omm, imm, _sl in zip(na.outer_in_mit_mot,
na.inner_in_mit_mot, na.mit_mot_in_slices): na.inner_in_mit_mot, na.mit_mot_in_slices):
sl = tuple(_sl) sl = tuple(_sl)
...@@ -1322,7 +1322,7 @@ def scan_merge_inouts(node): ...@@ -1322,7 +1322,7 @@ def scan_merge_inouts(node):
seen[(omm, sl)] = imm seen[(omm, sl)] = imm
if has_duplicates(na.outer_in_mit_sot): if has_duplicates(na.outer_in_mit_sot):
seen = {} seen = OrderedDict()
for oms, ims, _sl in zip(na.outer_in_mit_sot, for oms, ims, _sl in zip(na.outer_in_mit_sot,
na.inner_in_mit_sot, na.inner_in_mit_sot,
na.mit_sot_in_slices): na.mit_sot_in_slices):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论