提交 af6b947a authored 作者: David Warde-Farley's avatar David Warde-Farley

Fix test utility function.

上级 6d7c423a
...@@ -20,7 +20,8 @@ if enable_sparse == False: ...@@ -20,7 +20,8 @@ if enable_sparse == False:
from theano.sparse.basic import _is_dense, _is_sparse, _mtypes from theano.sparse.basic import _is_dense, _is_sparse, _mtypes
from theano.sparse.basic import _is_dense_variable, _is_sparse_variable from theano.sparse.basic import _is_dense_variable, _is_sparse_variable
from theano.sparse import as_sparse_variable, CSC, CSR, CSM, CSMProperties from theano.sparse import as_sparse_variable, CSC, CSR, CSM, CSMProperties
from theano.sparse import SparseType, StructuredDotCSC from theano.sparse import SparseType, StructuredDotCSC, CSMGrad
from theano.sparse import AddSS, AddSD, MulSS, MulSD, Transpose, Neg
from theano.sparse import add, mul, structured_dot, transpose from theano.sparse import add, mul, structured_dot, transpose
from theano.sparse import csc_from_dense, csr_from_dense, dense_from_sparse from theano.sparse import csc_from_dense, csr_from_dense, dense_from_sparse
from theano.sparse import Dot, Usmm, UsmmCscDense from theano.sparse import Dot, Usmm, UsmmCscDense
...@@ -95,12 +96,13 @@ class SparseInferShapeTester(unittest.TestCase): ...@@ -95,12 +96,13 @@ class SparseInferShapeTester(unittest.TestCase):
def setUp(self): def setUp(self):
utt.seed_rng() utt.seed_rng()
def _compile_and_check(self, inputs, outputs, numeric_inputs): def _compile_and_check(self, inputs, outputs, numeric_inputs, cls):
outputs_function = theano.function(inputs, outputs) outputs_function = theano.function(inputs, outputs)
shapes_function = theano.function(inputs, [o.shape for o in outputs]) shapes_function = theano.function(inputs, [o.shape for o in outputs])
# Check that the Op is removed from the compiled function.
topo = shapes_function.maker.env.toposort() topo = shapes_function.maker.env.toposort()
assert not any(isinstance(t, self.__class__) for t in topo) assert not any(isinstance(t, cls) for t in topo)
# Check that the shape produced agrees with the actual shape.
numeric_outputs = outputs_function(*numeric_inputs) numeric_outputs = outputs_function(*numeric_inputs)
numeric_shapes = shapes_function(*numeric_inputs) numeric_shapes = shapes_function(*numeric_inputs)
for out, shape in zip(numeric_outputs, numeric_shapes): for out, shape in zip(numeric_outputs, numeric_shapes):
...@@ -117,15 +119,16 @@ class SparseInferShapeTester(unittest.TestCase): ...@@ -117,15 +119,16 @@ class SparseInferShapeTester(unittest.TestCase):
self._compile_and_check([x], self._compile_and_check([x],
[x.T], [x.T],
[sp.csr_matrix(random_lil((10, 10), [sp.csr_matrix(random_lil((10, 10),
config.floatX, 3))]) config.floatX, 3))],
Transpose)
def test_neg(self): def test_neg(self):
x = SparseType('csr', dtype=config.floatX)() x = SparseType('csr', dtype=config.floatX)()
self._compile_and_check([x], self._compile_and_check([x],
[-x], [-x],
[sp.csr_matrix(random_lil((10, 10), [sp.csr_matrix(random_lil((10, 10),
config.floatX, 3))]) config.floatX, 3))],
Neg)
def test_add_ss(self): def test_add_ss(self):
x = SparseType('csr', dtype=config.floatX)() x = SparseType('csr', dtype=config.floatX)()
...@@ -135,7 +138,8 @@ class SparseInferShapeTester(unittest.TestCase): ...@@ -135,7 +138,8 @@ class SparseInferShapeTester(unittest.TestCase):
[sp.csr_matrix(random_lil((10, 10), [sp.csr_matrix(random_lil((10, 10),
config.floatX, 3)), config.floatX, 3)),
sp.csr_matrix(random_lil((10, 10), sp.csr_matrix(random_lil((10, 10),
config.floatX, 3))]) config.floatX, 3))],
AddSS)
def test_add_sd(self): def test_add_sd(self):
x = SparseType('csr', dtype=config.floatX)() x = SparseType('csr', dtype=config.floatX)()
...@@ -144,7 +148,8 @@ class SparseInferShapeTester(unittest.TestCase): ...@@ -144,7 +148,8 @@ class SparseInferShapeTester(unittest.TestCase):
[x + y], [x + y],
[sp.csr_matrix(random_lil((10, 10), [sp.csr_matrix(random_lil((10, 10),
config.floatX, 3)), config.floatX, 3)),
numpy.random.randn(10, 10)]) numpy.random.randn(10, 10)],
AddSD)
def test_mul_ss(self): def test_mul_ss(self):
x = SparseType('csr', dtype=config.floatX)() x = SparseType('csr', dtype=config.floatX)()
...@@ -153,7 +158,8 @@ class SparseInferShapeTester(unittest.TestCase): ...@@ -153,7 +158,8 @@ class SparseInferShapeTester(unittest.TestCase):
[x * y], [x * y],
[sp.csr_matrix(random_lil((10, 10), [sp.csr_matrix(random_lil((10, 10),
config.floatX, 3)), config.floatX, 3)),
] * 2) ] * 2,
MulSS)
def test_mul_sd(self): def test_mul_sd(self):
x = SparseType('csr', dtype=config.floatX)() x = SparseType('csr', dtype=config.floatX)()
...@@ -162,7 +168,8 @@ class SparseInferShapeTester(unittest.TestCase): ...@@ -162,7 +168,8 @@ class SparseInferShapeTester(unittest.TestCase):
[x * y], [x * y],
[sp.csr_matrix(random_lil((10, 10), [sp.csr_matrix(random_lil((10, 10),
config.floatX, 3)), config.floatX, 3)),
numpy.random.randn(10, 10)]) numpy.random.randn(10, 10)],
MulSD)
class T_AddMul(unittest.TestCase): class T_AddMul(unittest.TestCase):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论