提交 798941fd authored 作者: Christof Angermueller's avatar Christof Angermueller

Use six to iterate over dict

上级 57f7590e
...@@ -6,6 +6,7 @@ Author: Christof Angermueller <cangermueller@gmail.com> ...@@ -6,6 +6,7 @@ Author: Christof Angermueller <cangermueller@gmail.com>
import os import os
import shutil import shutil
import re import re
from six import iteritems
from .formatting import PyDotFormatter from .formatting import PyDotFormatter
...@@ -14,7 +15,7 @@ __path__ = os.path.dirname(os.path.realpath(__file__)) ...@@ -14,7 +15,7 @@ __path__ = os.path.dirname(os.path.realpath(__file__))
def replace_patterns(x, replace): def replace_patterns(x, replace):
"""Replace patterns `replace` in x.""" """Replace patterns `replace` in x."""
for from_, to in replace.items(): for from_, to in iteritems(replace):
x = x.replace(str(from_), str(to)) x = x.replace(str(from_), str(to))
return x return x
...@@ -28,11 +29,16 @@ def escape_quotes(s): ...@@ -28,11 +29,16 @@ def escape_quotes(s):
def d3viz(fct, outfile, copy_deps=True, *args, **kwargs): def d3viz(fct, outfile, copy_deps=True, *args, **kwargs):
"""Create HTML file with dynamic visualizing of a Theano function graph. """Create HTML file with dynamic visualizing of a Theano function graph.
:param fct: A compiled Theano function, variable, apply or a list of Parameters
variables. ----------
:param outfile: The output HTML file. fct -- theano.compile.function_module.Function
:param copy_deps: Copy javascript and CSS dependencies to output directory. A compiled Theano function, variable, apply or a list of variables.
:param *args, **kwargs: Arguments passed to PyDotFormatter. outfile -- str
Path to output HTML file.
copy_deps -- bool, optional
Copy javascript and CSS dependencies to output directory.
*args, **kwargs -- dict, optional
Arguments passed to PyDotFormatter.
In the HTML file, the whole graph or single nodes can be moved by drag and In the HTML file, the whole graph or single nodes can be moved by drag and
drop. Zooming is possible via the mouse wheel. Detailed information about drop. Zooming is possible via the mouse wheel. Detailed information about
...@@ -85,10 +91,8 @@ def d3viz(fct, outfile, copy_deps=True, *args, **kwargs): ...@@ -85,10 +91,8 @@ def d3viz(fct, outfile, copy_deps=True, *args, **kwargs):
html = replace_patterns(template, replace) html = replace_patterns(template, replace)
# Write HTML file # Write HTML file
if outfile is not None: with open(outfile, 'w') as f:
f = open(outfile, 'w')
f.write(html) f.write(html)
f.close()
def d3write(fct, path, *args, **kwargs): def d3write(fct, path, *args, **kwargs):
......
...@@ -5,7 +5,9 @@ Author: Christof Angermueller <cangermueller@gmail.com> ...@@ -5,7 +5,9 @@ Author: Christof Angermueller <cangermueller@gmail.com>
import numpy as np import numpy as np
import logging import logging
import os
from functools import reduce from functools import reduce
from six import iteritems, itervalues
try: try:
import pydot as pd import pydot as pd
...@@ -35,8 +37,7 @@ class PyDotFormatter(object): ...@@ -35,8 +37,7 @@ class PyDotFormatter(object):
""" """
if not pydot_imported: if not pydot_imported:
raise RuntimeError("Failed to import pydot. You must install pydot" raise RuntimeError("Failed to import pydot. Please install pydot!")
" for `pydotprint` to work.")
self.compact = compact self.compact = compact
self.node_colors = {'input': 'limegreen', self.node_colors = {'input': 'limegreen',
...@@ -124,7 +125,7 @@ class PyDotFormatter(object): ...@@ -124,7 +125,7 @@ class PyDotFormatter(object):
nparams['shape'] = self.shapes['apply'] nparams['shape'] = self.shapes['apply']
use_color = None use_color = None
for opName, color in self.apply_colors.items(): for opName, color in iteritems(self.apply_colors):
if opName in node.op.__class__.__name__: if opName in node.op.__class__.__name__:
use_color = color use_color = color
if use_color: if use_color:
...@@ -160,11 +161,11 @@ class PyDotFormatter(object): ...@@ -160,11 +161,11 @@ class PyDotFormatter(object):
edge_params = {} edge_params = {}
if hasattr(node.op, 'view_map') and \ if hasattr(node.op, 'view_map') and \
id in reduce(list.__add__, id in reduce(list.__add__,
node.op.view_map.values(), []): itervalues(node.op.view_map), []):
edge_params['color'] = self.node_colors['output'] edge_params['color'] = self.node_colors['output']
elif hasattr(node.op, 'destroy_map') and \ elif hasattr(node.op, 'destroy_map') and \
id in reduce(list.__add__, id in reduce(list.__add__,
node.op.destroy_map.values(), []): itervalues(node.op.destroy_map), []):
edge_params['color'] = 'red' edge_params['color'] = 'red'
edge_label = vparams['dtype'] edge_label = vparams['dtype']
...@@ -329,7 +330,7 @@ def type_to_str(t): ...@@ -329,7 +330,7 @@ def type_to_str(t):
def dict_to_pdnode(d): def dict_to_pdnode(d):
"""Create pydot node from dict.""" """Create pydot node from dict."""
e = dict() e = dict()
for k, v in d.items(): for k, v in iteritems(d):
if v is not None: if v is not None:
if isinstance(v, list): if isinstance(v, list):
v = '\t'.join([str(x) for x in v]) v = '\t'.join([str(x) for x in v])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论