Commit f40a8a25, authored by nouiz

Merge pull request #233 from delallea/sigm_opt_fix

Fixed optimization for exp(x) * sigmoid(-x)
差异被折叠。
...@@ -7,7 +7,9 @@ from theano import tensor as T ...@@ -7,7 +7,9 @@ from theano import tensor as T
from theano import config from theano import config
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
from theano.tensor.nnet import sigmoid, sigmoid_inplace, softplus, tensor from theano.tensor.nnet import sigmoid, sigmoid_inplace, softplus, tensor
from theano.tensor.nnet.sigm import register_local_1msigmoid from theano.tensor.nnet.sigm import (
compute_mul, parse_mul_tree, perform_sigm_times_exp,
register_local_1msigmoid, simplify_mul)
class T_sigmoid(unittest.TestCase): class T_sigmoid(unittest.TestCase):
...@@ -23,12 +25,29 @@ class T_softplus(unittest.TestCase): ...@@ -23,12 +25,29 @@ class T_softplus(unittest.TestCase):
utt.verify_grad(softplus, [numpy.random.rand(3,4)]) utt.verify_grad(softplus, [numpy.random.rand(3,4)])
class T_sigmoid_opts(unittest.TestCase): class T_sigmoid_opts(unittest.TestCase):
def test_exp_over_1_plus_exp(self):
def get_mode(self, excluding=None):
    """
    Return appropriate mode for the tests.

    :param excluding: Optional sequence of names of optimizations to
        exclude. `None` (the default) or an empty sequence excludes
        nothing.

    :return: The current default mode unless the `config.mode` option is
        set to 'FAST_COMPILE' (in which case it is replaced by the
        'FAST_RUN' mode), without the optimizations specified in
        `excluding`.
    """
    # The default used to be a shared mutable list (`excluding=[]`); `None`
    # is accepted the same way by the truthiness test below.
    if theano.config.mode == 'FAST_COMPILE':
        mode = theano.compile.mode.get_mode('FAST_RUN')
    else:
        mode = theano.compile.mode.get_default_mode()
    if not excluding:
        return mode
    return mode.excluding(*excluding)
def test_exp_over_1_plus_exp(self):
m = self.get_mode(excluding=['local_elemwise_fusion'])
x = T.dvector() x = T.dvector()
...@@ -60,10 +79,7 @@ class T_sigmoid_opts(unittest.TestCase): ...@@ -60,10 +79,7 @@ class T_sigmoid_opts(unittest.TestCase):
if not register_local_1msigmoid: if not register_local_1msigmoid:
return return
m = theano.config.mode m = self.get_mode()
if m == 'FAST_COMPILE':
m = 'FAST_RUN'
x = T.fmatrix() x = T.fmatrix()
# tests exp_over_1_plus_exp # tests exp_over_1_plus_exp
...@@ -77,6 +93,80 @@ class T_sigmoid_opts(unittest.TestCase): ...@@ -77,6 +93,80 @@ class T_sigmoid_opts(unittest.TestCase):
assert [node.op for node in f.maker.env.toposort()] == [tensor.neg, assert [node.op for node in f.maker.env.toposort()] == [tensor.neg,
sigmoid_inplace] sigmoid_inplace]
def test_local_sigm_times_exp(self):
    """
    Test the `local_sigm_times_exp` optimization.

    exp(x) * sigm(-x) -> sigm(x)
    exp(-x) * sigm(x) -> sigm(-x)
    """
    # NOTE(review): fusion and inplace optimizations are excluded,
    # presumably so that the compiled graph keeps the separate,
    # non-inplace ops listed in the expectations below -- confirm.
    mode = self.get_mode(excluding=['local_elemwise_fusion', 'inplace'])
    x, y = tensor.vectors('x', 'y')

    def check(inputs, expr, expected_ops):
        # Compile `expr`, print the resulting graph, then require that the
        # ops of the graph in topological order equal `expected_ops`.
        fn = theano.function(inputs, expr, mode=mode)
        theano.printing.debugprint(fn)
        found_ops = [node.op for node in fn.maker.env.toposort()]
        assert found_ops == expected_ops

    check([x], sigmoid(-x) * tensor.exp(x), [sigmoid])
    check([x], sigmoid(x) * tensor.exp(-x), [tensor.neg, sigmoid])
    # Three stacked negations on the sigmoid factor.
    check([x],
          -(-(-(sigmoid(x)))) * tensor.exp(-x),
          [tensor.neg, sigmoid, tensor.neg])
    # Larger product mixing two variables.
    check([x, y],
          (sigmoid(x) * sigmoid(-y) * -tensor.exp(-x) * tensor.exp(x * y) *
           tensor.exp(y)),
          [sigmoid, tensor.mul, tensor.neg, tensor.exp, sigmoid,
           tensor.mul, tensor.neg])
