提交 8ab2661b authored 作者: Frederic's avatar Frederic

pep8

上级 b8fa2b75
...@@ -8,24 +8,27 @@ import numpy.random ...@@ -8,24 +8,27 @@ import numpy.random
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
''' '''
Different tests that are not connected to any particular Op, or functionality of Different tests that are not connected to any particular Op, or
Theano. Here will go for example code that we will publish in papers, that we functionality of Theano. Here will go for example code that we will
should ensure that it will remain operational publish in papers, that we should ensure that it will remain
operational
''' '''
class T_scipy(unittest.TestCase): class T_scipy(unittest.TestCase):
def setUp(self): def setUp(self):
utt.seed_rng() utt.seed_rng()
self.orig_floatX = theano.config.floatX self.orig_floatX = theano.config.floatX
def tearDown(self): def tearDown(self):
theano.config.floatX = self.orig_floatX theano.config.floatX = self.orig_floatX
def test_scipy_paper_example1(self): def test_scipy_paper_example1(self):
a = theano.tensor.vector('a') # declare variable a = theano.tensor.vector('a') # declare variable
b = a + a**10 # build expression b = a + a**10 # build expression
f = theano.function([a], b) # compile function f = theano.function([a], b) # compile function
assert numpy.all(f([0,1,2]) == numpy.array([0,2,1026])) assert numpy.all(f([0, 1, 2]) == numpy.array([0, 2, 1026]))
def test_scipy_paper_example2(self): def test_scipy_paper_example2(self):
''' This just sees if things compile well and if they run ''' ''' This just sees if things compile well and if they run '''
...@@ -34,7 +37,7 @@ class T_scipy(unittest.TestCase): ...@@ -34,7 +37,7 @@ class T_scipy(unittest.TestCase):
shared = theano.shared shared = theano.shared
function = theano.function function = theano.function
rng = numpy.random rng = numpy.random
theano.config.floatX='float64' theano.config.floatX = 'float64'
# #
# ACTUAL SCRIPT FROM PAPER # ACTUAL SCRIPT FROM PAPER
...@@ -49,18 +52,18 @@ class T_scipy(unittest.TestCase): ...@@ -49,18 +52,18 @@ class T_scipy(unittest.TestCase):
xent = -y*T.log(p_1) - (1-y)*T.log(1-p_1) xent = -y*T.log(p_1) - (1-y)*T.log(1-p_1)
prediction = p_1 > 0.5 prediction = p_1 > 0.5
cost = xent.mean() + 0.01*(w**2).sum() cost = xent.mean() + 0.01*(w**2).sum()
gw,gb = T.grad(cost, [w,b]) gw, gb = T.grad(cost, [w, b])
# Compile expressions to functions # Compile expressions to functions
train = function( train = function(
inputs=[x,y], inputs=[x, y],
outputs=[prediction, xent], outputs=[prediction, xent],
updates=[(w, w-0.1*gw), (b, b-0.1*gb)]) updates=[(w, w-0.1*gw), (b, b-0.1*gb)])
predict = function(inputs=[x], outputs=prediction) predict = function(inputs=[x], outputs=prediction)
N = 4 N = 4
feats = 100 feats = 100
D = (rng.randn(N, feats), rng.randint(size=4,low=0, high=2)) D = (rng.randn(N, feats), rng.randint(size=4, low=0, high=2))
training_steps = 10 training_steps = 10
for i in range(training_steps): for i in range(training_steps):
pred, err = train(D[0], D[1]) pred, err = train(D[0], D[1])
...@@ -68,4 +71,3 @@ class T_scipy(unittest.TestCase): ...@@ -68,4 +71,3 @@ class T_scipy(unittest.TestCase):
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论