提交 7411a082 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

More direct access to special functions

上级 0b94be01
......@@ -9,7 +9,7 @@ from pathlib import Path
from textwrap import dedent
import numpy as np
import scipy.special
from scipy import special
from pytensor.configdefaults import config
from pytensor.gradient import grad_not_implemented, grad_undefined
......@@ -52,7 +52,7 @@ class Erf(UnaryScalarOp):
nfunc_spec = ("scipy.special.erf", 1, 1)
def impl(self, x):
return scipy.special.erf(x)
return special.erf(x)
def L_op(self, inputs, outputs, grads):
(x,) = inputs
......@@ -86,7 +86,7 @@ class Erfc(UnaryScalarOp):
nfunc_spec = ("scipy.special.erfc", 1, 1)
def impl(self, x):
return scipy.special.erfc(x)
return special.erfc(x)
def L_op(self, inputs, outputs, grads):
(x,) = inputs
......@@ -113,7 +113,7 @@ class Erfc(UnaryScalarOp):
return f"{z} = erfc(({cast}){x});"
# scipy.special.erfc don't support complex. Why?
# special.erfc don't support complex. Why?
erfc = Erfc(upgrade_to_float_no_complex, name="erfc")
......@@ -135,7 +135,7 @@ class Erfcx(UnaryScalarOp):
nfunc_spec = ("scipy.special.erfcx", 1, 1)
def impl(self, x):
return scipy.special.erfcx(x)
return special.erfcx(x)
def L_op(self, inputs, outputs, grads):
(x,) = inputs
......@@ -191,7 +191,7 @@ class Erfinv(UnaryScalarOp):
nfunc_spec = ("scipy.special.erfinv", 1, 1)
def impl(self, x):
return scipy.special.erfinv(x)
return special.erfinv(x)
def L_op(self, inputs, outputs, grads):
(x,) = inputs
......@@ -226,7 +226,7 @@ class Erfcinv(UnaryScalarOp):
nfunc_spec = ("scipy.special.erfcinv", 1, 1)
def impl(self, x):
return scipy.special.erfcinv(x)
return special.erfcinv(x)
def L_op(self, inputs, outputs, grads):
(x,) = inputs
......@@ -261,7 +261,7 @@ class Owens_t(BinaryScalarOp):
nfunc_spec = ("scipy.special.owens_t", 2, 1)
def impl(self, h, a):
return scipy.special.owens_t(h, a)
return special.owens_t(h, a)
def grad(self, inputs, grads):
(h, a) = inputs
......@@ -286,7 +286,7 @@ class Gamma(UnaryScalarOp):
nfunc_spec = ("scipy.special.gamma", 1, 1)
def impl(self, x):
return scipy.special.gamma(x)
return special.gamma(x)
def L_op(self, inputs, outputs, gout):
(x,) = inputs
......@@ -321,7 +321,7 @@ class GammaLn(UnaryScalarOp):
nfunc_spec = ("scipy.special.gammaln", 1, 1)
def impl(self, x):
return scipy.special.gammaln(x)
return special.gammaln(x)
def L_op(self, inputs, outputs, grads):
(x,) = inputs
......@@ -361,7 +361,7 @@ class Psi(UnaryScalarOp):
nfunc_spec = ("scipy.special.psi", 1, 1)
def impl(self, x):
return scipy.special.psi(x)
return special.psi(x)
def L_op(self, inputs, outputs, grads):
(x,) = inputs
......@@ -448,7 +448,7 @@ class TriGamma(UnaryScalarOp):
"""
def impl(self, x):
return scipy.special.polygamma(1, x)
return special.polygamma(1, x)
def L_op(self, inputs, outputs, outputs_gradients):
(x,) = inputs
......@@ -547,7 +547,7 @@ class PolyGamma(BinaryScalarOp):
return upgrade_to_float_no_complex(x_type)
def impl(self, n, x):
return scipy.special.polygamma(n, x)
return special.polygamma(n, x)
def L_op(self, inputs, outputs, output_gradients):
(n, x) = inputs
......@@ -574,7 +574,7 @@ class GammaInc(BinaryScalarOp):
nfunc_spec = ("scipy.special.gammainc", 2, 1)
def impl(self, k, x):
return scipy.special.gammainc(k, x)
return special.gammainc(k, x)
def grad(self, inputs, grads):
(k, x) = inputs
......@@ -621,7 +621,7 @@ class GammaIncC(BinaryScalarOp):
nfunc_spec = ("scipy.special.gammaincc", 2, 1)
def impl(self, k, x):
return scipy.special.gammaincc(k, x)
return special.gammaincc(k, x)
def grad(self, inputs, grads):
(k, x) = inputs
......@@ -668,7 +668,7 @@ class GammaIncInv(BinaryScalarOp):
nfunc_spec = ("scipy.special.gammaincinv", 2, 1)
def impl(self, k, x):
return scipy.special.gammaincinv(k, x)
return special.gammaincinv(k, x)
def grad(self, inputs, grads):
(k, x) = inputs
......@@ -693,7 +693,7 @@ class GammaIncCInv(BinaryScalarOp):
nfunc_spec = ("scipy.special.gammainccinv", 2, 1)
def impl(self, k, x):
return scipy.special.gammainccinv(k, x)
return special.gammainccinv(k, x)
def grad(self, inputs, grads):
(k, x) = inputs
......@@ -928,7 +928,7 @@ class GammaU(BinaryScalarOp):
# Note there is no basic SciPy version so no nfunc_spec.
def impl(self, k, x):
return scipy.special.gammaincc(k, x) * scipy.special.gamma(k)
return special.gammaincc(k, x) * special.gamma(k)
def c_support_code(self, **kwargs):
return (C_CODE_PATH / "gamma.c").read_text(encoding="utf-8")
......@@ -960,7 +960,7 @@ class GammaL(BinaryScalarOp):
# Note there is no basic SciPy version so no nfunc_spec.
def impl(self, k, x):
return scipy.special.gammainc(k, x) * scipy.special.gamma(k)
return special.gammainc(k, x) * special.gamma(k)
def c_support_code(self, **kwargs):
return (C_CODE_PATH / "gamma.c").read_text(encoding="utf-8")
......@@ -992,7 +992,7 @@ class Jv(BinaryScalarOp):
nfunc_spec = ("scipy.special.jv", 2, 1)
def impl(self, v, x):
return scipy.special.jv(v, x)
return special.jv(v, x)
def grad(self, inputs, grads):
v, x = inputs
......@@ -1017,7 +1017,7 @@ class J1(UnaryScalarOp):
nfunc_spec = ("scipy.special.j1", 1, 1)
def impl(self, x):
return scipy.special.j1(x)
return special.j1(x)
def grad(self, inputs, grads):
(x,) = inputs
......@@ -1044,7 +1044,7 @@ class J0(UnaryScalarOp):
nfunc_spec = ("scipy.special.j0", 1, 1)
def impl(self, x):
return scipy.special.j0(x)
return special.j0(x)
def grad(self, inp, grads):
(x,) = inp
......@@ -1071,7 +1071,7 @@ class Iv(BinaryScalarOp):
nfunc_spec = ("scipy.special.iv", 2, 1)
def impl(self, v, x):
return scipy.special.iv(v, x)
return special.iv(v, x)
def grad(self, inputs, grads):
v, x = inputs
......@@ -1096,7 +1096,7 @@ class I1(UnaryScalarOp):
nfunc_spec = ("scipy.special.i1", 1, 1)
def impl(self, x):
return scipy.special.i1(x)
return special.i1(x)
def grad(self, inputs, grads):
(x,) = inputs
......@@ -1118,7 +1118,7 @@ class I0(UnaryScalarOp):
nfunc_spec = ("scipy.special.i0", 1, 1)
def impl(self, x):
return scipy.special.i0(x)
return special.i0(x)
def grad(self, inp, grads):
(x,) = inp
......@@ -1140,7 +1140,7 @@ class Ive(BinaryScalarOp):
nfunc_spec = ("scipy.special.ive", 2, 1)
def impl(self, v, x):
return scipy.special.ive(v, x)
return special.ive(v, x)
def grad(self, inputs, grads):
v, x = inputs
......@@ -1165,7 +1165,7 @@ class Kve(BinaryScalarOp):
nfunc_spec = ("scipy.special.kve", 2, 1)
def impl(self, v, x):
return scipy.special.kve(v, x)
return special.kve(v, x)
def L_op(self, inputs, outputs, output_grads):
v, x = inputs
......@@ -1195,7 +1195,7 @@ class Sigmoid(UnaryScalarOp):
nfunc_spec = ("scipy.special.expit", 1, 1)
def impl(self, x):
return scipy.special.expit(x)
return special.expit(x)
def grad(self, inp, grads):
(x,) = inp
......@@ -1362,7 +1362,7 @@ class BetaInc(ScalarOp):
nfunc_spec = ("scipy.special.betainc", 3, 1)
def impl(self, a, b, x):
return scipy.special.betainc(a, b, x)
return special.betainc(a, b, x)
def grad(self, inp, grads):
a, b, x = inp
......@@ -1622,7 +1622,7 @@ class BetaIncInv(ScalarOp):
nfunc_spec = ("scipy.special.betaincinv", 3, 1)
def impl(self, a, b, x):
return scipy.special.betaincinv(a, b, x)
return special.betaincinv(a, b, x)
def grad(self, inputs, grads):
(a, b, x) = inputs
......@@ -1661,7 +1661,7 @@ class Hyp2F1(ScalarOp):
nfunc_spec = ("scipy.special.hyp2f1", 4, 1)
def impl(self, a, b, c, z):
return scipy.special.hyp2f1(a, b, c, z)
return special.hyp2f1(a, b, c, z)
def grad(self, inputs, grads):
a, b, c, z = inputs
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论