提交 e3569d12 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Merge pull request #4062 from mohammadpz/tile_test

test using theano tile added
......@@ -510,6 +510,28 @@ def apply_local_dimshuffle_lift(var):
return var
# Checks for two types of useless dimshuffles:
# 1 - dimshuffle all dimensions in order.
# 2 - dimshuffle a broadcastable dimension.
def is_dimshuffle_useless(new_order, input):
is_useless = True
if len(new_order) == input.type.ndim:
all_broadcastable_dims = [i for (i, is_broadcastable)
in enumerate(input.type.broadcastable)
if is_broadcastable] + ['x']
for i in range(input.type.ndim):
if (new_order[i] == i or
(i in all_broadcastable_dims and
new_order[i] in all_broadcastable_dims)):
is_useless = True
else:
is_useless = False
break
else:
is_useless = False
return is_useless
@gof.local_optimizer([DimShuffle])
def local_dimshuffle_lift(node):
"""
......@@ -531,6 +553,7 @@ def local_dimshuffle_lift(node):
input = node.inputs[0]
inode = input.owner
new_order = op.new_order
if inode and isinstance(inode.op, Elemwise) and (len(input.clients) == 1):
# Don't use make_node to have tag.test_value set.
new_inputs = []
......@@ -544,20 +567,18 @@ def local_dimshuffle_lift(node):
return ret
if inode and isinstance(inode.op, DimShuffle):
new_order = [x == 'x' and 'x' or inode.op.new_order[x] for x in
op.new_order]
new_order]
inplace = op.inplace and inode.op.inplace
iinput = inode.inputs[0]
# remove useless dimshuffle
if (new_order == list(range(len(new_order))) and
len(new_order) == iinput.type.ndim):
return [iinput]
else:
ret = op.__class__(iinput.type.broadcastable, new_order,
inplace)(iinput)
ret = apply_local_dimshuffle_lift(ret)
copy_stack_trace(node.outputs[0], ret)
return [ret]
input = inode.inputs[0]
if is_dimshuffle_useless(new_order, input):
return [input]
elif inode and isinstance(inode.op, DimShuffle):
ret = op.__class__(input.type.broadcastable, new_order,
inplace)(input)
ret = apply_local_dimshuffle_lift(ret)
copy_stack_trace(node.outputs[0], ret)
return [ret]
@register_canonicalize
......
......@@ -55,6 +55,7 @@ from theano.tensor import (
Subtensor,
TensorType,
Tile,
tile
)
from theano.tensor.elemwise import DimShuffle
from theano.tests import unittest_tools as utt
......@@ -112,14 +113,13 @@ class test_dimshuffle_lift(unittest.TestCase):
# Check stacktrace was copied over correctly after opt was applied
self.assertTrue(hasattr(g.outputs[0].tag, 'trace'))
def test_merge2(self):
x, y, z = inputs()
e = ds(ds(x, (1, 'x', 0)), (2, 0, 'x', 1))
g = FunctionGraph([x], [e])
self.assertTrue(
str(g) == "[DimShuffle{2,0,x,1}(DimShuffle{1,x,0}(x))]",
str(g))
str(g) == "[DimShuffle{2,0,x,1}(DimShuffle{1,x,0}(x))]",
str(g))
dimshuffle_lift.optimize(g)
self.assertTrue(str(g) == "[DimShuffle{0,1,x,x}(x)]", str(g))
# Check stacktrace was copied over correctly after opt was applied
......@@ -130,9 +130,9 @@ class test_dimshuffle_lift(unittest.TestCase):
e = ds(ds(ds(x, (0, 'x', 1)), (2, 0, 'x', 1)), (1, 0))
g = FunctionGraph([x], [e])
self.assertTrue(
str(g) == "[DimShuffle{1,0}(DimShuffle{2,0,x,1}"
"(DimShuffle{0,x,1}(x)))]",
str(g))
str(g) == "[DimShuffle{1,0}(DimShuffle{2,0,x,1}"
"(DimShuffle{0,x,1}(x)))]",
str(g))
dimshuffle_lift.optimize(g)
self.assertTrue(str(g) == "[x]", str(g))
# Check stacktrace was copied over correctly after opt was applied
......@@ -166,7 +166,6 @@ class test_dimshuffle_lift(unittest.TestCase):
# Check stacktrace was copied over correctly after opt was applied
self.assertTrue(hasattr(g.outputs[0].tag, 'trace'))
def test_recursive_lift(self):
v = T.vector(dtype="float64")
m = T.matrix(dtype="float64")
......@@ -179,8 +178,8 @@ class test_dimshuffle_lift(unittest.TestCase):
"Elemwise{add,no_inplace}"
"(<TensorType(float64, matrix)>, "
"DimShuffle{x,x}(TensorConstant{84}))))]")
self.assertTrue(str(g) == init_str_g)
new_out = local_dimshuffle_lift.transform(g.outputs[0].owner)[0]
new_g = FunctionGraph(g.inputs, [new_out])
opt_str_g = ("[Elemwise{mul,no_inplace}(Elemwise{add,no_inplace}"
......@@ -189,10 +188,35 @@ class test_dimshuffle_lift(unittest.TestCase):
"Elemwise{add,no_inplace}(DimShuffle{1,0}"
"(<TensorType(float64, matrix)>), "
"DimShuffle{x,x}(TensorConstant{84})))]")
self.assertTrue(str(new_g) == opt_str_g)
# Check stacktrace was copied over correctly after opt was applied
self.assertTrue(hasattr(new_g.outputs[0].tag, 'trace'))
def test_useless_dimshuffle(self):
x, _, _ = inputs()
e = ds(x, (0, 1))
g = FunctionGraph([x], [e])
self.assertTrue(str(g) == "[DimShuffle{0,1}(x)]")
dimshuffle_lift.optimize(g)
self.assertTrue(str(g) == "[x]")
# Check stacktrace was copied over correctly after opt was applied
self.assertTrue(hasattr(g.outputs[0].tag, 'trace'))
def test_dimshuffle_on_broadcastable(self):
x, y, z = inputs([False, True], [True, False, True], [False, False, True])
u = tensor.constant(1)
ds_x = ds(x, (0, 'x')) # useless
ds_y = ds(y, (2, 1, 0)) # useless
ds_z = ds(z, (2, 1, 0)) # usefull
ds_u = ds(u, ('x')) # usefull
g = FunctionGraph([x, y, z, u], [ds_x, ds_y, ds_z, ds_u])
self.assertTrue(str(g) == "[DimShuffle{0,x}(x), DimShuffle{2,1,0}(y), DimShuffle{2,1,0}(z), DimShuffle{x}(TensorConstant{1})]")
dimshuffle_lift.optimize(g)
self.assertTrue(str(g) == "[x, y, DimShuffle{2,1,0}(z), DimShuffle{x}(TensorConstant{1})]")
# Check stacktrace was copied over correctly after opt was applied
self.assertTrue(hasattr(g.outputs[0].tag, 'trace'))
def test_add_canonizer_problem0():
n_segments = 10
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论