提交 faeef95f authored 作者: Ziye Fan's avatar Ziye Fan

make local_fill_sink call itself recursively init commit

上级 9bdef90d
......@@ -3576,7 +3576,9 @@ def local_fill_sink(node):
f need to be an elemwise
"""
if not isinstance(node.op, T.Elemwise) or node.op == T.fill:
if (not hasattr(node, 'op') or
not isinstance(node.op, T.Elemwise) or
node.op == T.fill):
return False
models = []
inputs = []
......@@ -3591,6 +3593,18 @@ def local_fill_sink(node):
c = node.op(*inputs)
for model in models:
c = T.fill(model, c)
# The newly created node c doesn't has 'clients',
# so this iteration is took place with node.outputs[0]
for client, _ in node.outputs[0].clients:
if (hasattr(client, 'op') and
isinstance(client.op, T.Elemwise) and
not client.op == T.fill):
node_sub_c = client.op([v for v in client.inputs if v is not c])
# import ipdb; ipdb.set_trace()
r = local_fill_sink.transform(node_sub_c)
if r:
return {client: r}
return [c]
register_canonicalize(local_fill_sink)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论