提交 5db997f7 authored 作者: Mohammad Pezeshki's avatar Mohammad Pezeshki

two different optimizations on dimshuffle is now merged to one.

上级 b62fe58d
...@@ -556,6 +556,26 @@ def apply_local_dimshuffle_lift(var): ...@@ -556,6 +556,26 @@ def apply_local_dimshuffle_lift(var):
return var return var
# Checks for two types of useless dimshuffles:
# 1 - dimshuffle all dimensions in order.
# 2 - dimshuffle a broadcastable dimension.
def is_dimshuffle_useless(new_order, input):
is_useless = False
if len(new_order) == input.type.ndim:
all_broadcastable_dims = [i for (i, is_broadcastable)
in enumerate(input.type.broadcastable)
if is_broadcastable] + ['x']
for i in range(input.type.ndim):
if (new_order[i] == i or
(i in all_broadcastable_dims and
new_order[i] in all_broadcastable_dims)):
is_useless = True
else:
is_useless = False
break
return is_useless
@gof.local_optimizer([DimShuffle]) @gof.local_optimizer([DimShuffle])
def local_dimshuffle_lift(node): def local_dimshuffle_lift(node):
""" """
...@@ -577,6 +597,7 @@ def local_dimshuffle_lift(node): ...@@ -577,6 +597,7 @@ def local_dimshuffle_lift(node):
input = node.inputs[0] input = node.inputs[0]
inode = input.owner inode = input.owner
new_order = op.new_order
if inode and isinstance(inode.op, Elemwise) and (len(input.clients) == 1): if inode and isinstance(inode.op, Elemwise) and (len(input.clients) == 1):
# Don't use make_node to have tag.test_value set. # Don't use make_node to have tag.test_value set.
new_inputs = [] new_inputs = []
...@@ -590,40 +611,19 @@ def local_dimshuffle_lift(node): ...@@ -590,40 +611,19 @@ def local_dimshuffle_lift(node):
return ret return ret
if inode and isinstance(inode.op, DimShuffle): if inode and isinstance(inode.op, DimShuffle):
new_order = [x == 'x' and 'x' or inode.op.new_order[x] for x in new_order = [x == 'x' and 'x' or inode.op.new_order[x] for x in
op.new_order] new_order]
inplace = op.inplace and inode.op.inplace inplace = op.inplace and inode.op.inplace
iinput = inode.inputs[0] input = inode.inputs[0]
# remove useless dimshuffle caused by merging if is_dimshuffle_useless(new_order, input):
if (new_order == list(range(len(new_order))) and return [input]
len(new_order) == iinput.type.ndim): elif inode and isinstance(inode.op, DimShuffle):
return [iinput] ret = op.__class__(input.type.broadcastable, new_order,
else: inplace)(input)
ret = op.__class__(iinput.type.broadcastable, new_order,
inplace)(iinput)
ret = apply_local_dimshuffle_lift(ret) ret = apply_local_dimshuffle_lift(ret)
copy_stack_trace(node.outputs[0], ret) copy_stack_trace(node.outputs[0], ret)
return [ret] return [ret]
# remove useless dimshuffle in general
# covers two types of useless dimshuffle:
# 1 - dimshuffle all dimensions in order
# 2 - dimshuffle a broadcastable dimension
if len(op.new_order) == input.type.ndim:
is_useless = False
all_broadcastable_dims = [i for (i, is_broadcastable)
in enumerate(input.type.broadcastable)
if is_broadcastable] + ['x']
for i in range(input.type.ndim):
if (op.new_order[i] == i or
(i in all_broadcastable_dims and
op.new_order[i] in all_broadcastable_dims)):
is_useless = True
else:
is_useless = False
break
if is_useless:
return [input]
@register_canonicalize @register_canonicalize
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论