提交 6033a3e2 authored 作者: Mohammad Pezeshki's avatar Mohammad Pezeshki

useless dimshuffle in reshape is removed

上级 e845b755
......@@ -570,6 +570,25 @@ def local_dimshuffle_lift(node):
"""
op = node.op
if (isinstance(op, T.Reshape) and
node.inputs[0].owner is not None and
isinstance(node.inputs[0].owner.op, DimShuffle)):
new_order = node.inputs[0].owner.op.new_order
new_order = [i for i in new_order if i != 'x']
input = node.inputs[0].owner.inputs[0]
broadcastables = input.broadcastable
new_order_of_nonbroadcastables = []
for i, bd in zip(new_order, broadcastables):
if not bd:
new_order_of_nonbroadcastables.append(i)
no_change_in_order = all(
new_order_of_nonbroadcastables[i] <= new_order_of_nonbroadcastables[i + 1]
for i in xrange(len(new_order_of_nonbroadcastables) - 1))
if no_change_in_order:
shape = node.inputs[1]
ret = op.__class__(node.outputs[0].ndim)(input, shape)
copy_stack_trace(node.outputs[0], ret)
return [ret]
if not isinstance(op, DimShuffle):
return False
......
......@@ -220,6 +220,33 @@ class test_dimshuffle_lift(unittest.TestCase):
# Check stacktrace was copied over correctly after opt was applied
self.assertTrue(hasattr(g.outputs[0].tag, 'trace'))
def test_useless_dimshuffle_in_presence_of_reshape(self):
vector = TensorType(broadcastable=(False,), dtype='float64')('vector')
mat = TensorType(broadcastable=(False, False), dtype='float64')('mat')
row = TensorType(broadcastable=(True, False), dtype='float64')('row')
col = TensorType(broadcastable=(False, True), dtype='float64')('col')
reshape_dimshuffle_vector = tensor.reshape(vector.dimshuffle('x', 0), vector.shape)
reshape_dimshuffle_mat = tensor.reshape(mat.dimshuffle('x', 0, 'x', 1), mat.shape)
reshape_dimshuffle_row = tensor.reshape(row.dimshuffle(1, 'x'), row.shape)
reshape_dimshuffle_col = tensor.reshape(col.dimshuffle(0), col.shape)
g = FunctionGraph([vector, mat, row, col],
[reshape_dimshuffle_vector, reshape_dimshuffle_mat,
reshape_dimshuffle_row, reshape_dimshuffle_col])
self.assertTrue(str(g) == "[Reshape{1}(DimShuffle{x,0}(vector), Shape(vector)), "
"Reshape{2}(DimShuffle{x,0,x,1}(mat), Shape(mat)), "
"Reshape{2}(DimShuffle{1,x}(row), Shape(row)), "
"Reshape{2}(DimShuffle{0}(col), Shape(col))]")
dimshuffle_lift.optimize(g)
self.assertTrue(str(g) == "[Reshape{1}(vector, Shape(vector)), "
"Reshape{2}(mat, Shape(mat)), "
"Reshape{2}(row, Shape(row)), "
"Reshape{2}(col, Shape(col))]")
# Check stacktrace was copied over correctly after opt was applied
self.assertTrue(hasattr(g.outputs[0].tag, 'trace'))
def test_add_canonizer_problem0():
n_segments = 10
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论