提交 9bf98a6f authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Address issues after code review

上级 724412fa
...@@ -4200,7 +4200,8 @@ def local_useless_reshape(node): ...@@ -4200,7 +4200,8 @@ def local_useless_reshape(node):
if input.ndim != output.ndim: if input.ndim != output.ndim:
return False return False
# Simple case: both input and output have a single dimension # Simple case: both input and output have a single dimension.
# This could hide errors if the user provides inconsistent shapes.
if (input.ndim == 1 and output.ndim == 1 and if (input.ndim == 1 and output.ndim == 1 and
input.broadcastable == output.broadcastable): input.broadcastable == output.broadcastable):
return [input] return [input]
...@@ -4277,7 +4278,6 @@ def local_reshape_to_dimshuffle(node): ...@@ -4277,7 +4278,6 @@ def local_reshape_to_dimshuffle(node):
or be removed later on. or be removed later on.
For example: For example:
- reshape(v, (m,)) --> v # if v.ndim == 1
- reshape(x, (1, n)) --> dimshuffle{x,0}(reshape(x, (n,)) - reshape(x, (1, n)) --> dimshuffle{x,0}(reshape(x, (n,))
- reshape(x, (1, m, 1, n, 1, 1)) - reshape(x, (1, m, 1, n, 1, 1))
--> dimshuffle{x,0,x,1,x,x}(reshape(x, (m, n))) --> dimshuffle{x,0,x,1,x,x}(reshape(x, (m, n)))
...@@ -4295,7 +4295,11 @@ def local_reshape_to_dimshuffle(node): ...@@ -4295,7 +4295,11 @@ def local_reshape_to_dimshuffle(node):
new_output_shape = [] new_output_shape = []
index = 0 # index over the output of the new reshape index = 0 # index over the output of the new reshape
for i in xrange(output.ndim): for i in xrange(output.ndim):
dim = extract_constant(output_shape[i], only_process_constants=False) # Since output_shape is a symbolic vector, we trust extract_constant
# to go through however it is formed to see if its i-th element is 1.
# We need only_process_constants=False for that.
dim = extract_constant(output_shape[i], only_process_constants=False,
elemwise=False)
if dim == 1: if dim == 1:
dimshuffle_new_order.append('x') dimshuffle_new_order.append('x')
else: else:
......
...@@ -5124,6 +5124,7 @@ class T_reshape(utt.InferShapeTester, utt.TestOptimizationMixin): ...@@ -5124,6 +5124,7 @@ class T_reshape(utt.InferShapeTester, utt.TestOptimizationMixin):
assert len(topo_) <= 1, topo_ assert len(topo_) <= 1, topo_
else: else:
assert len(topo_) == 1, topo_ assert len(topo_) == 1, topo_
if len(topo_) > 0:
assert type(topo_[0].op) is self.op assert type(topo_[0].op) is self.op
return f return f
......
...@@ -225,7 +225,7 @@ class test_dimshuffle_lift(unittest.TestCase): ...@@ -225,7 +225,7 @@ class test_dimshuffle_lift(unittest.TestCase):
self.assertTrue(hasattr(g.outputs[0].tag, 'trace')) self.assertTrue(hasattr(g.outputs[0].tag, 'trace'))
def test_useless_dimshuffle_in_reshape(): def test_local_useless_dimshuffle_in_reshape():
vector = TensorType(broadcastable=(False,), dtype='float64')('vector') vector = TensorType(broadcastable=(False,), dtype='float64')('vector')
mat = TensorType(broadcastable=(False, False), dtype='float64')('mat') mat = TensorType(broadcastable=(False, False), dtype='float64')('mat')
row = TensorType(broadcastable=(True, False), dtype='float64')('row') row = TensorType(broadcastable=(True, False), dtype='float64')('row')
...@@ -250,8 +250,17 @@ def test_useless_dimshuffle_in_reshape(): ...@@ -250,8 +250,17 @@ def test_useless_dimshuffle_in_reshape():
"Reshape{2}(mat, Shape(mat)), " "Reshape{2}(mat, Shape(mat)), "
"Reshape{2}(row, Shape(row)), " "Reshape{2}(row, Shape(row)), "
"Reshape{2}(col, Shape(col))]") "Reshape{2}(col, Shape(col))]")
# Check stacktrace was copied over correctly after opt was applied # Check stacktrace was copied over correctly after opt was applied
assert_true(hasattr(g.outputs[0].tag, 'trace')) assert_true(check_stack_trace(g, ops_to_check='all'))
# Check that the optimization does not get applied when the order
# of dimensions has changed.
reshape_dimshuffle_mat2 = tensor.reshape(mat.dimshuffle('x', 1, 'x', 0), mat.shape)
h = FunctionGraph([mat], [reshape_dimshuffle_mat2])
str_h = str(h)
useless_dimshuffle_in_reshape.optimize(h)
assert_true(str(h) == str(h))
def test_add_canonizer_problem0(): def test_add_canonizer_problem0():
...@@ -6204,6 +6213,21 @@ class Test_local_useless_reshape(unittest.TestCase): ...@@ -6204,6 +6213,21 @@ class Test_local_useless_reshape(unittest.TestCase):
topo = f2.maker.fgraph.toposort() topo = f2.maker.fgraph.toposort()
assert not any(isinstance(n.op, tensor.basic.Reshape) for n in topo) assert not any(isinstance(n.op, tensor.basic.Reshape) for n in topo)
def test_2(self):
x = theano.tensor.matrix('x')
r = x.reshape([Shape_i(i)(x) for i in xrange(x.ndim)])
m0 = theano.compile.get_default_mode()
m1 = m0.including('local_useless_reshape')
f1 = theano.function([x], r, mode=m1)
topo = f1.maker.fgraph.toposort()
assert not any(isinstance(n.op, tensor.basic.Reshape) for n in topo)
m2 = m1.excluding('ShapeOpt')
f2 = theano.function([x], r, mode=m2)
topo = f2.maker.fgraph.toposort()
assert not any(isinstance(n.op, tensor.basic.Reshape) for n in topo)
class Test_local_reshape_to_dimshuffle(unittest.TestCase): class Test_local_reshape_to_dimshuffle(unittest.TestCase):
def setUp(self): def setUp(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论