提交 2d9b16b9 authored 作者: Frederic Bastien's avatar Frederic Bastien

Make SkipTest give a better error. Remove duplicate code

上级 3896827b
...@@ -14,25 +14,7 @@ from theano import gof ...@@ -14,25 +14,7 @@ from theano import gof
from theano.compile.profilemode import ProfileMode from theano.compile.profilemode import ProfileMode
from theano.compile import Function from theano.compile import Function
from theano.compile import builders from theano.compile import builders
from theano.printing import pd, pydot_imported, pydot_imported_msg
pydot_imported = False
try:
# pydot-ng is a fork of pydot that is better maintained
import pydot_ng as pd
if pd.find_graphviz():
pydot_imported = True
except ImportError:
try:
# fall back on pydot if necessary
import pydot as pd
if hasattr(pd, 'find_graphviz'):
if pd.find_graphviz():
pydot_imported = True
else:
pd.Dot.create(pd.Dot())
pydot_imported = True
except ImportError:
pass # tests should not fail on optional dependency
class PyDotFormatter(object): class PyDotFormatter(object):
...@@ -56,9 +38,8 @@ class PyDotFormatter(object): ...@@ -56,9 +38,8 @@ class PyDotFormatter(object):
def __init__(self, compact=True): def __init__(self, compact=True):
"""Construct PyDotFormatter object.""" """Construct PyDotFormatter object."""
if not pydot_imported: if not pydot_imported:
raise ImportError('Failed to import pydot. You must install ' raise ImportError('Failed to import pydot. ' +
'graphviz and either pydot or pydot-ng for ' pydot_imported_msg)
'`PyDotFormatter` to work.')
self.compact = compact self.compact = compact
self.node_colors = {'input': 'limegreen', self.node_colors = {'input': 'limegreen',
......
...@@ -11,9 +11,9 @@ import theano.d3viz as d3v ...@@ -11,9 +11,9 @@ import theano.d3viz as d3v
from theano.d3viz.tests import models from theano.d3viz.tests import models
from nose.plugins.skip import SkipTest from nose.plugins.skip import SkipTest
from theano.d3viz.formatting import pydot_imported from theano.d3viz.formatting import pydot_imported, pydot_imported_msg
if not pydot_imported: if not pydot_imported:
raise SkipTest('Missing requirements') raise SkipTest('pydot not available: ' + pydot_imported_msg)
class TestD3Viz(unittest.TestCase): class TestD3Viz(unittest.TestCase):
......
...@@ -8,9 +8,9 @@ from theano.d3viz.formatting import PyDotFormatter ...@@ -8,9 +8,9 @@ from theano.d3viz.formatting import PyDotFormatter
from theano.d3viz.tests import models from theano.d3viz.tests import models
from nose.plugins.skip import SkipTest from nose.plugins.skip import SkipTest
from theano.d3viz.formatting import pydot_imported from theano.d3viz.formatting import pydot_imported, pydot_imported_msg
if not pydot_imported: if not pydot_imported:
raise SkipTest('Missing requirements') raise SkipTest('pydot not available: ' + pydot_imported_msg)
class TestPyDotFormatter(unittest.TestCase): class TestPyDotFormatter(unittest.TestCase):
......
...@@ -30,7 +30,7 @@ try: ...@@ -30,7 +30,7 @@ try:
if pd.find_graphviz(): if pd.find_graphviz():
pydot_imported = True pydot_imported = True
else: else:
pydot_imported_msg = "pydot-ng can't find graphviz" pydot_imported_msg = "pydot-ng can't find graphviz. Install graphviz."
except ImportError: except ImportError:
try: try:
# fall back on pydot if necessary # fall back on pydot if necessary
...@@ -45,7 +45,8 @@ except ImportError: ...@@ -45,7 +45,8 @@ except ImportError:
pydot_imported = True pydot_imported = True
except ImportError: except ImportError:
# tests should not fail on optional dependency # tests should not fail on optional dependency
pydot_imported_msg = "Install the python package pydot or pydot-ng." pydot_imported_msg = ("Install the python package pydot or pydot-ng."
" Install graphviz.")
_logger = logging.getLogger("theano.printing") _logger = logging.getLogger("theano.printing")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论