提交 1d8302e5 authored 作者: Dustin Webb's avatar Dustin Webb 提交者: Frederic

Reworked register_stabilize to work like register_canonicalize and register_specialize.

上级 f018f34a
...@@ -1330,6 +1330,7 @@ class CrossentropyCategorical1Hot(gof.Op): ...@@ -1330,6 +1330,7 @@ class CrossentropyCategorical1Hot(gof.Op):
crossentropy_categorical_1hot = CrossentropyCategorical1Hot() crossentropy_categorical_1hot = CrossentropyCategorical1Hot()
@opt.register_stabilize('gpu')
@opt.register_specialize('gpu') @opt.register_specialize('gpu')
@gof.optimizer @gof.optimizer
def crossentropy_to_crossentropy_with_softmax_with_bias(fgraph): def crossentropy_to_crossentropy_with_softmax_with_bias(fgraph):
...@@ -1358,8 +1359,6 @@ def crossentropy_to_crossentropy_with_softmax_with_bias(fgraph): ...@@ -1358,8 +1359,6 @@ def crossentropy_to_crossentropy_with_softmax_with_bias(fgraph):
while search_make_one_sub(): while search_make_one_sub():
pass pass
return return
opt.register_stabilize(crossentropy_to_crossentropy_with_softmax_with_bias,
'gpu')
@gof.optimizer @gof.optimizer
......
...@@ -319,11 +319,16 @@ def register_canonicalize(lopt, *tags, **kwargs): ...@@ -319,11 +319,16 @@ def register_canonicalize(lopt, *tags, **kwargs):
compile.optdb['canonicalize'].register(name, lopt, 'fast_run', *tags) compile.optdb['canonicalize'].register(name, lopt, 'fast_run', *tags)
return lopt return lopt
def register_stabilize(lopt, *tags, **kwargs): def register_stabilize(lopt, *tags, **kwargs):
name = (kwargs and kwargs.pop('name')) or lopt.__name__ if type(lopt) == str:
compile.optdb['stabilize'].register(name, lopt, 'fast_run', *tags) def register(inner_lopt):
return lopt return register_stabilize(inner_lopt, *tags, **kwargs)
return register
else:
name = (kwargs and kwargs.pop('name')) or lopt.__name__
compile.optdb['stabilize'].register(name, lopt, 'fast_run', *tags)
return lopt
def register_specialize(lopt, *tags, **kwargs): def register_specialize(lopt, *tags, **kwargs):
...@@ -4161,6 +4166,8 @@ def attempt_distribution(factor, num, denum, out_type): ...@@ -4161,6 +4166,8 @@ def attempt_distribution(factor, num, denum, out_type):
neg_pairs))), num, denum neg_pairs))), num, denum
@register_canonicalize
@register_stabilize
@gof.local_optimizer([T.mul, T.true_div, T.inv]) @gof.local_optimizer([T.mul, T.true_div, T.inv])
def local_greedy_distributor(node): def local_greedy_distributor(node):
""" """
...@@ -4225,10 +4232,10 @@ def local_greedy_distributor(node): ...@@ -4225,10 +4232,10 @@ def local_greedy_distributor(node):
return [rval] return [rval]
register_canonicalize(local_greedy_distributor)
register_stabilize(local_greedy_distributor)
@register_canonicalize('fast_compile')
@register_stabilize('fast_compile')
@register_specialize('fast_compile')
@gof.local_optimizer(None) @gof.local_optimizer(None)
def constant_folding(node): def constant_folding(node):
for input in node.inputs: for input in node.inputs:
...@@ -4262,10 +4269,6 @@ def constant_folding(node): ...@@ -4262,10 +4269,6 @@ def constant_folding(node):
rval.append(constant(output.type, storage_map[output][0])) rval.append(constant(output.type, storage_map[output][0]))
return rval return rval
register_canonicalize(constant_folding, 'fast_compile')
register_stabilize(constant_folding, 'fast_compile')
register_specialize(constant_folding, 'fast_compile')
def _is_1(expr): def _is_1(expr):
"""rtype bool. True iff expr is a constant close to 1 """rtype bool. True iff expr is a constant close to 1
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论