提交 f4b3c8ad authored 作者: nouiz's avatar nouiz

Merge pull request #605 from ynd/csm_fix

bugfix gradient of csm_properties
...@@ -558,14 +558,10 @@ class CSMProperties(gof.Op): ...@@ -558,14 +558,10 @@ class CSMProperties(gof.Op):
out[2][0] = theano._asarray(csm.indptr, dtype='int32') out[2][0] = theano._asarray(csm.indptr, dtype='int32')
out[3][0] = theano._asarray(csm.shape, dtype='int32') out[3][0] = theano._asarray(csm.shape, dtype='int32')
# TODO FIX THIS
def grad(self, (csm,), g): def grad(self, (csm,), g):
assert [gg is None for gg in g[1:]] assert [gg is None for gg in g[1:]]
data, indices, indptr, shape = csm_properties(csm) data, indices, indptr, shape = csm_properties(csm)
if csm.format == 'csc': return [CSM(csm.format)(g[0], indices, indptr, shape)]
return [CSM('csc')(g_data, indices, indptr, shape)]
else:
return [CSR('csm')(g_data, indices, indptr, shape)]
# don't make this a function or it breaks some optimizations below # don't make this a function or it breaks some optimizations below
csm_properties = CSMProperties() csm_properties = CSMProperties()
......
...@@ -20,7 +20,8 @@ if enable_sparse == False: ...@@ -20,7 +20,8 @@ if enable_sparse == False:
from theano.sparse.basic import _is_dense, _is_sparse, _mtypes from theano.sparse.basic import _is_dense, _is_sparse, _mtypes
from theano.sparse.basic import _is_dense_variable, _is_sparse_variable from theano.sparse.basic import _is_dense_variable, _is_sparse_variable
from theano.sparse.basic import verify_grad_sparse from theano.sparse.basic import verify_grad_sparse
from theano.sparse import as_sparse_variable, CSC, CSR, CSM, CSMProperties from theano.sparse import (as_sparse_variable, CSC, CSR, CSM, CSMProperties,
csm_properties)
from theano.sparse import SparseType, CSMGrad from theano.sparse import SparseType, CSMGrad
from theano.sparse import StructuredDot, StructuredDotCSC from theano.sparse import StructuredDot, StructuredDotCSC
from theano.sparse import StructuredDotGradCSC, StructuredDotGradCSR from theano.sparse import StructuredDotGradCSC, StructuredDotGradCSR
...@@ -560,6 +561,49 @@ class T_conversion(unittest.TestCase): ...@@ -560,6 +561,49 @@ class T_conversion(unittest.TestCase):
self.assertRaises(TypeError, self.check_format_ndim, format, 4) self.assertRaises(TypeError, self.check_format_ndim, format, 4)
class test_csm_properties(unittest.TestCase):
def setUp(self):
utt.seed_rng()
def test_csm_properties_grad(self):
sp_types = {'csc': sp.csc_matrix,
'csr': sp.csr_matrix}
for format in ['csc', 'csr']:
for dtype in ['float32', 'float64']:
spmat = sp_types[format](random_lil((4, 3), dtype, 3))
verify_grad_sparse(lambda *x: CSMProperties()(*x)[0], [spmat],
structured=True)
verify_grad_sparse(lambda *x: CSMProperties()(*x)[1], [spmat],
structured=True)
verify_grad_sparse(lambda *x: CSMProperties()(*x)[2], [spmat],
structured=True)
verify_grad_sparse(lambda *x: CSMProperties()(*x)[2], [spmat],
structured=True)
def test_csm_properties(self):
sp_types = {'csc': sp.csc_matrix,
'csr': sp.csr_matrix}
for format in ['csc', 'csr']:
for dtype in ['float32', 'float64']:
x = SparseType(format, dtype=dtype)()
f = theano.function([x], csm_properties(x))
spmat = sp_types[format](random_lil((4, 3), dtype, 3))
data, indices, indptr, shape = f(spmat)
assert numpy.all(data == spmat.data)
assert numpy.all(indices == spmat.indices)
assert numpy.all(indptr == spmat.indptr)
assert numpy.all(shape == spmat.shape)
class test_structureddot(unittest.TestCase): class test_structureddot(unittest.TestCase):
def setUp(self): def setUp(self):
utt.seed_rng() utt.seed_rng()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论