提交 9df89a90 authored 作者: Mohammad Pezeshki's avatar Mohammad Pezeshki

useless_dimshuffle optimization and test

上级 6c38df98
......@@ -594,7 +594,7 @@ def local_dimshuffle_lift(node):
inplace = op.inplace and inode.op.inplace
iinput = inode.inputs[0]
# remove useless dimshuffle
# remove useless dimshuffle caused by merging
if (new_order == list(range(len(new_order))) and
len(new_order) == iinput.type.ndim):
return [iinput]
......@@ -605,6 +605,11 @@ def local_dimshuffle_lift(node):
copy_stack_trace(node.outputs[0], ret)
return [ret]
# remove useless dimshuffle in general
if (list(op.new_order) == list(range(len(op.new_order))) and
len(op.new_order) == input.type.ndim):
return [input]
@register_canonicalize
@gof.local_optimizer([T.DimShuffle])
......
......@@ -112,14 +112,13 @@ class test_dimshuffle_lift(unittest.TestCase):
# Check stacktrace was copied over correctly after opt was applied
self.assertTrue(hasattr(g.outputs[0].tag, 'trace'))
def test_merge2(self):
x, y, z = inputs()
e = ds(ds(x, (1, 'x', 0)), (2, 0, 'x', 1))
g = FunctionGraph([x], [e])
self.assertTrue(
str(g) == "[DimShuffle{2,0,x,1}(DimShuffle{1,x,0}(x))]",
str(g))
str(g) == "[DimShuffle{2,0,x,1}(DimShuffle{1,x,0}(x))]",
str(g))
dimshuffle_lift.optimize(g)
self.assertTrue(str(g) == "[DimShuffle{0,1,x,x}(x)]", str(g))
# Check stacktrace was copied over correctly after opt was applied
......@@ -130,9 +129,9 @@ class test_dimshuffle_lift(unittest.TestCase):
e = ds(ds(ds(x, (0, 'x', 1)), (2, 0, 'x', 1)), (1, 0))
g = FunctionGraph([x], [e])
self.assertTrue(
str(g) == "[DimShuffle{1,0}(DimShuffle{2,0,x,1}"
"(DimShuffle{0,x,1}(x)))]",
str(g))
str(g) == "[DimShuffle{1,0}(DimShuffle{2,0,x,1}"
"(DimShuffle{0,x,1}(x)))]",
str(g))
dimshuffle_lift.optimize(g)
self.assertTrue(str(g) == "[x]", str(g))
# Check stacktrace was copied over correctly after opt was applied
......@@ -166,7 +165,6 @@ class test_dimshuffle_lift(unittest.TestCase):
# Check stacktrace was copied over correctly after opt was applied
self.assertTrue(hasattr(g.outputs[0].tag, 'trace'))
def test_recursive_lift(self):
v = T.vector(dtype="float64")
m = T.matrix(dtype="float64")
......@@ -179,8 +177,8 @@ class test_dimshuffle_lift(unittest.TestCase):
"Elemwise{add,no_inplace}"
"(<TensorType(float64, matrix)>, "
"DimShuffle{x,x}(TensorConstant{84}))))]")
self.assertTrue(str(g) == init_str_g)
new_out = local_dimshuffle_lift.transform(g.outputs[0].owner)[0]
new_g = FunctionGraph(g.inputs, [new_out])
opt_str_g = ("[Elemwise{mul,no_inplace}(Elemwise{add,no_inplace}"
......@@ -189,10 +187,21 @@ class test_dimshuffle_lift(unittest.TestCase):
"Elemwise{add,no_inplace}(DimShuffle{1,0}"
"(<TensorType(float64, matrix)>), "
"DimShuffle{x,x}(TensorConstant{84})))]")
self.assertTrue(str(new_g) == opt_str_g)
# Check stacktrace was copied over correctly after opt was applied
self.assertTrue(hasattr(new_g.outputs[0].tag, 'trace'))
def test_useless_dimshuffle(self):
x, _, _ = inputs()
e = ds(x, (0, 1))
g = FunctionGraph([x], [e])
self.assertTrue(str(g) == "[DimShuffle{0,1}(x)]")
dimshuffle_lift.optimize(g)
self.assertTrue(str(g) == "[x]")
# Check stacktrace was copied over correctly after opt was applied
self.assertTrue(hasattr(g.outputs[0].tag, 'trace'))
def test_add_canonizer_problem0():
n_segments = 10
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论