提交 f3640b4b authored 作者: Frederic Bastien's avatar Frederic Bastien

added test.

上级 2f495a60
...@@ -86,12 +86,36 @@ class T_sigmoid_opts(unittest.TestCase): ...@@ -86,12 +86,36 @@ class T_sigmoid_opts(unittest.TestCase):
class T_softplus_opts(unittest.TestCase): class T_softplus_opts(unittest.TestCase):
def setUp(self): def setUp(self):
if theano.config.mode == 'FAST_COMPILE':
m = theano.compile.mode.get_mode('FAST_RUN')
else:
m = theano.compile.mode.get_default_mode().excluding('local_elemwise_fusion')
self.m = m
utt.seed_rng() utt.seed_rng()
# def test_logsigm_to_softplus(self): def test_logsigm_to_softplus(self):
# pass x = T.vector()
out = T.log(sigmoid(x))
f = theano.function([x],out,mode=self.m)
topo = f.maker.env.toposort()
print topo
assert len(topo)==3
assert isinstance(topo[0].op.scalar_op, theano.scalar.Neg)
assert isinstance(topo[1].op.scalar_op, theano.tensor.nnet.sigm.ScalarSoftplus)
assert isinstance(topo[2].op.scalar_op, theano.scalar.Neg)
f(numpy.random.rand(54))
def test_log1msigm_to_softplus(self):
x = T.vector()
out = T.log(1-sigmoid(x))
f = theano.function([x],out,mode=self.m)
topo = f.maker.env.toposort()
assert len(topo)==2
assert isinstance(topo[0].op.scalar_op, theano.tensor.nnet.sigm.ScalarSoftplus)
assert isinstance(topo[1].op.scalar_op, theano.scalar.Neg)
f(numpy.random.rand(54))
# def test_log1msigm_to_softplus(self):
# pass
def test_log1pexp_to_softplus(self): def test_log1pexp_to_softplus(self):
m = theano.config.mode m = theano.config.mode
if m == 'FAST_COMPILE': if m == 'FAST_COMPILE':
...@@ -100,7 +124,8 @@ class T_softplus_opts(unittest.TestCase): ...@@ -100,7 +124,8 @@ class T_softplus_opts(unittest.TestCase):
x = T.vector() x = T.vector()
out = T.log(1+T.exp(x)) out = T.log(1+T.exp(x))
f = theano.function([x],out) f = theano.function([x],out,mode=self.m)
topo = f.maker.env.toposort() topo = f.maker.env.toposort()
assert len(topo)==1 assert len(topo)==1
assert isinstance(topo[0].op.scalar_op,theano.tensor.nnet.sigm.ScalarSoftplus) assert isinstance(topo[0].op.scalar_op,theano.tensor.nnet.sigm.ScalarSoftplus)
f(numpy.random.rand(54))
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论