提交 9df89a90 authored 作者: Mohammad Pezeshki's avatar Mohammad Pezeshki

useless_dimshuffle optimization and test

上级 6c38df98
...@@ -594,7 +594,7 @@ def local_dimshuffle_lift(node): ...@@ -594,7 +594,7 @@ def local_dimshuffle_lift(node):
inplace = op.inplace and inode.op.inplace inplace = op.inplace and inode.op.inplace
iinput = inode.inputs[0] iinput = inode.inputs[0]
# remove useless dimshuffle # remove useless dimshuffle caused by merging
if (new_order == list(range(len(new_order))) and if (new_order == list(range(len(new_order))) and
len(new_order) == iinput.type.ndim): len(new_order) == iinput.type.ndim):
return [iinput] return [iinput]
...@@ -605,6 +605,11 @@ def local_dimshuffle_lift(node): ...@@ -605,6 +605,11 @@ def local_dimshuffle_lift(node):
copy_stack_trace(node.outputs[0], ret) copy_stack_trace(node.outputs[0], ret)
return [ret] return [ret]
# remove useless dimshuffle in general
if (list(op.new_order) == list(range(len(op.new_order))) and
len(op.new_order) == input.type.ndim):
return [input]
@register_canonicalize @register_canonicalize
@gof.local_optimizer([T.DimShuffle]) @gof.local_optimizer([T.DimShuffle])
......
...@@ -112,7 +112,6 @@ class test_dimshuffle_lift(unittest.TestCase): ...@@ -112,7 +112,6 @@ class test_dimshuffle_lift(unittest.TestCase):
# Check stacktrace was copied over correctly after opt was applied # Check stacktrace was copied over correctly after opt was applied
self.assertTrue(hasattr(g.outputs[0].tag, 'trace')) self.assertTrue(hasattr(g.outputs[0].tag, 'trace'))
def test_merge2(self): def test_merge2(self):
x, y, z = inputs() x, y, z = inputs()
e = ds(ds(x, (1, 'x', 0)), (2, 0, 'x', 1)) e = ds(ds(x, (1, 'x', 0)), (2, 0, 'x', 1))
...@@ -166,7 +165,6 @@ class test_dimshuffle_lift(unittest.TestCase): ...@@ -166,7 +165,6 @@ class test_dimshuffle_lift(unittest.TestCase):
# Check stacktrace was copied over correctly after opt was applied # Check stacktrace was copied over correctly after opt was applied
self.assertTrue(hasattr(g.outputs[0].tag, 'trace')) self.assertTrue(hasattr(g.outputs[0].tag, 'trace'))
def test_recursive_lift(self): def test_recursive_lift(self):
v = T.vector(dtype="float64") v = T.vector(dtype="float64")
m = T.matrix(dtype="float64") m = T.matrix(dtype="float64")
...@@ -179,8 +177,8 @@ class test_dimshuffle_lift(unittest.TestCase): ...@@ -179,8 +177,8 @@ class test_dimshuffle_lift(unittest.TestCase):
"Elemwise{add,no_inplace}" "Elemwise{add,no_inplace}"
"(<TensorType(float64, matrix)>, " "(<TensorType(float64, matrix)>, "
"DimShuffle{x,x}(TensorConstant{84}))))]") "DimShuffle{x,x}(TensorConstant{84}))))]")
self.assertTrue(str(g) == init_str_g)
self.assertTrue(str(g) == init_str_g)
new_out = local_dimshuffle_lift.transform(g.outputs[0].owner)[0] new_out = local_dimshuffle_lift.transform(g.outputs[0].owner)[0]
new_g = FunctionGraph(g.inputs, [new_out]) new_g = FunctionGraph(g.inputs, [new_out])
opt_str_g = ("[Elemwise{mul,no_inplace}(Elemwise{add,no_inplace}" opt_str_g = ("[Elemwise{mul,no_inplace}(Elemwise{add,no_inplace}"
...@@ -189,10 +187,21 @@ class test_dimshuffle_lift(unittest.TestCase): ...@@ -189,10 +187,21 @@ class test_dimshuffle_lift(unittest.TestCase):
"Elemwise{add,no_inplace}(DimShuffle{1,0}" "Elemwise{add,no_inplace}(DimShuffle{1,0}"
"(<TensorType(float64, matrix)>), " "(<TensorType(float64, matrix)>), "
"DimShuffle{x,x}(TensorConstant{84})))]") "DimShuffle{x,x}(TensorConstant{84})))]")
self.assertTrue(str(new_g) == opt_str_g) self.assertTrue(str(new_g) == opt_str_g)
# Check stacktrace was copied over correctly after opt was applied # Check stacktrace was copied over correctly after opt was applied
self.assertTrue(hasattr(new_g.outputs[0].tag, 'trace')) self.assertTrue(hasattr(new_g.outputs[0].tag, 'trace'))
def test_useless_dimshuffle(self):
x, _, _ = inputs()
e = ds(x, (0, 1))
g = FunctionGraph([x], [e])
self.assertTrue(str(g) == "[DimShuffle{0,1}(x)]")
dimshuffle_lift.optimize(g)
self.assertTrue(str(g) == "[x]")
# Check stacktrace was copied over correctly after opt was applied
self.assertTrue(hasattr(g.outputs[0].tag, 'trace'))
def test_add_canonizer_problem0(): def test_add_canonizer_problem0():
n_segments = 10 n_segments = 10
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论