提交 0845a3d6 authored 作者: nouiz's avatar nouiz

Merge pull request #464 from lamblin/fix_eq_opt_error

Fix max'd error in EquilibriumOptimizer with constant folding
......@@ -1113,6 +1113,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
max_use_abort = False
opt_name = None
process_count = {}
max_nb_nodes = 0
while changed and not max_use_abort:
changed = False
......@@ -1130,7 +1131,8 @@ class EquilibriumOptimizer(NavigatorOptimizer):
q = deque(graph.io_toposort(env.inputs, start_from))
max_use = len(q) * self.max_use_ratio
max_nb_nodes = max(max_nb_nodes, len(q))
max_use = max_nb_nodes * self.max_use_ratio
def importer(node):
if node is not current_node:
q.append(node)
......@@ -1148,17 +1150,16 @@ class EquilibriumOptimizer(NavigatorOptimizer):
current_node = node
for lopt in self.local_optimizers:
process_count.setdefault(lopt, 0)
if process_count[lopt] > max_use:
max_use_abort = True
opt_name = (getattr(lopt, "name", None)
or getattr(lopt, "__name__", None) or "")
else:
lopt_change = self.process_node(env, node, lopt)
if lopt_change:
process_count[lopt] += 1
changed = True
if node not in env.nodes:
break # go to next node
lopt_change = self.process_node(env, node, lopt)
if lopt_change:
process_count[lopt] += 1
changed = True
if process_count[lopt] > max_use:
max_use_abort = True
opt_name = (getattr(lopt, "name", None)
or getattr(lopt, "__name__", ""))
if node not in env.nodes:
break # go to next node
finally:
self.detach_updater(env, u)
self.detach_updater(env, u) #TODO: erase this line, it's redundant at best
......
......@@ -15,13 +15,9 @@ AddConfigVar('optdb.position_cutoff',
' position of the optimizer where to stop.',
FloatParam(numpy.inf),
in_c_key=False)
#upgraded to 20 to avoid EquibriumOptimizer error
# to be max'ed out by constant folding (can
# I increase the max ratio only for
# constant folding somehow?
AddConfigVar('optdb.max_use_ratio',
'A ratio that prevent infinite loop in EquilibriumOptimizer.',
FloatParam(20),
FloatParam(5),
in_c_key=False)
......
......@@ -406,4 +406,4 @@ class TestEquilibrium(object):
finally:
_logger.setLevel(oldlevel)
print 'after', g
assert str(g) == '[Op4(x, y)]'
assert str(g) == '[Op1(x, y)]'
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论