提交 21c367f7 authored 作者: khaotik's avatar khaotik

fix OpFromGraph visualization when inline

上级 cac2b4bf
...@@ -228,6 +228,7 @@ class PyDotFormatter(object): ...@@ -228,6 +228,7 @@ class PyDotFormatter(object):
label=vparams['dtype'])) label=vparams['dtype']))
# Create sub-graph for OpFromGraph nodes # Create sub-graph for OpFromGraph nodes
# FIXME:
if isinstance(node.op, builders.OpFromGraph): if isinstance(node.op, builders.OpFromGraph):
subgraph = pd.Cluster(__node_id) subgraph = pd.Cluster(__node_id)
gf = PyDotFormatter() gf = PyDotFormatter()
...@@ -244,15 +245,14 @@ class PyDotFormatter(object): ...@@ -244,15 +245,14 @@ class PyDotFormatter(object):
# Inputs mapping # Inputs mapping
ext_inputs = [self.__node_id(x) for x in node.inputs] ext_inputs = [self.__node_id(x) for x in node.inputs]
int_inputs = [gf.__node_id(x) int_inputs = [gf.__node_id(x)
for x in node.op.fn.maker.fgraph.inputs] for x in node.op.local_inputs]
assert len(ext_inputs) == len(int_inputs) assert len(ext_inputs) == len(int_inputs)
h = format_map(zip(ext_inputs, int_inputs)) h = format_map(zip(ext_inputs, int_inputs))
pd_node.get_attributes()['subg_map_inputs'] = h pd_node.get_attributes()['subg_map_inputs'] = h
# Outputs mapping # Outputs mapping
ext_outputs = [self.__node_id(x) for x in node.outputs] ext_outputs = [self.__node_id(x) for x in node.outputs]
int_outputs = node.op.fn.maker.fgraph.outputs int_outputs = [gf.__node_id(x) for x in node.op.local_outputs]
int_outputs = [gf.__node_id(x) for x in int_outputs]
assert len(ext_outputs) == len(int_outputs) assert len(ext_outputs) == len(int_outputs)
h = format_map(zip(int_outputs, ext_outputs)) h = format_map(zip(int_outputs, ext_outputs))
pd_node.get_attributes()['subg_map_outputs'] = h pd_node.get_attributes()['subg_map_outputs'] = h
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论