提交 e1c2bd9e authored 作者: Frederic's avatar Frederic

small speed up local_fill_sink and make the code more clear

上级 152e95ac
...@@ -2611,8 +2611,10 @@ register_canonicalize(gof.OpRemove(T.tensor_copy), name='remove_tensor_copy') ...@@ -2611,8 +2611,10 @@ register_canonicalize(gof.OpRemove(T.tensor_copy), name='remove_tensor_copy')
def local_fill_sink(node): def local_fill_sink(node):
""" """
f(fill(a, b), fill(c, d), e) -> fill(a, fill(c, f(b, d, e))) f(fill(a, b), fill(c, d), e) -> fill(a, fill(c, f(b, d, e)))
f need to be an elemwise
""" """
if not (node.op and isinstance(node.op, T.Elemwise) and node.op != T.fill): if not isinstance(node.op, T.Elemwise) or node.op == T.fill:
return False return False
models = [] models = []
inputs = [] inputs = []
...@@ -2622,7 +2624,7 @@ def local_fill_sink(node): ...@@ -2622,7 +2624,7 @@ def local_fill_sink(node):
inputs.append(input.owner.inputs[1]) inputs.append(input.owner.inputs[1])
else: else:
inputs.append(input) inputs.append(input)
if inputs == node.inputs: if not models:
return False return False
c = node.op(*inputs) c = node.op(*inputs)
for model in models: for model in models:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论