提交 476ad0e1 authored 作者: Christof Angermueller's avatar Christof Angermueller

Update d3print method to write dot and html file

上级 4d38e314
This source diff could not be displayed because it is too large. You can view the blob instead.
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#Table of Contents\n",
"* [Model](#Model)\n",
"* [Example 1](#Example-1)\n",
"* [Example 2](#Example-2)\n",
"* [Example 3](#Example-3)\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Couldn't import dot_parser, loading of dot files will not be possible.\n"
]
}
],
"source": [
"import numpy\n",
"import theano\n",
"import theano.tensor as T\n",
"import theano.printing as pr\n",
"import theano.d3printing as d3p\n",
"rng = numpy.random"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Model "
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Training data\n",
"N = 400\n",
"feats = 784\n",
"D = (rng.randn(N, feats).astype(theano.config.floatX), rng.randint(size=N,low=0, high=2).astype(theano.config.floatX))\n",
"training_steps = 10000\n",
"\n",
"# Declare Theano symbolic variables\n",
"x = T.matrix(\"x\")\n",
"y = T.vector(\"y\")\n",
"w = theano.shared(rng.randn(feats).astype(theano.config.floatX), name=\"w\")\n",
"b = theano.shared(numpy.asarray(0., dtype=theano.config.floatX), name=\"b\")\n",
"x.tag.test_value = D[0]\n",
"y.tag.test_value = D[1]\n",
"\n",
"# Construct Theano expression graph\n",
"p_1 = 1 / (1 + T.exp(-T.dot(x, w)-b)) # Probability of having a one\n",
"prediction = p_1 > 0.5 # The prediction that is done: 0 or 1\n",
"\n",
"# Compute gradients\n",
"xent = -y*T.log(p_1) - (1-y)*T.log(1-p_1) # Cross-entropy\n",
"cost = xent.mean() + 0.01*(w**2).sum() # The cost to optimize\n",
"gw,gb = T.grad(cost, [w,b])\n",
"\n",
"# Training and prediction function\n",
"train = theano.function(inputs=[x,y], outputs=[prediction, xent], updates=[[w, w-0.01*gw], [b, b-0.01*gb]], name = \"train\")\n",
"predict = theano.function(inputs=[x], outputs=prediction, name = \"predict\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Example 1 "
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The output file is available at p1.png\n",
"The output file is available at p1.html\n"
]
}
],
"source": [
"pr.pydotprint(p_1, outfile='p1.png', var_with_name_simple=True)\n",
"d3p.d3print(p_1, 'p1.html')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a href='p1.html'><img src='p1.png'/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"[open](./p1.html)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Example 2"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The output file is available at predict.png\n",
"The output file is available at predict.html\n"
]
}
],
"source": [
"pr.pydotprint(predict, outfile='predict.png', var_with_name_simple=True)\n",
"d3p.d3print(predict, 'predict.html')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a href='predict.html'><img src='predict.png'/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"[open](./predict.html)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Example 3"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The output file is available at train.png\n",
"The output file is available at train.html\n"
]
}
],
"source": [
"pr.pydotprint(train, outfile='train.png', var_with_name_simple=True)\n",
"d3p.d3print(train, 'train.html')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a href='train.html'><img src='train.png'/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"[open](./train.html)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.10"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
digraph G { graph [bb="0,0,719,672"]; "DimShuffle{x}" [height=0.5, pos="558,478", shape=ellipse, width=1.8374]; "Elemwise{sub,no_inplace}" [fillcolor="#FFAABB", height=0.5, pos="461,390", shape=ellipse, style=filled, width=3.0624]; "DimShuffle{x}" -> "Elemwise{sub,no_inplace}" [label="1 TensorType(float64, (True,))", lp="632.5,434", pos="e,506.93,406.57 553.52,459.62 549.99,448.82 544.12,435.29 535,426 529.47,420.37 522.85,415.53 515.89,411.41"]; "name=b TensorType(float64, scalar)" [fillcolor=limegreen, height=0.5, pos="558,566", shape=box, style=filled, width=3.0625]; "name=b TensorType(float64, scalar)" -> "DimShuffle{x}" [label="TensorType(float64, scalar)", lp="636,522", pos="e,558,496.08 558,547.6 558,535.75 558,519.82 558,506.29"]; dot [height=0.5, pos="363,566", shape=ellipse, width=0.75]; "Elemwise{neg,no_inplace}" [fillcolor="#FFAABB", height=0.5, pos="363,478", shape=ellipse, style=filled, width=3.0624]; dot -> "Elemwise{neg,no_inplace}" [label="TensorType(float64, vector)", lp="442,522", pos="e,363,496.08 363,547.6 363,535.75 363,519.82 363,506.29"]; "name=x TensorType(float64, matrix)" [fillcolor=limegreen, height=0.5, pos="241,654", shape=box, style=filled, width=3.1181]; "name=x TensorType(float64, matrix)" -> dot [label="0 TensorType(float64, matrix)", lp="378,610", pos="e,340.56,576.07 255.77,636 265.61,625.35 279.21,611.84 293,602 304.81,593.57 318.95,586.09 331.42,580.22"]; "name=w TensorType(float64, vector)" [fillcolor=limegreen, height=0.5, pos="485,654", shape=box, style=filled, width=3.1389]; "name=w TensorType(float64, vector)" -> dot [label="1 TensorType(float64, vector)", lp="559.5,610", pos="e,389.41,570.28 481.35,635.78 478.2,624.78 472.59,610.94 463,602 445.54,585.73 420.08,576.96 399.23,572.27"]; "DimShuffle{x} id=2" [height=0.5, pos="247,302", shape=ellipse, width=2.3721]; "Elemwise{add,no_inplace}" [fillcolor="#FFAABB", height=0.5, pos="369,194", shape=ellipse, style=filled, width=3.0624]; "DimShuffle{x} id=2" -> "Elemwise{add,no_inplace}" [label="0 TensorType(int8, (True,))", lp="334,248", pos="e,290.21,206.6 244.34,283.76 242.92,267.96 243.54,244.85 256,230 262.8,221.89 271.45,215.63 280.91,210.8"]; "val=1 TensorType(int8, scalar)" [fillcolor=limegreen, height=0.5, pos="161,390", shape=box, style=filled, width=2.6389]; "val=1 TensorType(int8, scalar)" -> "DimShuffle{x} id=2" [label="TensorType(int8, scalar)", lp="305.5,346", pos="e,242.7,320.2 201.81,371.98 210.24,367.09 218.52,361.12 225,354 231.25,347.13 235.92,338.22 239.3,329.76"]; "DimShuffle{x} id=3" [height=0.5, pos="85,248", shape=ellipse, width=2.3721]; "val=1 TensorType(int8, scalar)" -> "DimShuffle{x} id=3" [label="TensorType(int8, scalar)", lp="151.5,346", pos="e,78.947,266.23 101.23,371.9 93.625,367.28 86.827,361.42 82,354 66.992,330.94 70.718,298.58 76.308,275.9"]; "Elemwise{true_div,no_inplace}" [fillcolor="#FFAABB", height=0.5, pos="164,106", shape=ellipse, style=filled, width=3.5561]; "DimShuffle{x} id=3" -> "Elemwise{true_div,no_inplace}" [label="0 TensorType(int8, (True,))", lp="172,194", pos="e,140.09,123.74 84.103,229.75 84.024,214.78 85.627,192.93 94,176 102.79,158.23 118.02,142.37 132.04,130.36"]; "Elemwise{neg,no_inplace}" -> "Elemwise{sub,no_inplace}" [label="0 TensorType(float64, vector)", lp="446.5,434", pos="e,389.83,403.99 357.87,459.89 355.74,449.21 355.15,435.69 362,426 366.99,418.94 373.53,413.29 380.88,408.78"]; "Elemwise{exp,no_inplace}" [fillcolor="#FFAABB", height=0.5, pos="461,302", shape=ellipse, style=filled, width=3.0624]; "Elemwise{sub,no_inplace}" -> "Elemwise{exp,no_inplace}" [label="TensorType(float64, vector)", lp="540,346", pos="e,461,320.08 461,371.6 461,359.75 461,343.82 461,330.29"]; "Elemwise{exp,no_inplace}" -> "Elemwise{add,no_inplace}" [label="1 TensorType(float64, vector)", lp="526.5,248", pos="e,394.8,211.6 452.54,284.02 444.46,268.82 431.28,246.48 416,230 412.01,225.7 407.43,221.56 402.74,217.73"]; "Elemwise{add,no_inplace}" -> "Elemwise{true_div,no_inplace}" [label="1 TensorType(float64, vector)", lp="369.5,150", pos="e,202.57,123.18 330.92,177.03 297.33,162.93 248.22,142.33 211.96,127.12"]; "TensorType(float64, vector) id=12" [fillcolor=dodgerblue, height=0.5, pos="164,18", shape=box, style=filled, width=2.9236]; "Elemwise{true_div,no_inplace}" -> "TensorType(float64, vector) id=12" [label="TensorType(float64, vector)", lp="243,62", pos="e,164,36.084 164,87.597 164,75.746 164,59.817 164,46.292"]; }
\ No newline at end of file
差异被折叠。
差异被折叠。
差异被折叠。
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
# Authors: Christof Angermueller <cangermueller@gmail.com> # Authors: Christof Angermueller <cangermueller@gmail.com>
import os.path import os.path
from theano.printing import pydotprint from theano.printing import pydotprint
...@@ -46,8 +47,11 @@ def d3print(fct, outfile=None, return_html=False, print_message=True, ...@@ -46,8 +47,11 @@ def d3print(fct, outfile=None, return_html=False, print_message=True,
:param *args, **kwargs: Parameters passed to pydotprint :param *args, **kwargs: Parameters passed to pydotprint
""" """
# Generate dot graph definition by calling pydotprint # Generate dot graph by pydotprint and write to file
dot_graph = d3dot(fct, *args, **kwargs) dot_graph = d3dot(fct, *args, **kwargs)
dot_file = os.path.splitext(outfile)[0] + '.dot'
with open(dot_file, 'w') as f:
f.write(dot_graph)
# Read template HTML file and replace variables # Read template HTML file and replace variables
template_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), template_file = os.path.join(os.path.dirname(os.path.realpath(__file__)),
...@@ -56,7 +60,7 @@ def d3print(fct, outfile=None, return_html=False, print_message=True, ...@@ -56,7 +60,7 @@ def d3print(fct, outfile=None, return_html=False, print_message=True,
template = f.read() template = f.read()
f.close() f.close()
replace = { replace = {
'%% DOT_GRAPH %%': dot_graph, '%% DOT_FILE %%': dot_file,
} }
html = replace_patterns(template, replace) html = replace_patterns(template, replace)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论