提交 76c598ee authored 作者: Yann N. Dauphin's avatar Yann N. Dauphin

added tests for CSM op and gradient

上级 a5b7de8b
......@@ -22,7 +22,7 @@ from theano.sparse.basic import _is_dense_variable, _is_sparse_variable
from theano.sparse.basic import verify_grad_sparse
from theano.sparse import (as_sparse_variable, CSC, CSR, CSM, CSMProperties,
csm_properties)
from theano.sparse import SparseType, CSMGrad
from theano.sparse import SparseType, CSMGrad, CSMGradC
from theano.sparse import StructuredDot, StructuredDotCSC
from theano.sparse import StructuredDotGradCSC, StructuredDotGradCSR
from theano.sparse import AddSS, AddSD, MulSS, MulSD, Transpose, Neg, Remove0
......@@ -184,7 +184,7 @@ class SparseInferShapeTester(utt.InferShapeTester):
[out],
[spm.data, spm.indices, spm.indptr,
spm.shape],
CSMGrad
(CSMGrad, CSMGradC)
)
def test_transpose(self):
......@@ -616,6 +616,45 @@ class test_csm_properties(unittest.TestCase):
assert numpy.all(shape == spmat.shape)
class test_csm(unittest.TestCase):
def setUp(self):
utt.seed_rng()
def test_csm_grad(self):
sp_types = {'csc': sp.csc_matrix,
'csr': sp.csr_matrix}
for format in ['csc', 'csr']:
for dtype in ['float32', 'float64']:
spmat = sp_types[format](random_lil((4, 3), dtype, 3))
verify_grad_sparse(lambda x: CSM(format)(x, spmat.indices,
spmat.indptr, numpy.asarray(spmat.shape, 'int32')),
[spmat.data], structured=True)
def test_csm(self):
sp_types = {'csc': sp.csc_matrix,
'csr': sp.csr_matrix}
for format in ['csc', 'csr']:
for dtype in ['float32', 'float64']:
x = tensor.tensor(dtype=dtype, broadcastable=(False,))
y = tensor.ivector()
z = tensor.ivector()
s = tensor.ivector()
f = theano.function([x, y, z, s], CSM(format)(x, y, z, s))
spmat = sp_types[format](random_lil((4, 3), dtype, 3))
res = f(spmat.data, spmat.indices, spmat.indptr,
numpy.asarray(spmat.shape, 'int32'))
assert numpy.all(res.data == spmat.data)
assert numpy.all(res.indices == spmat.indices)
assert numpy.all(res.indptr == spmat.indptr)
assert numpy.all(res.shape == spmat.shape)
class test_structureddot(unittest.TestCase):
def setUp(self):
utt.seed_rng()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论