提交 7411a082 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

More direct access to special functions

上级 0b94be01
...@@ -9,7 +9,7 @@ from pathlib import Path ...@@ -9,7 +9,7 @@ from pathlib import Path
from textwrap import dedent from textwrap import dedent
import numpy as np import numpy as np
import scipy.special from scipy import special
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.gradient import grad_not_implemented, grad_undefined from pytensor.gradient import grad_not_implemented, grad_undefined
...@@ -52,7 +52,7 @@ class Erf(UnaryScalarOp): ...@@ -52,7 +52,7 @@ class Erf(UnaryScalarOp):
nfunc_spec = ("scipy.special.erf", 1, 1) nfunc_spec = ("scipy.special.erf", 1, 1)
def impl(self, x): def impl(self, x):
return scipy.special.erf(x) return special.erf(x)
def L_op(self, inputs, outputs, grads): def L_op(self, inputs, outputs, grads):
(x,) = inputs (x,) = inputs
...@@ -86,7 +86,7 @@ class Erfc(UnaryScalarOp): ...@@ -86,7 +86,7 @@ class Erfc(UnaryScalarOp):
nfunc_spec = ("scipy.special.erfc", 1, 1) nfunc_spec = ("scipy.special.erfc", 1, 1)
def impl(self, x): def impl(self, x):
return scipy.special.erfc(x) return special.erfc(x)
def L_op(self, inputs, outputs, grads): def L_op(self, inputs, outputs, grads):
(x,) = inputs (x,) = inputs
...@@ -113,7 +113,7 @@ class Erfc(UnaryScalarOp): ...@@ -113,7 +113,7 @@ class Erfc(UnaryScalarOp):
return f"{z} = erfc(({cast}){x});" return f"{z} = erfc(({cast}){x});"
# scipy.special.erfc don't support complex. Why? # special.erfc don't support complex. Why?
erfc = Erfc(upgrade_to_float_no_complex, name="erfc") erfc = Erfc(upgrade_to_float_no_complex, name="erfc")
...@@ -135,7 +135,7 @@ class Erfcx(UnaryScalarOp): ...@@ -135,7 +135,7 @@ class Erfcx(UnaryScalarOp):
nfunc_spec = ("scipy.special.erfcx", 1, 1) nfunc_spec = ("scipy.special.erfcx", 1, 1)
def impl(self, x): def impl(self, x):
return scipy.special.erfcx(x) return special.erfcx(x)
def L_op(self, inputs, outputs, grads): def L_op(self, inputs, outputs, grads):
(x,) = inputs (x,) = inputs
...@@ -191,7 +191,7 @@ class Erfinv(UnaryScalarOp): ...@@ -191,7 +191,7 @@ class Erfinv(UnaryScalarOp):
nfunc_spec = ("scipy.special.erfinv", 1, 1) nfunc_spec = ("scipy.special.erfinv", 1, 1)
def impl(self, x): def impl(self, x):
return scipy.special.erfinv(x) return special.erfinv(x)
def L_op(self, inputs, outputs, grads): def L_op(self, inputs, outputs, grads):
(x,) = inputs (x,) = inputs
...@@ -226,7 +226,7 @@ class Erfcinv(UnaryScalarOp): ...@@ -226,7 +226,7 @@ class Erfcinv(UnaryScalarOp):
nfunc_spec = ("scipy.special.erfcinv", 1, 1) nfunc_spec = ("scipy.special.erfcinv", 1, 1)
def impl(self, x): def impl(self, x):
return scipy.special.erfcinv(x) return special.erfcinv(x)
def L_op(self, inputs, outputs, grads): def L_op(self, inputs, outputs, grads):
(x,) = inputs (x,) = inputs
...@@ -261,7 +261,7 @@ class Owens_t(BinaryScalarOp): ...@@ -261,7 +261,7 @@ class Owens_t(BinaryScalarOp):
nfunc_spec = ("scipy.special.owens_t", 2, 1) nfunc_spec = ("scipy.special.owens_t", 2, 1)
def impl(self, h, a): def impl(self, h, a):
return scipy.special.owens_t(h, a) return special.owens_t(h, a)
def grad(self, inputs, grads): def grad(self, inputs, grads):
(h, a) = inputs (h, a) = inputs
...@@ -286,7 +286,7 @@ class Gamma(UnaryScalarOp): ...@@ -286,7 +286,7 @@ class Gamma(UnaryScalarOp):
nfunc_spec = ("scipy.special.gamma", 1, 1) nfunc_spec = ("scipy.special.gamma", 1, 1)
def impl(self, x): def impl(self, x):
return scipy.special.gamma(x) return special.gamma(x)
def L_op(self, inputs, outputs, gout): def L_op(self, inputs, outputs, gout):
(x,) = inputs (x,) = inputs
...@@ -321,7 +321,7 @@ class GammaLn(UnaryScalarOp): ...@@ -321,7 +321,7 @@ class GammaLn(UnaryScalarOp):
nfunc_spec = ("scipy.special.gammaln", 1, 1) nfunc_spec = ("scipy.special.gammaln", 1, 1)
def impl(self, x): def impl(self, x):
return scipy.special.gammaln(x) return special.gammaln(x)
def L_op(self, inputs, outputs, grads): def L_op(self, inputs, outputs, grads):
(x,) = inputs (x,) = inputs
...@@ -361,7 +361,7 @@ class Psi(UnaryScalarOp): ...@@ -361,7 +361,7 @@ class Psi(UnaryScalarOp):
nfunc_spec = ("scipy.special.psi", 1, 1) nfunc_spec = ("scipy.special.psi", 1, 1)
def impl(self, x): def impl(self, x):
return scipy.special.psi(x) return special.psi(x)
def L_op(self, inputs, outputs, grads): def L_op(self, inputs, outputs, grads):
(x,) = inputs (x,) = inputs
...@@ -448,7 +448,7 @@ class TriGamma(UnaryScalarOp): ...@@ -448,7 +448,7 @@ class TriGamma(UnaryScalarOp):
""" """
def impl(self, x): def impl(self, x):
return scipy.special.polygamma(1, x) return special.polygamma(1, x)
def L_op(self, inputs, outputs, outputs_gradients): def L_op(self, inputs, outputs, outputs_gradients):
(x,) = inputs (x,) = inputs
...@@ -547,7 +547,7 @@ class PolyGamma(BinaryScalarOp): ...@@ -547,7 +547,7 @@ class PolyGamma(BinaryScalarOp):
return upgrade_to_float_no_complex(x_type) return upgrade_to_float_no_complex(x_type)
def impl(self, n, x): def impl(self, n, x):
return scipy.special.polygamma(n, x) return special.polygamma(n, x)
def L_op(self, inputs, outputs, output_gradients): def L_op(self, inputs, outputs, output_gradients):
(n, x) = inputs (n, x) = inputs
...@@ -574,7 +574,7 @@ class GammaInc(BinaryScalarOp): ...@@ -574,7 +574,7 @@ class GammaInc(BinaryScalarOp):
nfunc_spec = ("scipy.special.gammainc", 2, 1) nfunc_spec = ("scipy.special.gammainc", 2, 1)
def impl(self, k, x): def impl(self, k, x):
return scipy.special.gammainc(k, x) return special.gammainc(k, x)
def grad(self, inputs, grads): def grad(self, inputs, grads):
(k, x) = inputs (k, x) = inputs
...@@ -621,7 +621,7 @@ class GammaIncC(BinaryScalarOp): ...@@ -621,7 +621,7 @@ class GammaIncC(BinaryScalarOp):
nfunc_spec = ("scipy.special.gammaincc", 2, 1) nfunc_spec = ("scipy.special.gammaincc", 2, 1)
def impl(self, k, x): def impl(self, k, x):
return scipy.special.gammaincc(k, x) return special.gammaincc(k, x)
def grad(self, inputs, grads): def grad(self, inputs, grads):
(k, x) = inputs (k, x) = inputs
...@@ -668,7 +668,7 @@ class GammaIncInv(BinaryScalarOp): ...@@ -668,7 +668,7 @@ class GammaIncInv(BinaryScalarOp):
nfunc_spec = ("scipy.special.gammaincinv", 2, 1) nfunc_spec = ("scipy.special.gammaincinv", 2, 1)
def impl(self, k, x): def impl(self, k, x):
return scipy.special.gammaincinv(k, x) return special.gammaincinv(k, x)
def grad(self, inputs, grads): def grad(self, inputs, grads):
(k, x) = inputs (k, x) = inputs
...@@ -693,7 +693,7 @@ class GammaIncCInv(BinaryScalarOp): ...@@ -693,7 +693,7 @@ class GammaIncCInv(BinaryScalarOp):
nfunc_spec = ("scipy.special.gammainccinv", 2, 1) nfunc_spec = ("scipy.special.gammainccinv", 2, 1)
def impl(self, k, x): def impl(self, k, x):
return scipy.special.gammainccinv(k, x) return special.gammainccinv(k, x)
def grad(self, inputs, grads): def grad(self, inputs, grads):
(k, x) = inputs (k, x) = inputs
...@@ -928,7 +928,7 @@ class GammaU(BinaryScalarOp): ...@@ -928,7 +928,7 @@ class GammaU(BinaryScalarOp):
# Note there is no basic SciPy version so no nfunc_spec. # Note there is no basic SciPy version so no nfunc_spec.
def impl(self, k, x): def impl(self, k, x):
return scipy.special.gammaincc(k, x) * scipy.special.gamma(k) return special.gammaincc(k, x) * special.gamma(k)
def c_support_code(self, **kwargs): def c_support_code(self, **kwargs):
return (C_CODE_PATH / "gamma.c").read_text(encoding="utf-8") return (C_CODE_PATH / "gamma.c").read_text(encoding="utf-8")
...@@ -960,7 +960,7 @@ class GammaL(BinaryScalarOp): ...@@ -960,7 +960,7 @@ class GammaL(BinaryScalarOp):
# Note there is no basic SciPy version so no nfunc_spec. # Note there is no basic SciPy version so no nfunc_spec.
def impl(self, k, x): def impl(self, k, x):
return scipy.special.gammainc(k, x) * scipy.special.gamma(k) return special.gammainc(k, x) * special.gamma(k)
def c_support_code(self, **kwargs): def c_support_code(self, **kwargs):
return (C_CODE_PATH / "gamma.c").read_text(encoding="utf-8") return (C_CODE_PATH / "gamma.c").read_text(encoding="utf-8")
...@@ -992,7 +992,7 @@ class Jv(BinaryScalarOp): ...@@ -992,7 +992,7 @@ class Jv(BinaryScalarOp):
nfunc_spec = ("scipy.special.jv", 2, 1) nfunc_spec = ("scipy.special.jv", 2, 1)
def impl(self, v, x): def impl(self, v, x):
return scipy.special.jv(v, x) return special.jv(v, x)
def grad(self, inputs, grads): def grad(self, inputs, grads):
v, x = inputs v, x = inputs
...@@ -1017,7 +1017,7 @@ class J1(UnaryScalarOp): ...@@ -1017,7 +1017,7 @@ class J1(UnaryScalarOp):
nfunc_spec = ("scipy.special.j1", 1, 1) nfunc_spec = ("scipy.special.j1", 1, 1)
def impl(self, x): def impl(self, x):
return scipy.special.j1(x) return special.j1(x)
def grad(self, inputs, grads): def grad(self, inputs, grads):
(x,) = inputs (x,) = inputs
...@@ -1044,7 +1044,7 @@ class J0(UnaryScalarOp): ...@@ -1044,7 +1044,7 @@ class J0(UnaryScalarOp):
nfunc_spec = ("scipy.special.j0", 1, 1) nfunc_spec = ("scipy.special.j0", 1, 1)
def impl(self, x): def impl(self, x):
return scipy.special.j0(x) return special.j0(x)
def grad(self, inp, grads): def grad(self, inp, grads):
(x,) = inp (x,) = inp
...@@ -1071,7 +1071,7 @@ class Iv(BinaryScalarOp): ...@@ -1071,7 +1071,7 @@ class Iv(BinaryScalarOp):
nfunc_spec = ("scipy.special.iv", 2, 1) nfunc_spec = ("scipy.special.iv", 2, 1)
def impl(self, v, x): def impl(self, v, x):
return scipy.special.iv(v, x) return special.iv(v, x)
def grad(self, inputs, grads): def grad(self, inputs, grads):
v, x = inputs v, x = inputs
...@@ -1096,7 +1096,7 @@ class I1(UnaryScalarOp): ...@@ -1096,7 +1096,7 @@ class I1(UnaryScalarOp):
nfunc_spec = ("scipy.special.i1", 1, 1) nfunc_spec = ("scipy.special.i1", 1, 1)
def impl(self, x): def impl(self, x):
return scipy.special.i1(x) return special.i1(x)
def grad(self, inputs, grads): def grad(self, inputs, grads):
(x,) = inputs (x,) = inputs
...@@ -1118,7 +1118,7 @@ class I0(UnaryScalarOp): ...@@ -1118,7 +1118,7 @@ class I0(UnaryScalarOp):
nfunc_spec = ("scipy.special.i0", 1, 1) nfunc_spec = ("scipy.special.i0", 1, 1)
def impl(self, x): def impl(self, x):
return scipy.special.i0(x) return special.i0(x)
def grad(self, inp, grads): def grad(self, inp, grads):
(x,) = inp (x,) = inp
...@@ -1140,7 +1140,7 @@ class Ive(BinaryScalarOp): ...@@ -1140,7 +1140,7 @@ class Ive(BinaryScalarOp):
nfunc_spec = ("scipy.special.ive", 2, 1) nfunc_spec = ("scipy.special.ive", 2, 1)
def impl(self, v, x): def impl(self, v, x):
return scipy.special.ive(v, x) return special.ive(v, x)
def grad(self, inputs, grads): def grad(self, inputs, grads):
v, x = inputs v, x = inputs
...@@ -1165,7 +1165,7 @@ class Kve(BinaryScalarOp): ...@@ -1165,7 +1165,7 @@ class Kve(BinaryScalarOp):
nfunc_spec = ("scipy.special.kve", 2, 1) nfunc_spec = ("scipy.special.kve", 2, 1)
def impl(self, v, x): def impl(self, v, x):
return scipy.special.kve(v, x) return special.kve(v, x)
def L_op(self, inputs, outputs, output_grads): def L_op(self, inputs, outputs, output_grads):
v, x = inputs v, x = inputs
...@@ -1195,7 +1195,7 @@ class Sigmoid(UnaryScalarOp): ...@@ -1195,7 +1195,7 @@ class Sigmoid(UnaryScalarOp):
nfunc_spec = ("scipy.special.expit", 1, 1) nfunc_spec = ("scipy.special.expit", 1, 1)
def impl(self, x): def impl(self, x):
return scipy.special.expit(x) return special.expit(x)
def grad(self, inp, grads): def grad(self, inp, grads):
(x,) = inp (x,) = inp
...@@ -1362,7 +1362,7 @@ class BetaInc(ScalarOp): ...@@ -1362,7 +1362,7 @@ class BetaInc(ScalarOp):
nfunc_spec = ("scipy.special.betainc", 3, 1) nfunc_spec = ("scipy.special.betainc", 3, 1)
def impl(self, a, b, x): def impl(self, a, b, x):
return scipy.special.betainc(a, b, x) return special.betainc(a, b, x)
def grad(self, inp, grads): def grad(self, inp, grads):
a, b, x = inp a, b, x = inp
...@@ -1622,7 +1622,7 @@ class BetaIncInv(ScalarOp): ...@@ -1622,7 +1622,7 @@ class BetaIncInv(ScalarOp):
nfunc_spec = ("scipy.special.betaincinv", 3, 1) nfunc_spec = ("scipy.special.betaincinv", 3, 1)
def impl(self, a, b, x): def impl(self, a, b, x):
return scipy.special.betaincinv(a, b, x) return special.betaincinv(a, b, x)
def grad(self, inputs, grads): def grad(self, inputs, grads):
(a, b, x) = inputs (a, b, x) = inputs
...@@ -1661,7 +1661,7 @@ class Hyp2F1(ScalarOp): ...@@ -1661,7 +1661,7 @@ class Hyp2F1(ScalarOp):
nfunc_spec = ("scipy.special.hyp2f1", 4, 1) nfunc_spec = ("scipy.special.hyp2f1", 4, 1)
def impl(self, a, b, c, z): def impl(self, a, b, c, z):
return scipy.special.hyp2f1(a, b, c, z) return special.hyp2f1(a, b, c, z)
def grad(self, inputs, grads): def grad(self, inputs, grads):
a, b, c, z = inputs a, b, c, z = inputs
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论