提交 4410fb0f authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #3014 from t13m/optimize_local_dimshuffle_list

Make local_dimshuffle_lift go through many nodes
......@@ -440,6 +440,17 @@ def local_0_dot_x(node):
######################
def apply_local_dimshuffle_lift(var):
# return var
# lift recursively
if not var.owner:
return var
new = local_dimshuffle_lift.transform(var.owner)
if new:
return new[0]
return var
@gof.local_optimizer([DimShuffle])
def local_dimshuffle_lift(node):
"""
......@@ -449,6 +460,7 @@ def local_dimshuffle_lift(node):
DimShuffle(Elemwise(x, y)) => Elemwise(DimShuffle(x), DimShuffle(y))
DimShuffle(DimShuffle(x)) => DimShuffle(x)
DimShuffle{0,1,...}(x) => x (when the dimshuffle do nothing)
After this transform, clusters of Elemwise operations are
void of DimShuffle operations.
......@@ -461,23 +473,27 @@ def local_dimshuffle_lift(node):
inode = input.owner
if inode and isinstance(inode.op, Elemwise) and (len(input.clients) == 1):
# Don't use make_node to have tag.test_value set.
ret = inode.op(*[op.__class__(inp.type.broadcastable,
op.new_order,
op.inplace)(inp) for inp in
inode.inputs], **dict(return_list=True))
new_inputs = []
for inp in inode.inputs:
new_inp = op.__class__(inp.type.broadcastable,
op.new_order,
op.inplace)(inp)
new_inputs.append(apply_local_dimshuffle_lift(new_inp))
ret = inode.op(*new_inputs, **dict(return_list=True))
return ret
if inode and isinstance(inode.op, DimShuffle):
new_order = [x == 'x' and 'x' or inode.op.new_order[x] for x in
op.new_order]
inplace = op.inplace and inode.op.inplace
iinput = inode.inputs[0]
# remove useless dimshuffle
if new_order == range(len(new_order)) and (len(new_order) ==
iinput.type.ndim):
return [iinput]
else:
ret = op.__class__(iinput.type.broadcastable, new_order,
inplace)(iinput, **dict(return_list=True))
return ret
inplace)(iinput)
return [apply_local_dimshuffle_lift(ret)]
@register_canonicalize
......
......@@ -153,6 +153,29 @@ class test_dimshuffle_lift(unittest.TestCase):
self.assertTrue(str(g) in (opt_str_g_inplace, opt_str_g_noinplace),
str(g))
def test_recursive_lift(self):
v = T.vector(dtype="float64")
m = T.matrix(dtype="float64")
out = ((v + 42) * (m + 84)).T
g = FunctionGraph([v, m], [out])
init_str_g = ("[DimShuffle{1,0}(Elemwise{mul,no_inplace}"
"(DimShuffle{x,0}(Elemwise{add,no_inplace}"
"(<TensorType(float64, vector)>, "
"DimShuffle{x}(TensorConstant{42}))), "
"Elemwise{add,no_inplace}"
"(<TensorType(float64, matrix)>, "
"DimShuffle{x,x}(TensorConstant{84}))))]")
self.assertTrue(str(g) == init_str_g)
new_out = local_dimshuffle_lift.transform(g.outputs[0].owner)[0]
new_g = FunctionGraph(g.inputs, [new_out])
opt_str_g = ("[Elemwise{mul,no_inplace}(Elemwise{add,no_inplace}"
"(DimShuffle{0,x}(<TensorType(float64, vector)>), "
"DimShuffle{x,x}(TensorConstant{42})), "
"Elemwise{add,no_inplace}(DimShuffle{1,0}"
"(<TensorType(float64, matrix)>), "
"DimShuffle{x,x}(TensorConstant{84})))]")
self.assertTrue(str(new_g) == opt_str_g)
def test_add_canonizer_problem0():
n_segments = 10
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论