提交 a4813f37 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Use the new verify_grad_sparse (if only to make sure it works).

上级 b30ebcdb
......@@ -19,6 +19,7 @@ if enable_sparse == False:
from theano.sparse.basic import _is_dense, _is_sparse, _mtypes
from theano.sparse.basic import _is_dense_variable, _is_sparse_variable
from theano.sparse.basic import verify_grad_sparse
from theano.sparse import as_sparse_variable, CSC, CSR, CSM, CSMProperties
from theano.sparse import SparseType, StructuredDotCSC, CSMGrad
from theano.sparse import AddSS, AddSD, MulSS, MulSD, Transpose, Neg
......@@ -462,16 +463,8 @@ class test_structureddot(unittest.TestCase):
mat = numpy.asarray(numpy.random.randn(3, 2), 'float32')
def buildgraphCSC(spdata, sym_mat):
csc = CSC(spdata, spmat.indices[:spmat.size],
spmat.indptr, spmat.shape)
assert csc.type.dtype == 'float32'
rval = structured_dot(csc, sym_mat)
assert rval.type.dtype == 'float32'
return rval
verify_grad_sparse(structured_dot, [spmat, mat])
utt.verify_grad(buildgraphCSC,
[spmat.data, mat])
def buildgraphCSC_T(spdata, sym_mat):
csc = CSC(spdata, spmat.indices[:spmat.size],
......@@ -493,16 +486,7 @@ class test_structureddot(unittest.TestCase):
mat = numpy.asarray(numpy.random.randn(3, 2), 'float64')
def buildgraph(spdata, sym_mat):
csr = CSR(spdata, spmat.indices[:spmat.size],
spmat.indptr, spmat.shape)
assert csr.type.dtype == 'float64'
rval = structured_dot(csr, sym_mat)
assert rval.type.dtype == 'float64'
return rval
utt.verify_grad(buildgraph,
[spmat.data, mat])
verify_grad_sparse(structured_dot, [spmat, mat])
def buildgraph_T(spdata, sym_mat):
csr = CSR(spdata, spmat.indices[:spmat.size],
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论