提交 584df0ee authored 作者: TimSalimans's avatar TimSalimans

Add check for scipy version >= 0.12

上级 43c14abb
......@@ -16,6 +16,7 @@ from nose.plugins.attrib import attr
import numpy
from numpy.testing import dec, assert_array_equal, assert_allclose
from numpy.testing.noseclasses import KnownFailureTest
from distutils.version import LooseVersion
import theano
from theano.compat import PY3, exc_message, operator_div
......@@ -56,6 +57,7 @@ mode_no_scipy = get_default_mode()
try:
import scipy.special
import scipy.stats
from scipy import __version__ as scipy_version
imported_scipy_special = True
except ImportError:
if config.mode == "FAST_COMPILE":
......@@ -1642,7 +1644,6 @@ ArctanhInplaceTester = makeBroadcastTester(
if imported_scipy_special:
expected_erf = scipy.special.erf
expected_erfc = scipy.special.erfc
expected_erfcx = scipy.special.erfcx
expected_erfinv = scipy.special.erfinv
expected_erfcinv = scipy.special.erfcinv
expected_gamma = scipy.special.gamma
......@@ -1650,6 +1651,12 @@ if imported_scipy_special:
expected_psi = scipy.special.psi
expected_chi2sf = lambda x, df: scipy.stats.chi2.sf(x, df).astype(x.dtype)
skip_scipy = False
if LooseVersion(scipy_version) >= LooseVersion("0.12.0"):
expected_erfcx = scipy.special.erfcx
skip_scipy12 = False
else:
expected_erfcx = []
skip_scipy12 = "the erfcx op requires scipy version >= 0.12, installed version is " + scipy_version
else:
expected_erf = []
expected_erfc = []
......@@ -1661,6 +1668,7 @@ else:
expected_psi = []
expected_chi2sf = []
skip_scipy = "scipy is not present"
skip_scipy12 = "scipy is not present"
ErfTester = makeBroadcastTester(
op=tensor.erf,
......@@ -1705,7 +1713,7 @@ ErfcxTester = makeBroadcastTester(
grad=_grad_broadcast_unary_normal,
eps=2e-10,
mode=mode_no_scipy,
skip=skip_scipy)
skip=skip_scipy12)
ErfcxInplaceTester = makeBroadcastTester(
op=inplace.erfcx_inplace,
expected=expected_erfcx,
......@@ -1714,7 +1722,7 @@ ErfcxInplaceTester = makeBroadcastTester(
eps=2e-10,
mode=mode_no_scipy,
inplace=True,
skip=skip_scipy)
skip=skip_scipy12)
ErfinvTester = makeBroadcastTester(
op=tensor.erfinv,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论