提交 f1643043 authored 作者: Yann N. Dauphin's avatar Yann N. Dauphin

added tests for csm_properties

上级 b11dd92e
...@@ -20,7 +20,8 @@ if enable_sparse == False: ...@@ -20,7 +20,8 @@ if enable_sparse == False:
from theano.sparse.basic import _is_dense, _is_sparse, _mtypes from theano.sparse.basic import _is_dense, _is_sparse, _mtypes
from theano.sparse.basic import _is_dense_variable, _is_sparse_variable from theano.sparse.basic import _is_dense_variable, _is_sparse_variable
from theano.sparse.basic import verify_grad_sparse from theano.sparse.basic import verify_grad_sparse
from theano.sparse import as_sparse_variable, CSC, CSR, CSM, CSMProperties from theano.sparse import (as_sparse_variable, CSC, CSR, CSM, CSMProperties,
csm_properties)
from theano.sparse import SparseType, CSMGrad from theano.sparse import SparseType, CSMGrad
from theano.sparse import StructuredDot, StructuredDotCSC from theano.sparse import StructuredDot, StructuredDotCSC
from theano.sparse import StructuredDotGradCSC, StructuredDotGradCSR from theano.sparse import StructuredDotGradCSC, StructuredDotGradCSR
...@@ -560,6 +561,49 @@ class T_conversion(unittest.TestCase): ...@@ -560,6 +561,49 @@ class T_conversion(unittest.TestCase):
self.assertRaises(TypeError, self.check_format_ndim, format, 4) self.assertRaises(TypeError, self.check_format_ndim, format, 4)
class test_csm_properties(unittest.TestCase):
def setUp(self):
utt.seed_rng()
def test_csm_properties_grad(self):
sp_types = {'csc': sp.csc_matrix,
'csr': sp.csr_matrix}
for format in ['csc', 'csr']:
for dtype in ['float32', 'float64']:
spmat = sp_types[format](random_lil((4, 3), dtype, 3))
verify_grad_sparse(lambda *x: CSMProperties()(*x)[0], [spmat],
structured=True)
verify_grad_sparse(lambda *x: CSMProperties()(*x)[1], [spmat],
structured=True)
verify_grad_sparse(lambda *x: CSMProperties()(*x)[2], [spmat],
structured=True)
verify_grad_sparse(lambda *x: CSMProperties()(*x)[2], [spmat],
structured=True)
def test_csm_properties(self):
sp_types = {'csc': sp.csc_matrix,
'csr': sp.csr_matrix}
for format in ['csc', 'csr']:
for dtype in ['float32', 'float64']:
x = SparseType(format, dtype=dtype)()
f = theano.function([x], csm_properties(x))
spmat = sp_types[format](random_lil((4, 3), dtype, 3))
data, indices, indptr, shape = f(spmat)
assert numpy.all(data == spmat.data)
assert numpy.all(indices == spmat.indices)
assert numpy.all(indptr == spmat.indptr)
assert numpy.all(shape == spmat.shape)
class test_structureddot(unittest.TestCase): class test_structureddot(unittest.TestCase):
def setUp(self): def setUp(self):
utt.seed_rng() utt.seed_rng()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论