提交 61099c6f authored 作者: Frederic Bastien's avatar Frederic Bastien 提交者: sentient07

Add a 'useless' optimization pass. It still needs to be timed.

上级 9f759f07
......@@ -150,6 +150,16 @@ optdb = gof.SequenceDB()
optdb.register('merge1', gof.MergeOptimizer(),
0, 'fast_run', 'fast_compile', 'merge')
# The 'useless' pass runs after the scan1 opt (position 0.5) and before
# ShapeOpt (position 1).
# Optimizations registered in it should only remove nodes.
# They must not do anything that needs shape inference.
# If a newly introduced node has no infer_shape, then the original node
# it replaces must not have infer_shape either.
optdb.register('useless', gof.EquilibriumDB(ignore_newtrees=False),
0.6, 'fast_run', 'fast_compile')
# Additional merge pass registered just after 'useless' (position 0.65).
optdb.register('merge1.1', gof.MergeOptimizer(),
0.65, 'fast_run', 'fast_compile', 'merge')
# rearranges elemwise expressions
optdb.register('canonicalize', gof.EquilibriumDB(ignore_newtrees=False),
1, 'fast_run', 'fast_compile', 'canonicalize_db')
......
......@@ -409,6 +409,18 @@ compile.optdb.register('inplace_elemwise_opt', inplace_elemwise_optimizer, 75,
'fast_run', 'inplace')
def register_useless(lopt, *tags, **kwargs):
    """Register an optimizer in the 'useless' optimization database.

    Can be used in two ways:

    * ``@register_useless`` or ``register_useless(opt, *tags, name=...)``:
      registers ``opt`` and returns it unchanged.
    * ``@register_useless('tag', ...)``: when the first argument is a
      string it is treated as a tag, and a decorator is returned that
      registers the decorated optimizer with that tag placed before the
      other ``tags``.

    Parameters
    ----------
    lopt
        The optimizer to register, or a tag string (decorator-factory form).
    *tags
        Extra tags forwarded to the database ``register`` call, after the
        always-present ``'fast_run'`` tag.
    **kwargs
        Forwarded to ``register``. The optional ``name`` entry is consumed
        here; when it is missing or falsy, ``lopt.__name__`` is used.

    Returns
    -------
    The optimizer itself, or a decorator when ``lopt`` is a string.
    """
    # isinstance (rather than an exact type comparison) so that str
    # subclasses are also recognised as tags instead of falling through to
    # the registration branch and failing on ``lopt.__name__``.
    if isinstance(lopt, str):
        def register(inner_lopt):
            # The string becomes the first tag of the real registration.
            return register_useless(inner_lopt, lopt, *tags, **kwargs)
        return register

    # An explicit name wins; otherwise fall back to the optimizer's name.
    name = kwargs.pop('name', None) or lopt.__name__
    compile.optdb['useless'].register(name, lopt, 'fast_run',
                                      *tags, **kwargs)
    return lopt
