提交 7b51945e authored 作者: Christof Angermueller's avatar Christof Angermueller

Skip d3viz if requirement missing

上级 04192c23
from theano.d3viz.d3viz import d3viz, d3write from theano.d3viz.d3viz import d3viz, d3write
has_requirements = True
try:
import pydot
except ImportError:
has_requirements = False
...@@ -4,19 +4,15 @@ Author: Christof Angermueller <cangermueller@gmail.com> ...@@ -4,19 +4,15 @@ Author: Christof Angermueller <cangermueller@gmail.com>
""" """
import numpy as np import numpy as np
import logging
import os import os
from functools import reduce from functools import reduce
from six import iteritems, itervalues from six import iteritems, itervalues
pydot_installed = True
try: try:
import pydot as pd import pydot as pd
if pd.find_graphviz():
pydot_imported = True
else:
pydot_imported = False
except ImportError: except ImportError:
pydot_imported = False pydot_installed = False
import theano import theano
from theano import gof from theano import gof
...@@ -24,8 +20,6 @@ from theano.compile.profilemode import ProfileMode ...@@ -24,8 +20,6 @@ from theano.compile.profilemode import ProfileMode
from theano.compile import Function from theano.compile import Function
from theano.compile import builders from theano.compile import builders
_logger = logging.getLogger("theano.printing")
class PyDotFormatter(object): class PyDotFormatter(object):
"""Create `pydot` graph object from Theano function. """Create `pydot` graph object from Theano function.
...@@ -47,8 +41,8 @@ class PyDotFormatter(object): ...@@ -47,8 +41,8 @@ class PyDotFormatter(object):
def __init__(self, compact=True): def __init__(self, compact=True):
"""Construct PyDotFormatter object.""" """Construct PyDotFormatter object."""
if not pydot_imported: if not pydot_installed:
raise RuntimeError("Failed to import pydot. Please install pydot!") raise ImportError('Failed to import pydot. Please install pydot!')
self.compact = compact self.compact = compact
self.node_colors = {'input': 'limegreen', self.node_colors = {'input': 'limegreen',
......
from nose.plugins.skip import SkipTest
from theano.d3viz import has_requirements
if not has_requirements:
raise SkipTest('Missing requirements')
import numpy as np import numpy as np
import os.path as pt import os.path as pt
import tempfile import tempfile
......
from nose.plugins.skip import SkipTest
from theano.d3viz import has_requirements
if not has_requirements:
raise SkipTest('Missing requirements')
import numpy as np import numpy as np
import unittest import unittest
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论