提交 40f8ae2e authored 作者: Christof Angermueller's avatar Christof Angermueller

Add node_colors argument to pydotprint

上级 470af8e0
......@@ -13,6 +13,20 @@ def replace_patterns(x, replace):
return x
def d3dot(fct, node_colors=None, *args, **kwargs):
if node_colors is None:
node_colors = {'input': 'limegreen',
'output': 'dodgerblue',
'unused': 'lightgrey'
}
dot_graph = pydotprint(fct, format='dot', return_image=True,
node_colors=node_colors, *args, **kwargs)
dot_graph = dot_graph.replace('\n', ' ')
dot_graph = dot_graph.replace('node [label="\N"];', '')
return dot_graph
def d3print(fct, outfile=None, return_html=False, print_message=True,
width=800, height=600,
*args, **kwargs):
......@@ -27,13 +41,11 @@ def d3print(fct, outfile=None, return_html=False, print_message=True,
"""
# Generate dot graph definition by calling pydotprint
dot_graph = pydotprint(fct, format='dot', return_image=True, *args, **kwargs)
dot_graph = dot_graph.replace('\n', ' ')
dot_graph = dot_graph.replace('node [label="\N"];', '')
dot_graph = d3dot(fct, *args, **kwargs)
# Read template HTML file and replace variables
template_file = os.path.join(os.path.dirname(os.path.realpath(__file__)),
'template.html')
'template.html')
f = open(template_file)
template = f.read()
f.close()
......
......@@ -555,6 +555,11 @@ default_colorCodes = {'GpuFromHost': 'red',
'Subtensor': '#FFAAFF', # purple
'Alloc': '#FFAA22'} # orange
default_node_colors = {'input': 'green',
'output': 'blue',
'unused': 'grey'
}
def pydotprint(fct, outfile=None,
compact=True, format='png', with_ids=False,
......@@ -564,6 +569,7 @@ def pydotprint(fct, outfile=None,
print_output_file=True,
assert_nb_all_strings=-1,
return_image=False,
node_colors=None
):
"""Print to a file the graph of a compiled theano function's ops. Supports
all pydot output formats, including png and svg.
......@@ -639,6 +645,9 @@ def pydotprint(fct, outfile=None,
if colorCodes is None:
colorCodes = default_colorCodes
if node_colors is None:
node_colors = default_node_colors
if outfile is None:
outfile = os.path.join(config.compiledir, 'theano.pydotprint.' +
config.device + '.' + format)
......@@ -851,7 +860,7 @@ def pydotprint(fct, outfile=None,
param = {}
if hasattr(node.op, 'view_map') and id in reduce(
list.__add__, node.op.view_map.values(), []):
param['color'] = 'blue'
param['color'] = node_colors['output']
elif hasattr(node.op, 'destroy_map') and id in reduce(
list.__add__, node.op.destroy_map.values(), []):
param['color'] = 'red'
......@@ -859,10 +868,11 @@ def pydotprint(fct, outfile=None,
if high_contrast:
g.add_node(pd.Node(varstr,
style='filled',
fillcolor='green',
fillcolor=node_colors['input'],
shape=var_shape))
else:
g.add_node(pd.Node(varstr, color='green', shape=var_shape))
g.add_node(pd.Node(varstr, color=node_colors['input'],
shape=var_shape))
g.add_edge(pd.Edge(varstr, astr, label=label, **param))
elif var.name or not compact:
g.add_edge(pd.Edge(varstr, astr, label=label, **param))
......@@ -883,16 +893,20 @@ def pydotprint(fct, outfile=None,
g.add_edge(pd.Edge(astr, varstr, label=label))
if high_contrast:
g.add_node(pd.Node(varstr, style='filled',
fillcolor='blue', shape=var_shape))
fillcolor=node_colors['output'],
shape=var_shape))
else:
g.add_node(pd.Node(varstr, color='blue', shape=var_shape))
g.add_node(pd.Node(varstr, color=node_colors['output'],
shape=var_shape))
elif len(var.clients) == 0:
g.add_edge(pd.Edge(astr, varstr, label=label))
if high_contrast:
g.add_node(pd.Node(varstr, style='filled',
fillcolor='grey', shape=var_shape))
fillcolor=node_colors['unused'],
shape=var_shape))
else:
g.add_node(pd.Node(varstr, color='grey', shape=var_shape))
g.add_node(pd.Node(varstr, color=node_colors['unused'],
shape=var_shape))
elif var.name or not compact:
g.add_edge(pd.Edge(astr, varstr, label=label))
# else:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论