提交 67cd4234 authored 作者: Pascal Lamblin's avatar Pascal Lamblin 提交者: GitHub

Merge pull request #5815 from nouiz/useless_reshape

Make opt useless_reshape handle -1
......@@ -92,7 +92,8 @@ def test_flatten():
assert GpuReshape in [type(node.op)
for node in f.maker.fgraph.toposort()]
f = theano.function([m], m.flatten(ndim=2), mode=mode_with_gpu)
f = theano.function([m], m.flatten(ndim=2),
mode=mode_with_gpu.excluding("local_useless_reshape"))
val = np.random.rand(10, 11).astype("float32")
res = f(val)
utt.assert_allclose(res, val)
......
......@@ -4745,6 +4745,8 @@ class Reshape(Op):
def __init__(self, ndim, name=None):
    """Create a Reshape op whose output has `ndim` dimensions.

    Parameters
    ----------
    ndim : int
        Number of dimensions of the reshaped output; must be >= 0.
    name : None
        Deprecated parameter; anything other than None is rejected.

    Raises
    ------
    ValueError
        If `ndim` is negative.
    """
    # Reject invalid ranks before storing any state.
    if ndim < 0:
        raise ValueError("The output dimensions after reshape must be 0 or greater")
    self.ndim = ndim
    # NOTE: assert (not raise) kept to preserve the original exception type.
    assert name is None, 'name attribute for Reshape has been deprecated'
def __str__(self):
......
......@@ -4380,6 +4380,7 @@ def local_useless_reshape(node):
else:
shape_feature = getattr(node.fgraph, 'shape_feature', None)
nb_m1 = 0
shape_match = [False] * input.ndim
for dim in xrange(input.ndim):
outshp_i = output_shape_is[dim]
......@@ -4403,11 +4404,17 @@ def local_useless_reshape(node):
continue
# Match 1 if input.broadcastable[dim] is True
if (input.broadcastable[dim] and
extract_constant(outshp_i, only_process_constants=1) == 1):
cst_outshp_i = extract_constant(outshp_i, only_process_constants=1)
if input.broadcastable[dim] and cst_outshp_i == 1:
shape_match[dim] = True
continue
# Match -1
if cst_outshp_i == -1:
shape_match[dim] = True
nb_m1 += 1
continue
# Match shape_of[input][dim] or its constant equivalent
if shape_feature:
inpshp_i = shape_feature.get_shape(input, dim)
......@@ -4417,7 +4424,7 @@ def local_useless_reshape(node):
shape_match[dim] = True
continue
if all(shape_match):
if all(shape_match) and nb_m1 <= 1:
return [input]
# TODO later: if all the shapes except one match, we may want to
......
......@@ -7632,21 +7632,19 @@ class TestInferShape(utt.InferShapeTester):
# Flatten
atens3 = tensor3()
atens3_val = rand(4, 5, 3)
self._compile_and_check([atens3],
[flatten(atens3, 1)],
[atens3_val], Reshape)
for outdim in (3, 2, 1):
self._compile_and_check([atens3],
[flatten(atens3, outdim)],
[atens3_val], Reshape)
[atens3_val], Reshape,
excluding=['local_useless_reshape'])
amat = matrix()
amat_val = rand(4, 5)
for outdim in (2, 1):
self._compile_and_check([amat],
[flatten(amat, outdim)],
[amat_val], Reshape)
[amat_val], Reshape,
excluding=['local_useless_reshape'])
avec = vector()
avec_val = rand(4)
......
......@@ -6475,6 +6475,21 @@ class Test_local_useless_reshape(unittest.TestCase):
topo = f2.maker.fgraph.toposort()
assert not any(isinstance(n.op, tensor.basic.Reshape) for n in topo)
def test_m1(self):
    """A reshape to (x.shape[0], -1) on a matrix is a no-op: the
    local_useless_reshape optimization must delete the Reshape node,
    both with and without the shape optimizer enabled."""
    x = theano.tensor.matrix('x')
    reshaped = x.reshape((x.shape[0], -1))

    base_mode = theano.compile.get_default_mode()
    mode_with_opt = base_mode.including('local_useless_reshape')
    mode_no_shapeopt = mode_with_opt.excluding('ShapeOpt')

    # The Reshape must be removed in either compilation mode.
    for mode in (mode_with_opt, mode_no_shapeopt):
        fn = theano.function([x], reshaped, mode=mode)
        nodes = fn.maker.fgraph.toposort()
        assert not any(isinstance(node.op, tensor.basic.Reshape)
                       for node in nodes)
class Test_local_reshape_to_dimshuffle(unittest.TestCase):
def setUp(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论