提交 a4813f37 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Use the new verify_grad_sparse (if only to make sure it works).

上级 b30ebcdb
...@@ -19,6 +19,7 @@ if enable_sparse == False: ...@@ -19,6 +19,7 @@ if enable_sparse == False:
from theano.sparse.basic import _is_dense, _is_sparse, _mtypes from theano.sparse.basic import _is_dense, _is_sparse, _mtypes
from theano.sparse.basic import _is_dense_variable, _is_sparse_variable from theano.sparse.basic import _is_dense_variable, _is_sparse_variable
from theano.sparse.basic import verify_grad_sparse
from theano.sparse import as_sparse_variable, CSC, CSR, CSM, CSMProperties from theano.sparse import as_sparse_variable, CSC, CSR, CSM, CSMProperties
from theano.sparse import SparseType, StructuredDotCSC, CSMGrad from theano.sparse import SparseType, StructuredDotCSC, CSMGrad
from theano.sparse import AddSS, AddSD, MulSS, MulSD, Transpose, Neg from theano.sparse import AddSS, AddSD, MulSS, MulSD, Transpose, Neg
...@@ -462,16 +463,8 @@ class test_structureddot(unittest.TestCase): ...@@ -462,16 +463,8 @@ class test_structureddot(unittest.TestCase):
mat = numpy.asarray(numpy.random.randn(3, 2), 'float32') mat = numpy.asarray(numpy.random.randn(3, 2), 'float32')
def buildgraphCSC(spdata, sym_mat): verify_grad_sparse(structured_dot, [spmat, mat])
csc = CSC(spdata, spmat.indices[:spmat.size],
spmat.indptr, spmat.shape)
assert csc.type.dtype == 'float32'
rval = structured_dot(csc, sym_mat)
assert rval.type.dtype == 'float32'
return rval
utt.verify_grad(buildgraphCSC,
[spmat.data, mat])
def buildgraphCSC_T(spdata, sym_mat): def buildgraphCSC_T(spdata, sym_mat):
csc = CSC(spdata, spmat.indices[:spmat.size], csc = CSC(spdata, spmat.indices[:spmat.size],
...@@ -493,16 +486,7 @@ class test_structureddot(unittest.TestCase): ...@@ -493,16 +486,7 @@ class test_structureddot(unittest.TestCase):
mat = numpy.asarray(numpy.random.randn(3, 2), 'float64') mat = numpy.asarray(numpy.random.randn(3, 2), 'float64')
def buildgraph(spdata, sym_mat): verify_grad_sparse(structured_dot, [spmat, mat])
csr = CSR(spdata, spmat.indices[:spmat.size],
spmat.indptr, spmat.shape)
assert csr.type.dtype == 'float64'
rval = structured_dot(csr, sym_mat)
assert rval.type.dtype == 'float64'
return rval
utt.verify_grad(buildgraph,
[spmat.data, mat])
def buildgraph_T(spdata, sym_mat): def buildgraph_T(spdata, sym_mat):
csr = CSR(spdata, spmat.indices[:spmat.size], csr = CSR(spdata, spmat.indices[:spmat.size],
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论