提交 e5de51da authored 作者: Ethan Buchman's avatar Ethan Buchman

added scipy.stats.chi2.sf op

上级 0845ddc3
...@@ -3260,6 +3260,11 @@ def gammaln(a): ...@@ -3260,6 +3260,11 @@ def gammaln(a):
def psi(a): def psi(a):
"""derivative of log gamma function""" """derivative of log gamma function"""
@_scal_elemwise
def chi2sf(x, k):
"""chi squared survival function"""
@_scal_elemwise_with_nfunc('real', 1, -1) @_scal_elemwise_with_nfunc('real', 1, -1)
def real(z): def real(z):
......
...@@ -229,6 +229,10 @@ def gammaln_inplace(a): ...@@ -229,6 +229,10 @@ def gammaln_inplace(a):
def psi_inplace(a): def psi_inplace(a):
"""derivative of log gamma function""" """derivative of log gamma function"""
@_scal_inplace
def chi2sf_inplace(x, k):
"""chi squared survival function"""
@_scal_inplace @_scal_inplace
def second_inplace(a): def second_inplace(a):
"""Fill `a` with `b`""" """Fill `a` with `b`"""
......
...@@ -52,6 +52,7 @@ imported_scipy_special = False ...@@ -52,6 +52,7 @@ imported_scipy_special = False
mode_no_scipy = get_default_mode() mode_no_scipy = get_default_mode()
try: try:
import scipy.special import scipy.special
import scipy.stats
imported_scipy_special = True imported_scipy_special = True
except ImportError: except ImportError:
if config.mode == "FAST_COMPILE": if config.mode == "FAST_COMPILE":
...@@ -1419,6 +1420,7 @@ if imported_scipy_special: ...@@ -1419,6 +1420,7 @@ if imported_scipy_special:
expected_gamma = scipy.special.gamma expected_gamma = scipy.special.gamma
expected_gammaln = scipy.special.gammaln expected_gammaln = scipy.special.gammaln
expected_psi = scipy.special.psi expected_psi = scipy.special.psi
expected_chi2sf = scipy.stats.chi2.sf
skip_scipy = False skip_scipy = False
else: else:
expected_erf = [] expected_erf = []
...@@ -1428,6 +1430,7 @@ else: ...@@ -1428,6 +1430,7 @@ else:
expected_gamma = [] expected_gamma = []
expected_gammaln = [] expected_gammaln = []
expected_psi = [] expected_psi = []
expected_chi2sf = []
skip_scipy = "scipy is not present" skip_scipy = "scipy is not present"
ErfTester = makeBroadcastTester( ErfTester = makeBroadcastTester(
...@@ -1547,6 +1550,32 @@ PsiInplaceTester = makeBroadcastTester( ...@@ -1547,6 +1550,32 @@ PsiInplaceTester = makeBroadcastTester(
inplace=True, inplace=True,
skip=skip_scipy) skip=skip_scipy)
'''
#chi2sf takes two inputs, a value (x) and a degrees of freedom (k).
# not sure how to deal with that here...
_good_broadcast_unary_psi = dict(
normal=(rand_ranged(1, 10, (2, 3)),),
empty=(numpy.asarray([]),),)
Chi2SFTester = makeBroadcastTester(
op=tensor.chi2sf,
expected=expected_chi2sf,
good=_good_broadcast_unary_chi2sf,
eps=2e-10,
mode=mode_no_scipy,
skip=skip_scipy)
Chi2SFInplaceTester = makeBroadcastTester(
op=inplace.chi2sf_inplace,
expected=expected_chi2sf,
good=_good_broadcast_unary_chi2sf,
eps=2e-10,
mode=mode_no_scipy,
inplace=True,
skip=skip_scipy)
'''
ZerosLikeTester = makeBroadcastTester( ZerosLikeTester = makeBroadcastTester(
op=tensor.zeros_like, op=tensor.zeros_like,
expected=numpy.zeros_like, expected=numpy.zeros_like,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论