提交 09944cfe authored 作者: Xavier Bouthillier's avatar Xavier Bouthillier

Merge pull request #2780 from ejls/master

Add an optimization that remove Split with only 1 split. (Close #2545)
...@@ -3077,6 +3077,25 @@ def local_useless_tile(node): ...@@ -3077,6 +3077,25 @@ def local_useless_tile(node):
return return
##############
# Split Opts #
##############
@register_canonicalize
@register_specialize
@gof.local_optimizer([T.Split])
def local_useless_split(node):
""" Split{n_splits=1}(x, y) -> x
Remove Split with only 1 split.
"""
if isinstance(node.op, T.Split):
if node.op.len_splits == 1:
x, axis, splits = node.inputs
out = assert_op(x, T.eq(splits.shape[0], 1))
out = assert_op(out, T.eq(x.shape[axis], splits[0]))
return [out]
################ ################
# Flatten Opts # # Flatten Opts #
################ ################
......
...@@ -5013,6 +5013,26 @@ def test_local_div_to_inv(): ...@@ -5013,6 +5013,26 @@ def test_local_div_to_inv():
assert numpy.allclose(out_val, 0.5) assert numpy.allclose(out_val, 0.5)
def test_local_useless_split():
x = tensor.matrix('x')
splits = tensor.ivector('splits')
opt = tensor.split(x, splits, n_splits=1)
nonopt = tensor.split(x, splits, n_splits=3)
mode = compile.get_default_mode().including("local_useless_split")
f_opt = theano.function([x, splits], opt, mode=mode)
f_nonopt = theano.function([x, splits], nonopt, mode=mode)
f_opt(numpy.random.rand(4,4).astype(config.floatX), [4])
f_nonopt(numpy.random.rand(4,4).astype(config.floatX), [1,2,1])
graph_opt = f_opt.maker.fgraph.toposort()
graph_nonopt = f_nonopt.maker.fgraph.toposort()
assert isinstance(graph_opt[-1].op, DeepCopyOp)
assert len(graph_nonopt)==1
assert isinstance(graph_nonopt[0].op, tensor.Split)
def test_local_flatten_lift(): def test_local_flatten_lift():
for i in range(1, 4): for i in range(1, 4):
op = tensor.Flatten(i) op = tensor.Flatten(i)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论