提交 5d0baa3f authored 作者: Frederic Bastien's avatar Frederic Bastien

Make local_useless_reshape work with -1 in shapes

上级 c13a44b5
...@@ -4380,6 +4380,7 @@ def local_useless_reshape(node): ...@@ -4380,6 +4380,7 @@ def local_useless_reshape(node):
else: else:
shape_feature = getattr(node.fgraph, 'shape_feature', None) shape_feature = getattr(node.fgraph, 'shape_feature', None)
nb_m1 = 0
shape_match = [False] * input.ndim shape_match = [False] * input.ndim
for dim in xrange(input.ndim): for dim in xrange(input.ndim):
outshp_i = output_shape_is[dim] outshp_i = output_shape_is[dim]
...@@ -4403,11 +4404,17 @@ def local_useless_reshape(node): ...@@ -4403,11 +4404,17 @@ def local_useless_reshape(node):
continue continue
# Match 1 if input.broadcastable[dim] is True # Match 1 if input.broadcastable[dim] is True
if (input.broadcastable[dim] and cst_outshp_i = extract_constant(outshp_i, only_process_constants=1)
extract_constant(outshp_i, only_process_constants=1) == 1): if input.broadcastable[dim] and cst_outshp_i == 1:
shape_match[dim] = True shape_match[dim] = True
continue continue
# Match -1
if cst_outshp_i == -1:
shape_match[dim] = True
nb_m1 += 1
continue
# Match shape_of[input][dim] or its constant equivalent # Match shape_of[input][dim] or its constant equivalent
if shape_feature: if shape_feature:
inpshp_i = shape_feature.get_shape(input, dim) inpshp_i = shape_feature.get_shape(input, dim)
...@@ -4417,7 +4424,7 @@ def local_useless_reshape(node): ...@@ -4417,7 +4424,7 @@ def local_useless_reshape(node):
shape_match[dim] = True shape_match[dim] = True
continue continue
if all(shape_match): if all(shape_match) and nb_m1 <= 1:
return [input] return [input]
# TODO later: if all the shapes except one match, we may want to # TODO later: if all the shapes except one match, we may want to
......
...@@ -6475,6 +6475,21 @@ class Test_local_useless_reshape(unittest.TestCase): ...@@ -6475,6 +6475,21 @@ class Test_local_useless_reshape(unittest.TestCase):
topo = f2.maker.fgraph.toposort() topo = f2.maker.fgraph.toposort()
assert not any(isinstance(n.op, tensor.basic.Reshape) for n in topo) assert not any(isinstance(n.op, tensor.basic.Reshape) for n in topo)
def test_m1(self):
x = theano.tensor.matrix('x')
r = x.reshape((x.shape[0], -1))
m0 = theano.compile.get_default_mode()
m1 = m0.including('local_useless_reshape')
f1 = theano.function([x], r, mode=m1)
topo = f1.maker.fgraph.toposort()
assert not any(isinstance(n.op, tensor.basic.Reshape) for n in topo)
m2 = m1.excluding('ShapeOpt')
f2 = theano.function([x], r, mode=m2)
topo = f2.maker.fgraph.toposort()
assert not any(isinstance(n.op, tensor.basic.Reshape) for n in topo)
class Test_local_reshape_to_dimshuffle(unittest.TestCase): class Test_local_reshape_to_dimshuffle(unittest.TestCase):
def setUp(self): def setUp(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论