提交 2ed8808a authored 作者: Ziye Fan's avatar Ziye Fan

recursive local_dimshuffle_lift

上级 2425cd11
...@@ -440,6 +440,17 @@ def local_0_dot_x(node): ...@@ -440,6 +440,17 @@ def local_0_dot_x(node):
###################### ######################
def apply_local_dimshuffle_lift(var):
# return var
# lift recursively
if not var.owner:
return var
new = local_dimshuffle_lift.transform(var.owner)
if new:
return new[0]
return var
@gof.local_optimizer([DimShuffle]) @gof.local_optimizer([DimShuffle])
def local_dimshuffle_lift(node): def local_dimshuffle_lift(node):
""" """
...@@ -449,6 +460,7 @@ def local_dimshuffle_lift(node): ...@@ -449,6 +460,7 @@ def local_dimshuffle_lift(node):
DimShuffle(Elemwise(x, y)) => Elemwise(DimShuffle(x), DimShuffle(y)) DimShuffle(Elemwise(x, y)) => Elemwise(DimShuffle(x), DimShuffle(y))
DimShuffle(DimShuffle(x)) => DimShuffle(x) DimShuffle(DimShuffle(x)) => DimShuffle(x)
DimShuffle{0,1,...}(x) => x (when the dimshuffle do nothing)
After this transform, clusters of Elemwise operations are After this transform, clusters of Elemwise operations are
void of DimShuffle operations. void of DimShuffle operations.
...@@ -461,23 +473,27 @@ def local_dimshuffle_lift(node): ...@@ -461,23 +473,27 @@ def local_dimshuffle_lift(node):
inode = input.owner inode = input.owner
if inode and isinstance(inode.op, Elemwise) and (len(input.clients) == 1): if inode and isinstance(inode.op, Elemwise) and (len(input.clients) == 1):
# Don't use make_node to have tag.test_value set. # Don't use make_node to have tag.test_value set.
ret = inode.op(*[op.__class__(inp.type.broadcastable, new_inputs = []
for inp in inode.inputs:
new_inp = op.__class__(inp.type.broadcastable,
op.new_order, op.new_order,
op.inplace)(inp) for inp in op.inplace)(inp)
inode.inputs], **dict(return_list=True)) new_inputs.append(apply_local_dimshuffle_lift(new_inp))
ret = inode.op(*new_inputs, **dict(return_list=True))
return ret return ret
if inode and isinstance(inode.op, DimShuffle): if inode and isinstance(inode.op, DimShuffle):
new_order = [x == 'x' and 'x' or inode.op.new_order[x] for x in new_order = [x == 'x' and 'x' or inode.op.new_order[x] for x in
op.new_order] op.new_order]
inplace = op.inplace and inode.op.inplace inplace = op.inplace and inode.op.inplace
iinput = inode.inputs[0] iinput = inode.inputs[0]
# remove useless dimshuffle
if new_order == range(len(new_order)) and (len(new_order) == if new_order == range(len(new_order)) and (len(new_order) ==
iinput.type.ndim): iinput.type.ndim):
return [iinput] return [iinput]
else: else:
ret = op.__class__(iinput.type.broadcastable, new_order, ret = op.__class__(iinput.type.broadcastable, new_order,
inplace)(iinput, **dict(return_list=True)) inplace)(iinput)
return ret return [apply_local_dimshuffle_lift(ret)]
@register_canonicalize @register_canonicalize
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论