提交 0edaad7b authored 作者: Frederic's avatar Frederic

Add sparse.dot test for infer_shape

上级 c316ae1c
......@@ -12,7 +12,7 @@ except ImportError:
import theano
from theano import compile, config
from theano.sparse import enable_sparse
from theano.gof.python25 import product
from theano.gof.python25 import all, product
if enable_sparse == False:
raise SkipTest('Optional package sparse disabled')
......@@ -22,6 +22,7 @@ from theano.sparse.basic import _mtypes
from theano.sparse import as_sparse_variable, CSC, CSR, CSM, CSMProperties, SparseType, StructuredDotCSC
from theano.sparse import add, mul, structured_dot, transpose
from theano.sparse import csc_from_dense, csr_from_dense, dense_from_sparse
from theano.sparse import Dot, Usmm, UsmmCscDense
from theano.tests import unittest_tools as utt
from theano import tensor
......@@ -543,7 +544,18 @@ class DotTests(unittest.TestCase):
f_a = theano.function([x, y], theano.sparse.dot(x, y))
f_b = lambda x, y: x * y
assert abs(f_a(self.x_csr, self.y) - f_b(self.x_csr, self.y)).max() < 1e-4
assert _allclose(f_a(self.x_csr, self.y), f_b(self.x_csr, self.y))
# Test infer_shape
f_a = theano.function([x, y], theano.sparse.dot(x, y).shape)
f_b = lambda x, y: (x * y).shape
assert numpy.all(f_a(self.x_csr, self.y) == f_b(self.x_csr, self.y))
topo = f_a.maker.env.toposort()
if theano.config.mode!='FAST_COMPILE':
nb = 0
else:
nb = 1
assert sum([isinstance(node.op, (Dot, Usmm, UsmmCscDense)) for node in topo]) == nb
def test_csc_dense(self):
x = theano.sparse.csc_matrix('x')
......@@ -552,8 +564,18 @@ class DotTests(unittest.TestCase):
f_a = theano.function([x, y], theano.sparse.dot(x, y))
f_b = lambda x, y: x * y
assert (abs(f_a(self.x_csc, self.y) - f_b(self.x_csc, self.y)).max()
< 1e-4)
assert _allclose(f_a(self.x_csc, self.y), f_b(self.x_csc, self.y))
# Test infer_shape
f_a = theano.function([x, y], theano.sparse.dot(x, y).shape)
f_b = lambda x, y: (x * y).shape
assert numpy.all(f_a(self.x_csc, self.y) == f_b(self.x_csc, self.y))
topo = f_a.maker.env.toposort()
if theano.config.mode!='FAST_COMPILE':
nb = 0
else:
nb = 1
assert sum([isinstance(node.op, (Dot, Usmm, UsmmCscDense)) for node in topo]) == nb
def test_sparse_sparse(self):
for d1, d2 in [('float32', 'float32'),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论