提交 458e1594 authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #3772 from gw0/feat-pydot-fallback

Import pydot-ng with pydot as fallback.
from theano.d3viz.d3viz import d3viz, d3write from theano.d3viz.d3viz import d3viz, d3write
has_requirements = False
try:
import pydot as pd
if pd.find_graphviz():
has_requirements = True
except ImportError:
pass
...@@ -16,11 +16,18 @@ from theano.compile import builders ...@@ -16,11 +16,18 @@ from theano.compile import builders
pydot_imported = False pydot_imported = False
try: try:
import pydot as pd # pydot-ng is a fork of pydot that is better maintained
import pydot_ng as pd
if pd.find_graphviz(): if pd.find_graphviz():
pydot_imported = True pydot_imported = True
except ImportError: except ImportError:
pass try:
# fall back on pydot if necessary
import pydot as pd
if pd.find_graphviz():
pydot_imported = True
except ImportError:
pass # tests should not fail on optional dependency
class PyDotFormatter(object): class PyDotFormatter(object):
......
from nose.plugins.skip import SkipTest from nose.plugins.skip import SkipTest
from theano.d3viz import has_requirements from theano.d3viz.formatting import pydot_imported
if not has_requirements: if not pydot_imported:
raise SkipTest('Missing requirements') raise SkipTest('Missing requirements')
import numpy as np import numpy as np
......
from nose.plugins.skip import SkipTest from nose.plugins.skip import SkipTest
from theano.d3viz import has_requirements from theano.d3viz.formatting import pydot_imported
if not has_requirements: if not pydot_imported:
raise SkipTest('Missing requirements') raise SkipTest('Missing requirements')
import numpy as np import numpy as np
......
...@@ -22,25 +22,20 @@ from theano.gof import Op, Apply ...@@ -22,25 +22,20 @@ from theano.gof import Op, Apply
from theano.compile import Function, debugmode, SharedVariable from theano.compile import Function, debugmode, SharedVariable
from theano.compile.profilemode import ProfileMode from theano.compile.profilemode import ProfileMode
# pydot-ng is a fork of pydot that is better maintained, and works pydot_imported = False
# with more recent version of its dependencies (in particular pyparsing)
try: try:
# pydot-ng is a fork of pydot that is better maintained
import pydot_ng as pd import pydot_ng as pd
if pd.find_graphviz(): if pd.find_graphviz():
pydot_imported = True pydot_imported = True
else: except ImportError:
pydot_imported = False
except Exception:
# Sometimes, a Windows-specific exception is raised
pydot_imported = False
# Fall back on pydot if necessary
if not pydot_imported:
try: try:
# fall back on pydot if necessary
import pydot as pd import pydot as pd
if pd.find_graphviz(): if pd.find_graphviz():
pydot_imported = True pydot_imported = True
except Exception: except ImportError:
pass pass # tests should not fail on optional dependency
_logger = logging.getLogger("theano.printing") _logger = logging.getLogger("theano.printing")
VALID_ASSOC = set(['left', 'right', 'either']) VALID_ASSOC = set(['left', 'right', 'either'])
...@@ -733,7 +728,6 @@ def pydotprint(fct, outfile=None, ...@@ -733,7 +728,6 @@ def pydotprint(fct, outfile=None,
if not pydot_imported: if not pydot_imported:
raise RuntimeError("Failed to import pydot. You must install pydot" raise RuntimeError("Failed to import pydot. You must install pydot"
" and graphviz for `pydotprint` to work.") " and graphviz for `pydotprint` to work.")
return
g = pd.Dot() g = pd.Dot()
...@@ -1065,13 +1059,10 @@ def pydotprint_variables(vars, ...@@ -1065,13 +1059,10 @@ def pydotprint_variables(vars,
if outfile is None: if outfile is None:
outfile = os.path.join(config.compiledir, 'theano.pydotprint.' + outfile = os.path.join(config.compiledir, 'theano.pydotprint.' +
config.device + '.' + format) config.device + '.' + format)
try: if not pydot_imported:
import pydot as pd raise RuntimeError("Failed to import pydot. You must install pydot"
except ImportError: " and graphviz for `pydotprint_variables` to work.")
err = ("Failed to import pydot. You must install pydot for " +
"`pydotprint_variables` to work.")
print(err)
return
g = pd.Dot() g = pd.Dot()
my_list = {} my_list = {}
orphanes = [] orphanes = []
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论