提交 726cce89 authored 作者: Christof Angermueller's avatar Christof Angermueller

Update d3print to use GraphFormatter

上级 45363f3f
...@@ -5,6 +5,8 @@ ...@@ -5,6 +5,8 @@
import os.path import os.path
from theano.printing import pydotprint from theano.printing import pydotprint
from formatting import GraphFormatter
def replace_patterns(x, replace): def replace_patterns(x, replace):
...@@ -29,13 +31,13 @@ def d3dot(fct, node_colors=None, *args, **kwargs): ...@@ -29,13 +31,13 @@ def d3dot(fct, node_colors=None, *args, **kwargs):
def d3write(fct, path, *args, **kwargs): def d3write(fct, path, *args, **kwargs):
dot = d3dot(fct, *args, **kwargs) # Convert theano graph to pydot graph and write to file
with open(path, 'w') as f: gf = GraphFormatter(*args, **kwargs)
f.write(dot) g = gf.to_pydot(fct)
g.write_dot(path)
def d3print(fct, outfile=None, return_html=False, print_message=True, def d3print(fct, outfile=None, return_html=False, print_message=True,
width=800, height=600,
*args, **kwargs): *args, **kwargs):
"""Creates dynamic graph visualization using d3.js javascript library. """Creates dynamic graph visualization using d3.js javascript library.
...@@ -47,11 +49,10 @@ def d3print(fct, outfile=None, return_html=False, print_message=True, ...@@ -47,11 +49,10 @@ def d3print(fct, outfile=None, return_html=False, print_message=True,
:param *args, **kwargs: Parameters passed to pydotprint :param *args, **kwargs: Parameters passed to pydotprint
""" """
# Generate dot graph by pydotprint and write to file # Create dot file
dot_graph = d3dot(fct, *args, **kwargs)
dot_file = os.path.splitext(outfile)[0] + '.dot' dot_file = os.path.splitext(outfile)[0] + '.dot'
with open(dot_file, 'w') as f: d3write(fct, dot_file, *args, **kwargs)
f.write(dot_graph)
# Read template HTML file and replace variables # Read template HTML file and replace variables
template_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), template_file = os.path.join(os.path.dirname(os.path.realpath(__file__)),
......
...@@ -59,9 +59,9 @@ class GraphFormatter(object): ...@@ -59,9 +59,9 @@ class GraphFormatter(object):
'Elemwise': '#FFAABB', # dark pink 'Elemwise': '#FFAABB', # dark pink
'Subtensor': '#FFAAFF', # purple 'Subtensor': '#FFAAFF', # purple
'Alloc': '#FFAA22'} # orange 'Alloc': '#FFAA22'} # orange
self.node_colors = {'input': 'green', self.node_colors = {'input': 'limegreen',
'output': 'blue', 'output': 'dodgerblue',
'unused': 'grey' 'unused': 'lightgrey'
} }
self.max_label_size = 70 self.max_label_size = 70
...@@ -168,7 +168,7 @@ class GraphFormatter(object): ...@@ -168,7 +168,7 @@ class GraphFormatter(object):
nw_node = pd.Node(astr, shape=apply_shape) nw_node = pd.Node(astr, shape=apply_shape)
elif self.high_contrast: elif self.high_contrast:
nw_node = pd.Node(astr, style='filled', fillcolor=use_color, nw_node = pd.Node(astr, style='filled', fillcolor=use_color,
shape=apply_shape) shape=apply_shape, type='colored')
else: else:
nw_node = pd.Node(astr, color=use_color, shape=apply_shape) nw_node = pd.Node(astr, color=use_color, shape=apply_shape)
g.add_node(nw_node) g.add_node(nw_node)
......
...@@ -146,6 +146,7 @@ ...@@ -146,6 +146,7 @@
var posMax = [0, 0]; var posMax = [0, 0];
for (var nodeId in dotGraph._nodes) { for (var nodeId in dotGraph._nodes) {
var node = dotGraph._nodes[nodeId]; var node = dotGraph._nodes[nodeId];
node.value.label = node.id;
node.value.pos = node.value.pos.split(',').map(function(d) {return parseInt(d);}); node.value.pos = node.value.pos.split(',').map(function(d) {return parseInt(d);});
node.value.width = parseInt(node.value.width); node.value.width = parseInt(node.value.width);
node.value.height = parseInt(node.value.height); node.value.height = parseInt(node.value.height);
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论