提交 584df0ee authored 作者: TimSalimans's avatar TimSalimans

Add check for scipy version >= 0.12

上级 43c14abb
...@@ -16,6 +16,7 @@ from nose.plugins.attrib import attr ...@@ -16,6 +16,7 @@ from nose.plugins.attrib import attr
import numpy import numpy
from numpy.testing import dec, assert_array_equal, assert_allclose from numpy.testing import dec, assert_array_equal, assert_allclose
from numpy.testing.noseclasses import KnownFailureTest from numpy.testing.noseclasses import KnownFailureTest
from distutils.version import LooseVersion
import theano import theano
from theano.compat import PY3, exc_message, operator_div from theano.compat import PY3, exc_message, operator_div
...@@ -56,6 +57,7 @@ mode_no_scipy = get_default_mode() ...@@ -56,6 +57,7 @@ mode_no_scipy = get_default_mode()
try: try:
import scipy.special import scipy.special
import scipy.stats import scipy.stats
from scipy import __version__ as scipy_version
imported_scipy_special = True imported_scipy_special = True
except ImportError: except ImportError:
if config.mode == "FAST_COMPILE": if config.mode == "FAST_COMPILE":
...@@ -1642,7 +1644,6 @@ ArctanhInplaceTester = makeBroadcastTester( ...@@ -1642,7 +1644,6 @@ ArctanhInplaceTester = makeBroadcastTester(
if imported_scipy_special: if imported_scipy_special:
expected_erf = scipy.special.erf expected_erf = scipy.special.erf
expected_erfc = scipy.special.erfc expected_erfc = scipy.special.erfc
expected_erfcx = scipy.special.erfcx
expected_erfinv = scipy.special.erfinv expected_erfinv = scipy.special.erfinv
expected_erfcinv = scipy.special.erfcinv expected_erfcinv = scipy.special.erfcinv
expected_gamma = scipy.special.gamma expected_gamma = scipy.special.gamma
...@@ -1650,6 +1651,12 @@ if imported_scipy_special: ...@@ -1650,6 +1651,12 @@ if imported_scipy_special:
expected_psi = scipy.special.psi expected_psi = scipy.special.psi
expected_chi2sf = lambda x, df: scipy.stats.chi2.sf(x, df).astype(x.dtype) expected_chi2sf = lambda x, df: scipy.stats.chi2.sf(x, df).astype(x.dtype)
skip_scipy = False skip_scipy = False
if LooseVersion(scipy_version) >= LooseVersion("0.12.0"):
expected_erfcx = scipy.special.erfcx
skip_scipy12 = False
else:
expected_erfcx = []
skip_scipy12 = "the erfcx op requires scipy version >= 0.12, installed version is " + scipy_version
else: else:
expected_erf = [] expected_erf = []
expected_erfc = [] expected_erfc = []
...@@ -1661,6 +1668,7 @@ else: ...@@ -1661,6 +1668,7 @@ else:
expected_psi = [] expected_psi = []
expected_chi2sf = [] expected_chi2sf = []
skip_scipy = "scipy is not present" skip_scipy = "scipy is not present"
skip_scipy12 = "scipy is not present"
ErfTester = makeBroadcastTester( ErfTester = makeBroadcastTester(
op=tensor.erf, op=tensor.erf,
...@@ -1705,7 +1713,7 @@ ErfcxTester = makeBroadcastTester( ...@@ -1705,7 +1713,7 @@ ErfcxTester = makeBroadcastTester(
grad=_grad_broadcast_unary_normal, grad=_grad_broadcast_unary_normal,
eps=2e-10, eps=2e-10,
mode=mode_no_scipy, mode=mode_no_scipy,
skip=skip_scipy) skip=skip_scipy12)
ErfcxInplaceTester = makeBroadcastTester( ErfcxInplaceTester = makeBroadcastTester(
op=inplace.erfcx_inplace, op=inplace.erfcx_inplace,
expected=expected_erfcx, expected=expected_erfcx,
...@@ -1714,7 +1722,7 @@ ErfcxInplaceTester = makeBroadcastTester( ...@@ -1714,7 +1722,7 @@ ErfcxInplaceTester = makeBroadcastTester(
eps=2e-10, eps=2e-10,
mode=mode_no_scipy, mode=mode_no_scipy,
inplace=True, inplace=True,
skip=skip_scipy) skip=skip_scipy12)
ErfinvTester = makeBroadcastTester( ErfinvTester = makeBroadcastTester(
op=tensor.erfinv, op=tensor.erfinv,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论