提交 519e007d authored 作者: Olivier Mastropietro's avatar Olivier Mastropietro

First point on Issue #4647

上级 dab0b393
...@@ -98,3 +98,25 @@ def local_max_to_min(node): ...@@ -98,3 +98,25 @@ def local_max_to_min(node):
max.owner.op.axis)(neg.owner.inputs[0])] max.owner.op.axis)(neg.owner.inputs[0])]
return False return False
@register_uncanonicalize
@gof.local_optimizer([T.alloc])
def local_alloc_dimshuffle(node):
if node.op == T.alloc:
input_ = node.inputs[0]
if isinstance(input_.owner.op, DimShuffle):
# check if it only adds dimension to the left
pattern = input_.type.broadcastable
if not pattern[0] :
return False
j = 0
for i, bool_ in enumerate(pattern) :
if not bool_ :
j = i
break
if sum(pattern[j:] == 0 :
return input_.inputs
else :
return False
return False
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论