提交 c571066f authored 作者: Olivier Mastropietro's avatar Olivier Mastropietro

Fixed condition on local_alloc_dimshuffle

上级 4e1e90ea
...@@ -102,7 +102,7 @@ def local_max_to_min(node): ...@@ -102,7 +102,7 @@ def local_max_to_min(node):
@register_uncanonicalize @register_uncanonicalize
@gof.local_optimizer([T.alloc]) @gof.local_optimizer([T.Alloc])
def local_alloc_dimshuffle(node): def local_alloc_dimshuffle(node):
""" """
If a dimshuffle is inside an alloc and only adds dimension to the If a dimshuffle is inside an alloc and only adds dimension to the
...@@ -113,24 +113,16 @@ def local_alloc_dimshuffle(node): ...@@ -113,24 +113,16 @@ def local_alloc_dimshuffle(node):
if input_.owner and isinstance(input_.owner.op, DimShuffle): if input_.owner and isinstance(input_.owner.op, DimShuffle):
# check if it only adds dimension to the left # check if it only adds dimension to the left
new_order = input_.owner.op.new_order new_order = input_.owner.op.new_order
flag = False expected_new_order = ('x',) * (input_.ndim - input_.owner.inputs[0].ndim) + \
for i, dim in enumerate(new_order_bool): tuple(range(input_.owner.inputs.ndim))
if i == 0 and dim == 'x': if new_order != expected_new_order:
flag = True
elif dim == 'x' and flag:
continue
elif i > 0 and flag:
flag = False
elif i > 0 and not dim == 'x':
continue
else:
return False return False
return input_.owner.inputs return input_.owner.inputs
return False return False
@register_uncanonicalize @register_uncanonicalize
@gof.local_optimizer([T.reshape]) @gof.local_optimizer([T.Reshape])
def local_reshape_dimshuffle(node): def local_reshape_dimshuffle(node):
""" """
If a dimshuffle is inside a reshape and does not change the order If a dimshuffle is inside a reshape and does not change the order
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论