提交 e1c2bd9e authored 作者: Frederic's avatar Frederic

small speed up local_fill_sink and make the code more clear

上级 152e95ac
......@@ -2611,8 +2611,10 @@ register_canonicalize(gof.OpRemove(T.tensor_copy), name='remove_tensor_copy')
def local_fill_sink(node):
"""
f(fill(a, b), fill(c, d), e) -> fill(a, fill(c, f(b, d, e)))
f need to be an elemwise
"""
if not (node.op and isinstance(node.op, T.Elemwise) and node.op != T.fill):
if not isinstance(node.op, T.Elemwise) or node.op == T.fill:
return False
models = []
inputs = []
......@@ -2622,7 +2624,7 @@ def local_fill_sink(node):
inputs.append(input.owner.inputs[1])
else:
inputs.append(input)
if inputs == node.inputs:
if not models:
return False
c = node.op(*inputs)
for model in models:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论