def test_perform_sigm_times_exp(self):
    """
    Test the core function doing the `sigm_times_exp` optimization.

    It is easier to test different graph scenarios this way than by
    compiling a theano function.
    """
    x, y, z, t = tensor.vectors('x', 'y', 'z', 't')
    exp = tensor.exp

    def ok(expr1, expr2):
        """
        Assert that optimizing `expr1` gives the same mul tree as `expr2`.

        Both expressions are parsed with `parse_mul_tree`. The tree of
        `expr1` is passed to `perform_sigm_times_exp` (whose return value
        is not used here) and then to `simplify_mul`; the result must
        have the same string representation as the tree of `expr2`.
        """
        trees = [parse_mul_tree(e) for e in (expr1, expr2)]
        perform_sigm_times_exp(trees[0])
        trees[0] = simplify_mul(trees[0])
        # TODO Ideally we would do a full comparison without `str`. However
        # the implementation of `__eq__` in Variables is not currently
        # appropriate for this. So for now we use this limited technique,
        # but it could be improved on.
        good = str(trees[0]) == str(trees[1])
        if not good:
            # Dump both trees and the graphs rebuilt from them before
            # failing. Single-argument `print(...)` calls are used because
            # they behave the same under Python 2 and Python 3.
            print(trees[0])
            print(trees[1])
            print('***')
            theano.printing.debugprint(compute_mul(trees[0]))
            print('***')
            theano.printing.debugprint(compute_mul(trees[1]))
        assert good

    ok(sigmoid(x) * exp(-x), sigmoid(-x))
    ok(-x * sigmoid(x) * (y * (-1 * z) * exp(-x)),
       -x * sigmoid(-x) * (y * (-1 * z)))
    ok(-sigmoid(-x) *
       (exp(y) * (-exp(-z) * 3 * -exp(x)) *
        (y * 2 * (-sigmoid(-y) * (z + t) * exp(z)) * sigmoid(z))) *
       -sigmoid(x),
       sigmoid(x) *
       (-sigmoid(y) * (-sigmoid(-z) * 3) * (y * 2 * ((z + t) * exp(z)))) *
       -sigmoid(x))
    ok(exp(-x) * -exp(-x) * (-sigmoid(x) * -sigmoid(x)),
       -sigmoid(-x) * sigmoid(-x))
    ok(-exp(x) * -sigmoid(-x) * -exp(-x),
       -sigmoid(-x))
class T_softplus_opts(unittest.TestCase): class T_softplus_opts(unittest.TestCase):
def setUp(self): def setUp(self):
...@@ -123,3 +213,28 @@ class T_softplus_opts(unittest.TestCase): ...@@ -123,3 +213,28 @@ class T_softplus_opts(unittest.TestCase):
assert len(topo)==1 assert len(topo)==1
assert isinstance(topo[0].op.scalar_op,theano.tensor.nnet.sigm.ScalarSoftplus) assert isinstance(topo[0].op.scalar_op,theano.tensor.nnet.sigm.ScalarSoftplus)
f(numpy.random.rand(54).astype(config.floatX)) f(numpy.random.rand(54).astype(config.floatX))
class T_sigmoid_utils(unittest.TestCase):
    """
    Test utility functions found in 'sigm.py'.
    """

    def test_compute_mul(self):
        """
        Check that parsing the output of `compute_mul` gives back the tree.
        """
        x, y, z = tensor.vectors('x', 'y', 'z')
        expr = (x * y) * -z
        expected_tree = parse_mul_tree(expr)
        # Only the round trip tree -> variable -> tree is checked. The
        # reverse identity, i.e.
        #   compute_mul(parse_mul_tree(expr)) == expr
        # is not tested because Theano currently lacks an easy way to
        # compare variables.
        rebuilt = compute_mul(expected_tree)
        assert parse_mul_tree(rebuilt) == expected_tree

    def test_parse_mul_tree(self):
        """
        Check `parse_mul_tree` on small products and negations.

        As the expected values show, a tree is a two-element list: a
        boolean that is True where the expression is negated, followed by
        either the variable itself or a list of sub-trees for a product.
        """
        x, y, z = tensor.vectors('x', 'y', 'z')
        cases = [
            (x * y, [False, [[False, x], [False, y]]]),
            (-(x * y), [True, [[False, x], [False, y]]]),
            (-x * y, [False, [[True, x], [False, y]]]),
            (-x, [True, x]),
            ((x * y) * -z,
             [False, [[False, [[False, x], [False, y]]], [True, z]]]),
        ]
        for expr, expected in cases:
            assert parse_mul_tree(expr) == expected
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论