Commit 2bbfe6dd authored by AdeB

Fix dimshuffle problems in log_sum_exp. More tests.

Parent 8eaa12f2
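This commit touches the `local_log_sum_exp` optimization, which rewrites `log(sum(exp(x), axis))` into the numerically stable form `max(x, axis) + log(sum(exp(x - max(x, axis)), axis))`, and fixes how the rewrite handles a `DimShuffle` sitting between the `log` and the `Sum`. A minimal NumPy sketch of the identity being exploited (illustrative only, not part of the commit):

    import numpy

    def log_sum_exp(x, axis):
        # Stable form: log(sum(exp(x))) == max(x) + log(sum(exp(x - max(x))))
        x_max = numpy.max(x, axis=axis, keepdims=True)
        return (numpy.squeeze(x_max, axis=axis) +
                numpy.log(numpy.sum(numpy.exp(x - x_max), axis=axis)))

    x = numpy.random.rand(4, 3)
    assert numpy.allclose(log_sum_exp(x, axis=1),
                          numpy.log(numpy.sum(numpy.exp(x), axis=1)))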
@@ -6095,7 +6095,10 @@ def local_log_sum_exp(node):
     sum_node = node.inputs[0].owner
     # If the sum has keepdims=True, there might be a dimshuffle
     if sum_node and isinstance(sum_node.op, T.DimShuffle):
+        dimshuffle_op = sum_node.op
         sum_node = sum_node.inputs[0].owner
+    else:
+        dimshuffle_op = None
     if not sum_node or not isinstance(sum_node.op, T.Sum):
         return
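The `DimShuffle` check above exists because Theano implements `keepdims=True` reductions (and explicit transposes) as a `DimShuffle` wrapped around the actual `Sum`, so the optimizer has to look through it to find the reduction, and must now also remember it. A quick way to see this, sketched under the assumption of a stock Theano install:

    import theano.tensor as T

    x = T.matrix('x')
    s = x.sum(axis=1, keepdims=True)
    # keepdims is implemented via makeKeepDims: the outermost op is a
    # DimShuffle and the actual Sum sits one level below it in the graph.
    print(type(s.owner.op).__name__)                  # DimShuffle
    print(type(s.owner.inputs[0].owner.op).__name__)  # Sum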
@@ -6107,14 +6110,15 @@ def local_log_sum_exp(node):
         return
     pre_exp = exp_node.inputs[0]
-    max_pre_keepdims = T.max(pre_exp, axis=axis, keepdims=True)
+    max_pre_exp = T.max(pre_exp, axis=axis)
+    max_pre_exp_keepdims = T.makeKeepDims(pre_exp, max_pre_exp, axis)
-    ret = (max_pre_keepdims + T.log(T.sum(T.exp(pre_exp - max_pre_keepdims),
-                                          axis=axis, keepdims=True)))
+    ret = (max_pre_exp +
+           T.log(T.sum(T.exp(pre_exp - max_pre_exp_keepdims), axis=axis)))
-    # Restore shape and broadcastable pattern
-    ret = T.reshape(ret, node.inputs[0].shape)
-    ret = T.patternbroadcast(ret, node.inputs[0].broadcastable)
+    # Restore the dimshuffle op, if any.
+    if dimshuffle_op:
+        ret = dimshuffle_op(ret)
     return [ret]
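The rewritten branch computes the max with the reduced shape and only re-expands it (via `T.makeKeepDims`) where broadcasting against `pre_exp` requires it, then reapplies the recorded `dimshuffle_op` so the output keeps whatever shape the original graph produced, replacing the old reshape/patternbroadcast repair. A hedged NumPy analogue of what `makeKeepDims` does here:

    import numpy

    def make_keep_dims(x, y, axis):
        # Reinsert the axes reduced out of y as size-1 dimensions so that
        # y broadcasts against x (NumPy analogue of T.makeKeepDims).
        # Tuple axis for expand_dims needs NumPy >= 1.18.
        if axis is None:
            axis = tuple(range(x.ndim))
        elif isinstance(axis, int):
            axis = (axis,)
        return numpy.expand_dims(y, axis=axis)

    x = numpy.random.rand(4, 3, 2)
    m = x.max(axis=(0, 1))                 # shape (2,)
    m_keep = make_keep_dims(x, m, (0, 1))  # shape (1, 1, 2)
    assert (x - m_keep).shape == x.shape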
@@ -50,7 +50,7 @@ from theano import tensor
 from theano import tensor as T
 from theano.tensor import scalar, iscalar, lscalar, fscalar, dscalar
 from theano.tensor import vector, ivector, lvector, fvector, dvector
-from theano.tensor import matrix, imatrix, lmatrix, fmatrix, dmatrix
+from theano.tensor import matrix, imatrix, lmatrix, fmatrix, dmatrix, tensor3
 from theano.tensor import scalars, vectors, matrices, fmatrices, dmatrices
 from theano.tensor import (
     AdvancedSubtensor,
@@ -6656,14 +6656,20 @@ def test_local_useless_alloc():
     assert isinstance(topo[-1].op, T.Alloc)

-def test_local_log_sum_exp1():
-    # Tests if optimization is applied
-    x = matrix('x')
-    y = T.log(T.sum(T.exp(x), axis=1))
+def compile_graph_log_sum_exp(x, axis, dimshuffle_op=None):
+    sum_exp = T.sum(T.exp(x), axis=axis)
+    if dimshuffle_op:
+        sum_exp = dimshuffle_op(sum_exp)
+    y = T.log(sum_exp)
     MODE = theano.compile.get_default_mode().including('local_log_sum_exp')
-    f = function([x], y, mode=MODE)
+    return function([x], y, mode=MODE)

+def check_max_log_sum_exp(x, axis, dimshuffle_op=None):
+    f = compile_graph_log_sum_exp(x, axis, dimshuffle_op)
-    for node in f.maker.fgraph.toposort():
+    fgraph = f.maker.fgraph.toposort()
+    for node in fgraph:
         if (hasattr(node.op, 'scalar_op') and
                 node.op.scalar_op == theano.scalar.basic.maximum):
             return
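The helper scans the toposorted nodes of the compiled graph for an op whose `scalar_op` is `maximum`, which only appears once the rewrite has introduced `T.max`. The same inspection outside the test harness might look like this (a sketch, not part of the commit):

    import theano
    import theano.tensor as T

    x = T.matrix('x')
    y = T.log(T.sum(T.exp(x), axis=1))
    mode = theano.compile.get_default_mode().including('local_log_sum_exp')
    f = theano.function([x], y, mode=mode)
    has_max = any(getattr(node.op, 'scalar_op', None) ==
                  theano.scalar.basic.maximum
                  for node in f.maker.fgraph.toposort())
    assert has_max  # the stable rewrite introduced a max reduction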
@@ -6676,27 +6682,48 @@ def test_local_log_sum_exp1():
     raise Exception('No maximum detected after log_sum_exp optimisation')

+def test_local_log_sum_exp1():
+    # Tests if optimization is applied by checking the presence of the maximum
+    x = tensor3('x')
+    check_max_log_sum_exp(x, axis=(0,), dimshuffle_op=None)
+    check_max_log_sum_exp(x, axis=(1,), dimshuffle_op=None)
+    check_max_log_sum_exp(x, axis=(2,), dimshuffle_op=None)
+    check_max_log_sum_exp(x, axis=(0, 1), dimshuffle_op=None)
+    check_max_log_sum_exp(x, axis=(0, 1, 2), dimshuffle_op=None)
+
+    # If a transpose is applied to the sum
+    transpose_op = DimShuffle((False, False), (1, 0))
+    check_max_log_sum_exp(x, axis=2, dimshuffle_op=transpose_op)
+
+    # If the sum is performed with keepdims=True
+    x = TensorType(dtype='floatX', broadcastable=(False, True, False))('x')
+    sum_keepdims_op = x.sum(axis=(0, 1), keepdims=True).owner.op
+    check_max_log_sum_exp(x, axis=(0, 1), dimshuffle_op=sum_keepdims_op)

 def test_local_log_sum_exp2():
     # Tests if the optimization works (result is correct) around 1.0
-    x = matrix('x')
-    y = T.log(T.sum(T.exp(x), axis=1))
-    MODE = theano.compile.get_default_mode().including('local_log_sum_exp')
-    f = function([x], y, mode=MODE)
-    x_val = 1.0 + numpy.random.rand(4, 3).astype(config.floatX)/10.0
+    x = tensor3('x')
+    x_val = 1.0 + numpy.random.rand(4, 3, 2).astype(config.floatX) / 10.0
+    f = compile_graph_log_sum_exp(x, axis=(1,))
     naive_ret = numpy.log(numpy.sum(numpy.exp(x_val), axis=1))
     optimised_ret = f(x_val)
     assert numpy.allclose(naive_ret, optimised_ret)
+
+    # If a transpose is applied
+    transpose_op = DimShuffle((False, False), (1, 0))
+    f = compile_graph_log_sum_exp(x, axis=(1,), dimshuffle_op=transpose_op)
+    naive_ret = numpy.log(numpy.sum(numpy.exp(x_val), axis=1).T)
+    optimised_ret = f(x_val)
+    assert numpy.allclose(naive_ret, optimised_ret)

 def test_local_log_sum_exp3():
     # Tests if the optimization works (result is correct) for extreme value 100
     x = vector('x')
-    y = T.log(T.sum(T.exp(x), axis=0))
-    MODE = theano.compile.get_default_mode().including('local_log_sum_exp')
-    f = function([x], y, mode=MODE)
+    f = compile_graph_log_sum_exp(x, axis=0)
     x_val = numpy.array([-100., 100.]).astype(config.floatX)
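The value 100 is extreme because `exp(100.)` overflows float32 (though not float64), so the unoptimized graph would return `inf` under `floatX = float32`, while the stable form stays finite. A quick NumPy illustration of the behavior this test targets (illustrative only; the diff is truncated before the test's assertions):

    import numpy

    x_val = numpy.array([-100., 100.], dtype='float32')
    x_max = x_val.max()
    stable = x_max + numpy.log(numpy.sum(numpy.exp(x_val - x_max)))
    print(stable)                                  # 100.0
    print(numpy.log(numpy.sum(numpy.exp(x_val))))  # inf: exp(100) overflows float32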