提交 82368140 authored 作者: Frederic's avatar Frederic

Add computation tests for sparse.Dot with a vector as input.

上级 155e2eb9
...@@ -1054,20 +1054,17 @@ class DotTests(utt.InferShapeTester): ...@@ -1054,20 +1054,17 @@ class DotTests(utt.InferShapeTester):
y = theano.tensor.matrix('y') y = theano.tensor.matrix('y')
v = theano.tensor.vector('v') v = theano.tensor.vector('v')
for (x, y, x_v, y_v) in [(x, y, self.x_csr, self.y),
(x, v, self.x_csr, self.v_100),
(v, x, self.v_10, self.x_csr)]:
f_a = theano.function([x, y], theano.sparse.dot(x, y)) f_a = theano.function([x, y], theano.sparse.dot(x, y))
f_b = lambda x, y: x * y f_b = lambda x, y: x * y
assert _allclose(f_a(self.x_csr, self.y), f_b(self.x_csr, self.y)) assert _allclose(f_a(x_v, y_v), f_b(x_v, y_v))
# Test infer_shape # Test infer_shape
self._compile_and_check([x, y], [theano.sparse.dot(x, y)], self._compile_and_check([x, y], [theano.sparse.dot(x, y)],
[self.x_csr, self.y], [x_v, y_v],
(Dot, Usmm, UsmmCscDense))
self._compile_and_check([v, x], [theano.sparse.dot(v, x)],
[self.v_10, self.x_csr],
(Dot, Usmm, UsmmCscDense))
self._compile_and_check([x, v], [theano.sparse.dot(x, v)],
[self.x_csr, self.v_100],
(Dot, Usmm, UsmmCscDense)) (Dot, Usmm, UsmmCscDense))
def test_csc_dense(self): def test_csc_dense(self):
...@@ -1075,20 +1072,18 @@ class DotTests(utt.InferShapeTester): ...@@ -1075,20 +1072,18 @@ class DotTests(utt.InferShapeTester):
y = theano.tensor.matrix('y') y = theano.tensor.matrix('y')
v = theano.tensor.vector('v') v = theano.tensor.vector('v')
for (x, y, x_v, y_v) in [(x, y, self.x_csc, self.y),
(x, v, self.x_csc, self.v_100),
(v, x, self.v_10, self.x_csc)]:
f_a = theano.function([x, y], theano.sparse.dot(x, y)) f_a = theano.function([x, y], theano.sparse.dot(x, y))
f_b = lambda x, y: x * y f_b = lambda x, y: x * y
assert _allclose(f_a(self.x_csc, self.y), f_b(self.x_csc, self.y)) assert _allclose(f_a(x_v, y_v), f_b(x_v, y_v))
# Test infer_shape # Test infer_shape
self._compile_and_check([x, y], [theano.sparse.dot(x, y)], self._compile_and_check([x, y], [theano.sparse.dot(x, y)],
[self.x_csc, self.y], [x_v, y_v],
(Dot, Usmm, UsmmCscDense))
self._compile_and_check([v, x], [theano.sparse.dot(v, x)],
[self.v_10, self.x_csc],
(Dot, Usmm, UsmmCscDense))
self._compile_and_check([x, v], [theano.sparse.dot(x, v)],
[self.x_csc, self.v_100],
(Dot, Usmm, UsmmCscDense)) (Dot, Usmm, UsmmCscDense))
def test_sparse_sparse(self): def test_sparse_sparse(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论