提交 8cd6cebb authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Add more test cases to structureddot test

上级 332cd5d4
...@@ -441,6 +441,17 @@ class test_structureddot(unittest.TestCase): ...@@ -441,6 +441,17 @@ class test_structureddot(unittest.TestCase):
utt.verify_grad(buildgraphCSC, utt.verify_grad(buildgraphCSC,
[spmat.data, mat]) [spmat.data, mat])
def buildgraphCSC_T(spdata, sym_mat):
csc = CSC(spdata, spmat.indices[:spmat.size],
spmat.indptr, spmat.shape)
assert csc.type.dtype == 'float32'
rval = structured_dot(sym_mat.T, csc.T)
assert rval.type.dtype == 'float32'
return rval
utt.verify_grad(buildgraphCSC_T,
[spmat.data, mat])
def test_structureddot_csr_grad(self): def test_structureddot_csr_grad(self):
#shortcut: testing csc in float32, testing csr in float64 #shortcut: testing csc in float32, testing csr in float64
...@@ -461,6 +472,17 @@ class test_structureddot(unittest.TestCase): ...@@ -461,6 +472,17 @@ class test_structureddot(unittest.TestCase):
utt.verify_grad(buildgraph, utt.verify_grad(buildgraph,
[spmat.data, mat]) [spmat.data, mat])
def buildgraph_T(spdata, sym_mat):
csr = CSR(spdata, spmat.indices[:spmat.size],
spmat.indptr, spmat.shape)
assert csr.type.dtype == 'float64'
rval = structured_dot(sym_mat.T, csr.T)
assert rval.type.dtype == 'float64'
return rval
utt.verify_grad(buildgraph,
[spmat.data, mat])
def test_infer_shape_csr_csc_grad(self): def test_infer_shape_csr_csc_grad(self):
for sparsetype in ('csr', 'csc'): for sparsetype in ('csr', 'csc'):
a = SparseType(sparsetype, dtype=config.floatX)() a = SparseType(sparsetype, dtype=config.floatX)()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论