Fixed test_matinv to work for any seed

上级 aeffaed7
...@@ -1305,14 +1305,26 @@ class test_matinv(unittest.TestCase): ...@@ -1305,14 +1305,26 @@ class test_matinv(unittest.TestCase):
ssd, gw = fn(x,w) ssd, gw = fn(x,w)
#print ssd, x*w, x, w #print ssd, x*w, x, w
if i == 0: if i == 0:
str0 = str(ssd) ssd0 = ssd
w -= 0.4 * gw w -= 0.4 * gw
return str0, str(ssd) return ssd0, ssd
def test_reciprocal(self): def test_reciprocal(self):
"""Matrix reciprocal by gradient descent""" """Matrix reciprocal by gradient descent"""
self.assertEqual(('6.10141615619', '0.00703816291711'), self.mat_reciprocal(3)) ssd0,ssd = self.mat_reciprocal(3)
numpy.random.seed(unittest_tools.fetch_seed(1))
# hand-coded numpy implementation for verification
x = numpy.random.rand(3,3)+0.1
w = numpy.random.rand(3,3)
myssd0 = numpy.sum((x*w - numpy.ones((3,3)))**2.0)
for i in xrange(300):
gw = 2*(x*w - numpy.ones((3,3)))*x # derivative of dMSE/dw
myssd = numpy.sum((x*w - numpy.ones((3,3)))**2)
w -= 0.4 * gw
self.failUnlessAlmostEqual(ssd0, myssd0)
self.failUnlessAlmostEqual(ssd, myssd)
class t_dot(unittest.TestCase): class t_dot(unittest.TestCase):
def setUp(self): def setUp(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论