提交 84c9c685 authored 作者: Frederic Bastien's avatar Frederic Bastien

Better pydotprint error message

上级 dba69909
......@@ -23,19 +23,26 @@ from theano.compile import Function, debugmode, SharedVariable
from theano.compile.profilemode import ProfileMode
pydot_imported = False
pydot_imported_msg = ""
try:
# pydot-ng is a fork of pydot that is better maintained
import pydot_ng as pd
if pd.find_graphviz():
pydot_imported = True
else:
pydot_imported_msg = "pydot-ng can't find graphviz"
except ImportError:
try:
# fall back on pydot if necessary
import pydot as pd
if pd.find_graphviz():
pydot_imported = True
else:
pydot_imported_msg = "pydot can't find graphviz"
except ImportError:
pass # tests should not fail on optional dependency
# tests should not fail on optional dependency
pydot_imported_msg = "Install the python package pydot or pydot-ng."
_logger = logging.getLogger("theano.printing")
VALID_ASSOC = set(['left', 'right', 'either'])
......@@ -728,7 +735,8 @@ def pydotprint(fct, outfile=None,
topo = fct.toposort()
if not pydot_imported:
raise RuntimeError("Failed to import pydot. You must install pydot"
" and graphviz for `pydotprint` to work.")
" and graphviz for `pydotprint` to work.",
pydot_imported_msg)
g = pd.Dot()
......@@ -1063,7 +1071,8 @@ def pydotprint_variables(vars,
config.device + '.' + format)
if not pydot_imported:
raise RuntimeError("Failed to import pydot. You must install pydot"
" and graphviz for `pydotprint_variables` to work.")
" and graphviz for `pydotprint_variables` to work.",
pydot_imported_msg)
if pd.__name__ == "pydot_ng":
raise RuntimeError("pydotprint_variables do not support pydot_ng."
"pydotprint_variables is also deprecated, "
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论