提交 810415ed authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Update sparse test to test more combinations.

上级 55cd62b2
...@@ -347,9 +347,13 @@ class test_structureddot(unittest.TestCase): ...@@ -347,9 +347,13 @@ class test_structureddot(unittest.TestCase):
def test_dot_sparse_sparse(self): def test_dot_sparse_sparse(self):
#test dot for 2 input sparse matrix #test dot for 2 input sparse matrix
sparse_dtype = 'float64' sparse_dtype = 'float64'
for sparse_format in ['csc','csr']: sp_mat = {'csc':sp.csc_matrix,
a = SparseType(sparse_format, dtype=sparse_dtype)() 'csr':sp.csr_matrix}
b = SparseType(sparse_format, dtype=sparse_dtype)()
for sparse_format_a in ['csc','csr']:
for sparse_format_b in ['csc', 'csr']:
a = SparseType(sparse_format_a, dtype=sparse_dtype)()
b = SparseType(sparse_format_b, dtype=sparse_dtype)()
d = theano.dot(a,b) d = theano.dot(a,b)
f = theano.function([a,b], theano.Out(d, borrow=True)) f = theano.function([a,b], theano.Out(d, borrow=True))
topo = f.maker.env.toposort() topo = f.maker.env.toposort()
...@@ -358,13 +362,9 @@ class test_structureddot(unittest.TestCase): ...@@ -358,13 +362,9 @@ class test_structureddot(unittest.TestCase):
(40,30,20,30), (40,30,20,30),
(400,3000,200,6000), (400,3000,200,6000),
]: ]:
if sparse_format == 'csc': a_val = sp_mat[sparse_format_a](random_lil((M,N), sparse_dtype, nnz))
spmat = sp.csc_matrix(random_lil((M,N), sparse_dtype, nnz)) b_val = sp_mat[sparse_format_b](random_lil((N,K), sparse_dtype, nnz))
spmat2 = sp.csc_matrix(random_lil((N,K), sparse_dtype, nnz)) f(a_val, b_val)
elif sparse_format == 'csr':
spmat = sp.csr_matrix(random_lil((M,N), sparse_dtype, nnz))
spmat2 = sp.csr_matrix(random_lil((N,K), sparse_dtype, nnz))
f(spmat,spmat2)
def test_csc_correct_output_faster_than_scipy(self): def test_csc_correct_output_faster_than_scipy(self):
sparse_dtype = 'float64' sparse_dtype = 'float64'
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论