提交 daf7d9c9 authored 作者: nouiz's avatar nouiz

Merge pull request #444 from lamblin/fix_sparse_infershape_tests

Include "canonicalize" in tests for shape opt.
......@@ -95,10 +95,15 @@ class T_transpose(unittest.TestCase):
class SparseInferShapeTester(unittest.TestCase):
def setUp(self):
utt.seed_rng()
# This mode seems to be the minimal one including the shape_i
# optimizations, if we don't want to enumerate them explicitly.
self.mode = theano.compile.get_default_mode().including("canonicalize")
def _compile_and_check(self, inputs, outputs, numeric_inputs, cls):
outputs_function = theano.function(inputs, outputs)
shapes_function = theano.function(inputs, [o.shape for o in outputs])
outputs_function = theano.function(inputs, outputs, mode=self.mode)
shapes_function = theano.function(inputs, [o.shape for o in outputs],
mode=self.mode)
theano.printing.debugprint(shapes_function)
# Check that the Op is removed from the compiled function.
topo_shape = shapes_function.maker.env.toposort()
assert not any(isinstance(t.op, cls) for t in topo_shape)
......@@ -161,12 +166,13 @@ class SparseInferShapeTester(unittest.TestCase):
def test_add_sd(self):
x = SparseType('csr', dtype=config.floatX)()
y = tensor.matrix()
self._compile_and_check([x, y],
[x + y],
[sp.csr_matrix(random_lil((10, 40),
config.floatX, 3)),
numpy.random.randn(10, 40)],
AddSD)
self._compile_and_check(
[x, y],
[x + y],
[sp.csr_matrix(random_lil((10, 40),
config.floatX, 3)),
numpy.random.randn(10, 40).astype(config.floatX)],
AddSD)
def test_mul_ss(self):
x = SparseType('csr', dtype=config.floatX)()
......@@ -181,12 +187,13 @@ class SparseInferShapeTester(unittest.TestCase):
def test_mul_sd(self):
x = SparseType('csr', dtype=config.floatX)()
y = tensor.matrix()
self._compile_and_check([x, y],
[x * y],
[sp.csr_matrix(random_lil((10, 40),
config.floatX, 3)),
numpy.random.randn(10, 40)],
MulSD)
self._compile_and_check(
[x, y],
[x * y],
[sp.csr_matrix(random_lil((10, 40),
config.floatX, 3)),
numpy.random.randn(10, 40).astype(config.floatX)],
MulSD)
class T_AddMul(unittest.TestCase):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论