def register_canonicalize(lopt, *tags, **kwargs):
if type(lopt) == str:
def register(inner_lopt):
......@@ -1756,6 +1768,7 @@ compile.optdb.register('local_elemwise_alloc',
@register_canonicalize("fast_compile")
@register_useless("fast_compile")
@gof.local_optimizer([T.fill])
def local_useless_fill(node):
"""fill(s,v) -> v
......@@ -1776,6 +1789,7 @@ def local_useless_fill(node):
@register_specialize
@register_stabilize
@register_canonicalize
@register_useless
@gof.local_optimizer([T.alloc])
def local_useless_alloc(node):
"""
......@@ -1925,6 +1939,7 @@ def local_subtensor_remove_broadcastable_index(node):
@register_specialize
@register_canonicalize('fast_compile_gpu')
#@register_useless
@gof.local_optimizer([Subtensor, AdvancedSubtensor1])
def local_subtensor_make_vector(node):
"""
......@@ -2009,6 +2024,7 @@ def local_subtensor_make_vector(node):
# TODO: for the other optimizations (and, or, xor, le and ge), see ticket #496.
@register_useless('fast_compile')
@register_canonicalize('fast_compile')
@register_specialize
@gof.local_optimizer([T.Elemwise])
......@@ -2428,6 +2444,7 @@ def local_upcast_elemwise_constant_inputs(node):
##################
@register_useless
@register_canonicalize
@register_specialize
@gof.local_optimizer([IncSubtensor])
......@@ -2518,6 +2535,7 @@ def local_set_to_inc_subtensor(node):
return [ret]
@register_useless
@register_canonicalize
@register_specialize
@gof.local_optimizer([Subtensor])
......@@ -2547,6 +2565,7 @@ def local_useless_slice(node):
return [out]
@register_useless
@register_canonicalize
@register_specialize
@gof.local_optimizer([Subtensor, AdvancedSubtensor1])
......@@ -3373,6 +3392,7 @@ def local_adv_sub1_adv_inc_sub1(node):
@register_specialize
@register_stabilize
@register_canonicalize
@register_useless
@gof.local_optimizer([IncSubtensor,
AdvancedIncSubtensor,
AdvancedIncSubtensor1])
......@@ -3484,6 +3504,7 @@ def local_useless_inc_subtensor_alloc(node):
# Rebroadcast opts #
####################
@register_useless
@register_canonicalize
@register_specialize
@gof.local_optimizer([T.Rebroadcast])
......@@ -3611,6 +3632,7 @@ def apply_rebroadcast_opt(rval):
#############
@register_specialize
@register_canonicalize
@register_useless
@gof.local_optimizer([T.Join])
def local_join_1(node):
"""Join(i, x) => x
......@@ -3627,6 +3649,7 @@ def local_join_1(node):
return [tensors[0]]
# TODO: merge in local_useless_join
@register_specialize
@register_canonicalize
@gof.local_optimizer([T.Join])
......@@ -3683,6 +3706,7 @@ def local_join_empty(node):
@register_specialize
@register_canonicalize
@register_useless
@gof.local_optimizer([T.Join])
def local_join_make_vector(node):
"""Join(0, make_vector1, make_vector2, ...) => Join(0, make_vector12, ...)
......@@ -3785,6 +3809,7 @@ def local_expm1(node):
###############
# Switch opts #
###############
@register_useless('fast_compile')
@register_canonicalize('fast_compile', 'local_remove_switch_const_cond')
@register_specialize
@gof.local_optimizer([T.Elemwise])
......@@ -4053,6 +4078,7 @@ def local_merge_switch_same_cond(node):
#############
# Tile Opts #
#############
@register_useless
@register_canonicalize
@register_stabilize
@gof.local_optimizer([T.Tile])
......@@ -4099,6 +4125,7 @@ def local_useless_tile(node):
##############
# Split Opts #
##############
@register_useless
@register_canonicalize
@register_specialize
@gof.local_optimizer([T.Split])
......@@ -4179,6 +4206,7 @@ register_canonicalize(local_reshape_chain(T.Reshape),
name='local_reshape_chain')
@register_useless
@register_canonicalize
@register_stabilize
@gof.local_optimizer([T.Reshape])
......@@ -4987,6 +5015,7 @@ def local_elemwise_sub_zeros(node):
return [T.zeros_like(node.inputs[0])]
@register_useless
@register_specialize
@register_stabilize
@register_canonicalize
......@@ -5435,9 +5464,10 @@ def local_reduce_join(node):
return [ret]
@register_canonicalize('fast_compile')
@register_canonicalize('fast_compile', 'local_cut_useless_reduce')
@register_useless
@gof.local_optimizer(ALL_REDUCE)
def local_cut_useless_reduce(node):
def local_useless_reduce(node):
"""Sum(a, axis=[]) -> a """
if isinstance(node.op, T.CAReduce):
summed, = node.inputs
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论