提交 b62fe58d authored 作者: Mohammad Pezeshki's avatar Mohammad Pezeshki

length check added

上级 835d8444
...@@ -609,20 +609,21 @@ def local_dimshuffle_lift(node): ...@@ -609,20 +609,21 @@ def local_dimshuffle_lift(node):
# covers two types of useless dimshuffle: # covers two types of useless dimshuffle:
# 1 - dimshuffle all dimensions in order # 1 - dimshuffle all dimensions in order
# 2 - dimshuffle a broadcastable dimension # 2 - dimshuffle a broadcastable dimension
is_useless = False if len(op.new_order) == input.type.ndim:
all_broadcastable_dims = [i for (i, is_broadcastable) is_useless = False
in enumerate(input.type.broadcastable) all_broadcastable_dims = [i for (i, is_broadcastable)
if is_broadcastable] + ['x'] in enumerate(input.type.broadcastable)
for i in range(input.type.ndim): if is_broadcastable] + ['x']
if (op.new_order[i] == i or for i in range(input.type.ndim):
(i in all_broadcastable_dims and if (op.new_order[i] == i or
op.new_order[i] in all_broadcastable_dims)): (i in all_broadcastable_dims and
is_useless = True op.new_order[i] in all_broadcastable_dims)):
else: is_useless = True
is_useless = False else:
break is_useless = False
if is_useless: break
return [input] if is_useless:
return [input]
@register_canonicalize @register_canonicalize
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论