提交 726cce89 authored 作者: Christof Angermueller's avatar Christof Angermueller

Update d3print to use GraphFormatter

上级 45363f3f
......@@ -5,6 +5,8 @@
import os.path
from theano.printing import pydotprint
from formatting import GraphFormatter
def replace_patterns(x, replace):
......@@ -29,13 +31,13 @@ def d3dot(fct, node_colors=None, *args, **kwargs):
def d3write(fct, path, *args, **kwargs):
dot = d3dot(fct, *args, **kwargs)
with open(path, 'w') as f:
f.write(dot)
# Convert theano graph to pydot graph and write to file
gf = GraphFormatter(*args, **kwargs)
g = gf.to_pydot(fct)
g.write_dot(path)
def d3print(fct, outfile=None, return_html=False, print_message=True,
width=800, height=600,
*args, **kwargs):
"""Creates dynamic graph visualization using d3.js javascript library.
......@@ -47,11 +49,10 @@ def d3print(fct, outfile=None, return_html=False, print_message=True,
:param *args, **kwargs: Parameters passed to pydotprint
"""
# Generate dot graph by pydotprint and write to file
dot_graph = d3dot(fct, *args, **kwargs)
# Create dot file
dot_file = os.path.splitext(outfile)[0] + '.dot'
with open(dot_file, 'w') as f:
f.write(dot_graph)
d3write(fct, dot_file, *args, **kwargs)
# Read template HTML file and replace variables
template_file = os.path.join(os.path.dirname(os.path.realpath(__file__)),
......
......@@ -59,9 +59,9 @@ class GraphFormatter(object):
'Elemwise': '#FFAABB', # dark pink
'Subtensor': '#FFAAFF', # purple
'Alloc': '#FFAA22'} # orange
self.node_colors = {'input': 'green',
'output': 'blue',
'unused': 'grey'
self.node_colors = {'input': 'limegreen',
'output': 'dodgerblue',
'unused': 'lightgrey'
}
self.max_label_size = 70
......@@ -168,7 +168,7 @@ class GraphFormatter(object):
nw_node = pd.Node(astr, shape=apply_shape)
elif self.high_contrast:
nw_node = pd.Node(astr, style='filled', fillcolor=use_color,
shape=apply_shape)
shape=apply_shape, type='colored')
else:
nw_node = pd.Node(astr, color=use_color, shape=apply_shape)
g.add_node(nw_node)
......
......@@ -146,6 +146,7 @@
var posMax = [0, 0];
for (var nodeId in dotGraph._nodes) {
var node = dotGraph._nodes[nodeId];
node.value.label = node.id;
node.value.pos = node.value.pos.split(',').map(function(d) {return parseInt(d);});
node.value.width = parseInt(node.value.width);
node.value.height = parseInt(node.value.height);
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论