提交 82368140 authored 作者: Frederic's avatar Frederic

Add computation tests for sparse.Dot with a vector as input.

上级 155e2eb9
......@@ -1054,42 +1054,37 @@ class DotTests(utt.InferShapeTester):
y = theano.tensor.matrix('y')
v = theano.tensor.vector('v')
f_a = theano.function([x, y], theano.sparse.dot(x, y))
f_b = lambda x, y: x * y
assert _allclose(f_a(self.x_csr, self.y), f_b(self.x_csr, self.y))
# Test infer_shape
self._compile_and_check([x, y], [theano.sparse.dot(x, y)],
[self.x_csr, self.y],
(Dot, Usmm, UsmmCscDense))
self._compile_and_check([v, x], [theano.sparse.dot(v, x)],
[self.v_10, self.x_csr],
(Dot, Usmm, UsmmCscDense))
self._compile_and_check([x, v], [theano.sparse.dot(x, v)],
[self.x_csr, self.v_100],
(Dot, Usmm, UsmmCscDense))
for (x, y, x_v, y_v) in [(x, y, self.x_csr, self.y),
(x, v, self.x_csr, self.v_100),
(v, x, self.v_10, self.x_csr)]:
f_a = theano.function([x, y], theano.sparse.dot(x, y))
f_b = lambda x, y: x * y
assert _allclose(f_a(x_v, y_v), f_b(x_v, y_v))
# Test infer_shape
self._compile_and_check([x, y], [theano.sparse.dot(x, y)],
[x_v, y_v],
(Dot, Usmm, UsmmCscDense))
def test_csc_dense(self):
x = theano.sparse.csc_matrix('x')
y = theano.tensor.matrix('y')
v = theano.tensor.vector('v')
f_a = theano.function([x, y], theano.sparse.dot(x, y))
f_b = lambda x, y: x * y
assert _allclose(f_a(self.x_csc, self.y), f_b(self.x_csc, self.y))
# Test infer_shape
self._compile_and_check([x, y], [theano.sparse.dot(x, y)],
[self.x_csc, self.y],
(Dot, Usmm, UsmmCscDense))
self._compile_and_check([v, x], [theano.sparse.dot(v, x)],
[self.v_10, self.x_csc],
(Dot, Usmm, UsmmCscDense))
self._compile_and_check([x, v], [theano.sparse.dot(x, v)],
[self.x_csc, self.v_100],
(Dot, Usmm, UsmmCscDense))
for (x, y, x_v, y_v) in [(x, y, self.x_csc, self.y),
(x, v, self.x_csc, self.v_100),
(v, x, self.v_10, self.x_csc)]:
f_a = theano.function([x, y], theano.sparse.dot(x, y))
f_b = lambda x, y: x * y
assert _allclose(f_a(x_v, y_v), f_b(x_v, y_v))
# Test infer_shape
self._compile_and_check([x, y], [theano.sparse.dot(x, y)],
[x_v, y_v],
(Dot, Usmm, UsmmCscDense))
def test_sparse_sparse(self):
for d1, d2 in [('float32', 'float32'),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论