提交 5db997f7 authored 作者: Mohammad Pezeshki's avatar Mohammad Pezeshki

two different optimizations on dimshuffle is now merged to one.

上级 b62fe58d
......@@ -556,6 +556,26 @@ def apply_local_dimshuffle_lift(var):
return var
# Checks for two types of useless dimshuffles:
# 1 - dimshuffle all dimensions in order.
# 2 - dimshuffle a broadcastable dimension.
def is_dimshuffle_useless(new_order, input):
is_useless = False
if len(new_order) == input.type.ndim:
all_broadcastable_dims = [i for (i, is_broadcastable)
in enumerate(input.type.broadcastable)
if is_broadcastable] + ['x']
for i in range(input.type.ndim):
if (new_order[i] == i or
(i in all_broadcastable_dims and
new_order[i] in all_broadcastable_dims)):
is_useless = True
else:
is_useless = False
break
return is_useless
@gof.local_optimizer([DimShuffle])
def local_dimshuffle_lift(node):
"""
......@@ -577,6 +597,7 @@ def local_dimshuffle_lift(node):
input = node.inputs[0]
inode = input.owner
new_order = op.new_order
if inode and isinstance(inode.op, Elemwise) and (len(input.clients) == 1):
# Don't use make_node to have tag.test_value set.
new_inputs = []
......@@ -590,40 +611,19 @@ def local_dimshuffle_lift(node):
return ret
if inode and isinstance(inode.op, DimShuffle):
new_order = [x == 'x' and 'x' or inode.op.new_order[x] for x in
op.new_order]
new_order]
inplace = op.inplace and inode.op.inplace
iinput = inode.inputs[0]
# remove useless dimshuffle caused by merging
if (new_order == list(range(len(new_order))) and
len(new_order) == iinput.type.ndim):
return [iinput]
else:
ret = op.__class__(iinput.type.broadcastable, new_order,
inplace)(iinput)
ret = apply_local_dimshuffle_lift(ret)
copy_stack_trace(node.outputs[0], ret)
return [ret]
input = inode.inputs[0]
if is_dimshuffle_useless(new_order, input):
return [input]
elif inode and isinstance(inode.op, DimShuffle):
ret = op.__class__(input.type.broadcastable, new_order,
inplace)(input)
ret = apply_local_dimshuffle_lift(ret)
copy_stack_trace(node.outputs[0], ret)
return [ret]
# remove useless dimshuffle in general
# covers two types of useless dimshuffle:
# 1 - dimshuffle all dimensions in order
# 2 - dimshuffle a broadcastable dimension
if len(op.new_order) == input.type.ndim:
is_useless = False
all_broadcastable_dims = [i for (i, is_broadcastable)
in enumerate(input.type.broadcastable)
if is_broadcastable] + ['x']
for i in range(input.type.ndim):
if (op.new_order[i] == i or
(i in all_broadcastable_dims and
op.new_order[i] in all_broadcastable_dims)):
is_useless = True
else:
is_useless = False
break
if is_useless:
return [input]
@register_canonicalize
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论