Commit a9472ea7 authored by nouiz

Merge pull request #873 from goodfeli/stabilize_log_softmax

added an optimization to stabilize log softmax
......@@ -1729,3 +1729,23 @@ class Prepend_scalar_to_each_row(gof.Op):
# Convenience instances of the Prepend_scalar_* Ops defined above:
# prepend an arbitrary scalar, a constant 0., or a constant 1. as a new
# first column of each row of a matrix.
prepend_scalar_to_each_row = Prepend_scalar_to_each_row()
prepend_0_to_each_row = Prepend_scalar_constant_to_each_row(0.)
prepend_1_to_each_row = Prepend_scalar_constant_to_each_row(1.)
# Numerically stabilize log(softmax(X)) by rewriting it as:
#   X - X.max(axis=1).dimshuffle(0, 'x')
#     - log(exp(X - X.max(axis=1).dimshuffle(0, 'x')).sum(axis=1)).dimshuffle(0, 'x')
def make_out_pattern(X):
    """Build the numerically stable graph for log(softmax(X)).

    With m = X.max(axis=1), returns
        (X - m) - log(exp(X - m).sum(axis=1))
    broadcast back over rows, which avoids overflow in exp().
    """
    shifted = X - X.max(axis=1).dimshuffle(0, 'x')
    log_norm = tensor.log(tensor.exp(shifted).sum(axis=1)).dimshuffle(0, 'x')
    result = shifted - log_norm
    # Tell DEBUG_MODE it is OK if the original graph produced NaN where
    # the optimized graph does not.
    result.values_eq_approx = result.type.values_eq_approx_remove_nan
    return result
# Graph rewrite: log(softmax(x)) -> stabilized expression from make_out_pattern.
local_log_softmax = gof.PatternSub(
    in_pattern=(tensor.log, (softmax, 'x')),
    out_pattern=(make_out_pattern, 'x'),
    allow_multiple_clients=True)

# Deliberately NOT register_stabilize: local_log_softmax must run only
# after another, more specific optimization that stabilizes cross entropy.
# opt.register_stabilize(local_log_softmax, name = 'local_log_softmax')
opt.register_specialize(local_log_softmax, name='local_log_softmax')
......@@ -132,6 +132,25 @@ class test_dimshuffle_lift(unittest.TestCase):
"{x,0,1}(y)), z)]"), str(g))
def test_stabilize_log_softmax():
    """Verify log(softmax(x)) is rewritten into the stabilized graph.

    Compiles log(softmax(x)) with the local_log_softmax optimization
    enabled, asserts the softmax op no longer appears in the compiled
    graph, and runs the function once so DEBUG_MODE can compare the
    optimized graph against the unoptimized one.
    """
    mode = theano.compile.mode.get_default_mode()
    mode = mode.including('local_log_softmax')

    x = matrix()
    y = theano.tensor.nnet.softmax(x)
    z = theano.tensor.log(y)
    # Bug fix: `mode` was built above but never passed to function(), so
    # the test silently compiled with the default mode instead of the one
    # that explicitly includes the optimization under test.
    f = function([x], z, mode=mode)

    # Check that the softmax op has been optimized out of the graph.
    for node in f.maker.fgraph.toposort():
        assert not isinstance(node.op, y.owner.op.__class__)

    # Call the function so debug mode can verify the optimized version
    # matches the unoptimized version.
    rng = numpy.random.RandomState([2012, 8, 22])
    f(numpy.cast[config.floatX](rng.randn(2, 3)))
def test_add_canonizer_problem0():
n_segments = 10
label = lscalar('label')
......@@ -3781,4 +3800,4 @@ if __name__ == '__main__':
# unittest.main()
test_fusion().tes_memory_leak()
"""
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论