提交 d6b3b7a6 authored 作者: Olivier Mastropietro's avatar Olivier Mastropietro

Used op.new_order for Alloc(DimShuffle()) and implemented Reshape(DimShuffle()) case

上级 0dec0a90
......@@ -104,20 +104,48 @@ def local_max_to_min(node):
@register_uncanonicalize
@gof.local_optimizer([T.alloc])
def local_alloc_dimshuffle(node):
"""
If a dimshuffle is inside an alloc and only adds dimension to the
left, remove it.
"""
if node.op == T.alloc:
input_ = node.inputs[0]
if getattr(input_, 'owner', None) and isinstance(input_.owner.op, DimShuffle):
if input_.owner and isinstance(input_.owner.op, DimShuffle):
# check if it only adds dimension to the left
pattern = input_.type.broadcastable
if not pattern[0]:
return False
j = 0
for i, bool_ in enumerate(pattern):
if not bool_:
j = i
break
if sum(pattern[j:]) == 0:
return input_.inputs
else :
return False
new_order = input_.owner.op.new_order
flag = False
for i, dim in enumerate(new_order_bool):
if i == 0 and dim == 'x':
flag = True
elif dim == 'x' and flag:
continue
elif i > 0 and flag:
flag = False
elif i > 0 and not dim == 'x':
continue
else:
return False
return input_.inputs
return False
@register_uncanonicalize
@gof.local_optimizer([T.reshape])
def local_reshape_dimshuffle(node):
"""
If a dimshuffle is inside a reshape and does not change the order
of dimensions, remove it.
"""
if node.op == T.reshape:
input_ = node.inputs[0]
if input_.owner and isinstance(input_.owner.op, DimShuffle):
new_order = input_owner.op.new_order
offset = 0
for i, dim in enumerate(new_order):
if dim == 'x':
offset += 1
continue
elif i != dim + offset:
return False
return input_.inputs
return False
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论