提交 9b8bf5b3 authored 作者: James Bergstra's avatar James Bergstra

changes to merge optimization to encourage stability

上级 b18ce6b3
...@@ -187,9 +187,15 @@ class MergeOptimizer(Optimizer): ...@@ -187,9 +187,15 @@ class MergeOptimizer(Optimizer):
env.extend(toolbox.ReplaceValidate()) env.extend(toolbox.ReplaceValidate())
def apply_constant_merge(self, env): def apply_constant_merge(self, env):
seen_constants = set()
const_sig = _metadict() # result -> result.signature() (for constants) const_sig = _metadict() # result -> result.signature() (for constants)
const_sig_inv = _metadict() # signature -> result (for constants) const_sig_inv = _metadict() # signature -> result (for constants)
for i, c in enumerate([r for r in env.results if isinstance(r, graph.Constant)]): for node in _list_of_nodes(env):
for i, c in enumerate([r for r in node.inputs if isinstance(r, graph.Constant)]):
if id(c) in seen_constants:
continue
else:
seen_constants.add(id(c))
sig = c.signature() sig = c.signature()
other_c = const_sig_inv.get(sig, None) other_c = const_sig_inv.get(sig, None)
if other_c is not None: if other_c is not None:
...@@ -242,19 +248,20 @@ class MergeOptimizer(Optimizer): ...@@ -242,19 +248,20 @@ class MergeOptimizer(Optimizer):
# we clear the dicts because the Constants signatures are not necessarily hashable # we clear the dicts because the Constants signatures are not necessarily hashable
# and it's more efficient to give them an integer like the other Results # and it's more efficient to give them an integer like the other Results
nodes_seen = set() nodes_seen = {}
for node in _list_of_nodes(env): for node_idx, node in enumerate(_list_of_nodes(env)):
# #
# these asserts ensure that the env has set the clients field properly the clients # these asserts ensure that the env has set the clients field properly the clients
# should at least contain `node` itself! # should at least contain `node` itself!
# #
assert len(node.inputs[0].clients) > 0 assert len(node.inputs[0].clients) > 0
assert (node,0) in node.inputs[0].clients assert (node,0) in node.inputs[0].clients
merge_candidates = [c for (c,i) in node.inputs[0].clients if c in nodes_seen] merge_candidates = [(nodes_seen[c],c) for (c,i) in node.inputs[0].clients if c in nodes_seen]
nodes_seen.add(node) merge_candidates.sort()
nodes_seen[node] = node_idx
#print 'NODE', node, merge_candidates, node.inputs[0].clients #print 'NODE', node, merge_candidates, node.inputs[0].clients
for candidate in merge_candidates: for candidate_idx, candidate in merge_candidates:
if len(node.inputs) != len(candidate.inputs): if len(node.inputs) != len(candidate.inputs):
continue continue
inputs_match = all(node_in is cand_in for node_in, cand_in in zip(node.inputs, candidate.inputs)) inputs_match = all(node_in is cand_in for node_in, cand_in in zip(node.inputs, candidate.inputs))
...@@ -626,8 +633,8 @@ class NavigatorOptimizer(Optimizer): ...@@ -626,8 +633,8 @@ class NavigatorOptimizer(Optimizer):
def warn(exc, nav, repl_pairs, local_opt): def warn(exc, nav, repl_pairs, local_opt):
"""failure_callback for NavigatorOptimizer: print traceback """failure_callback for NavigatorOptimizer: print traceback
""" """
print "WARNING: Optimization failure due to: ", local_opt print >> sys.stderr, "WARNING: Optimization failure due to: ", local_opt
print "TRACEBACK:" print >> sys.stderr, "TRACEBACK:"
traceback.print_exc() traceback.print_exc()
@staticmethod @staticmethod
def warn_inplace(exc, nav, repl_pairs, local_opt): def warn_inplace(exc, nav, repl_pairs, local_opt):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论