提交 af6b947a authored 作者: David Warde-Farley's avatar David Warde-Farley

Fix test utility function.

上级 6d7c423a
......@@ -20,7 +20,8 @@ if enable_sparse == False:
from theano.sparse.basic import _is_dense, _is_sparse, _mtypes
from theano.sparse.basic import _is_dense_variable, _is_sparse_variable
from theano.sparse import as_sparse_variable, CSC, CSR, CSM, CSMProperties
from theano.sparse import SparseType, StructuredDotCSC
from theano.sparse import SparseType, StructuredDotCSC, CSMGrad
from theano.sparse import AddSS, AddSD, MulSS, MulSD, Transpose, Neg
from theano.sparse import add, mul, structured_dot, transpose
from theano.sparse import csc_from_dense, csr_from_dense, dense_from_sparse
from theano.sparse import Dot, Usmm, UsmmCscDense
......@@ -95,12 +96,13 @@ class SparseInferShapeTester(unittest.TestCase):
def setUp(self):
utt.seed_rng()
def _compile_and_check(self, inputs, outputs, numeric_inputs):
def _compile_and_check(self, inputs, outputs, numeric_inputs, cls):
outputs_function = theano.function(inputs, outputs)
shapes_function = theano.function(inputs, [o.shape for o in outputs])
# Check that the Op is removed from the compiled function.
topo = shapes_function.maker.env.toposort()
assert not any(isinstance(t, self.__class__) for t in topo)
assert not any(isinstance(t, cls) for t in topo)
# Check that the shape produced agrees with the actual shape.
numeric_outputs = outputs_function(*numeric_inputs)
numeric_shapes = shapes_function(*numeric_inputs)
for out, shape in zip(numeric_outputs, numeric_shapes):
......@@ -117,15 +119,16 @@ class SparseInferShapeTester(unittest.TestCase):
self._compile_and_check([x],
[x.T],
[sp.csr_matrix(random_lil((10, 10),
config.floatX, 3))])
config.floatX, 3))],
Transpose)
def test_neg(self):
x = SparseType('csr', dtype=config.floatX)()
self._compile_and_check([x],
[-x],
[sp.csr_matrix(random_lil((10, 10),
config.floatX, 3))])
config.floatX, 3))],
Neg)
def test_add_ss(self):
x = SparseType('csr', dtype=config.floatX)()
......@@ -135,7 +138,8 @@ class SparseInferShapeTester(unittest.TestCase):
[sp.csr_matrix(random_lil((10, 10),
config.floatX, 3)),
sp.csr_matrix(random_lil((10, 10),
config.floatX, 3))])
config.floatX, 3))],
AddSS)
def test_add_sd(self):
x = SparseType('csr', dtype=config.floatX)()
......@@ -144,7 +148,8 @@ class SparseInferShapeTester(unittest.TestCase):
[x + y],
[sp.csr_matrix(random_lil((10, 10),
config.floatX, 3)),
numpy.random.randn(10, 10)])
numpy.random.randn(10, 10)],
AddSD)
def test_mul_ss(self):
x = SparseType('csr', dtype=config.floatX)()
......@@ -153,7 +158,8 @@ class SparseInferShapeTester(unittest.TestCase):
[x * y],
[sp.csr_matrix(random_lil((10, 10),
config.floatX, 3)),
] * 2)
] * 2,
MulSS)
def test_mul_sd(self):
x = SparseType('csr', dtype=config.floatX)()
......@@ -162,7 +168,8 @@ class SparseInferShapeTester(unittest.TestCase):
[x * y],
[sp.csr_matrix(random_lil((10, 10),
config.floatX, 3)),
numpy.random.randn(10, 10)])
numpy.random.randn(10, 10)],
MulSD)
class T_AddMul(unittest.TestCase):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论