Import pydot-ng with pydot as fallback.

上级 d7e653ab
......@@ -2,8 +2,10 @@ from theano.d3viz.d3viz import d3viz, d3write
has_requirements = False
try:
# pydot-ng is a fork of pydot that is better maintained
import pydot_ng as pd
except ImportError:
# fall back on pydot if necessary
import pydot as pd
if pd.find_graphviz():
if pd.find_graphviz():
has_requirements = True
except ImportError:
pass
......@@ -16,11 +16,13 @@ from theano.compile import builders
pydot_imported = False
try:
# pydot-ng is a fork of pydot that is better maintained
import pydot_ng as pd
except ImportError:
# fall back on pydot if necessary
import pydot as pd
if pd.find_graphviz():
if pd.find_graphviz():
pydot_imported = True
except ImportError:
pass
class PyDotFormatter(object):
......
......@@ -22,25 +22,15 @@ from theano.gof import Op, Apply
from theano.compile import Function, debugmode, SharedVariable
from theano.compile.profilemode import ProfileMode
# pydot-ng is a fork of pydot that is better maintained, and works
# with more recent version of its dependencies (in particular pyparsing)
pydot_imported = False
try:
# pydot-ng is a fork of pydot that is better maintained
import pydot_ng as pd
if pd.find_graphviz():
pydot_imported = True
else:
pydot_imported = False
except Exception:
# Sometimes, a Windows-specific exception is raised
pydot_imported = False
# Fall back on pydot if necessary
if not pydot_imported:
try:
except ImportError:
# fall back on pydot if necessary
import pydot as pd
if pd.find_graphviz():
if pd.find_graphviz():
pydot_imported = True
except Exception:
pass
_logger = logging.getLogger("theano.printing")
VALID_ASSOC = set(['left', 'right', 'either'])
......@@ -733,7 +723,6 @@ def pydotprint(fct, outfile=None,
if not pydot_imported:
raise RuntimeError("Failed to import pydot. You must install pydot"
" and graphviz for `pydotprint` to work.")
return
g = pd.Dot()
......@@ -1065,13 +1054,10 @@ def pydotprint_variables(vars,
if outfile is None:
outfile = os.path.join(config.compiledir, 'theano.pydotprint.' +
config.device + '.' + format)
try:
import pydot as pd
except ImportError:
err = ("Failed to import pydot. You must install pydot for " +
"`pydotprint_variables` to work.")
print(err)
return
if not pydot_imported:
raise RuntimeError("Failed to import pydot. You must install pydot"
" and graphviz for `pydotprint_variables` to work.")
g = pd.Dot()
my_list = {}
orphanes = []
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论