提交 0edaad7b authored 作者: Frederic's avatar Frederic

Add sparse.dot test for infer_shape

上级 c316ae1c
...@@ -12,7 +12,7 @@ except ImportError: ...@@ -12,7 +12,7 @@ except ImportError:
import theano import theano
from theano import compile, config from theano import compile, config
from theano.sparse import enable_sparse from theano.sparse import enable_sparse
from theano.gof.python25 import product from theano.gof.python25 import all, product
if enable_sparse == False: if enable_sparse == False:
raise SkipTest('Optional package sparse disabled') raise SkipTest('Optional package sparse disabled')
...@@ -22,6 +22,7 @@ from theano.sparse.basic import _mtypes ...@@ -22,6 +22,7 @@ from theano.sparse.basic import _mtypes
from theano.sparse import as_sparse_variable, CSC, CSR, CSM, CSMProperties, SparseType, StructuredDotCSC from theano.sparse import as_sparse_variable, CSC, CSR, CSM, CSMProperties, SparseType, StructuredDotCSC
from theano.sparse import add, mul, structured_dot, transpose from theano.sparse import add, mul, structured_dot, transpose
from theano.sparse import csc_from_dense, csr_from_dense, dense_from_sparse from theano.sparse import csc_from_dense, csr_from_dense, dense_from_sparse
from theano.sparse import Dot, Usmm, UsmmCscDense
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
from theano import tensor from theano import tensor
...@@ -543,7 +544,18 @@ class DotTests(unittest.TestCase): ...@@ -543,7 +544,18 @@ class DotTests(unittest.TestCase):
f_a = theano.function([x, y], theano.sparse.dot(x, y)) f_a = theano.function([x, y], theano.sparse.dot(x, y))
f_b = lambda x, y: x * y f_b = lambda x, y: x * y
assert abs(f_a(self.x_csr, self.y) - f_b(self.x_csr, self.y)).max() < 1e-4 assert _allclose(f_a(self.x_csr, self.y), f_b(self.x_csr, self.y))
# Test infer_shape
f_a = theano.function([x, y], theano.sparse.dot(x, y).shape)
f_b = lambda x, y: (x * y).shape
assert numpy.all(f_a(self.x_csr, self.y) == f_b(self.x_csr, self.y))
topo = f_a.maker.env.toposort()
if theano.config.mode!='FAST_COMPILE':
nb = 0
else:
nb = 1
assert sum([isinstance(node.op, (Dot, Usmm, UsmmCscDense)) for node in topo]) == nb
def test_csc_dense(self): def test_csc_dense(self):
x = theano.sparse.csc_matrix('x') x = theano.sparse.csc_matrix('x')
...@@ -552,8 +564,18 @@ class DotTests(unittest.TestCase): ...@@ -552,8 +564,18 @@ class DotTests(unittest.TestCase):
f_a = theano.function([x, y], theano.sparse.dot(x, y)) f_a = theano.function([x, y], theano.sparse.dot(x, y))
f_b = lambda x, y: x * y f_b = lambda x, y: x * y
assert (abs(f_a(self.x_csc, self.y) - f_b(self.x_csc, self.y)).max() assert _allclose(f_a(self.x_csc, self.y), f_b(self.x_csc, self.y))
< 1e-4)
# Test infer_shape
f_a = theano.function([x, y], theano.sparse.dot(x, y).shape)
f_b = lambda x, y: (x * y).shape
assert numpy.all(f_a(self.x_csc, self.y) == f_b(self.x_csc, self.y))
topo = f_a.maker.env.toposort()
if theano.config.mode!='FAST_COMPILE':
nb = 0
else:
nb = 1
assert sum([isinstance(node.op, (Dot, Usmm, UsmmCscDense)) for node in topo]) == nb
def test_sparse_sparse(self): def test_sparse_sparse(self):
for d1, d2 in [('float32', 'float32'), for d1, d2 in [('float32', 'float32'),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论