提交 56a46551 作者: Frederic

Speed up optimization from 96s to 0.37s

fix gh-3412
上级 88fa607c
...@@ -452,8 +452,9 @@ def register_canonicalize(lopt, *tags, **kwargs): ...@@ -452,8 +452,9 @@ def register_canonicalize(lopt, *tags, **kwargs):
return register_canonicalize(inner_lopt, lopt, *tags, **kwargs) return register_canonicalize(inner_lopt, lopt, *tags, **kwargs)
return register return register
else: else:
name = (kwargs and kwargs.pop('name')) or lopt.__name__ name = kwargs.pop('name', None) or lopt.__name__
compile.optdb['canonicalize'].register(name, lopt, 'fast_run', *tags) compile.optdb['canonicalize'].register(name, lopt, 'fast_run',
*tags, **kwargs)
return lopt return lopt
...@@ -463,8 +464,9 @@ def register_stabilize(lopt, *tags, **kwargs): ...@@ -463,8 +464,9 @@ def register_stabilize(lopt, *tags, **kwargs):
return register_stabilize(inner_lopt, lopt, *tags, **kwargs) return register_stabilize(inner_lopt, lopt, *tags, **kwargs)
return register return register
else: else:
name = (kwargs and kwargs.pop('name')) or lopt.__name__ name = kwargs.pop('name', None) or lopt.__name__
compile.optdb['stabilize'].register(name, lopt, 'fast_run', *tags) compile.optdb['stabilize'].register(name, lopt, 'fast_run',
*tags, **kwargs)
return lopt return lopt
...@@ -474,9 +476,9 @@ def register_specialize(lopt, *tags, **kwargs): ...@@ -474,9 +476,9 @@ def register_specialize(lopt, *tags, **kwargs):
return register_specialize(inner_lopt, lopt, *tags, **kwargs) return register_specialize(inner_lopt, lopt, *tags, **kwargs)
return register return register
else: else:
name = (kwargs and kwargs.pop('name')) or lopt.__name__ name = kwargs.pop('name', None) or lopt.__name__
compile.optdb['specialize'].register(name, lopt, 'fast_run', compile.optdb['specialize'].register(name, lopt, 'fast_run',
*tags) *tags, **kwargs)
return lopt return lopt
...@@ -4018,7 +4020,9 @@ class Canonizer(gof.LocalOptimizer): ...@@ -4018,7 +4020,9 @@ class Canonizer(gof.LocalOptimizer):
""" """
if isinstance(v, Variable): if isinstance(v, Variable):
try: try:
return get_scalar_constant_value(v) # As the constant folding is in the canonicalize phase,
# We don't need to check all the graph each time.
return get_scalar_constant_value(v, only_process_constants=True)
except NotScalarConstantError: except NotScalarConstantError:
return None return None
else: else:
...@@ -5467,9 +5471,6 @@ def local_greedy_distributor(node): ...@@ -5467,9 +5471,6 @@ def local_greedy_distributor(node):
return [rval] return [rval]
@register_canonicalize('fast_compile')
@register_stabilize('fast_compile')
@register_specialize('fast_compile')
@gof.local_optimizer(None) @gof.local_optimizer(None)
def constant_folding(node): def constant_folding(node):
for input in node.inputs: for input in node.inputs:
...@@ -5519,6 +5520,13 @@ def constant_folding(node): ...@@ -5519,6 +5520,13 @@ def constant_folding(node):
return rval return rval
# Wrap the node-level `constant_folding` local optimizer with `in2out`, under
# the name "topo_constant_folding", and register that wrapper (instead of
# decorating `constant_folding` directly) in the canonicalize, stabilize and
# specialize optimizer databases with the extra 'fast_compile' tag.
# NOTE(review): `in2out` is defined elsewhere; presumably it builds an
# optimizer that walks the graph from inputs to outputs -- confirm.
# NOTE(review): `final_opt=True` is passed on by the register_* helpers as a
# keyword argument to `compile.optdb[...].register(...)`; its exact meaning is
# not visible here (presumably "run as a final pass of that phase") -- confirm.
topo_constant_folding=in2out(constant_folding, ignore_newtrees=False,
name="topo_constant_folding")
register_canonicalize(topo_constant_folding, 'fast_compile', final_opt=True)
register_stabilize(topo_constant_folding, 'fast_compile', final_opt=True)
register_specialize(topo_constant_folding, 'fast_compile', final_opt=True)
def _is_1(expr): def _is_1(expr):
""" """
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论