提交 80b0d16a authored 作者: Frederic Bastien's avatar Frederic Bastien

Make a unit file to test if scipy is available.

上级 b7a4cd25
...@@ -458,7 +458,9 @@ def diag(x): ...@@ -458,7 +458,9 @@ def diag(x):
raise TypeError('diag requires vector or matrix argument', x) raise TypeError('diag requires vector or matrix argument', x)
class Det(Op): class Det(Op):
"""matrix determinant""" """matrix determinant
TODO: move this op to another file that request scipy.
"""
def make_node(self, x): def make_node(self, x):
x = as_tensor_variable(x) x = as_tensor_variable(x)
o = theano.tensor.scalar(dtype=x.dtype) o = theano.tensor.scalar(dtype=x.dtype)
......
import numpy import numpy
import theano import theano
import theano.sparse # To know if scipy is available. import theano.scipy # To know if scipy is available.
from theano import tensor, function from theano import tensor, function
from theano.tensor.basic import _allclose from theano.tensor.basic import _allclose
...@@ -64,9 +64,7 @@ def test_inverse_grad(): ...@@ -64,9 +64,7 @@ def test_inverse_grad():
def test_det_grad(): def test_det_grad():
# If scipy is not available, this test will fail, thus we skip it. # If scipy is not available, this test will fail, thus we skip it.
# Note that currently we re-use the `enable_sparse` flag, but it may be if not theano.scipy.scipy_available:
# cleaner to have a different `scipy_available` flag in the future.
if not theano.sparse.enable_sparse:
raise SkipTest('Scipy is not available') raise SkipTest('Scipy is not available')
rng = numpy.random.RandomState(1234) rng = numpy.random.RandomState(1234)
......
import sys import sys
import theano.scipy
enable_sparse=True enable_sparse = False
if theano.scipy.scipy_available:
try:
import scipy import scipy
if scipy.__version__ < '0.7': if not scipy.__version__ < '0.7':
enable_sparse = True
else:
sys.stderr.write("WARNING: scipy version = %s. We request version >=0.7.0 for the sparse code as it has bugs fixed in the sparse matrix code.\n" % scipy.__version__) sys.stderr.write("WARNING: scipy version = %s. We request version >=0.7.0 for the sparse code as it has bugs fixed in the sparse matrix code.\n" % scipy.__version__)
enable_sparse=False else:
except ImportError:
enable_sparse=False
sys.stderr.write("WARNING: scipy can't be imported. We disable the sparse matrix code.") sys.stderr.write("WARNING: scipy can't be imported. We disable the sparse matrix code.")
if enable_sparse: if enable_sparse:
from basic import * from basic import *
import sharedvar import sharedvar
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论