提交 c571066f authored 作者: Olivier Mastropietro's avatar Olivier Mastropietro

Fixed condition on local_alloc_dimshuffle

上级 4e1e90ea
......@@ -102,7 +102,7 @@ def local_max_to_min(node):
@register_uncanonicalize
@gof.local_optimizer([T.alloc])
@gof.local_optimizer([T.Alloc])
def local_alloc_dimshuffle(node):
"""
If a dimshuffle is inside an alloc and only adds dimension to the
......@@ -113,24 +113,16 @@ def local_alloc_dimshuffle(node):
if input_.owner and isinstance(input_.owner.op, DimShuffle):
# check if it only adds dimension to the left
new_order = input_.owner.op.new_order
flag = False
for i, dim in enumerate(new_order_bool):
if i == 0 and dim == 'x':
flag = True
elif dim == 'x' and flag:
continue
elif i > 0 and flag:
flag = False
elif i > 0 and not dim == 'x':
continue
else:
return False
expected_new_order = ('x',) * (input_.ndim - input_.owner.inputs[0].ndim) + \
tuple(range(input_.owner.inputs.ndim))
if new_order != expected_new_order:
return False
return input_.owner.inputs
return False
@register_uncanonicalize
@gof.local_optimizer([T.reshape])
@gof.local_optimizer([T.Reshape])
def local_reshape_dimshuffle(node):
"""
If a dimshuffle is inside a reshape and does not change the order
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论