Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
04192c23
提交
04192c23
authored
9月 01, 2015
作者:
Christof Angermueller
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Update docstring and fix flake8 errors
上级
e8c47f1e
显示空白字符变更
内嵌
并排
正在显示
4 个修改的文件
包含
101 行增加
和
28 行删除
+101
-28
d3viz.py
theano/d3viz/d3viz.py
+45
-15
formatting.py
theano/d3viz/formatting.py
+55
-12
test_d3viz.py
theano/d3viz/tests/test_d3viz.py
+0
-1
test_flake8.py
theano/tests/test_flake8.py
+1
-0
没有找到文件。
theano/d3viz/d3viz.py
浏览文件 @
04192c23
...
@@ -8,20 +8,35 @@ import shutil
...
@@ -8,20 +8,35 @@ import shutil
import
re
import
re
from
six
import
iteritems
from
six
import
iteritems
from
.formatting
import
PyDotFormatter
from
theano.d3viz
.formatting
import
PyDotFormatter
__path__
=
os
.
path
.
dirname
(
os
.
path
.
realpath
(
__file__
))
__path__
=
os
.
path
.
dirname
(
os
.
path
.
realpath
(
__file__
))
def
replace_patterns
(
x
,
replace
):
def
replace_patterns
(
x
,
replace
):
"""Replace patterns `replace` in x."""
"""Replace `replace` in string `x`.
Parameters
----------
s: str
String on which function is applied
replace: dict
`key`, `value` pairs where key is a regular expression and `value` a
string by which `key` is replaced
"""
for
from_
,
to
in
iteritems
(
replace
):
for
from_
,
to
in
iteritems
(
replace
):
x
=
x
.
replace
(
str
(
from_
),
str
(
to
))
x
=
x
.
replace
(
str
(
from_
),
str
(
to
))
return
x
return
x
def
escape_quotes
(
s
):
def
escape_quotes
(
s
):
"""Escape quotes in string."""
"""Escape quotes in string.
Parameters
----------
s: str
String on which function is applied
"""
s
=
re
.
sub
(
r'''(['"])'''
,
r'\\\1'
,
s
)
s
=
re
.
sub
(
r'''(['"])'''
,
r'\\\1'
,
s
)
return
s
return
s
...
@@ -29,17 +44,6 @@ def escape_quotes(s):
...
@@ -29,17 +44,6 @@ def escape_quotes(s):
def
d3viz
(
fct
,
outfile
,
copy_deps
=
True
,
*
args
,
**
kwargs
):
def
d3viz
(
fct
,
outfile
,
copy_deps
=
True
,
*
args
,
**
kwargs
):
"""Create HTML file with dynamic visualizing of a Theano function graph.
"""Create HTML file with dynamic visualizing of a Theano function graph.
Parameters
----------
fct -- theano.compile.function_module.Function
A compiled Theano function, variable, apply or a list of variables.
outfile -- str
Path to output HTML file.
copy_deps -- bool, optional
Copy javascript and CSS dependencies to output directory.
*args, **kwargs -- dict, optional
Arguments passed to PyDotFormatter.
In the HTML file, the whole graph or single nodes can be moved by drag and
In the HTML file, the whole graph or single nodes can be moved by drag and
drop. Zooming is possible via the mouse wheel. Detailed information about
drop. Zooming is possible via the mouse wheel. Detailed information about
nodes and edges are displayed via mouse-over events. Node labels can be
nodes and edges are displayed via mouse-over events. Node labels can be
...
@@ -53,6 +57,19 @@ def d3viz(fct, outfile, copy_deps=True, *args, **kwargs):
...
@@ -53,6 +57,19 @@ def d3viz(fct, outfile, copy_deps=True, *args, **kwargs):
Edges are black by default. If a node returns a view of an
Edges are black by default. If a node returns a view of an
input, the input edge will be blue. If it returns a destroyed input, the
input, the input edge will be blue. If it returns a destroyed input, the
edge will be red.
edge will be red.
Parameters
----------
fct : theano.compile.function_module.Function
A compiled Theano function, variable, apply or a list of variables.
outfile : str
Path to output HTML file.
copy_deps : bool, optional
Copy javascript and CSS dependencies to output directory.
*args : tuple, optional
Arguments passed to PyDotFormatter.
*kwargs : dict, optional
Arguments passed to PyDotFormatter.
"""
"""
# Create DOT graph
# Create DOT graph
...
@@ -96,7 +113,20 @@ def d3viz(fct, outfile, copy_deps=True, *args, **kwargs):
...
@@ -96,7 +113,20 @@ def d3viz(fct, outfile, copy_deps=True, *args, **kwargs):
def
d3write
(
fct
,
path
,
*
args
,
**
kwargs
):
def
d3write
(
fct
,
path
,
*
args
,
**
kwargs
):
"""Convert Theano graph to pydot graph and write to file."""
"""Convert Theano graph to pydot graph and write to dot file.
Parameters
----------
fct : theano.compile.function_module.Function
A compiled Theano function, variable, apply or a list of variables.
path: str
Path to output file
*args : tuple, optional
Arguments passed to PyDotFormatter.
*kwargs : dict, optional
Arguments passed to PyDotFormatter.
"""
formatter
=
PyDotFormatter
(
*
args
,
**
kwargs
)
formatter
=
PyDotFormatter
(
*
args
,
**
kwargs
)
graph
=
formatter
(
fct
)
graph
=
formatter
(
fct
)
graph
.
write_dot
(
path
)
graph
.
write_dot
(
path
)
theano/d3viz/formatting.py
浏览文件 @
04192c23
...
@@ -28,14 +28,25 @@ _logger = logging.getLogger("theano.printing")
...
@@ -28,14 +28,25 @@ _logger = logging.getLogger("theano.printing")
class
PyDotFormatter
(
object
):
class
PyDotFormatter
(
object
):
"""Create `pydot` graph object from Theano function.
def
__init__
(
self
,
compact
=
True
):
"""Converts compute to to `pydot` object.
Parameters
----------
:param compact: if True, will remove intermediate variables without
compact : bool
name.
if True, will remove intermediate variables without name.
Attributes
----------
node_colors : dict
Color table of node types.
apply_colors : dict
Color table of apply nodes.
shapes : dict
Shape table of node types.
"""
"""
def
__init__
(
self
,
compact
=
True
):
"""Construct PyDotFormatter object."""
if
not
pydot_imported
:
if
not
pydot_imported
:
raise
RuntimeError
(
"Failed to import pydot. Please install pydot!"
)
raise
RuntimeError
(
"Failed to import pydot. Please install pydot!"
)
...
@@ -59,14 +70,36 @@ class PyDotFormatter(object):
...
@@ -59,14 +70,36 @@ class PyDotFormatter(object):
self
.
__node_prefix
=
'n'
self
.
__node_prefix
=
'n'
def
__add_node
(
self
,
node
):
def
__add_node
(
self
,
node
):
"""Add new node to node list and return unique id."""
"""Add new node to node list and return unique id.
Parameters
----------
node : Theano graph node
Apply node, tensor variable, or shared variable in compute graph.
Returns
-------
str
Unique node id.
"""
assert
node
not
in
self
.
__nodes
assert
node
not
in
self
.
__nodes
_id
=
'
%
s
%
d'
%
(
self
.
__node_prefix
,
len
(
self
.
__nodes
)
+
1
)
_id
=
'
%
s
%
d'
%
(
self
.
__node_prefix
,
len
(
self
.
__nodes
)
+
1
)
self
.
__nodes
[
node
]
=
_id
self
.
__nodes
[
node
]
=
_id
return
_id
return
_id
def
__node_id
(
self
,
node
):
def
__node_id
(
self
,
node
):
"""Return unique node id."""
"""Return unique node id.
Parameters
----------
node : Theano graph node
Apply node, tensor variable, or shared variable in compute graph.
Returns
-------
str
Unique node id.
"""
if
node
in
self
.
__nodes
:
if
node
in
self
.
__nodes
:
return
self
.
__nodes
[
node
]
return
self
.
__nodes
[
node
]
else
:
else
:
...
@@ -75,10 +108,18 @@ class PyDotFormatter(object):
...
@@ -75,10 +108,18 @@ class PyDotFormatter(object):
def
__call__
(
self
,
fct
,
graph
=
None
):
def
__call__
(
self
,
fct
,
graph
=
None
):
"""Create pydot graph from function.
"""Create pydot graph from function.
:param fct: a compiled Theano function, a Variable, an Apply or
Parameters
a list of Variable.
----------
:param graph: `pydot` graph to which nodes are added. Creates new one
fct : theano.compile.function_module.Function
if not given.
A compiled Theano function, variable, apply or a list of variables.
graph: pydot.Dot
`pydot` graph to which nodes are added. Creates new one if
undefined.
Returns
-------
pydot.Dot
Pydot graph of `fct`
"""
"""
if
graph
is
None
:
if
graph
is
None
:
graph
=
pd
.
Dot
()
graph
=
pd
.
Dot
()
...
@@ -287,6 +328,7 @@ def apply_profile(node, profile):
...
@@ -287,6 +328,7 @@ def apply_profile(node, profile):
def
broadcastable_to_str
(
b
):
def
broadcastable_to_str
(
b
):
"""Return string representation of broadcastable."""
named_broadcastable
=
{():
'scalar'
,
named_broadcastable
=
{():
'scalar'
,
(
False
,):
'vector'
,
(
False
,):
'vector'
,
(
False
,
True
):
'col'
,
(
False
,
True
):
'col'
,
...
@@ -300,6 +342,7 @@ def broadcastable_to_str(b):
...
@@ -300,6 +342,7 @@ def broadcastable_to_str(b):
def
dtype_to_char
(
dtype
):
def
dtype_to_char
(
dtype
):
"""Return character that represents data type."""
dtype_char
=
{
dtype_char
=
{
'complex64'
:
'c'
,
'complex64'
:
'c'
,
'complex128'
:
'z'
,
'complex128'
:
'z'
,
...
...
theano/d3viz/tests/test_d3viz.py
浏览文件 @
04192c23
...
@@ -3,7 +3,6 @@ import os.path as pt
...
@@ -3,7 +3,6 @@ import os.path as pt
import
tempfile
import
tempfile
import
unittest
import
unittest
import
filecmp
import
filecmp
import
sys
import
theano
as
th
import
theano
as
th
import
theano.d3viz
as
d3v
import
theano.d3viz
as
d3v
...
...
theano/tests/test_flake8.py
浏览文件 @
04192c23
...
@@ -216,6 +216,7 @@ whitelist_flake8 = [
...
@@ -216,6 +216,7 @@ whitelist_flake8 = [
"gof/tests/test_cc.py"
,
"gof/tests/test_cc.py"
,
"gof/tests/test_compute_test_value.py"
,
"gof/tests/test_compute_test_value.py"
,
"gof/sandbox/equilibrium.py"
,
"gof/sandbox/equilibrium.py"
,
"d3viz/__init__.py"
]
]
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论