提交 1e2d5bca authored 作者: Frederic's avatar Frederic

Add test for opt local_flatten_lift

上级 5eb8077a
...@@ -3988,6 +3988,22 @@ def test_local_div_to_inv(): ...@@ -3988,6 +3988,22 @@ def test_local_div_to_inv():
assert numpy.allclose(out_val, 0.5) assert numpy.allclose(out_val, 0.5)
def test_local_flatten_lift():
for i in range(1, 4):
op = tensor.Flatten(i)
x = tensor.tensor4()
out = op(T.exp(x))
assert out.ndim == i
mode = compile.mode.get_default_mode()
mode = mode.including('local_flatten_lift')
f = theano.function([x], out, mode=mode)
f(numpy.random.rand(5, 4, 3, 2).astype(config.floatX))
topo = f.maker.fgraph.toposort()
assert len(topo) == 2
assert isinstance(topo[0].op, tensor.Flatten)
assert isinstance(topo[1].op, tensor.Elemwise)
class Test_lift_transpose_through_dot(unittest.TestCase): class Test_lift_transpose_through_dot(unittest.TestCase):
def simple_optimize(self, g): def simple_optimize(self, g):
out2in(opt.local_useless_elemwise).optimize(g) out2in(opt.local_useless_elemwise).optimize(g)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论