提交 8ab2661b authored 作者: Frederic's avatar Frederic

pep8

上级 b8fa2b75
......@@ -5,27 +5,30 @@ import theano
import numpy
import random
import numpy.random
from theano.tests import unittest_tools as utt
from theano.tests import unittest_tools as utt
'''
Different tests that are not connected to any particular Op, or functionality of
Theano. Here will go for example code that we will publish in papers, that we
should ensure that it will remain operational
Different tests that are not connected to any particular Op, or
functionality of Theano. Here will go for example code that we will
publish in papers, that we should ensure that it will remain
operational
'''
class T_scipy(unittest.TestCase):
def setUp(self):
utt.seed_rng()
self.orig_floatX = theano.config.floatX
def tearDown(self):
theano.config.floatX = self.orig_floatX
def test_scipy_paper_example1(self):
a = theano.tensor.vector('a') # declare variable
b = a + a**10 # build expression
f = theano.function([a], b) # compile function
assert numpy.all(f([0,1,2]) == numpy.array([0,2,1026]))
a = theano.tensor.vector('a') # declare variable
b = a + a**10 # build expression
f = theano.function([a], b) # compile function
assert numpy.all(f([0, 1, 2]) == numpy.array([0, 2, 1026]))
def test_scipy_paper_example2(self):
''' This just sees if things compile well and if they run '''
......@@ -34,7 +37,7 @@ class T_scipy(unittest.TestCase):
shared = theano.shared
function = theano.function
rng = numpy.random
theano.config.floatX='float64'
theano.config.floatX = 'float64'
#
# ACTUAL SCRIPT FROM PAPER
......@@ -49,18 +52,18 @@ class T_scipy(unittest.TestCase):
xent = -y*T.log(p_1) - (1-y)*T.log(1-p_1)
prediction = p_1 > 0.5
cost = xent.mean() + 0.01*(w**2).sum()
gw,gb = T.grad(cost, [w,b])
gw, gb = T.grad(cost, [w, b])
# Compile expressions to functions
train = function(
inputs=[x,y],
inputs=[x, y],
outputs=[prediction, xent],
updates=[(w, w-0.1*gw), (b, b-0.1*gb)])
predict = function(inputs=[x], outputs=prediction)
N = 4
feats = 100
D = (rng.randn(N, feats), rng.randint(size=4,low=0, high=2))
D = (rng.randn(N, feats), rng.randint(size=4, low=0, high=2))
training_steps = 10
for i in range(training_steps):
pred, err = train(D[0], D[1])
......@@ -68,4 +71,3 @@ class T_scipy(unittest.TestCase):
if __name__ == '__main__':
unittest.main()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论