提交 863f8167 authored 作者: Benjamin Scellier's avatar Benjamin Scellier

file theano/compile/tests/test_misc.py

上级 1baba0e2
from __future__ import absolute_import, print_function, division from __future__ import absolute_import, print_function, division
import numpy import numpy as np
import unittest import unittest
from theano.compile.pfunc import pfunc from theano.compile.pfunc import pfunc
...@@ -20,8 +20,8 @@ class NNet(object): ...@@ -20,8 +20,8 @@ class NNet(object):
self.input = input self.input = input
self.target = target self.target = target
self.lr = shared(lr, 'learning_rate') self.lr = shared(lr, 'learning_rate')
self.w1 = shared(numpy.zeros((n_hidden, n_input)), 'w1') self.w1 = shared(np.zeros((n_hidden, n_input)), 'w1')
self.w2 = shared(numpy.zeros((n_output, n_hidden)), 'w2') self.w2 = shared(np.zeros((n_output, n_hidden)), 'w2')
# print self.lr.type # print self.lr.type
self.hidden = sigmoid(tensor.dot(self.w1, self.input)) self.hidden = sigmoid(tensor.dot(self.w1, self.input))
...@@ -45,7 +45,7 @@ class NNet(object): ...@@ -45,7 +45,7 @@ class NNet(object):
class TestNnet(unittest.TestCase): class TestNnet(unittest.TestCase):
def test_nnet(self): def test_nnet(self):
rng = numpy.random.RandomState(1827) rng = np.random.RandomState(1827)
data = rng.rand(10, 4) data = rng.rand(10, 4)
nnet = NNet(n_input=3, n_hidden=10) nnet = NNet(n_input=3, n_hidden=10)
for epoch in range(3): for epoch in range(3):
...@@ -60,4 +60,4 @@ class TestNnet(unittest.TestCase): ...@@ -60,4 +60,4 @@ class TestNnet(unittest.TestCase):
self.assertTrue(abs(mean_cost - 0.20588975452) < 1e-6) self.assertTrue(abs(mean_cost - 0.20588975452) < 1e-6)
# Just call functions to make sure they do not crash. # Just call functions to make sure they do not crash.
nnet.compute_output(input) nnet.compute_output(input)
nnet.output_from_hidden(numpy.ones(10)) nnet.output_from_hidden(np.ones(10))
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论