提交 c2b462ed authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Fix perform of structureddot_grad with dense grad

上级 8cd6cebb
......@@ -1705,7 +1705,13 @@ class StructuredDotGradCSC(gof.Op):
ind1 = a_indptr[j + 1]
for i_idx in xrange(ind0, ind1):
i = a_indices[i_idx]
g_a_data[i_idx] = numpy.dot(g_ab[i], b[j].T)[0, 0]
# Depending on the type of g_ab and b (sparse or dense),
# the following dot product can result in a scalar or
# a (1, 1) sparse matrix.
dot_val = numpy.dot(g_ab[i], b[j].T)
if isinstance(dot_val, scipy.sparse.spmatrix):
dot_val = dot_val[0, 0]
g_a_data[i_idx] = dot_val
out[0] = g_a_data
def c_code(self, node, name, (_indices, _indptr, _d, _g), (_zout, ), sub):
......@@ -1820,7 +1826,13 @@ class StructuredDotGradCSR(gof.Op):
for j_idx in xrange(ind0, ind1):
j = a_indices[j_idx]
# grad is dot product of i-th row of gradient with j-th row of b
g_a_data[j_idx] = numpy.dot(g_ab[i], b[j].T)[0, 0]
# Depending on the type of g_ab and b (sparse or dense),
# the following dot product can result in a scalar or
# a (1, 1) sparse matrix.
dot_val = numpy.dot(g_ab[i], b[j].T)
if isinstance(dot_val, scipy.sparse.spmatrix):
dot_val = dot_val[0, 0]
g_a_data[j_idx] = dot_val
out[0] = g_a_data
def c_code(self, node, name, (_indices, _indptr, _d, _g), (_zout, ), sub):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论