提交 04192c23 authored 作者: Christof Angermueller's avatar Christof Angermueller

Update docstring and fix flake8 errors

上级 e8c47f1e
...@@ -8,20 +8,35 @@ import shutil ...@@ -8,20 +8,35 @@ import shutil
import re import re
from six import iteritems from six import iteritems
from .formatting import PyDotFormatter from theano.d3viz.formatting import PyDotFormatter
__path__ = os.path.dirname(os.path.realpath(__file__)) __path__ = os.path.dirname(os.path.realpath(__file__))
def replace_patterns(x, replace): def replace_patterns(x, replace):
"""Replace patterns `replace` in x.""" """Replace `replace` in string `x`.
Parameters
----------
s: str
String on which function is applied
replace: dict
`key`, `value` pairs where key is a regular expression and `value` a
string by which `key` is replaced
"""
for from_, to in iteritems(replace): for from_, to in iteritems(replace):
x = x.replace(str(from_), str(to)) x = x.replace(str(from_), str(to))
return x return x
def escape_quotes(s): def escape_quotes(s):
"""Escape quotes in string.""" """Escape quotes in string.
Parameters
----------
s: str
String on which function is applied
"""
s = re.sub(r'''(['"])''', r'\\\1', s) s = re.sub(r'''(['"])''', r'\\\1', s)
return s return s
...@@ -29,17 +44,6 @@ def escape_quotes(s): ...@@ -29,17 +44,6 @@ def escape_quotes(s):
def d3viz(fct, outfile, copy_deps=True, *args, **kwargs): def d3viz(fct, outfile, copy_deps=True, *args, **kwargs):
"""Create HTML file with dynamic visualizing of a Theano function graph. """Create HTML file with dynamic visualizing of a Theano function graph.
Parameters
----------
fct -- theano.compile.function_module.Function
A compiled Theano function, variable, apply or a list of variables.
outfile -- str
Path to output HTML file.
copy_deps -- bool, optional
Copy javascript and CSS dependencies to output directory.
*args, **kwargs -- dict, optional
Arguments passed to PyDotFormatter.
In the HTML file, the whole graph or single nodes can be moved by drag and In the HTML file, the whole graph or single nodes can be moved by drag and
drop. Zooming is possible via the mouse wheel. Detailed information about drop. Zooming is possible via the mouse wheel. Detailed information about
nodes and edges are displayed via mouse-over events. Node labels can be nodes and edges are displayed via mouse-over events. Node labels can be
...@@ -53,6 +57,19 @@ def d3viz(fct, outfile, copy_deps=True, *args, **kwargs): ...@@ -53,6 +57,19 @@ def d3viz(fct, outfile, copy_deps=True, *args, **kwargs):
Edges are black by default. If a node returns a view of an Edges are black by default. If a node returns a view of an
input, the input edge will be blue. If it returns a destroyed input, the input, the input edge will be blue. If it returns a destroyed input, the
edge will be red. edge will be red.
Parameters
----------
fct : theano.compile.function_module.Function
A compiled Theano function, variable, apply or a list of variables.
outfile : str
Path to output HTML file.
copy_deps : bool, optional
Copy javascript and CSS dependencies to output directory.
*args : tuple, optional
Arguments passed to PyDotFormatter.
*kwargs : dict, optional
Arguments passed to PyDotFormatter.
""" """
# Create DOT graph # Create DOT graph
...@@ -96,7 +113,20 @@ def d3viz(fct, outfile, copy_deps=True, *args, **kwargs): ...@@ -96,7 +113,20 @@ def d3viz(fct, outfile, copy_deps=True, *args, **kwargs):
def d3write(fct, path, *args, **kwargs): def d3write(fct, path, *args, **kwargs):
"""Convert Theano graph to pydot graph and write to file.""" """Convert Theano graph to pydot graph and write to dot file.
Parameters
----------
fct : theano.compile.function_module.Function
A compiled Theano function, variable, apply or a list of variables.
path: str
Path to output file
*args : tuple, optional
Arguments passed to PyDotFormatter.
*kwargs : dict, optional
Arguments passed to PyDotFormatter.
"""
formatter = PyDotFormatter(*args, **kwargs) formatter = PyDotFormatter(*args, **kwargs)
graph = formatter(fct) graph = formatter(fct)
graph.write_dot(path) graph.write_dot(path)
...@@ -28,14 +28,25 @@ _logger = logging.getLogger("theano.printing") ...@@ -28,14 +28,25 @@ _logger = logging.getLogger("theano.printing")
class PyDotFormatter(object): class PyDotFormatter(object):
"""Create `pydot` graph object from Theano function.
Parameters
----------
compact : bool
if True, will remove intermediate variables without name.
Attributes
----------
node_colors : dict
Color table of node types.
apply_colors : dict
Color table of apply nodes.
shapes : dict
Shape table of node types.
"""
def __init__(self, compact=True): def __init__(self, compact=True):
"""Converts compute to to `pydot` object. """Construct PyDotFormatter object."""
:param compact: if True, will remove intermediate variables without
name.
"""
if not pydot_imported: if not pydot_imported:
raise RuntimeError("Failed to import pydot. Please install pydot!") raise RuntimeError("Failed to import pydot. Please install pydot!")
...@@ -59,14 +70,36 @@ class PyDotFormatter(object): ...@@ -59,14 +70,36 @@ class PyDotFormatter(object):
self.__node_prefix = 'n' self.__node_prefix = 'n'
def __add_node(self, node): def __add_node(self, node):
"""Add new node to node list and return unique id.""" """Add new node to node list and return unique id.
Parameters
----------
node : Theano graph node
Apply node, tensor variable, or shared variable in compute graph.
Returns
-------
str
Unique node id.
"""
assert node not in self.__nodes assert node not in self.__nodes
_id = '%s%d' % (self.__node_prefix, len(self.__nodes) + 1) _id = '%s%d' % (self.__node_prefix, len(self.__nodes) + 1)
self.__nodes[node] = _id self.__nodes[node] = _id
return _id return _id
def __node_id(self, node): def __node_id(self, node):
"""Return unique node id.""" """Return unique node id.
Parameters
----------
node : Theano graph node
Apply node, tensor variable, or shared variable in compute graph.
Returns
-------
str
Unique node id.
"""
if node in self.__nodes: if node in self.__nodes:
return self.__nodes[node] return self.__nodes[node]
else: else:
...@@ -75,10 +108,18 @@ class PyDotFormatter(object): ...@@ -75,10 +108,18 @@ class PyDotFormatter(object):
def __call__(self, fct, graph=None): def __call__(self, fct, graph=None):
"""Create pydot graph from function. """Create pydot graph from function.
:param fct: a compiled Theano function, a Variable, an Apply or Parameters
a list of Variable. ----------
:param graph: `pydot` graph to which nodes are added. Creates new one fct : theano.compile.function_module.Function
if not given. A compiled Theano function, variable, apply or a list of variables.
graph: pydot.Dot
`pydot` graph to which nodes are added. Creates new one if
undefined.
Returns
-------
pydot.Dot
Pydot graph of `fct`
""" """
if graph is None: if graph is None:
graph = pd.Dot() graph = pd.Dot()
...@@ -287,6 +328,7 @@ def apply_profile(node, profile): ...@@ -287,6 +328,7 @@ def apply_profile(node, profile):
def broadcastable_to_str(b): def broadcastable_to_str(b):
"""Return string representation of broadcastable."""
named_broadcastable = {(): 'scalar', named_broadcastable = {(): 'scalar',
(False,): 'vector', (False,): 'vector',
(False, True): 'col', (False, True): 'col',
...@@ -300,6 +342,7 @@ def broadcastable_to_str(b): ...@@ -300,6 +342,7 @@ def broadcastable_to_str(b):
def dtype_to_char(dtype): def dtype_to_char(dtype):
"""Return character that represents data type."""
dtype_char = { dtype_char = {
'complex64': 'c', 'complex64': 'c',
'complex128': 'z', 'complex128': 'z',
......
...@@ -3,7 +3,6 @@ import os.path as pt ...@@ -3,7 +3,6 @@ import os.path as pt
import tempfile import tempfile
import unittest import unittest
import filecmp import filecmp
import sys
import theano as th import theano as th
import theano.d3viz as d3v import theano.d3viz as d3v
......
...@@ -216,6 +216,7 @@ whitelist_flake8 = [ ...@@ -216,6 +216,7 @@ whitelist_flake8 = [
"gof/tests/test_cc.py", "gof/tests/test_cc.py",
"gof/tests/test_compute_test_value.py", "gof/tests/test_compute_test_value.py",
"gof/sandbox/equilibrium.py", "gof/sandbox/equilibrium.py",
"d3viz/__init__.py"
] ]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论