提交 498e7652 authored 作者: Frederic Bastien's avatar Frederic Bastien

Remove type on edge when the variable have it.

上级 b73758ee
...@@ -884,11 +884,9 @@ def pydotprint(fct, outfile=None, ...@@ -884,11 +884,9 @@ def pydotprint(fct, outfile=None,
for id, var in enumerate(node.inputs): for id, var in enumerate(node.inputs):
varstr = var_name(var) varstr = var_name(var)
label = str(var.type) label = ""
if len(node.inputs) > 1: if len(node.inputs) > 1:
label = str(id) + ' ' + label label = str(id)
if len(label) > max_label_size:
label = label[:max_label_size - 3] + '...'
param = {} param = {}
if hasattr(node.op, 'view_map') and id in reduce( if hasattr(node.op, 'view_map') and id in reduce(
list.__add__, node.op.view_map.values(), []): list.__add__, node.op.view_map.values(), []):
...@@ -896,6 +894,8 @@ def pydotprint(fct, outfile=None, ...@@ -896,6 +894,8 @@ def pydotprint(fct, outfile=None,
elif hasattr(node.op, 'destroy_map') and id in reduce( elif hasattr(node.op, 'destroy_map') and id in reduce(
list.__add__, node.op.destroy_map.values(), []): list.__add__, node.op.destroy_map.values(), []):
param['color'] = 'red' param['color'] = 'red'
if label:
param['label'] = label
if var.owner is None: if var.owner is None:
color = 'green' color = 'green'
if isinstance(var, theano.compile.SharedVariable): if isinstance(var, theano.compile.SharedVariable):
...@@ -909,13 +909,18 @@ def pydotprint(fct, outfile=None, ...@@ -909,13 +909,18 @@ def pydotprint(fct, outfile=None,
shape=var_shape)) shape=var_shape))
else: else:
g.add_node(pd.Node(varstr, color=color, shape=var_shape)) g.add_node(pd.Node(varstr, color=color, shape=var_shape))
g.add_edge(pd.Edge(varstr, astr, label=label, **param)) g.add_edge(pd.Edge(varstr, astr, **param))
elif var.name or not compact or var in outputs: elif var.name or not compact or var in outputs:
g.add_edge(pd.Edge(varstr, astr, label=label, **param)) g.add_edge(pd.Edge(varstr, astr, **param))
else: else:
# no name, so we don't make a var ellipse # no name, so we don't make a var ellipse
g.add_edge(pd.Edge(apply_name(var.owner), astr, if label:
label=label, **param)) label += " "
label += str(var.type)
if len(label) > max_label_size:
label = label[:max_label_size - 3] + '...'
param['label'] = label
g.add_edge(pd.Edge(apply_name(var.owner), astr, **param))
for id, var in enumerate(node.outputs): for id, var in enumerate(node.outputs):
varstr = var_name(var) varstr = var_name(var)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论