提交 7987b518 authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Fixed bug in identification of sigmoid

It used to think that (any_constant + exp(x)) was equal to (1 + exp(x)).
上级 177f1296
......@@ -148,8 +148,10 @@ opt.register_stabilize(log1msigm_to_softplus, name = 'log1msigm_to_softplus')
opt.register_stabilize(log1pexp_to_softplus, name = 'log1pexp_to_softplus')
def is_1pexp(t):
# if t is of form (1+exp(x)), return x
# else return None
"""
If 't' is of the form (1+exp(x)), return (False, x).
Else return None.
"""
if t.owner and t.owner.op == tensor.add:
scalars, scalar_inputs, nonconsts = \
opt.scalarconsts_rest(t.owner.inputs)
......@@ -157,6 +159,12 @@ def is_1pexp(t):
if len(nonconsts) == 1:
maybe_exp = nonconsts[0]
if maybe_exp.owner and maybe_exp.owner.op == tensor.exp:
# Verify that the constant terms sum to 1.
if scalars:
scal_sum = scalars[0]
for s in scalars[1:]:
scal_sum = scal_sum + s
if numpy.allclose(scal_sum, 1):
return False, maybe_exp.owner.inputs[0]
return None
......
import unittest
from itertools import imap
import numpy
......@@ -8,7 +9,7 @@ from theano import config
from theano.tests import unittest_tools as utt
from theano.tensor.nnet import sigmoid, sigmoid_inplace, softplus, tensor
from theano.tensor.nnet.sigm import (
compute_mul, parse_mul_tree, perform_sigm_times_exp,
compute_mul, is_1pexp, parse_mul_tree, perform_sigm_times_exp,
register_local_1msigmoid, simplify_mul)
......@@ -235,3 +236,17 @@ class T_sigmoid_utils(unittest.TestCase):
assert parse_mul_tree(-x) == [True, x]
assert parse_mul_tree((x * y) * -z) == [
False, [[False, [[False, x], [False, y]]], [True, z]]]
def test_is_1pexp(self):
x = tensor.vector('x')
exp = tensor.exp
assert is_1pexp(1 + exp(x)) == (False, x)
assert is_1pexp(exp(x) + 1) == (False, x)
for neg, exp_arg in imap(is_1pexp, [(1 + exp(-x)), (exp(-x) + 1)]):
assert not neg and theano.gof.graph.is_same_graph(exp_arg, -x)
assert is_1pexp(1 - exp(x)) is None
assert is_1pexp(2 + exp(x)) is None
assert is_1pexp(exp(x) + 2) is None
assert is_1pexp(exp(x) - 1) is None
assert is_1pexp(-1 + exp(x)) is None
assert is_1pexp(1 + 2 * exp(x)) is None
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论