提交 2dc75912 authored 作者: Rune Michael Dominik's avatar Rune Michael Dominik 提交者: Brandon T. Willard

Add scipy's owens_t function as op

上级 fd50f36b
...@@ -250,6 +250,35 @@ class Erfcinv(UnaryScalarOp): ...@@ -250,6 +250,35 @@ class Erfcinv(UnaryScalarOp):
erfcinv = Erfcinv(upgrade_to_float_no_complex, name="erfcinv") erfcinv = Erfcinv(upgrade_to_float_no_complex, name="erfcinv")
class Owens_t(BinaryScalarOp):
nfunc_spec = ("scipy.special.owens_t", 2, 1)
@staticmethod
def st_impl(h, a):
return scipy.special.owens_t(h, a)
def impl(self, h, a):
return Owens_t.st_impl(h, a)
def grad(self, inputs, grads):
(h, a) = inputs
(gz,) = grads
return [
gz
* (-1)
* exp(-(h**2) / 2)
* erf(a * h / np.sqrt(2))
/ (2 * np.sqrt(2 * np.pi)),
gz * exp(-0.5 * (a**2 + 1) * h**2) / (2 * np.pi * (a**2 + 1)),
]
def c_code(self, *args, **kwargs):
raise NotImplementedError()
owens_t = Owens_t(upgrade_to_float, name="owens_t")
class Gamma(UnaryScalarOp): class Gamma(UnaryScalarOp):
nfunc_spec = ("scipy.special.gamma", 1, 1) nfunc_spec = ("scipy.special.gamma", 1, 1)
......
...@@ -233,6 +233,11 @@ def erfcx_inplace(a): ...@@ -233,6 +233,11 @@ def erfcx_inplace(a):
"""scaled complementary error function""" """scaled complementary error function"""
@scalar_elemwise
def owens_t_inplace(h, a):
"""owens t function"""
@scalar_elemwise @scalar_elemwise
def gamma_inplace(a): def gamma_inplace(a):
"""gamma function""" """gamma function"""
......
...@@ -1339,6 +1339,11 @@ def erfcinv(a): ...@@ -1339,6 +1339,11 @@ def erfcinv(a):
"""inverse complementary error function""" """inverse complementary error function"""
@scalar_elemwise
def owens_t(h, a):
"""owens t function"""
@scalar_elemwise @scalar_elemwise
def gamma(a): def gamma(a):
"""gamma function""" """gamma function"""
...@@ -3062,6 +3067,7 @@ __all__ = [ ...@@ -3062,6 +3067,7 @@ __all__ = [
"erfcx", "erfcx",
"erfinv", "erfinv",
"erfcinv", "erfcinv",
"owens_t",
"gamma", "gamma",
"gammaln", "gammaln",
"psi", "psi",
......
...@@ -53,6 +53,7 @@ expected_erf = scipy.special.erf ...@@ -53,6 +53,7 @@ expected_erf = scipy.special.erf
expected_erfc = scipy.special.erfc expected_erfc = scipy.special.erfc
expected_erfinv = scipy.special.erfinv expected_erfinv = scipy.special.erfinv
expected_erfcinv = scipy.special.erfcinv expected_erfcinv = scipy.special.erfcinv
expected_owenst = scipy.special.owens_t
expected_gamma = scipy.special.gamma expected_gamma = scipy.special.gamma
expected_gammaln = scipy.special.gammaln expected_gammaln = scipy.special.gammaln
expected_psi = scipy.special.psi expected_psi = scipy.special.psi
...@@ -146,6 +147,55 @@ TestErfcinvBroadcast = makeBroadcastTester( ...@@ -146,6 +147,55 @@ TestErfcinvBroadcast = makeBroadcastTester(
mode=mode_no_scipy, mode=mode_no_scipy,
) )
rng = np.random.default_rng(seed=utt.fetch_seed())
_good_broadcast_binary_owenst = dict(
normal=(
random_ranged(-5, 5, (2, 3), rng=rng),
random_ranged(-5, 5, (2, 3), rng=rng),
),
empty=(np.asarray([], dtype=config.floatX), np.asarray([], dtype=config.floatX)),
int=(
integers_ranged(-5, 5, (2, 3), rng=rng),
integers_ranged(-5, 5, (2, 3), rng=rng),
),
uint8=(
integers_ranged(1, 6, (2, 3), rng=rng).astype("uint8"),
integers_ranged(1, 6, (2, 3), rng=rng).astype("uint8"),
),
uint16=(
integers_ranged(1, 10, (2, 3), rng=rng).astype("uint16"),
integers_ranged(1, 10, (2, 3), rng=rng).astype("uint16"),
),
uint64=(
integers_ranged(1, 10, (2, 3), rng=rng).astype("uint64"),
integers_ranged(1, 10, (2, 3), rng=rng).astype("uint64"),
),
)
_grad_broadcast_binary_owenst = dict(
normal=(
random_ranged(-5, 5, (2, 3), rng=rng),
random_ranged(-5, 5, (2, 3), rng=rng),
)
)
TestOwensTBroadcast = makeBroadcastTester(
op=at.owens_t,
expected=expected_owenst,
good=_good_broadcast_binary_owenst,
grad=_grad_broadcast_binary_owenst,
eps=2e-10,
mode=mode_no_scipy,
)
TestOwensTInplaceBroadcast = makeBroadcastTester(
op=inplace.owens_t_inplace,
expected=expected_owenst,
good=_good_broadcast_binary_owenst,
eps=2e-10,
mode=mode_no_scipy,
inplace=True,
)
rng = np.random.default_rng(seed=utt.fetch_seed()) rng = np.random.default_rng(seed=utt.fetch_seed())
_good_broadcast_unary_gammaln = dict( _good_broadcast_unary_gammaln = dict(
normal=(random_ranged(-1 + 1e-2, 10, (2, 3), rng=rng),), normal=(random_ranged(-1 + 1e-2, 10, (2, 3), rng=rng),),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论