Commit 2bbfe6dd authored by AdeB

Fix dimshuffle problems in log_sum_exp. More tests.

Parent 8eaa12f2
...@@ -6095,7 +6095,10 @@ def local_log_sum_exp(node): ...@@ -6095,7 +6095,10 @@ def local_log_sum_exp(node):
sum_node = node.inputs[0].owner sum_node = node.inputs[0].owner
# If the sum has keepdims=True, there might be a dimshuffle # If the sum has keepdims=True, there might be a dimshuffle
if sum_node and isinstance(sum_node.op, T.DimShuffle): if sum_node and isinstance(sum_node.op, T.DimShuffle):
dimshuffle_op = sum_node.op
sum_node = sum_node.inputs[0].owner sum_node = sum_node.inputs[0].owner
else:
dimshuffle_op = None
if not sum_node or not isinstance(sum_node.op, T.Sum): if not sum_node or not isinstance(sum_node.op, T.Sum):
return return
...@@ -6107,14 +6110,15 @@ def local_log_sum_exp(node): ...@@ -6107,14 +6110,15 @@ def local_log_sum_exp(node):
return return
pre_exp = exp_node.inputs[0] pre_exp = exp_node.inputs[0]
max_pre_keepdims = T.max(pre_exp, axis=axis, keepdims=True) max_pre_exp = T.max(pre_exp, axis=axis)
max_pre_exp_keepdims = T.makeKeepDims(pre_exp, max_pre_exp, axis)
ret = (max_pre_keepdims + T.log(T.sum(T.exp(pre_exp - max_pre_keepdims), ret = (max_pre_exp +
axis=axis, keepdims=True))) T.log(T.sum(T.exp(pre_exp - max_pre_exp_keepdims), axis=axis)))
# Restore shape and broadcastable pattern # Restore the dimshuffle op, if any.
ret = T.reshape(ret, node.inputs[0].shape) if dimshuffle_op:
ret = T.patternbroadcast(ret, node.inputs[0].broadcastable) ret = dimshuffle_op(ret)
return [ret] return [ret]
......
...@@ -50,7 +50,7 @@ from theano import tensor ...@@ -50,7 +50,7 @@ from theano import tensor
from theano import tensor as T from theano import tensor as T
from theano.tensor import scalar, iscalar, lscalar, fscalar, dscalar from theano.tensor import scalar, iscalar, lscalar, fscalar, dscalar
from theano.tensor import vector, ivector, lvector, fvector, dvector from theano.tensor import vector, ivector, lvector, fvector, dvector
from theano.tensor import matrix, imatrix, lmatrix, fmatrix, dmatrix from theano.tensor import matrix, imatrix, lmatrix, fmatrix, dmatrix, tensor3
from theano.tensor import scalars, vectors, matrices, fmatrices, dmatrices from theano.tensor import scalars, vectors, matrices, fmatrices, dmatrices
from theano.tensor import ( from theano.tensor import (
AdvancedSubtensor, AdvancedSubtensor,
...@@ -6656,14 +6656,20 @@ def test_local_useless_alloc(): ...@@ -6656,14 +6656,20 @@ def test_local_useless_alloc():
assert isinstance(topo[-1].op, T.Alloc) assert isinstance(topo[-1].op, T.Alloc)
def compile_graph_log_sum_exp(x, axis, dimshuffle_op=None):
    """Compile ``log(sum(exp(x), axis))`` with the log_sum_exp rewrite enabled.

    Parameters
    ----------
    x : theano tensor variable
        Symbolic input of the graph.
    axis : int or tuple of ints
        Axis (or axes) the inner sum reduces over.
    dimshuffle_op : DimShuffle op, optional
        If given, it is applied to the inner sum before the log, to mimic
        graphs produced by ``keepdims=True`` or by a transpose of the sum.

    Returns
    -------
    theano function
        Compiled function mapping ``x`` to the (optionally dimshuffled)
        log-sum-exp, built with the 'local_log_sum_exp' optimization
        included in the compilation mode.
    """
    sum_exp = T.sum(T.exp(x), axis=axis)
    # Insert the dimshuffle between the sum and the log, exactly where the
    # optimizer must look through it.
    if dimshuffle_op:
        sum_exp = dimshuffle_op(sum_exp)
    y = T.log(sum_exp)
    MODE = theano.compile.get_default_mode().including('local_log_sum_exp')
    return function([x], y, mode=MODE)
def check_max_log_sum_exp(x, axis, dimshuffle_op=None):
    """Check that the log_sum_exp optimization was applied to the graph.

    The optimization rewrites ``log(sum(exp(x)))`` into a numerically
    stable form that subtracts the max; its presence is detected by
    looking for an elemwise node whose scalar op is ``maximum`` in the
    compiled graph.

    Raises
    ------
    Exception
        If no ``maximum`` node is found, i.e. the optimization did not fire.
    """
    f = compile_graph_log_sum_exp(x, axis, dimshuffle_op)
    fgraph = f.maker.fgraph.toposort()
    for node in fgraph:
        # The stable rewrite introduces T.max(...), which shows up as an
        # elemwise node wrapping the scalar `maximum` op.
        if (hasattr(node.op, 'scalar_op') and
                node.op.scalar_op == theano.scalar.basic.maximum):
            return
    raise Exception('No maximum detected after log_sum_exp optimisation')
def test_local_log_sum_exp1():
    """Test that the optimization is applied (a `maximum` node appears)."""
    # Plain sums over various axis combinations of a 3-d tensor.
    x = tensor3('x')
    check_max_log_sum_exp(x, axis=(0,), dimshuffle_op=None)
    check_max_log_sum_exp(x, axis=(1,), dimshuffle_op=None)
    check_max_log_sum_exp(x, axis=(2,), dimshuffle_op=None)
    check_max_log_sum_exp(x, axis=(0, 1), dimshuffle_op=None)
    check_max_log_sum_exp(x, axis=(0, 1, 2), dimshuffle_op=None)

    # If a transpose is applied to the sum (sum over axis 2 of a tensor3
    # is 2-d, so a 2-d transpose DimShuffle applies).
    transpose_op = DimShuffle((False, False), (1, 0))
    check_max_log_sum_exp(x, axis=2, dimshuffle_op=transpose_op)

    # If the sum is performed with keepdims=True, the graph contains a
    # DimShuffle restoring the reduced dimensions; reuse that exact op.
    # NOTE(review): dtype='floatX' is passed as a literal string — confirm
    # TensorType resolves it (vs. requiring theano.config.floatX).
    x = TensorType(dtype='floatX', broadcastable=(False, True, False))('x')
    sum_keepdims_op = x.sum(axis=(0, 1), keepdims=True).owner.op
    check_max_log_sum_exp(x, axis=(0, 1), dimshuffle_op=sum_keepdims_op)
def test_local_log_sum_exp2():
    """Test that the optimized graph is numerically correct around 1.0."""
    x = tensor3('x')
    # Values in [1.0, 1.1): no overflow, so the naive formula is a valid
    # reference to compare against.
    x_val = 1.0 + numpy.random.rand(4, 3, 2).astype(config.floatX) / 10.0

    f = compile_graph_log_sum_exp(x, axis=(1,))
    naive_ret = numpy.log(numpy.sum(numpy.exp(x_val), axis=1))
    optimised_ret = f(x_val)
    assert numpy.allclose(naive_ret, optimised_ret)

    # If a transpose is applied after the sum (the reduced result is 2-d).
    transpose_op = DimShuffle((False, False), (1, 0))
    f = compile_graph_log_sum_exp(x, axis=(1,), dimshuffle_op=transpose_op)
    naive_ret = numpy.log(numpy.sum(numpy.exp(x_val), axis=1).T)
    optimised_ret = f(x_val)
    assert numpy.allclose(naive_ret, optimised_ret)
def test_local_log_sum_exp3(): def test_local_log_sum_exp3():
# Tests if the optimization works (result is correct) for extreme value 100 # Tests if the optimization works (result is correct) for extreme value 100
x = vector('x') x = vector('x')
y = T.log(T.sum(T.exp(x), axis=0)) f = compile_graph_log_sum_exp(x, axis=0)
MODE = theano.compile.get_default_mode().including('local_log_sum_exp')
f = function([x], y, mode=MODE)
x_val = numpy.array([-100., 100.]).astype(config.floatX) x_val = numpy.array([-100., 100.]).astype(config.floatX)
......
Markdown is supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Finish editing this comment first!
Register or sign in to comment