提交 fd7c30ce authored 作者: nouiz's avatar nouiz

Merge pull request #445 from lamblin/fix_test_structureddot_grad

Fix perform of structureddot_grad with dense grad
......@@ -1666,17 +1666,10 @@ def structured_dot_grad(sparse_A, dense_B, ga):
if sparse_A.type.format == 'csc':
sdgcsx = sdg_csc
else:
sdgcsx = sdg_csr
#backport
#sdgcsx = sdg_csc if sparse_A.type.format == 'csc' else sdg_csr
if sparse_A.type.format == 'csc':
CSx = CSC
else:
sdgcsx = sdg_csr
CSx = CSR
#backport
#CSx = CSC if sparse_A.type.format == 'csc' else CSR
g_A_data = sdgcsx(csm_indices(sparse_A), \
csm_indptr(sparse_A), dense_B, ga)
......@@ -1705,7 +1698,13 @@ class StructuredDotGradCSC(gof.Op):
ind1 = a_indptr[j + 1]
for i_idx in xrange(ind0, ind1):
i = a_indices[i_idx]
g_a_data[i_idx] = numpy.dot(g_ab[i], b[j].T)[0, 0]
# Depending on the type of g_ab and b (sparse or dense),
# the following dot product can result in a scalar or
# a (1, 1) sparse matrix.
dot_val = numpy.dot(g_ab[i], b[j].T)
if isinstance(dot_val, scipy.sparse.spmatrix):
dot_val = dot_val[0, 0]
g_a_data[i_idx] = dot_val
out[0] = g_a_data
def c_code(self, node, name, (_indices, _indptr, _d, _g), (_zout, ), sub):
......@@ -1820,7 +1819,13 @@ class StructuredDotGradCSR(gof.Op):
for j_idx in xrange(ind0, ind1):
j = a_indices[j_idx]
# grad is dot product of i-th row of gradient with j-th row of b
g_a_data[j_idx] = numpy.dot(g_ab[i], b[j].T)[0, 0]
# Depending on the type of g_ab and b (sparse or dense),
# the following dot product can result in a scalar or
# a (1, 1) sparse matrix.
dot_val = numpy.dot(g_ab[i], b[j].T)
if isinstance(dot_val, scipy.sparse.spmatrix):
dot_val = dot_val[0, 0]
g_a_data[j_idx] = dot_val
out[0] = g_a_data
def c_code(self, node, name, (_indices, _indptr, _d, _g), (_zout, ), sub):
......
......@@ -448,6 +448,17 @@ class test_structureddot(unittest.TestCase):
utt.verify_grad(buildgraphCSC,
[spmat.data, mat])
def buildgraphCSC_T(spdata, sym_mat):
csc = CSC(spdata, spmat.indices[:spmat.size],
spmat.indptr, spmat.shape)
assert csc.type.dtype == 'float32'
rval = structured_dot(sym_mat.T, csc.T)
assert rval.type.dtype == 'float32'
return rval
utt.verify_grad(buildgraphCSC_T,
[spmat.data, mat])
def test_structureddot_csr_grad(self):
#shortcut: testing csc in float32, testing csr in float64
......@@ -468,6 +479,17 @@ class test_structureddot(unittest.TestCase):
utt.verify_grad(buildgraph,
[spmat.data, mat])
def buildgraph_T(spdata, sym_mat):
csr = CSR(spdata, spmat.indices[:spmat.size],
spmat.indptr, spmat.shape)
assert csr.type.dtype == 'float64'
rval = structured_dot(sym_mat.T, csr.T)
assert rval.type.dtype == 'float64'
return rval
utt.verify_grad(buildgraph,
[spmat.data, mat])
def test_infer_shape_csr_csc_grad(self):
for sparsetype in ('csr', 'csc'):
a = SparseType(sparsetype, dtype=config.floatX)()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论