Commit 83d5709a, authored by Pascal Lamblin, committed by GitHub

Merge pull request #4736 from adbrebs/logSumExp

Log sum exp optimization for numerical stability
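The rewrite below relies on the identity log(sum_i(exp(x_i))) = x_max + log(sum_i(exp(x_i - x_max))): subtracting the maximum before exponentiating makes every exponent non-positive, so the sum cannot overflow. A minimal NumPy sketch (illustrative only, not part of this diff) showing why the naive form fails:

import numpy as np

x = np.array([-1000., 1000.])

# Naive evaluation: np.exp(1000.) overflows float64, so the result is inf.
naive = np.log(np.sum(np.exp(x)))

# Stable evaluation via the identity above: the largest exponent becomes 0.
m = np.max(x)
stable = m + np.log(np.sum(np.exp(x - m)))

print(naive, stable)  # inf 1000.0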
@@ -6121,6 +6121,49 @@ def local_log_add(node):
    return [ret]

@gof.local_optimizer([T.log])
def local_log_sum_exp(node):
    # log(sum_i(exp(x_i))) = x_max + log(sum_i(exp(x_i - x_max)))
    if node.op != T.log:
        return

    sum_node = node.inputs[0].owner
    # If the sum has keepdims=True, there might be a dimshuffle
    if sum_node and isinstance(sum_node.op, T.DimShuffle):
        dimshuffle_op = sum_node.op
        sum_node = sum_node.inputs[0].owner
    else:
        dimshuffle_op = None

    if not sum_node or not isinstance(sum_node.op, T.Sum):
        return

    exp_node, axis = sum_node.inputs[0].owner, sum_node.op.axis
    if not exp_node or not (
            isinstance(exp_node.op, Elemwise) and
            isinstance(exp_node.op.scalar_op, scalar.Exp)):
        return

    pre_exp = exp_node.inputs[0]
    max_pre_exp = T.max(pre_exp, axis=axis)
    max_pre_exp_keepdims = T.makeKeepDims(pre_exp, max_pre_exp, axis)

    ret = (max_pre_exp +
           T.log(T.sum(T.exp(pre_exp - max_pre_exp_keepdims), axis=axis)))

    # Restore the dimshuffle op, if any.
    if dimshuffle_op:
        ret = dimshuffle_op(ret)

    return [ret]


compile.optdb.register('local_log_sum_exp',
                       in2out(local_log_sum_exp, ignore_newtrees=True),
                       1.6, 'fast_run')
def add_calculate(num, denum, aslist=False, out_type=None):
    # TODO: make sure that this function and mul_calculate are similar
    if out_type is None:
......
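Once registered in the optimization database under 'fast_run', the rewrite fires for any compiled graph containing log(sum(exp(x))). A minimal usage sketch (illustrative, not part of this diff; assumes a Theano installation with this patch applied):

import numpy
import theano
import theano.tensor as T

x = T.vector('x')
f = theano.function([x], T.log(T.sum(T.exp(x))))  # default mode includes 'fast_run'

x_val = numpy.array([-1000., 1000.], dtype=theano.config.floatX)
print(f(x_val))  # ~1000.0; without the rewrite this would overflow to inf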
@@ -50,7 +50,7 @@ from theano import tensor
from theano import tensor as T
from theano.tensor import scalar, iscalar, lscalar, fscalar, dscalar
from theano.tensor import vector, ivector, lvector, fvector, dvector
from theano.tensor import matrix, imatrix, lmatrix, fmatrix, dmatrix, tensor3
from theano.tensor import scalars, vectors, matrices, fmatrices, dmatrices
from theano.tensor import ( from theano.tensor import (
AdvancedSubtensor, AdvancedSubtensor,
@@ -6649,6 +6649,81 @@ def test_local_useless_alloc():
    assert isinstance(topo[-1].op, T.Alloc)

def compile_graph_log_sum_exp(x, axis, dimshuffle_op=None):
    sum_exp = T.sum(T.exp(x), axis=axis)
    if dimshuffle_op:
        sum_exp = dimshuffle_op(sum_exp)
    y = T.log(sum_exp)
    MODE = theano.compile.get_default_mode().including('local_log_sum_exp')
    return function([x], y, mode=MODE)


def check_max_log_sum_exp(x, axis, dimshuffle_op=None):
    f = compile_graph_log_sum_exp(x, axis, dimshuffle_op)

    fgraph = f.maker.fgraph.toposort()
    for node in fgraph:
        if (hasattr(node.op, 'scalar_op') and
                node.op.scalar_op == theano.scalar.basic.maximum):
            return

        # In mode FAST_COMPILE, the optimisations don't replace the
        # MaxAndArgmax op.
        if isinstance(node.op, theano.tensor.MaxAndArgmax):
            return

    raise Exception('No maximum detected after log_sum_exp optimisation')


def test_local_log_sum_exp1():
    # Tests if the optimization is applied by checking the presence of the
    # maximum
    x = tensor3('x')
    check_max_log_sum_exp(x, axis=(0,), dimshuffle_op=None)
    check_max_log_sum_exp(x, axis=(1,), dimshuffle_op=None)
    check_max_log_sum_exp(x, axis=(2,), dimshuffle_op=None)
    check_max_log_sum_exp(x, axis=(0, 1), dimshuffle_op=None)
    check_max_log_sum_exp(x, axis=(0, 1, 2), dimshuffle_op=None)

    # If a transpose is applied to the sum
    transpose_op = DimShuffle((False, False), (1, 0))
    check_max_log_sum_exp(x, axis=2, dimshuffle_op=transpose_op)

    # If the sum is performed with keepdims=True
    x = TensorType(dtype='floatX', broadcastable=(False, True, False))('x')
    sum_keepdims_op = x.sum(axis=(0, 1), keepdims=True).owner.op
    check_max_log_sum_exp(x, axis=(0, 1), dimshuffle_op=sum_keepdims_op)


def test_local_log_sum_exp2():
    # Tests if the optimization works (result is correct) around 1.0
    x = tensor3('x')
    x_val = 1.0 + numpy.random.rand(4, 3, 2).astype(config.floatX) / 10.0

    f = compile_graph_log_sum_exp(x, axis=(1,))
    naive_ret = numpy.log(numpy.sum(numpy.exp(x_val), axis=1))
    optimised_ret = f(x_val)
    assert numpy.allclose(naive_ret, optimised_ret)

    # If a transpose is applied
    transpose_op = DimShuffle((False, False), (1, 0))
    f = compile_graph_log_sum_exp(x, axis=(1,), dimshuffle_op=transpose_op)
    naive_ret = numpy.log(numpy.sum(numpy.exp(x_val), axis=1).T)
    optimised_ret = f(x_val)
    assert numpy.allclose(naive_ret, optimised_ret)


def test_local_log_sum_exp3():
    # Tests if the optimization works (result is correct) for the extreme
    # value 100
    x = vector('x')
    f = compile_graph_log_sum_exp(x, axis=0)

    x_val = numpy.array([-100., 100.]).astype(config.floatX)
    optimised_ret = f(x_val)

    assert numpy.allclose(optimised_ret, 100.)
if __name__ == '__main__':
    t = TestMakeVector('setUp')
......