提交 2ed8808a authored 作者: Ziye Fan's avatar Ziye Fan

recursive local_dimshuffle_lift

上级 2425cd11
......@@ -440,6 +440,17 @@ def local_0_dot_x(node):
######################
def apply_local_dimshuffle_lift(var):
# return var
# lift recursively
if not var.owner:
return var
new = local_dimshuffle_lift.transform(var.owner)
if new:
return new[0]
return var
@gof.local_optimizer([DimShuffle])
def local_dimshuffle_lift(node):
"""
......@@ -449,6 +460,7 @@ def local_dimshuffle_lift(node):
DimShuffle(Elemwise(x, y)) => Elemwise(DimShuffle(x), DimShuffle(y))
DimShuffle(DimShuffle(x)) => DimShuffle(x)
DimShuffle{0,1,...}(x) => x (when the dimshuffle do nothing)
After this transform, clusters of Elemwise operations are
void of DimShuffle operations.
......@@ -461,23 +473,27 @@ def local_dimshuffle_lift(node):
inode = input.owner
if inode and isinstance(inode.op, Elemwise) and (len(input.clients) == 1):
# Don't use make_node to have tag.test_value set.
ret = inode.op(*[op.__class__(inp.type.broadcastable,
op.new_order,
op.inplace)(inp) for inp in
inode.inputs], **dict(return_list=True))
new_inputs = []
for inp in inode.inputs:
new_inp = op.__class__(inp.type.broadcastable,
op.new_order,
op.inplace)(inp)
new_inputs.append(apply_local_dimshuffle_lift(new_inp))
ret = inode.op(*new_inputs, **dict(return_list=True))
return ret
if inode and isinstance(inode.op, DimShuffle):
new_order = [x == 'x' and 'x' or inode.op.new_order[x] for x in
op.new_order]
inplace = op.inplace and inode.op.inplace
iinput = inode.inputs[0]
# remove useless dimshuffle
if new_order == range(len(new_order)) and (len(new_order) ==
iinput.type.ndim):
return [iinput]
else:
ret = op.__class__(iinput.type.broadcastable, new_order,
inplace)(iinput, **dict(return_list=True))
return ret
inplace)(iinput)
return [apply_local_dimshuffle_lift(ret)]
@register_canonicalize
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论