提交 798941fd authored 作者: Christof Angermueller's avatar Christof Angermueller

Use six to iterate over dict

上级 57f7590e
......@@ -6,6 +6,7 @@ Author: Christof Angermueller <cangermueller@gmail.com>
import os
import shutil
import re
from six import iteritems
from .formatting import PyDotFormatter
......@@ -14,7 +15,7 @@ __path__ = os.path.dirname(os.path.realpath(__file__))
def replace_patterns(x, replace):
"""Replace patterns `replace` in x."""
for from_, to in replace.items():
for from_, to in iteritems(replace):
x = x.replace(str(from_), str(to))
return x
......@@ -28,11 +29,16 @@ def escape_quotes(s):
def d3viz(fct, outfile, copy_deps=True, *args, **kwargs):
"""Create HTML file with dynamic visualizing of a Theano function graph.
:param fct: A compiled Theano function, variable, apply or a list of
variables.
:param outfile: The output HTML file.
:param copy_deps: Copy javascript and CSS dependencies to output directory.
:param *args, **kwargs: Arguments passed to PyDotFormatter.
Parameters
----------
fct -- theano.compile.function_module.Function
A compiled Theano function, variable, apply or a list of variables.
outfile -- str
Path to output HTML file.
copy_deps -- bool, optional
Copy javascript and CSS dependencies to output directory.
*args, **kwargs -- dict, optional
Arguments passed to PyDotFormatter.
In the HTML file, the whole graph or single nodes can be moved by drag and
drop. Zooming is possible via the mouse wheel. Detailed information about
......@@ -85,10 +91,8 @@ def d3viz(fct, outfile, copy_deps=True, *args, **kwargs):
html = replace_patterns(template, replace)
# Write HTML file
if outfile is not None:
f = open(outfile, 'w')
with open(outfile, 'w') as f:
f.write(html)
f.close()
def d3write(fct, path, *args, **kwargs):
......
......@@ -5,7 +5,9 @@ Author: Christof Angermueller <cangermueller@gmail.com>
import numpy as np
import logging
import os
from functools import reduce
from six import iteritems, itervalues
try:
import pydot as pd
......@@ -35,8 +37,7 @@ class PyDotFormatter(object):
"""
if not pydot_imported:
raise RuntimeError("Failed to import pydot. You must install pydot"
" for `pydotprint` to work.")
raise RuntimeError("Failed to import pydot. Please install pydot!")
self.compact = compact
self.node_colors = {'input': 'limegreen',
......@@ -124,7 +125,7 @@ class PyDotFormatter(object):
nparams['shape'] = self.shapes['apply']
use_color = None
for opName, color in self.apply_colors.items():
for opName, color in iteritems(self.apply_colors):
if opName in node.op.__class__.__name__:
use_color = color
if use_color:
......@@ -160,11 +161,11 @@ class PyDotFormatter(object):
edge_params = {}
if hasattr(node.op, 'view_map') and \
id in reduce(list.__add__,
node.op.view_map.values(), []):
itervalues(node.op.view_map), []):
edge_params['color'] = self.node_colors['output']
elif hasattr(node.op, 'destroy_map') and \
id in reduce(list.__add__,
node.op.destroy_map.values(), []):
itervalues(node.op.destroy_map), []):
edge_params['color'] = 'red'
edge_label = vparams['dtype']
......@@ -329,7 +330,7 @@ def type_to_str(t):
def dict_to_pdnode(d):
"""Create pydot node from dict."""
e = dict()
for k, v in d.items():
for k, v in iteritems(d):
if v is not None:
if isinstance(v, list):
v = '\t'.join([str(x) for x in v])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论