提交 61099c6f authored 作者: Frederic Bastien's avatar Frederic Bastien 提交者: sentient07

Add a "useless" optimization pass (it removes useless nodes). Its runtime still needs to be timed.

上级 9f759f07
...@@ -150,6 +150,16 @@ optdb = gof.SequenceDB() ...@@ -150,6 +150,16 @@ optdb = gof.SequenceDB()
optdb.register('merge1', gof.MergeOptimizer(), optdb.register('merge1', gof.MergeOptimizer(),
0, 'fast_run', 'fast_compile', 'merge') 0, 'fast_run', 'fast_compile', 'merge')
# The 'useless' pass runs after the scan1 opt (position 0.5) and before
# ShapeOpt (position 1).
# Optimizations registered in it should only remove nodes.
# They must not do anything that needs shape inference.
# A replacement may introduce a node without infer_shape only if the
# original node did not have infer_shape either.
optdb.register('useless', gof.EquilibriumDB(ignore_newtrees=False),
0.6, 'fast_run', 'fast_compile')
# A further MergeOptimizer pass, registered at 0.65, right after the
# 'useless' pass at 0.6.
optdb.register('merge1.1', gof.MergeOptimizer(),
0.65, 'fast_run', 'fast_compile', 'merge')
# rearranges elemwise expressions # rearranges elemwise expressions
optdb.register('canonicalize', gof.EquilibriumDB(ignore_newtrees=False), optdb.register('canonicalize', gof.EquilibriumDB(ignore_newtrees=False),
1, 'fast_run', 'fast_compile', 'canonicalize_db') 1, 'fast_run', 'fast_compile', 'canonicalize_db')
......
...@@ -409,6 +409,18 @@ compile.optdb.register('inplace_elemwise_opt', inplace_elemwise_optimizer, 75, ...@@ -409,6 +409,18 @@ compile.optdb.register('inplace_elemwise_opt', inplace_elemwise_optimizer, 75,
'fast_run', 'inplace') 'fast_run', 'inplace')
def register_useless(lopt, *tags, **kwargs):
    """Register a local optimizer in the 'useless' optimization database.

    Usable either directly on an optimizer (``@register_useless``) or as a
    decorator factory whose first argument is a tag string
    (``@register_useless('fast_compile')``).

    Parameters
    ----------
    lopt
        Either the optimizer to register, or a string. When it is a string
        it is treated as a tag and a decorator is returned.
    *tags
        Extra tags forwarded to the database's ``register`` call, after the
        always-present ``'fast_run'`` tag.
    **kwargs
        Forwarded to ``register``. An optional ``name`` entry is removed
        from the forwarded keywords and used as the registration name; when
        it is missing or falsy, ``lopt.__name__`` is used instead.

    Returns
    -------
    The optimizer itself when ``lopt`` is an optimizer, otherwise a
    one-argument decorator that registers the optimizer it receives.
    """
    # isinstance (rather than an exact type comparison) so that str
    # subclasses are also recognised as tags instead of being mistaken
    # for an optimizer.
    if isinstance(lopt, str):
        def register(inner_lopt):
            # The leading string becomes the first tag of the real call.
            return register_useless(inner_lopt, lopt, *tags, **kwargs)
        return register

    name = kwargs.pop('name', None) or lopt.__name__
    compile.optdb['useless'].register(name, lopt, 'fast_run',
                                      *tags, **kwargs)
    return lopt
