提交 7dfa866c authored 作者: Frederic's avatar Frederic

fix sparse tests due to new opt in sparse/sandbox/sp2.py

上级 42809e8b
...@@ -244,7 +244,7 @@ class SparseInferShapeTester(utt.InferShapeTester): ...@@ -244,7 +244,7 @@ class SparseInferShapeTester(utt.InferShapeTester):
[sp.csr_matrix(random_lil((10, 40), [sp.csr_matrix(random_lil((10, 40),
config.floatX, 3)), config.floatX, 3)),
numpy.random.randn(10, 40).astype(config.floatX)], numpy.random.randn(10, 40).astype(config.floatX)],
MulSD) MulSD, excluding=["local_mul_s_d"])
def test_remove0(self): def test_remove0(self):
x = SparseType('csr', dtype=config.floatX)() x = SparseType('csr', dtype=config.floatX)()
......
...@@ -170,10 +170,15 @@ class InferShapeTester(unittest.TestCase): ...@@ -170,10 +170,15 @@ class InferShapeTester(unittest.TestCase):
# optimizations, if we don't want to enumerate them explicitly. # optimizations, if we don't want to enumerate them explicitly.
self.mode = theano.compile.get_default_mode().including("canonicalize") self.mode = theano.compile.get_default_mode().including("canonicalize")
def _compile_and_check(self, inputs, outputs, numeric_inputs, cls): def _compile_and_check(self, inputs, outputs, numeric_inputs, cls,
outputs_function = theano.function(inputs, outputs, mode=self.mode) excluding=None):
mode = self.mode
if excluding:
mode = mode.excluding(*excluding)
outputs_function = theano.function(inputs, outputs, mode=mode)
shapes_function = theano.function(inputs, [o.shape for o in outputs], shapes_function = theano.function(inputs, [o.shape for o in outputs],
mode=self.mode) mode=mode)
#theano.printing.debugprint(shapes_function) #theano.printing.debugprint(shapes_function)
# Check that the Op is removed from the compiled function. # Check that the Op is removed from the compiled function.
topo_shape = shapes_function.maker.env.toposort() topo_shape = shapes_function.maker.env.toposort()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论