提交 4b574d29 authored 作者: Ziye Fan's avatar Ziye Fan

make fill_sink be applied on all its clients

上级 ac62b206
......@@ -3596,18 +3596,18 @@ def local_fill_sink(node):
# The newly created node c doesn't has 'clients',
# so this iteration is took place with node.outputs[0]
replacements = {node.outputs[0]: c}
for client, _ in node.outputs[0].clients:
if (hasattr(client, 'op') and
isinstance(client.op, T.Elemwise) and
not client.op == T.fill):
# node_sub_c = client.op([v for v in client.inputs if v is not c])
# import ipdb; ipdb.set_trace()
r = local_fill_sink.transform(client)
if isinstance(r, list):
return {client.outputs[0]: r[0]}
elif isinstance(r, dict):
return r
return [c]
if r is False:
continue
replacements.pop(node.outputs[0], None)
replacements.update(r)
return replacements
register_canonicalize(local_fill_sink)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论