提交 0845a3d6 authored 作者: nouiz's avatar nouiz

Merge pull request #464 from lamblin/fix_eq_opt_error

Fix max'd error in EquilibriumOptimizer with constant folding
...@@ -1113,6 +1113,7 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -1113,6 +1113,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
max_use_abort = False max_use_abort = False
opt_name = None opt_name = None
process_count = {} process_count = {}
max_nb_nodes = 0
while changed and not max_use_abort: while changed and not max_use_abort:
changed = False changed = False
...@@ -1130,7 +1131,8 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -1130,7 +1131,8 @@ class EquilibriumOptimizer(NavigatorOptimizer):
q = deque(graph.io_toposort(env.inputs, start_from)) q = deque(graph.io_toposort(env.inputs, start_from))
max_use = len(q) * self.max_use_ratio max_nb_nodes = max(max_nb_nodes, len(q))
max_use = max_nb_nodes * self.max_use_ratio
def importer(node): def importer(node):
if node is not current_node: if node is not current_node:
q.append(node) q.append(node)
...@@ -1148,17 +1150,16 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -1148,17 +1150,16 @@ class EquilibriumOptimizer(NavigatorOptimizer):
current_node = node current_node = node
for lopt in self.local_optimizers: for lopt in self.local_optimizers:
process_count.setdefault(lopt, 0) process_count.setdefault(lopt, 0)
if process_count[lopt] > max_use: lopt_change = self.process_node(env, node, lopt)
max_use_abort = True if lopt_change:
opt_name = (getattr(lopt, "name", None) process_count[lopt] += 1
or getattr(lopt, "__name__", None) or "") changed = True
else: if process_count[lopt] > max_use:
lopt_change = self.process_node(env, node, lopt) max_use_abort = True
if lopt_change: opt_name = (getattr(lopt, "name", None)
process_count[lopt] += 1 or getattr(lopt, "__name__", ""))
changed = True if node not in env.nodes:
if node not in env.nodes: break # go to next node
break # go to next node
finally: finally:
self.detach_updater(env, u) self.detach_updater(env, u)
self.detach_updater(env, u) #TODO: erase this line, it's redundant at best self.detach_updater(env, u) #TODO: erase this line, it's redundant at best
......
...@@ -15,13 +15,9 @@ AddConfigVar('optdb.position_cutoff', ...@@ -15,13 +15,9 @@ AddConfigVar('optdb.position_cutoff',
' position of the optimizer where to stop.', ' position of the optimizer where to stop.',
FloatParam(numpy.inf), FloatParam(numpy.inf),
in_c_key=False) in_c_key=False)
#upgraded to 20 to avoid EquibriumOptimizer error
# to be max'ed out by constant folding (can
# I increase the max ratio only for
# constant folding somehow?
AddConfigVar('optdb.max_use_ratio', AddConfigVar('optdb.max_use_ratio',
'A ratio that prevent infinite loop in EquilibriumOptimizer.', 'A ratio that prevent infinite loop in EquilibriumOptimizer.',
FloatParam(20), FloatParam(5),
in_c_key=False) in_c_key=False)
......
...@@ -406,4 +406,4 @@ class TestEquilibrium(object): ...@@ -406,4 +406,4 @@ class TestEquilibrium(object):
finally: finally:
_logger.setLevel(oldlevel) _logger.setLevel(oldlevel)
print 'after', g print 'after', g
assert str(g) == '[Op4(x, y)]' assert str(g) == '[Op1(x, y)]'
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论