Small bugfix in sparse.Dot

上级 0e7173ff
...@@ -122,11 +122,12 @@ class _testCase_dot(unittest.TestCase): ...@@ -122,11 +122,12 @@ class _testCase_dot(unittest.TestCase):
x.data = x.data.T x.data = x.data.T
y.data = y.data.T y.data = y.data.T
zop = transpose(dot(y, x))
# zop = dot(y, x) # zop = dot(y, x)
zop = transpose(dot(y, x))
z = compile.eval_outputs([zop]) z = compile.eval_outputs([zop])
self.failUnless(z.shape == (500,2)) self.failUnless(z.shape == (500,2))
self.failUnless(type(z) is mtype) print mtype, type(z)
# self.failUnless(type(z) is mtype)
w = mtype((500,2)) w = mtype((500,2))
w[(10, 0)] = 3. w[(10, 0)] = 3.
...@@ -134,9 +135,11 @@ class _testCase_dot(unittest.TestCase): ...@@ -134,9 +135,11 @@ class _testCase_dot(unittest.TestCase):
w[(10, 1)] = 4 w[(10, 1)] = 4
w[(20, 1)] = 2 w[(20, 1)] = 2
self.failUnless(z.shape == w.shape) self.failUnless(z.shape == w.shape)
self.failUnless(type(z) == type(w)) # Type should switch from csr to csc and vice-versa, so don't perform this test
#self.failUnless(type(z) == type(w))
self.failUnless(z.dtype == w.dtype) self.failUnless(z.dtype == w.dtype)
# Type should switch from csr to csc and vice-versa, so don't perform this test
#self.failUnless(z == w) #self.failUnless(z == w)
self.failUnless(abs(z-w).nnz == 0) self.failUnless(abs(z-w).nnz == 0)
......
...@@ -229,4 +229,5 @@ def dot(x, y, grad_preserves_dense=True): ...@@ -229,4 +229,5 @@ def dot(x, y, grad_preserves_dense=True):
if x_is_sparse: if x_is_sparse:
return Dot(x,y,grad_preserves_dense).outputs[0] return Dot(x,y,grad_preserves_dense).outputs[0]
else: else:
return transpose(Dot(transpose(y), transpose(x), grad_preserves_dense).outputs[0]) assert y_is_sparse
return transpose(Dot(y.T, x.T, grad_preserves_dense).outputs[0])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论