提交 fb435772 authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #3511 from yingzha/ccw

Remove useless reshape
......@@ -3703,6 +3703,22 @@ register_canonicalize(local_reshape_chain(T.Reshape),
name='local_reshape_chain')
@register_canonicalize
@register_stabilize
@gof.local_optimizer([T.Reshape])
def local_useless_reshape(node):
"""
Remove Reshape when both the input and the output have a
single dimension.
"""
if isinstance(node.op, T.Reshape):
if (node.inputs[0].ndim == 1 and node.outputs[0].ndim == 1 and
node.inputs[0].broadcastable ==
node.outputs[0].broadcastable):
return [node.inputs[0]]
@register_canonicalize
@register_stabilize
@gof.local_optimizer([T.Reshape])
......
......@@ -3135,7 +3135,7 @@ def test_local_fill_useless():
assert T.Alloc in ops
f(m_, x_)
class Test_local_useless_elemwise_comparison(unittest.TestCase):
def test_local_useless_elemwise_comparison(self):
# TODO: test each case individually.
......@@ -3171,7 +3171,7 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase):
>Sum{acc_dtype=float64} [id L] ''
> |X[t] [id M] -> [id I]
"""
mode = theano.compile.get_default_mode().excluding('fusion')
f = theano.function([X, Y], Z, mode=mode)
theano.printing.debugprint(f, print_type=True)
......@@ -3211,7 +3211,7 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase):
assert len(elem.inputs) == 1, elem.inputs
assert isinstance(elem.inputs[0], T.TensorConstant), elem
assert T.extract_constant(elem.inputs[0]) == val, val
def assert_identity(self, f):
topo = f.maker.fgraph.toposort()
assert len(topo) == 1
......@@ -3224,10 +3224,10 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase):
def test_inequality_with_self(self):
x = T.scalar('x', dtype=config.floatX)
mode = theano.compile.get_default_mode().including('local_useless_elemwise_comparison')
f = theano.function([x], T.lt(x, x), mode=mode)
self.assert_eqs_const(f, 0)
f = theano.function([x], T.le(x, x), mode=mode)
self.assert_eqs_const(f, 1)
......@@ -3289,10 +3289,10 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase):
f = theano.function([x, y], T.ge(x.shape[0]+y.shape[0], 0), mode=mode)
self.assert_eqs_const(f, 1)
def test_and(self):
mode = theano.compile.get_default_mode().including('canonicalize')
x = T.scalar('x', dtype='int8')
f = theano.function([x], T.and_(x, 0), mode=mode)
......@@ -5704,7 +5704,7 @@ def test_local_flatten_lift():
class Test_Reshape(unittest.TestCase):
def setUp(self):
self.mode = mode_opt
self.mode = mode_opt
self.op = tensor.Reshape
def test_local_reshape(self):
......@@ -5716,6 +5716,16 @@ class Test_Reshape(unittest.TestCase):
assert sum(isinstance(node.op, self.op) for node in topo) == 1
def test_local_useless_reshape():
mode = theano.compile.get_default_mode().including(
'local_useless_reshape')
i = T.iscalar('i')
m = theano.tensor.mgrid[0:i,]
f = theano.function([i], m, mode=mode)
topo = f.maker.fgraph.toposort()
assert not any(isinstance(n.op, tensor.basic.Reshape) for n in topo)
def test_local_reshape_lift():
x = tensor.tensor4()
out = T.exp(x).reshape([x.size])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论