Small bugfix in sparse.Dot

上级 0e7173ff
......@@ -122,11 +122,12 @@ class _testCase_dot(unittest.TestCase):
x.data = x.data.T
y.data = y.data.T
zop = transpose(dot(y, x))
# zop = dot(y, x)
zop = transpose(dot(y, x))
z = compile.eval_outputs([zop])
self.failUnless(z.shape == (500,2))
self.failUnless(type(z) is mtype)
print mtype, type(z)
# self.failUnless(type(z) is mtype)
w = mtype((500,2))
w[(10, 0)] = 3.
......@@ -134,9 +135,11 @@ class _testCase_dot(unittest.TestCase):
w[(10, 1)] = 4
w[(20, 1)] = 2
self.failUnless(z.shape == w.shape)
self.failUnless(type(z) == type(w))
# Type should switch from csr to csc and vice-versa, so don't perform this test
#self.failUnless(type(z) == type(w))
self.failUnless(z.dtype == w.dtype)
# Type should switch from csr to csc and vice-versa, so don't perform this test
#self.failUnless(z == w)
self.failUnless(abs(z-w).nnz == 0)
......
......@@ -229,4 +229,5 @@ def dot(x, y, grad_preserves_dense=True):
if x_is_sparse:
return Dot(x,y,grad_preserves_dense).outputs[0]
else:
return transpose(Dot(transpose(y), transpose(x), grad_preserves_dense).outputs[0])
assert y_is_sparse
return transpose(Dot(y.T, x.T, grad_preserves_dense).outputs[0])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论