提交 2f495a60 authored 作者: Frederic Bastien's avatar Frederic Bastien

add an optimisation log1p(exp(x))-> softplus(x). This make computation more stable.

上级 3e84f926
......@@ -117,8 +117,15 @@ log1msigm_to_softplus = gof.PatternSub(
(tensor.neg, (softplus, 'x')),
allow_multiple_clients = True)
log1pexp_to_softplus = gof.PatternSub(
(tensor.log1p,
(tensor.exp, 'x')),
(softplus, 'x'),
allow_multiple_clients = True)
opt.register_stabilize(logsigm_to_softplus, name = 'logsigm_to_softplus')
opt.register_stabilize(log1msigm_to_softplus, name = 'log1msigm_to_softplus')
opt.register_stabilize(log1pexp_to_softplus, name = 'log1pexp_to_softplus')
def is_1pexp(t):
# if t is of form (1+exp(x)), return x
......
......@@ -83,3 +83,24 @@ class T_sigmoid_opts(unittest.TestCase):
assert [node.op for node in f.maker.env.toposort()] == [tensor.neg,
sigmoid_inplace]
class T_softplus_opts(unittest.TestCase):
def setUp(self):
utt.seed_rng()
# def test_logsigm_to_softplus(self):
# pass
# def test_log1msigm_to_softplus(self):
# pass
def test_log1pexp_to_softplus(self):
m = theano.config.mode
if m == 'FAST_COMPILE':
m = 'FAST_RUN'
x = T.vector()
out = T.log(1+T.exp(x))
f = theano.function([x],out)
topo = f.maker.env.toposort()
assert len(topo)==1
assert isinstance(topo[0].op.scalar_op,theano.tensor.nnet.sigm.ScalarSoftplus)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论