提交 470af8e0 authored 作者: Christof Angermueller's avatar Christof Angermueller

Add d3printing module

上级 1bf7ea39
......@@ -34,4 +34,5 @@ distribute-*.egg
distribute-*.tar.gz
Theano.suo
*.DS_Store
*.bak
\ No newline at end of file
*.bak
.ipynb_checkpoints
<!DOCTYPE html>
<html>
<head lang='en'>
<meta charset="utf-8">
<script type="text/javascript" src="http://d3js.org/d3.v3.min.js"></script>
<script type='text/javascript' src="http://cpettitt.github.io/project/dagre-d3/v0.1.5/dagre-d3.min.js"></script>
<script type='text/javascript' src="http://cpettitt.github.io/project/graphlib-dot/v0.4.10/graphlib-dot.min.js"></script>
</head>
<body>
<style>
.svg {
}
.nodeRect {
stroke: black;
border: 3px solid black;
fill: lightsteelblue;
}
.nodeText {
color: black;
font-family: courier;
}
.edge {
stroke: black;
stroke-width: 3px;
}
.edgeLabelRect {
fill: white;
}
.edgeLabelText {
fill: limegreen;
font-family: courier;
text-anchor: start;
}
.arrowHead {
stroke: black;
stroke-width: 3px;
fill: black;
}
</style>
<script type="text/javascript">
// Global attributes
var dotGraphDef = 'digraph G { graph [bb="0,0,763,388"]; "DimShuffle{x,0}" [height=0.5, pos="600,282", shape=ellipse, width=2.0339]; "Elemwise{add,no_inplace}" [fillcolor="#FFAABB", height=0.5, pos="474,194", shape=ellipse, style=filled, width=3.0624]; "DimShuffle{x,0}" -> "Elemwise{add,no_inplace}" [label="1 TensorType(float64, row)", lp="648,238", pos="e,516.94,210.65 587.13,264.21 578.18,253.39 565.47,239.59 552,230 544.1,224.38 535.16,219.35 526.22,214.97"]; "name=b TensorType(float64, vector)" [fillcolor=green, height=0.5, pos="607,370", shape=box, style=filled, width=3.0903]; "name=b TensorType(float64, vector)" -> "DimShuffle{x,0}" [label="TensorType(float64, vector)", lp="684,326", pos="e,601.39,300.08 605.58,351.6 604.62,339.75 603.32,323.82 602.22,310.29"]; dot [height=0.5, pos="362,282", shape=ellipse, width=0.75]; dot -> "Elemwise{add,no_inplace}" [label="0 TensorType(float64, matrix)", lp="467,238", pos="e,414.72,209.23 365,263.89 367.74,252.92 372.81,239.1 382,230 388.76,223.31 396.93,217.87 405.58,213.46"]; "name=X TensorType(float64, matrix)" [fillcolor=green, height=0.5, pos="114,370", shape=box, style=filled, width=3.1667]; "name=X TensorType(float64, matrix)" -> dot [label="0 TensorType(float64, matrix)", lp="273,326", pos="e,335.37,285 134.04,351.84 148.22,340.57 168.18,326.4 188,318 233.06,298.9 289.18,290.03 325.29,286.04"]; "name=W TensorType(float64, matrix)" [fillcolor=green, height=0.5, pos="362,370", shape=box, style=filled, width=3.2014]; "name=W TensorType(float64, matrix)" -> dot [label="1 TensorType(float64, matrix)", lp="447,326", pos="e,362,300.08 362,351.6 362,339.75 362,323.82 362,310.29"]; Softmax [height=0.5, pos="474,106", shape=ellipse, width=1.1472]; "Elemwise{add,no_inplace}" -> Softmax [label="TensorType(float64, matrix)", lp="554,150", pos="e,474,124.08 474,175.6 474,163.75 474,147.82 474,134.29"]; "TensorType(float64, matrix) id=6" [fillcolor=blue, height=0.5, pos="474,18", shape=box, style=filled, width=2.8403]; Softmax -> "TensorType(float64, matrix) id=6" [label="TensorType(float64, matrix)", lp="554,62", pos="e,474,36.084 474,87.597 474,75.746 474,59.817 474,46.292"]; } ';
var width = 800;
var height = 600;
// Add SVG element
var svg = d3.select('body').append('svg').attr('class', 'svg').attr('width', width).attr('height', height);
var pane = svg.append('g').attr('transform', 'scale(0.8)');
// Definition head of edges
svg.append("defs").append("marker")
.attr("id", 'markerEnd')
.attr("markerWidth", 6)
.attr("markerHeight", 6)
.attr("refX", 5)
.attr("refY", 3)
.attr("orient", "auto")
.append("path")
.attr("d", "M0,0 L6,3 L0,6 Z")
.attr('style', 'arrowHead');
function textWidth(text) {
return text.length * 10.5;
}
// Parse dot graph definition
var graph = {};
var nodes = [];
var edges = [];
var dotGraph = graphlibDot.parse(dotGraphDef);
// TODO: parse from file
var dotWidth = 763;
var dotHeight = 388;
var scaleDotX = d3.scale.linear().domain([0, dotWidth]).range([0, width]);
var scaleDotY = d3.scale.linear().domain([0, dotHeight]).range([0, height]);
// Parse nodes
var i = 0;
for (nodeId in dotGraph._nodes) {
var node = dotGraph._nodes[nodeId];
node.index = i++;
node.value.width = textWidth(node.value.label);
node.value.height = 40;
node.value.cx = node.value.width / 2;
node.value.cy = node.value.height / 2;
node.value.pos = node.value.pos.split(',').map(function(d) {return parseInt(d);});
node.x = scaleDotX(node.value.pos[0]);
node.y = scaleDotY(dotHeight - node.value.pos[1]);
node.fixed = false;
nodes.push(node);
dotGraph._nodes[nodeId] = node;
}
// Parse edges
for (edgeId in dotGraph._edges) {
var edge = dotGraph._edges[edgeId];
edge.source = dotGraph._nodes[edge.u].index;
edge.target = dotGraph._nodes[edge.v].index;
edge.value.width = textWidth(edge.value.label);
edge.value.height = 40;
edges.push(edge);
dotGraph._edges[edgeId] = edge;
}
// Setup graph
graph['nodes'] = nodes;
graph['edges'] = edges;
// Add edges
edges = pane.append('g').attr('id', 'edges').selectAll('path').data(graph['edges']).enter().append('path')
.attr('class', 'edge')
.attr('marker-end', 'url(#markerEnd)');
// Add edge labels
edgeLabels = pane.append('g').attr('id', 'edgeLabels').selectAll('g').data(graph['edges']).enter().append('g')
.attr('opacity', 0)
.on('mouseover', function(d) {d3.select(this).attr('opacity', 1.0);})
.on('mouseout', function(d) {d3.select(this).attr('opacity', 0.0);});
var edgeLabelsRect = edgeLabels.append('rect')
.attr('class', 'edgeLabelRect')
.attr('fill', 'white')
.attr('width', function(d) {return d.value.width;})
.attr('height', function(d) {return d.value.height;});
var edgeLabelsText = edgeLabels.append('text')
.attr('class', 'edgeLabelText')
.attr('dy', function(d) {return 0.5 * d.value.height;})
.text(function(d) {return d.value.label;});
// Add nodes
nodes = pane.append('g').attr('id', 'nodes').selectAll('g').data(graph['nodes']).enter().append('g');
var nodesRect = nodes.append('rect')
.attr('class', 'nodeRect')
.attr('width', function(d) {return d.value.width;})
.attr('height', function(d) {return d.value.height;});
var nodesText = nodes.append('text')
.attr('class', 'nodeText')
.attr('x', 5)
.attr('dy', function(d) {return d.value.height - 15;})
.text(function(d) {return d.value.label;});
// Update graph
function update() {
// Update nodes
nodes.attr('transform', function(d) { return 'translate(' + d.x + ' ' + d.y + ')'; });
// Update edges
edges.attr('d', function(d) {
return 'M' + (d.source.x + d.source.value.cx) + ',' + (d.source.y + d.source.value.cy) + ' L' + (d.target.x + d.target.value.cx) + ',' + (d.target.y - 3);
});
// Update edge labels
edgeLabels.attr('transform', function(d) {
return 'translate(' + (0.5 * (d.source.x + d.source.value.cx + d.target.x + d.target.value.cx) - 0.5 * d.value.width) + ',' + (0.5 * (d.source.y + d.source.value.cy + d.target.y + d.target.value.cy) - 0.5 * d.value.height) + ')';
});
}
// Drag-start event handler
function dragStart(d) {
d3.event.sourceEvent.stopPropagation();
d.fixed = true;
}
// Zoom and translate event handler
function zoom(d) {
pane.attr('transform', 'translate(' + d3.event.translate + ') scale(' + d3.event.scale + ')');
}
// Force layout
var layout = d3.layout.force()
.nodes(graph['nodes'])
.links(graph['edges'])
.size([width, height])
.gravity(0.2)
.charge(-6000)
.linkDistance(75)
.linkStrength(0.1)
.on('tick', update);
// Drag behavour
var drag = layout.drag()
.on('dragstart', dragStart);
nodes.call(drag);
// Zoom behaviour
var bZoom = d3.behavior.zoom()
.scaleExtent([0.2, 8])
.on('zoom', zoom);
svg.call(bZoom);
// Start force layout
layout.start();
</script>
</body>
</html>
{
"cells": [
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The autoreload extension is already loaded. To reload it, use:\n",
" %reload_ext autoreload\n"
]
}
],
"source": [
"import numpy as np\n",
"import theano\n",
"import theano.tensor as T\n",
"import theano.d3printing as d3p\n",
"from IPython.display import HTML\n",
"\n",
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Logistic regression model\n",
"num_in = 64**2\n",
"num_out = 10\n",
"W = theano.shared(np.random.randn(num_in, num_out), name='W', borrow=True)\n",
"b = theano.shared(np.random.randn(num_out), name='b', borrow=True)\n",
"X = T.dmatrix('X')\n",
"y = T.nnet.softmax(X.dot(W) + b)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The output file is available at logreg.html\n"
]
}
],
"source": [
"html = d3p.d3print(y, 'logreg.html', return_html=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.9"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
from d3printing import *
"""Extends printing module by dynamic visualizations."""
# Authors: Christof Angermueller <cangermueller@gmail.com>
import os.path
from theano.printing import pydotprint
def replace_patterns(x, replace):
""" Replaces patterns defined by `replace` in x."""
for from_, to in replace.items():
x = x.replace(str(from_), str(to))
return x
def d3print(fct, outfile=None, return_html=False, print_message=True,
width=800, height=600,
*args, **kwargs):
"""Creates dynamic graph visualization using d3.js javascript library.
:param fct: A compiled Theano function, variable, apply or a list of
variables
:param outfile: The output file
:param return_html: If True, return HTML code
:param print_message: If True, print message at the end
:param *args, **kwargs: Parameters passed to pydotprint
"""
# Generate dot graph definition by calling pydotprint
dot_graph = pydotprint(fct, format='dot', return_image=True, *args, **kwargs)
dot_graph = dot_graph.replace('\n', ' ')
dot_graph = dot_graph.replace('node [label="\N"];', '')
# Read template HTML file and replace variables
template_file = os.path.join(os.path.dirname(os.path.realpath(__file__)),
'template.html')
f = open(template_file)
template = f.read()
f.close()
replace = {
'%% DOT_GRAPH %%': dot_graph,
'%% WIDTH %%': width,
'%% HEIGHT %%': height
}
html = replace_patterns(template, replace)
# Write output file
if outfile is not None:
f = open(outfile, 'w')
f.write(html)
f.close()
if print_message:
print('The output file is available at %s' % (outfile))
if return_html:
return html
<!DOCTYPE html>
<html>
<head lang='en'>
<meta charset="utf-8">
<script type="text/javascript" src="http://d3js.org/d3.v3.min.js"></script>
<script type='text/javascript' src="http://cpettitt.github.io/project/dagre-d3/v0.1.5/dagre-d3.min.js"></script>
<script type='text/javascript' src="http://cpettitt.github.io/project/graphlib-dot/v0.4.10/graphlib-dot.min.js"></script>
</head>
<body>
<style>
.svg {
}
.nodeRect {
stroke: black;
border: 3px solid black;
fill: lightsteelblue;
}
.nodeText {
color: black;
font-family: courier;
}
.edge {
stroke: black;
stroke-width: 3px;
}
.edgeLabelRect {
fill: white;
}
.edgeLabelText {
fill: limegreen;
font-family: courier;
text-anchor: start;
}
.arrowHead {
stroke: black;
stroke-width: 3px;
fill: black;
}
</style>
<script type="text/javascript">
// Global attributes
var dotGraphDef = '%% DOT_GRAPH %%';
var width = %% WIDTH %%;
var height = %% HEIGHT %%;
// Add SVG element
var svg = d3.select('body').append('svg').attr('class', 'svg').attr('width', width).attr('height', height);
var pane = svg.append('g').attr('transform', 'scale(0.8)');
// Definition head of edges
svg.append("defs").append("marker")
.attr("id", 'markerEnd')
.attr("markerWidth", 6)
.attr("markerHeight", 6)
.attr("refX", 5)
.attr("refY", 3)
.attr("orient", "auto")
.append("path")
.attr("d", "M0,0 L6,3 L0,6 Z")
.attr('style', 'arrowHead');
function textWidth(text) {
return text.length * 10.5;
}
// Parse dot graph definition
var graph = {};
var nodes = [];
var edges = [];
var dotGraph = graphlibDot.parse(dotGraphDef);
// TODO: parse from file
var dotWidth = 763;
var dotHeight = 388;
var scaleDotX = d3.scale.linear().domain([0, dotWidth]).range([0, width]);
var scaleDotY = d3.scale.linear().domain([0, dotHeight]).range([0, height]);
// Parse nodes
var i = 0;
for (nodeId in dotGraph._nodes) {
var node = dotGraph._nodes[nodeId];
node.index = i++;
node.value.width = textWidth(node.value.label);
node.value.height = 40;
node.value.cx = node.value.width / 2;
node.value.cy = node.value.height / 2;
node.value.pos = node.value.pos.split(',').map(function(d) {return parseInt(d);});
node.x = scaleDotX(node.value.pos[0]);
node.y = scaleDotY(dotHeight - node.value.pos[1]);
node.fixed = false;
nodes.push(node);
dotGraph._nodes[nodeId] = node;
}
// Parse edges
for (edgeId in dotGraph._edges) {
var edge = dotGraph._edges[edgeId];
edge.source = dotGraph._nodes[edge.u].index;
edge.target = dotGraph._nodes[edge.v].index;
edge.value.width = textWidth(edge.value.label);
edge.value.height = 40;
edges.push(edge);
dotGraph._edges[edgeId] = edge;
}
// Setup graph
graph['nodes'] = nodes;
graph['edges'] = edges;
// Add edges
edges = pane.append('g').attr('id', 'edges').selectAll('path').data(graph['edges']).enter().append('path')
.attr('class', 'edge')
.attr('marker-end', 'url(#markerEnd)');
// Add edge labels
edgeLabels = pane.append('g').attr('id', 'edgeLabels').selectAll('g').data(graph['edges']).enter().append('g')
.attr('opacity', 0)
.on('mouseover', function(d) {d3.select(this).attr('opacity', 1.0);})
.on('mouseout', function(d) {d3.select(this).attr('opacity', 0.0);});
var edgeLabelsRect = edgeLabels.append('rect')
.attr('class', 'edgeLabelRect')
.attr('fill', 'white')
.attr('width', function(d) {return d.value.width;})
.attr('height', function(d) {return d.value.height;});
var edgeLabelsText = edgeLabels.append('text')
.attr('class', 'edgeLabelText')
.attr('dy', function(d) {return 0.5 * d.value.height;})
.text(function(d) {return d.value.label;});
// Add nodes
nodes = pane.append('g').attr('id', 'nodes').selectAll('g').data(graph['nodes']).enter().append('g');
var nodesRect = nodes.append('rect')
.attr('class', 'nodeRect')
.attr('width', function(d) {return d.value.width;})
.attr('height', function(d) {return d.value.height;});
var nodesText = nodes.append('text')
.attr('class', 'nodeText')
.attr('x', 5)
.attr('dy', function(d) {return d.value.height - 15;})
.text(function(d) {return d.value.label;});
// Update graph
function update() {
// Update nodes
nodes.attr('transform', function(d) { return 'translate(' + d.x + ' ' + d.y + ')'; });
// Update edges
edges.attr('d', function(d) {
return 'M' + (d.source.x + d.source.value.cx) + ',' + (d.source.y + d.source.value.cy) + ' L' + (d.target.x + d.target.value.cx) + ',' + (d.target.y - 3);
});
// Update edge labels
edgeLabels.attr('transform', function(d) {
return 'translate(' + (0.5 * (d.source.x + d.source.value.cx + d.target.x + d.target.value.cx) - 0.5 * d.value.width) + ',' + (0.5 * (d.source.y + d.source.value.cy + d.target.y + d.target.value.cy) - 0.5 * d.value.height) + ')';
});
}
// Drag-start event handler
function dragStart(d) {
d3.event.sourceEvent.stopPropagation();
d.fixed = true;
}
// Zoom and translate event handler
function zoom(d) {
pane.attr('transform', 'translate(' + d3.event.translate + ') scale(' + d3.event.scale + ')');
}
// Force layout
var layout = d3.layout.force()
.nodes(graph['nodes'])
.links(graph['edges'])
.size([width, height])
.gravity(0.2)
.charge(-6000)
.linkDistance(75)
.linkStrength(0.1)
.on('tick', update);
// Drag behavour
var drag = layout.drag()
.on('dragstart', dragStart);
nodes.call(drag);
// Zoom behaviour
var bZoom = d3.behavior.zoom()
.scaleExtent([0.2, 8])
.on('zoom', zoom);
svg.call(bZoom);
// Start force layout
layout.start();
</script>
</body>
</html>
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论