提交 fb435772 authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #3511 from yingzha/ccw

Remove useless reshape
...@@ -3703,6 +3703,22 @@ register_canonicalize(local_reshape_chain(T.Reshape), ...@@ -3703,6 +3703,22 @@ register_canonicalize(local_reshape_chain(T.Reshape),
name='local_reshape_chain') name='local_reshape_chain')
@register_canonicalize
@register_stabilize
@gof.local_optimizer([T.Reshape])
def local_useless_reshape(node):
"""
Remove Reshape when both the input and the output have a
single dimension.
"""
if isinstance(node.op, T.Reshape):
if (node.inputs[0].ndim == 1 and node.outputs[0].ndim == 1 and
node.inputs[0].broadcastable ==
node.outputs[0].broadcastable):
return [node.inputs[0]]
@register_canonicalize @register_canonicalize
@register_stabilize @register_stabilize
@gof.local_optimizer([T.Reshape]) @gof.local_optimizer([T.Reshape])
......
...@@ -5716,6 +5716,16 @@ class Test_Reshape(unittest.TestCase): ...@@ -5716,6 +5716,16 @@ class Test_Reshape(unittest.TestCase):
assert sum(isinstance(node.op, self.op) for node in topo) == 1 assert sum(isinstance(node.op, self.op) for node in topo) == 1
def test_local_useless_reshape():
mode = theano.compile.get_default_mode().including(
'local_useless_reshape')
i = T.iscalar('i')
m = theano.tensor.mgrid[0:i,]
f = theano.function([i], m, mode=mode)
topo = f.maker.fgraph.toposort()
assert not any(isinstance(n.op, tensor.basic.Reshape) for n in topo)
def test_local_reshape_lift(): def test_local_reshape_lift():
x = tensor.tensor4() x = tensor.tensor4()
out = T.exp(x).reshape([x.size]) out = T.exp(x).reshape([x.size])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论