提交 9e0e8bbf authored 作者: Frederic's avatar Frederic

Make Usmm test faster and remove deprecated warning.

上级 8741d36a
...@@ -613,9 +613,10 @@ class DotTests(unittest.TestCase): ...@@ -613,9 +613,10 @@ class DotTests(unittest.TestCase):
class UsmmTests(unittest.TestCase): class UsmmTests(unittest.TestCase):
""" Test the Usmm and UsmmCscDense class and related optimization """
def setUp(self): def setUp(self):
x_size = (10, 200) x_size = (10, 100)
y_size = (200, 2000) y_size = (100, 200)
z_size = (x_size[0], y_size[1]) z_size = (x_size[0], y_size[1])
self.x = numpy.asarray(numpy.random.binomial(1, 0.5, x_size), dtype=theano.config.floatX) self.x = numpy.asarray(numpy.random.binomial(1, 0.5, x_size), dtype=theano.config.floatX)
...@@ -639,9 +640,7 @@ class UsmmTests(unittest.TestCase): ...@@ -639,9 +640,7 @@ class UsmmTests(unittest.TestCase):
x = mat(format1, 'x', dtype1) x = mat(format1, 'x', dtype1)
y = mat(format2, 'y', dtype2) y = mat(format2, 'y', dtype2)
a = theano.tensor.scalar('a', dtype=dtype3) a = theano.tensor.scalar('a', dtype=dtype3)
z = theano.tensor.shared( z = theano.shared(numpy.asarray(self.z, dtype=dtype4).copy())
numpy.asarray(self.z, dtype=dtype4).copy()
)
f_b = lambda z, a, x, y: z - a * (x * y) f_b = lambda z, a, x, y: z - a * (x * y)
x_data = numpy.asarray(self.x, dtype=dtype1) x_data = numpy.asarray(self.x, dtype=dtype1)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论