提交 ab3bce3f authored 作者: Dustin Webb's avatar Dustin Webb 提交者: Frederic

Combined register_canonicalize_tags and register_canonicalize.

上级 b94282cb
...@@ -310,15 +310,14 @@ compile.optdb.register('inplace_elemwise_opt', inplace_elemwise_optimizer, 75, ...@@ -310,15 +310,14 @@ compile.optdb.register('inplace_elemwise_opt', inplace_elemwise_optimizer, 75,
def register_canonicalize(lopt, *tags, **kwargs): def register_canonicalize(lopt, *tags, **kwargs):
name = (kwargs and kwargs.pop('name')) or lopt.__name__ if type(lopt) == str:
compile.optdb['canonicalize'].register(name, lopt, 'fast_run', *tags) def register(lopt):
return lopt return register_canonicalize(lopt, *tags, **kwargs)
return register
else:
def register_canonicalize_tags(*tags, **kwargs): name = (kwargs and kwargs.pop('name')) or lopt.__name__
def register(lopt): compile.optdb['canonicalize'].register(name, lopt, 'fast_run', *tags)
return register_canonicalize(lopt, *tags, **kwargs) return lopt
return register
def register_stabilize(lopt, *tags, **kwargs): def register_stabilize(lopt, *tags, **kwargs):
...@@ -1310,7 +1309,7 @@ def local_track_shape_i(node): ...@@ -1310,7 +1309,7 @@ def local_track_shape_i(node):
@register_specialize @register_specialize
@register_canonicalize_tags('gpu') @register_canonicalize('gpu')
@gof.local_optimizer([Subtensor]) @gof.local_optimizer([Subtensor])
def local_subtensor_make_vector(node): def local_subtensor_make_vector(node):
# replace all subtensor(make_vector) like: # replace all subtensor(make_vector) like:
...@@ -1360,7 +1359,7 @@ def local_subtensor_make_vector(node): ...@@ -1360,7 +1359,7 @@ def local_subtensor_make_vector(node):
#TODO: the other optimization for and, or, xor, le and ge see ticket #496. #TODO: the other optimization for and, or, xor, le and ge see ticket #496.
@register_canonicalize_tags('fast_compile') @register_canonicalize('fast_compile')
@register_specialize @register_specialize
@gof.local_optimizer([T.Elemwise]) @gof.local_optimizer([T.Elemwise])
def local_useless_elemwise(node): def local_useless_elemwise(node):
...@@ -3513,7 +3512,7 @@ def local_reduce_join(node): ...@@ -3513,7 +3512,7 @@ def local_reduce_join(node):
#else the reduction do something about the dtype. #else the reduction do something about the dtype.
@register_canonicalize_tags('fast_compile') @register_canonicalize('fast_compile')
@gof.local_optimizer(ALL_REDUCE) @gof.local_optimizer(ALL_REDUCE)
def local_cut_useless_reduce(node): def local_cut_useless_reduce(node):
"""Sum(a, axis=[]) -> a """ """Sum(a, axis=[]) -> a """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论