提交 40f8ae2e authored 作者: Christof Angermueller's avatar Christof Angermueller

Add node_colors argument to pydotprint

上级 470af8e0
...@@ -13,6 +13,20 @@ def replace_patterns(x, replace): ...@@ -13,6 +13,20 @@ def replace_patterns(x, replace):
return x return x
def d3dot(fct, node_colors=None, *args, **kwargs):
if node_colors is None:
node_colors = {'input': 'limegreen',
'output': 'dodgerblue',
'unused': 'lightgrey'
}
dot_graph = pydotprint(fct, format='dot', return_image=True,
node_colors=node_colors, *args, **kwargs)
dot_graph = dot_graph.replace('\n', ' ')
dot_graph = dot_graph.replace('node [label="\N"];', '')
return dot_graph
def d3print(fct, outfile=None, return_html=False, print_message=True, def d3print(fct, outfile=None, return_html=False, print_message=True,
width=800, height=600, width=800, height=600,
*args, **kwargs): *args, **kwargs):
...@@ -27,13 +41,11 @@ def d3print(fct, outfile=None, return_html=False, print_message=True, ...@@ -27,13 +41,11 @@ def d3print(fct, outfile=None, return_html=False, print_message=True,
""" """
# Generate dot graph definition by calling pydotprint # Generate dot graph definition by calling pydotprint
dot_graph = pydotprint(fct, format='dot', return_image=True, *args, **kwargs) dot_graph = d3dot(fct, *args, **kwargs)
dot_graph = dot_graph.replace('\n', ' ')
dot_graph = dot_graph.replace('node [label="\N"];', '')
# Read template HTML file and replace variables # Read template HTML file and replace variables
template_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), template_file = os.path.join(os.path.dirname(os.path.realpath(__file__)),
'template.html') 'template.html')
f = open(template_file) f = open(template_file)
template = f.read() template = f.read()
f.close() f.close()
......
...@@ -555,6 +555,11 @@ default_colorCodes = {'GpuFromHost': 'red', ...@@ -555,6 +555,11 @@ default_colorCodes = {'GpuFromHost': 'red',
'Subtensor': '#FFAAFF', # purple 'Subtensor': '#FFAAFF', # purple
'Alloc': '#FFAA22'} # orange 'Alloc': '#FFAA22'} # orange
default_node_colors = {'input': 'green',
'output': 'blue',
'unused': 'grey'
}
def pydotprint(fct, outfile=None, def pydotprint(fct, outfile=None,
compact=True, format='png', with_ids=False, compact=True, format='png', with_ids=False,
...@@ -564,6 +569,7 @@ def pydotprint(fct, outfile=None, ...@@ -564,6 +569,7 @@ def pydotprint(fct, outfile=None,
print_output_file=True, print_output_file=True,
assert_nb_all_strings=-1, assert_nb_all_strings=-1,
return_image=False, return_image=False,
node_colors=None
): ):
"""Print to a file the graph of a compiled theano function's ops. Supports """Print to a file the graph of a compiled theano function's ops. Supports
all pydot output formats, including png and svg. all pydot output formats, including png and svg.
...@@ -639,6 +645,9 @@ def pydotprint(fct, outfile=None, ...@@ -639,6 +645,9 @@ def pydotprint(fct, outfile=None,
if colorCodes is None: if colorCodes is None:
colorCodes = default_colorCodes colorCodes = default_colorCodes
if node_colors is None:
node_colors = default_node_colors
if outfile is None: if outfile is None:
outfile = os.path.join(config.compiledir, 'theano.pydotprint.' + outfile = os.path.join(config.compiledir, 'theano.pydotprint.' +
config.device + '.' + format) config.device + '.' + format)
...@@ -851,7 +860,7 @@ def pydotprint(fct, outfile=None, ...@@ -851,7 +860,7 @@ def pydotprint(fct, outfile=None,
param = {} param = {}
if hasattr(node.op, 'view_map') and id in reduce( if hasattr(node.op, 'view_map') and id in reduce(
list.__add__, node.op.view_map.values(), []): list.__add__, node.op.view_map.values(), []):
param['color'] = 'blue' param['color'] = node_colors['output']
elif hasattr(node.op, 'destroy_map') and id in reduce( elif hasattr(node.op, 'destroy_map') and id in reduce(
list.__add__, node.op.destroy_map.values(), []): list.__add__, node.op.destroy_map.values(), []):
param['color'] = 'red' param['color'] = 'red'
...@@ -859,10 +868,11 @@ def pydotprint(fct, outfile=None, ...@@ -859,10 +868,11 @@ def pydotprint(fct, outfile=None,
if high_contrast: if high_contrast:
g.add_node(pd.Node(varstr, g.add_node(pd.Node(varstr,
style='filled', style='filled',
fillcolor='green', fillcolor=node_colors['input'],
shape=var_shape)) shape=var_shape))
else: else:
g.add_node(pd.Node(varstr, color='green', shape=var_shape)) g.add_node(pd.Node(varstr, color=node_colors['input'],
shape=var_shape))
g.add_edge(pd.Edge(varstr, astr, label=label, **param)) g.add_edge(pd.Edge(varstr, astr, label=label, **param))
elif var.name or not compact: elif var.name or not compact:
g.add_edge(pd.Edge(varstr, astr, label=label, **param)) g.add_edge(pd.Edge(varstr, astr, label=label, **param))
...@@ -883,16 +893,20 @@ def pydotprint(fct, outfile=None, ...@@ -883,16 +893,20 @@ def pydotprint(fct, outfile=None,
g.add_edge(pd.Edge(astr, varstr, label=label)) g.add_edge(pd.Edge(astr, varstr, label=label))
if high_contrast: if high_contrast:
g.add_node(pd.Node(varstr, style='filled', g.add_node(pd.Node(varstr, style='filled',
fillcolor='blue', shape=var_shape)) fillcolor=node_colors['output'],
shape=var_shape))
else: else:
g.add_node(pd.Node(varstr, color='blue', shape=var_shape)) g.add_node(pd.Node(varstr, color=node_colors['output'],
shape=var_shape))
elif len(var.clients) == 0: elif len(var.clients) == 0:
g.add_edge(pd.Edge(astr, varstr, label=label)) g.add_edge(pd.Edge(astr, varstr, label=label))
if high_contrast: if high_contrast:
g.add_node(pd.Node(varstr, style='filled', g.add_node(pd.Node(varstr, style='filled',
fillcolor='grey', shape=var_shape)) fillcolor=node_colors['unused'],
shape=var_shape))
else: else:
g.add_node(pd.Node(varstr, color='grey', shape=var_shape)) g.add_node(pd.Node(varstr, color=node_colors['unused'],
shape=var_shape))
elif var.name or not compact: elif var.name or not compact:
g.add_edge(pd.Edge(astr, varstr, label=label)) g.add_edge(pd.Edge(astr, varstr, label=label))
# else: # else:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论