提交 f018f34a authored 作者: Dustin Webb's avatar Dustin Webb 提交者: Frederic

Reworked opt.register_specialize to match the way opt.register_canonicalize works.

上级 f34caad6
......@@ -577,6 +577,7 @@ class Softmax(gof.Op):
softmax = Softmax()
@opt.register_specialize('gpu')
@gof.local_optimizer([softmax])
def local_softmax_with_bias(node):
"""Try to turn softmax(sum_of_stuff) -> softmax_w_bias(matrix, bias)
......@@ -635,7 +636,6 @@ def local_softmax_with_bias(node):
#This condition is not always true. See the test
#nnet/tests/test_nnet.py:T_SoftmaxWithBias.test_broadcast
return [sm_bias]
opt.register_specialize(local_softmax_with_bias, 'gpu')
def softmax_simplifier(numerators, denominators):
......@@ -1330,6 +1330,7 @@ class CrossentropyCategorical1Hot(gof.Op):
crossentropy_categorical_1hot = CrossentropyCategorical1Hot()
@opt.register_specialize('gpu')
@gof.optimizer
def crossentropy_to_crossentropy_with_softmax_with_bias(fgraph):
"""This is a stabilization optimization
......@@ -1359,8 +1360,6 @@ def crossentropy_to_crossentropy_with_softmax_with_bias(fgraph):
return
opt.register_stabilize(crossentropy_to_crossentropy_with_softmax_with_bias,
'gpu')
opt.register_specialize(crossentropy_to_crossentropy_with_softmax_with_bias,
'gpu')
@gof.optimizer
......@@ -1409,6 +1408,7 @@ optdb.register('crossentropy_to_crossentropy_with_softmax',
'fast_run', 'xent', 'gpu')
@opt.register_specialize('gpu')
@gof.local_optimizer([softmax_grad])
def local_crossentropy_to_crossentropy_with_softmax_grad(node):
if node.op == softmax_grad:
......@@ -1419,10 +1419,9 @@ def local_crossentropy_to_crossentropy_with_softmax_grad(node):
dx = crossentropy_softmax_1hot_with_bias_dx(g_nll,
coding_dist, true_one_of_n)
return [dx]
opt.register_specialize(local_crossentropy_to_crossentropy_with_softmax_grad,
'gpu')
@opt.register_specialize('gpu')
@gof.local_optimizer([tensor._max_and_argmax])
def local_argmax_pushdown(node):
if node.op == tensor._max_and_argmax and node.inputs[0].owner and \
......@@ -1453,7 +1452,6 @@ def local_argmax_pushdown(node):
tensor.DimShuffle(
pre_bias.broadcastable,
('x', 0))(pre_bias), axis)
opt.register_specialize(local_argmax_pushdown, 'gpu')
# Utility function used by the two next optimizations
......@@ -1509,6 +1507,7 @@ def _is_const(z, val, approx=False):
return numpy.all(maybe == val)
@opt.register_specialize('gpu')
@gof.local_optimizer([subtensor.AdvancedSubtensor, tensor.log])
def local_advanced_indexing_crossentropy_onehot(node):
log = None
......@@ -1547,9 +1546,9 @@ def local_advanced_indexing_crossentropy_onehot(node):
return [-crossentropy_softmax_argmax_1hot_with_bias(x_var,
b_var,
labels)[0]]
opt.register_specialize(local_advanced_indexing_crossentropy_onehot, 'gpu')
@opt.register_specialize('gpu')
@gof.local_optimizer([softmax_grad])
def local_advanced_indexing_crossentropy_onehot_grad(node):
if not (node.op == softmax_grad):
......@@ -1769,11 +1768,10 @@ def local_advanced_indexing_crossentropy_onehot_grad(node):
if labels.ndim == 1 and x_var.ndim == 2:
return [crossentropy_softmax_1hot_with_bias_dx(out_grad, sm, labels)]
else:
return
opt.register_specialize(local_advanced_indexing_crossentropy_onehot_grad,
'gpu')
return
@opt.register_specialize('gpu')
@gof.local_optimizer([softmax_with_bias])
def graph_merge_softmax_with_crossentropy_softmax(node):
if node.op == softmax_with_bias:
......@@ -1785,8 +1783,6 @@ def graph_merge_softmax_with_crossentropy_softmax(node):
xx, bb, ll = big_client.inputs
mergeable_client = big_client.op(x, b, ll)
return [mergeable_client[1]]
opt.register_specialize(graph_merge_softmax_with_crossentropy_softmax,
'gpu')
def binary_crossentropy(output, target):
......
......@@ -327,6 +327,11 @@ def register_stabilize(lopt, *tags, **kwargs):
def register_specialize(lopt, *tags, **kwargs):
    """Register `lopt` in the 'specialize' optimization database.

    Mirrors the way ``register_canonicalize`` works: it can be called
    either directly with an optimizer::

        register_specialize(my_lopt, 'gpu')

    or used as a decorator, in which case the first positional argument
    is a tag string::

        @register_specialize('gpu')
        @gof.local_optimizer([...])
        def my_lopt(node): ...

    :param lopt: the local optimizer to register, or — in the decorator
        form — the first extra tag (a string).
    :param tags: additional tags to register under, on top of 'fast_run'.
    :param kwargs: may contain ``name`` to override ``lopt.__name__``.
    :returns: the decorator (string form) or `lopt` itself, so stacked
        decorators keep working.
    """
    if isinstance(lopt, str):
        # Decorator form: `lopt` is really the first tag, so it must be
        # forwarded to the recursive call — otherwise the tag is lost.
        def register(inner_lopt):
            return register_specialize(inner_lopt, lopt, *tags, **kwargs)
        return register
    else:
        # pop with a default: kwargs may carry other keys, in which case
        # a bare pop('name') would raise KeyError.
        name = kwargs.pop('name', None) or lopt.__name__
        compile.optdb['specialize'].register(name, lopt, 'fast_run', *tags)
        return lopt
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论