提交 1d8302e5 authored 作者: Dustin Webb's avatar Dustin Webb 提交者: Frederic

Reworked register_stabilize to work like register_canonicalize and register_specialize.

上级 f018f34a
......@@ -1330,6 +1330,7 @@ class CrossentropyCategorical1Hot(gof.Op):
crossentropy_categorical_1hot = CrossentropyCategorical1Hot()
@opt.register_stabilize('gpu')
@opt.register_specialize('gpu')
@gof.optimizer
def crossentropy_to_crossentropy_with_softmax_with_bias(fgraph):
......@@ -1358,8 +1359,6 @@ def crossentropy_to_crossentropy_with_softmax_with_bias(fgraph):
while search_make_one_sub():
pass
return
opt.register_stabilize(crossentropy_to_crossentropy_with_softmax_with_bias,
'gpu')
@gof.optimizer
......
......@@ -319,11 +319,16 @@ def register_canonicalize(lopt, *tags, **kwargs):
compile.optdb['canonicalize'].register(name, lopt, 'fast_run', *tags)
return lopt
def register_stabilize(lopt, *tags, **kwargs):
name = (kwargs and kwargs.pop('name')) or lopt.__name__
compile.optdb['stabilize'].register(name, lopt, 'fast_run', *tags)
return lopt
if type(lopt) == str:
def register(inner_lopt):
return register_stabilize(inner_lopt, *tags, **kwargs)
return register
else:
name = (kwargs and kwargs.pop('name')) or lopt.__name__
compile.optdb['stabilize'].register(name, lopt, 'fast_run', *tags)
return lopt
def register_specialize(lopt, *tags, **kwargs):
......@@ -4161,6 +4166,8 @@ def attempt_distribution(factor, num, denum, out_type):
neg_pairs))), num, denum
@register_canonicalize
@register_stabilize
@gof.local_optimizer([T.mul, T.true_div, T.inv])
def local_greedy_distributor(node):
"""
......@@ -4225,10 +4232,10 @@ def local_greedy_distributor(node):
return [rval]
register_canonicalize(local_greedy_distributor)
register_stabilize(local_greedy_distributor)
@register_canonicalize('fast_compile')
@register_stabilize('fast_compile')
@register_specialize('fast_compile')
@gof.local_optimizer(None)
def constant_folding(node):
for input in node.inputs:
......@@ -4262,10 +4269,6 @@ def constant_folding(node):
rval.append(constant(output.type, storage_map[output][0]))
return rval
register_canonicalize(constant_folding, 'fast_compile')
register_stabilize(constant_folding, 'fast_compile')
register_specialize(constant_folding, 'fast_compile')
def _is_1(expr):
"""rtype bool. True iff expr is a constant close to 1
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论