提交 e3569d12 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Merge pull request #4062 from mohammadpz/tile_test

test using theano tile added
...@@ -510,6 +510,28 @@ def apply_local_dimshuffle_lift(var): ...@@ -510,6 +510,28 @@ def apply_local_dimshuffle_lift(var):
return var return var
# Checks for two types of useless dimshuffles:
# 1 - dimshuffle all dimensions in order.
# 2 - dimshuffle a broadcastable dimension.
def is_dimshuffle_useless(new_order, input):
is_useless = True
if len(new_order) == input.type.ndim:
all_broadcastable_dims = [i for (i, is_broadcastable)
in enumerate(input.type.broadcastable)
if is_broadcastable] + ['x']
for i in range(input.type.ndim):
if (new_order[i] == i or
(i in all_broadcastable_dims and
new_order[i] in all_broadcastable_dims)):
is_useless = True
else:
is_useless = False
break
else:
is_useless = False
return is_useless
@gof.local_optimizer([DimShuffle]) @gof.local_optimizer([DimShuffle])
def local_dimshuffle_lift(node): def local_dimshuffle_lift(node):
""" """
...@@ -531,6 +553,7 @@ def local_dimshuffle_lift(node): ...@@ -531,6 +553,7 @@ def local_dimshuffle_lift(node):
input = node.inputs[0] input = node.inputs[0]
inode = input.owner inode = input.owner
new_order = op.new_order
if inode and isinstance(inode.op, Elemwise) and (len(input.clients) == 1): if inode and isinstance(inode.op, Elemwise) and (len(input.clients) == 1):
# Don't use make_node to have tag.test_value set. # Don't use make_node to have tag.test_value set.
new_inputs = [] new_inputs = []
...@@ -544,20 +567,18 @@ def local_dimshuffle_lift(node): ...@@ -544,20 +567,18 @@ def local_dimshuffle_lift(node):
return ret return ret
if inode and isinstance(inode.op, DimShuffle): if inode and isinstance(inode.op, DimShuffle):
new_order = [x == 'x' and 'x' or inode.op.new_order[x] for x in new_order = [x == 'x' and 'x' or inode.op.new_order[x] for x in
op.new_order] new_order]
inplace = op.inplace and inode.op.inplace inplace = op.inplace and inode.op.inplace
iinput = inode.inputs[0] input = inode.inputs[0]
# remove useless dimshuffle if is_dimshuffle_useless(new_order, input):
if (new_order == list(range(len(new_order))) and return [input]
len(new_order) == iinput.type.ndim): elif inode and isinstance(inode.op, DimShuffle):
return [iinput] ret = op.__class__(input.type.broadcastable, new_order,
else: inplace)(input)
ret = op.__class__(iinput.type.broadcastable, new_order, ret = apply_local_dimshuffle_lift(ret)
inplace)(iinput) copy_stack_trace(node.outputs[0], ret)
ret = apply_local_dimshuffle_lift(ret) return [ret]
copy_stack_trace(node.outputs[0], ret)
return [ret]
@register_canonicalize @register_canonicalize
......
...@@ -55,6 +55,7 @@ from theano.tensor import ( ...@@ -55,6 +55,7 @@ from theano.tensor import (
Subtensor, Subtensor,
TensorType, TensorType,
Tile, Tile,
tile
) )
from theano.tensor.elemwise import DimShuffle from theano.tensor.elemwise import DimShuffle
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
...@@ -112,14 +113,13 @@ class test_dimshuffle_lift(unittest.TestCase): ...@@ -112,14 +113,13 @@ class test_dimshuffle_lift(unittest.TestCase):
# Check stacktrace was copied over correctly after opt was applied # Check stacktrace was copied over correctly after opt was applied
self.assertTrue(hasattr(g.outputs[0].tag, 'trace')) self.assertTrue(hasattr(g.outputs[0].tag, 'trace'))
def test_merge2(self): def test_merge2(self):
x, y, z = inputs() x, y, z = inputs()
e = ds(ds(x, (1, 'x', 0)), (2, 0, 'x', 1)) e = ds(ds(x, (1, 'x', 0)), (2, 0, 'x', 1))
g = FunctionGraph([x], [e]) g = FunctionGraph([x], [e])
self.assertTrue( self.assertTrue(
str(g) == "[DimShuffle{2,0,x,1}(DimShuffle{1,x,0}(x))]", str(g) == "[DimShuffle{2,0,x,1}(DimShuffle{1,x,0}(x))]",
str(g)) str(g))
dimshuffle_lift.optimize(g) dimshuffle_lift.optimize(g)
self.assertTrue(str(g) == "[DimShuffle{0,1,x,x}(x)]", str(g)) self.assertTrue(str(g) == "[DimShuffle{0,1,x,x}(x)]", str(g))
# Check stacktrace was copied over correctly after opt was applied # Check stacktrace was copied over correctly after opt was applied
...@@ -130,9 +130,9 @@ class test_dimshuffle_lift(unittest.TestCase): ...@@ -130,9 +130,9 @@ class test_dimshuffle_lift(unittest.TestCase):
e = ds(ds(ds(x, (0, 'x', 1)), (2, 0, 'x', 1)), (1, 0)) e = ds(ds(ds(x, (0, 'x', 1)), (2, 0, 'x', 1)), (1, 0))
g = FunctionGraph([x], [e]) g = FunctionGraph([x], [e])
self.assertTrue( self.assertTrue(
str(g) == "[DimShuffle{1,0}(DimShuffle{2,0,x,1}" str(g) == "[DimShuffle{1,0}(DimShuffle{2,0,x,1}"
"(DimShuffle{0,x,1}(x)))]", "(DimShuffle{0,x,1}(x)))]",
str(g)) str(g))
dimshuffle_lift.optimize(g) dimshuffle_lift.optimize(g)
self.assertTrue(str(g) == "[x]", str(g)) self.assertTrue(str(g) == "[x]", str(g))
# Check stacktrace was copied over correctly after opt was applied # Check stacktrace was copied over correctly after opt was applied
...@@ -166,7 +166,6 @@ class test_dimshuffle_lift(unittest.TestCase): ...@@ -166,7 +166,6 @@ class test_dimshuffle_lift(unittest.TestCase):
# Check stacktrace was copied over correctly after opt was applied # Check stacktrace was copied over correctly after opt was applied
self.assertTrue(hasattr(g.outputs[0].tag, 'trace')) self.assertTrue(hasattr(g.outputs[0].tag, 'trace'))
def test_recursive_lift(self): def test_recursive_lift(self):
v = T.vector(dtype="float64") v = T.vector(dtype="float64")
m = T.matrix(dtype="float64") m = T.matrix(dtype="float64")
...@@ -179,8 +178,8 @@ class test_dimshuffle_lift(unittest.TestCase): ...@@ -179,8 +178,8 @@ class test_dimshuffle_lift(unittest.TestCase):
"Elemwise{add,no_inplace}" "Elemwise{add,no_inplace}"
"(<TensorType(float64, matrix)>, " "(<TensorType(float64, matrix)>, "
"DimShuffle{x,x}(TensorConstant{84}))))]") "DimShuffle{x,x}(TensorConstant{84}))))]")
self.assertTrue(str(g) == init_str_g) self.assertTrue(str(g) == init_str_g)
new_out = local_dimshuffle_lift.transform(g.outputs[0].owner)[0] new_out = local_dimshuffle_lift.transform(g.outputs[0].owner)[0]
new_g = FunctionGraph(g.inputs, [new_out]) new_g = FunctionGraph(g.inputs, [new_out])
opt_str_g = ("[Elemwise{mul,no_inplace}(Elemwise{add,no_inplace}" opt_str_g = ("[Elemwise{mul,no_inplace}(Elemwise{add,no_inplace}"
...@@ -189,10 +188,35 @@ class test_dimshuffle_lift(unittest.TestCase): ...@@ -189,10 +188,35 @@ class test_dimshuffle_lift(unittest.TestCase):
"Elemwise{add,no_inplace}(DimShuffle{1,0}" "Elemwise{add,no_inplace}(DimShuffle{1,0}"
"(<TensorType(float64, matrix)>), " "(<TensorType(float64, matrix)>), "
"DimShuffle{x,x}(TensorConstant{84})))]") "DimShuffle{x,x}(TensorConstant{84})))]")
self.assertTrue(str(new_g) == opt_str_g) self.assertTrue(str(new_g) == opt_str_g)
# Check stacktrace was copied over correctly after opt was applied # Check stacktrace was copied over correctly after opt was applied
self.assertTrue(hasattr(new_g.outputs[0].tag, 'trace')) self.assertTrue(hasattr(new_g.outputs[0].tag, 'trace'))
def test_useless_dimshuffle(self):
x, _, _ = inputs()
e = ds(x, (0, 1))
g = FunctionGraph([x], [e])
self.assertTrue(str(g) == "[DimShuffle{0,1}(x)]")
dimshuffle_lift.optimize(g)
self.assertTrue(str(g) == "[x]")
# Check stacktrace was copied over correctly after opt was applied
self.assertTrue(hasattr(g.outputs[0].tag, 'trace'))
def test_dimshuffle_on_broadcastable(self):
x, y, z = inputs([False, True], [True, False, True], [False, False, True])
u = tensor.constant(1)
ds_x = ds(x, (0, 'x')) # useless
ds_y = ds(y, (2, 1, 0)) # useless
ds_z = ds(z, (2, 1, 0)) # usefull
ds_u = ds(u, ('x')) # usefull
g = FunctionGraph([x, y, z, u], [ds_x, ds_y, ds_z, ds_u])
self.assertTrue(str(g) == "[DimShuffle{0,x}(x), DimShuffle{2,1,0}(y), DimShuffle{2,1,0}(z), DimShuffle{x}(TensorConstant{1})]")
dimshuffle_lift.optimize(g)
self.assertTrue(str(g) == "[x, y, DimShuffle{2,1,0}(z), DimShuffle{x}(TensorConstant{1})]")
# Check stacktrace was copied over correctly after opt was applied
self.assertTrue(hasattr(g.outputs[0].tag, 'trace'))
def test_add_canonizer_problem0(): def test_add_canonizer_problem0():
n_segments = 10 n_segments = 10
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论