提交 d6b3b7a6 authored 作者: Olivier Mastropietro's avatar Olivier Mastropietro

Used op.new_order for Alloc(DimShuffle()) and implemented Reshape(DimShuffle()) case

上级 0dec0a90
...@@ -104,20 +104,48 @@ def local_max_to_min(node): ...@@ -104,20 +104,48 @@ def local_max_to_min(node):
@register_uncanonicalize @register_uncanonicalize
@gof.local_optimizer([T.alloc]) @gof.local_optimizer([T.alloc])
def local_alloc_dimshuffle(node): def local_alloc_dimshuffle(node):
"""
If a dimshuffle is inside an alloc and only adds dimension to the
left, remove it.
"""
if node.op == T.alloc: if node.op == T.alloc:
input_ = node.inputs[0] input_ = node.inputs[0]
if getattr(input_, 'owner', None) and isinstance(input_.owner.op, DimShuffle): if input_.owner and isinstance(input_.owner.op, DimShuffle):
# check if it only adds dimension to the left # check if it only adds dimension to the left
pattern = input_.type.broadcastable new_order = input_.owner.op.new_order
if not pattern[0]: flag = False
return False for i, dim in enumerate(new_order_bool):
j = 0 if i == 0 and dim == 'x':
for i, bool_ in enumerate(pattern): flag = True
if not bool_: elif dim == 'x' and flag:
j = i continue
break elif i > 0 and flag:
if sum(pattern[j:]) == 0: flag = False
return input_.inputs elif i > 0 and not dim == 'x':
else : continue
return False else:
return False
return input_.inputs
return False
@register_uncanonicalize
@gof.local_optimizer([T.reshape])
def local_reshape_dimshuffle(node):
"""
If a dimshuffle is inside a reshape and does not change the order
of dimensions, remove it.
"""
if node.op == T.reshape:
input_ = node.inputs[0]
if input_.owner and isinstance(input_.owner.op, DimShuffle):
new_order = input_owner.op.new_order
offset = 0
for i, dim in enumerate(new_order):
if dim == 'x':
offset += 1
continue
elif i != dim + offset:
return False
return input_.inputs
return False return False
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论