def register_canonicalize(lopt, *tags, **kwargs): def register_canonicalize(lopt, *tags, **kwargs):
if type(lopt) == str: if type(lopt) == str:
def register(inner_lopt): def register(inner_lopt):
...@@ -1756,6 +1768,7 @@ compile.optdb.register('local_elemwise_alloc', ...@@ -1756,6 +1768,7 @@ compile.optdb.register('local_elemwise_alloc',
@register_canonicalize("fast_compile") @register_canonicalize("fast_compile")
@register_useless("fast_compile")
@gof.local_optimizer([T.fill]) @gof.local_optimizer([T.fill])
def local_useless_fill(node): def local_useless_fill(node):
"""fill(s,v) -> v """fill(s,v) -> v
...@@ -1776,6 +1789,7 @@ def local_useless_fill(node): ...@@ -1776,6 +1789,7 @@ def local_useless_fill(node):
@register_specialize @register_specialize
@register_stabilize @register_stabilize
@register_canonicalize @register_canonicalize
@register_useless
@gof.local_optimizer([T.alloc]) @gof.local_optimizer([T.alloc])
def local_useless_alloc(node): def local_useless_alloc(node):
""" """
...@@ -1925,6 +1939,7 @@ def local_subtensor_remove_broadcastable_index(node): ...@@ -1925,6 +1939,7 @@ def local_subtensor_remove_broadcastable_index(node):
@register_specialize @register_specialize
@register_canonicalize('fast_compile_gpu') @register_canonicalize('fast_compile_gpu')
#@register_useless
@gof.local_optimizer([Subtensor, AdvancedSubtensor1]) @gof.local_optimizer([Subtensor, AdvancedSubtensor1])
def local_subtensor_make_vector(node): def local_subtensor_make_vector(node):
""" """
...@@ -2009,6 +2024,7 @@ def local_subtensor_make_vector(node): ...@@ -2009,6 +2024,7 @@ def local_subtensor_make_vector(node):
# TODO: the other optimization for and, or, xor, le and ge see ticket #496. # TODO: the other optimization for and, or, xor, le and ge see ticket #496.
@register_useless('fast_compile')
@register_canonicalize('fast_compile') @register_canonicalize('fast_compile')
@register_specialize @register_specialize
@gof.local_optimizer([T.Elemwise]) @gof.local_optimizer([T.Elemwise])
...@@ -2428,6 +2444,7 @@ def local_upcast_elemwise_constant_inputs(node): ...@@ -2428,6 +2444,7 @@ def local_upcast_elemwise_constant_inputs(node):
################## ##################
@register_useless
@register_canonicalize @register_canonicalize
@register_specialize @register_specialize
@gof.local_optimizer([IncSubtensor]) @gof.local_optimizer([IncSubtensor])
...@@ -2518,6 +2535,7 @@ def local_set_to_inc_subtensor(node): ...@@ -2518,6 +2535,7 @@ def local_set_to_inc_subtensor(node):
return [ret] return [ret]
@register_useless
@register_canonicalize @register_canonicalize
@register_specialize @register_specialize
@gof.local_optimizer([Subtensor]) @gof.local_optimizer([Subtensor])
...@@ -2547,6 +2565,7 @@ def local_useless_slice(node): ...@@ -2547,6 +2565,7 @@ def local_useless_slice(node):
return [out] return [out]
@register_useless
@register_canonicalize @register_canonicalize
@register_specialize @register_specialize
@gof.local_optimizer([Subtensor, AdvancedSubtensor1]) @gof.local_optimizer([Subtensor, AdvancedSubtensor1])
...@@ -3373,6 +3392,7 @@ def local_adv_sub1_adv_inc_sub1(node): ...@@ -3373,6 +3392,7 @@ def local_adv_sub1_adv_inc_sub1(node):
@register_specialize @register_specialize
@register_stabilize @register_stabilize
@register_canonicalize @register_canonicalize
@register_useless
@gof.local_optimizer([IncSubtensor, @gof.local_optimizer([IncSubtensor,
AdvancedIncSubtensor, AdvancedIncSubtensor,
AdvancedIncSubtensor1]) AdvancedIncSubtensor1])
...@@ -3484,6 +3504,7 @@ def local_useless_inc_subtensor_alloc(node): ...@@ -3484,6 +3504,7 @@ def local_useless_inc_subtensor_alloc(node):
# Rebroadcast opts # # Rebroadcast opts #
#################### ####################
@register_useless
@register_canonicalize @register_canonicalize
@register_specialize @register_specialize
@gof.local_optimizer([T.Rebroadcast]) @gof.local_optimizer([T.Rebroadcast])
...@@ -3611,6 +3632,7 @@ def apply_rebroadcast_opt(rval): ...@@ -3611,6 +3632,7 @@ def apply_rebroadcast_opt(rval):
############# #############
@register_specialize @register_specialize
@register_canonicalize @register_canonicalize
@register_useless
@gof.local_optimizer([T.Join]) @gof.local_optimizer([T.Join])
def local_join_1(node): def local_join_1(node):
"""Join(i, x) => x """Join(i, x) => x
...@@ -3627,6 +3649,7 @@ def local_join_1(node): ...@@ -3627,6 +3649,7 @@ def local_join_1(node):
return [tensors[0]] return [tensors[0]]
# TODO: merge in local_useless_join
@register_specialize @register_specialize
@register_canonicalize @register_canonicalize
@gof.local_optimizer([T.Join]) @gof.local_optimizer([T.Join])
...@@ -3683,6 +3706,7 @@ def local_join_empty(node): ...@@ -3683,6 +3706,7 @@ def local_join_empty(node):
@register_specialize @register_specialize
@register_canonicalize @register_canonicalize
@register_useless
@gof.local_optimizer([T.Join]) @gof.local_optimizer([T.Join])
def local_join_make_vector(node): def local_join_make_vector(node):
"""Join(0, make_vector1, make_vector2, ...) => Join(0, make_vector12, ...) """Join(0, make_vector1, make_vector2, ...) => Join(0, make_vector12, ...)
...@@ -3785,6 +3809,7 @@ def local_expm1(node): ...@@ -3785,6 +3809,7 @@ def local_expm1(node):
############### ###############
# Switch opts # # Switch opts #
############### ###############
@register_useless('fast_compile')
@register_canonicalize('fast_compile', 'local_remove_switch_const_cond') @register_canonicalize('fast_compile', 'local_remove_switch_const_cond')
@register_specialize @register_specialize
@gof.local_optimizer([T.Elemwise]) @gof.local_optimizer([T.Elemwise])
...@@ -4053,6 +4078,7 @@ def local_merge_switch_same_cond(node): ...@@ -4053,6 +4078,7 @@ def local_merge_switch_same_cond(node):
############# #############
# Tile Opts # # Tile Opts #
############# #############
@register_useless
@register_canonicalize @register_canonicalize
@register_stabilize @register_stabilize
@gof.local_optimizer([T.Tile]) @gof.local_optimizer([T.Tile])
...@@ -4099,6 +4125,7 @@ def local_useless_tile(node): ...@@ -4099,6 +4125,7 @@ def local_useless_tile(node):
############## ##############
# Split Opts # # Split Opts #
############## ##############
@register_useless
@register_canonicalize @register_canonicalize
@register_specialize @register_specialize
@gof.local_optimizer([T.Split]) @gof.local_optimizer([T.Split])
...@@ -4179,6 +4206,7 @@ register_canonicalize(local_reshape_chain(T.Reshape), ...@@ -4179,6 +4206,7 @@ register_canonicalize(local_reshape_chain(T.Reshape),
name='local_reshape_chain') name='local_reshape_chain')
@register_useless
@register_canonicalize @register_canonicalize
@register_stabilize @register_stabilize
@gof.local_optimizer([T.Reshape]) @gof.local_optimizer([T.Reshape])
...@@ -4987,6 +5015,7 @@ def local_elemwise_sub_zeros(node): ...@@ -4987,6 +5015,7 @@ def local_elemwise_sub_zeros(node):
return [T.zeros_like(node.inputs[0])] return [T.zeros_like(node.inputs[0])]
@register_useless
@register_specialize @register_specialize
@register_stabilize @register_stabilize
@register_canonicalize @register_canonicalize
...@@ -5435,9 +5464,10 @@ def local_reduce_join(node): ...@@ -5435,9 +5464,10 @@ def local_reduce_join(node):
return [ret] return [ret]
@register_canonicalize('fast_compile') @register_canonicalize('fast_compile', 'local_cut_useless_reduce')
@register_useless
@gof.local_optimizer(ALL_REDUCE) @gof.local_optimizer(ALL_REDUCE)
def local_cut_useless_reduce(node): def local_useless_reduce(node):
"""Sum(a, axis=[]) -> a """ """Sum(a, axis=[]) -> a """
if isinstance(node.op, T.CAReduce): if isinstance(node.op, T.CAReduce):
summed, = node.inputs summed, = node.inputs
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论