提交 0b07727b authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Remove unused ScalarOp.st_impl

上级 60c2d925
...@@ -10,7 +10,6 @@ from textwrap import dedent ...@@ -10,7 +10,6 @@ from textwrap import dedent
import numpy as np import numpy as np
import scipy.special import scipy.special
import scipy.stats
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.gradient import grad_not_implemented, grad_undefined from pytensor.gradient import grad_not_implemented, grad_undefined
...@@ -261,12 +260,8 @@ erfcinv = Erfcinv(upgrade_to_float_no_complex, name="erfcinv") ...@@ -261,12 +260,8 @@ erfcinv = Erfcinv(upgrade_to_float_no_complex, name="erfcinv")
class Owens_t(BinaryScalarOp): class Owens_t(BinaryScalarOp):
nfunc_spec = ("scipy.special.owens_t", 2, 1) nfunc_spec = ("scipy.special.owens_t", 2, 1)
@staticmethod
def st_impl(h, a):
return scipy.special.owens_t(h, a)
def impl(self, h, a): def impl(self, h, a):
return Owens_t.st_impl(h, a) return scipy.special.owens_t(h, a)
def grad(self, inputs, grads): def grad(self, inputs, grads):
(h, a) = inputs (h, a) = inputs
...@@ -290,12 +285,8 @@ owens_t = Owens_t(upgrade_to_float, name="owens_t") ...@@ -290,12 +285,8 @@ owens_t = Owens_t(upgrade_to_float, name="owens_t")
class Gamma(UnaryScalarOp): class Gamma(UnaryScalarOp):
nfunc_spec = ("scipy.special.gamma", 1, 1) nfunc_spec = ("scipy.special.gamma", 1, 1)
@staticmethod
def st_impl(x):
return scipy.special.gamma(x)
def impl(self, x): def impl(self, x):
return Gamma.st_impl(x) return scipy.special.gamma(x)
def L_op(self, inputs, outputs, gout): def L_op(self, inputs, outputs, gout):
(x,) = inputs (x,) = inputs
...@@ -329,12 +320,8 @@ class GammaLn(UnaryScalarOp): ...@@ -329,12 +320,8 @@ class GammaLn(UnaryScalarOp):
nfunc_spec = ("scipy.special.gammaln", 1, 1) nfunc_spec = ("scipy.special.gammaln", 1, 1)
@staticmethod
def st_impl(x):
return scipy.special.gammaln(x)
def impl(self, x): def impl(self, x):
return GammaLn.st_impl(x) return scipy.special.gammaln(x)
def L_op(self, inputs, outputs, grads): def L_op(self, inputs, outputs, grads):
(x,) = inputs (x,) = inputs
...@@ -373,12 +360,8 @@ class Psi(UnaryScalarOp): ...@@ -373,12 +360,8 @@ class Psi(UnaryScalarOp):
nfunc_spec = ("scipy.special.psi", 1, 1) nfunc_spec = ("scipy.special.psi", 1, 1)
@staticmethod
def st_impl(x):
return scipy.special.psi(x)
def impl(self, x): def impl(self, x):
return Psi.st_impl(x) return scipy.special.psi(x)
def L_op(self, inputs, outputs, grads): def L_op(self, inputs, outputs, grads):
(x,) = inputs (x,) = inputs
...@@ -464,12 +447,8 @@ class TriGamma(UnaryScalarOp): ...@@ -464,12 +447,8 @@ class TriGamma(UnaryScalarOp):
""" """
@staticmethod
def st_impl(x):
return scipy.special.polygamma(1, x)
def impl(self, x): def impl(self, x):
return TriGamma.st_impl(x) return scipy.special.polygamma(1, x)
def L_op(self, inputs, outputs, outputs_gradients): def L_op(self, inputs, outputs, outputs_gradients):
(x,) = inputs (x,) = inputs
...@@ -567,12 +546,8 @@ class PolyGamma(BinaryScalarOp): ...@@ -567,12 +546,8 @@ class PolyGamma(BinaryScalarOp):
# Scipy doesn't support it # Scipy doesn't support it
return upgrade_to_float_no_complex(x_type) return upgrade_to_float_no_complex(x_type)
@staticmethod
def st_impl(n, x):
return scipy.special.polygamma(n, x)
def impl(self, n, x): def impl(self, n, x):
return PolyGamma.st_impl(n, x) return scipy.special.polygamma(n, x)
def L_op(self, inputs, outputs, output_gradients): def L_op(self, inputs, outputs, output_gradients):
(n, x) = inputs (n, x) = inputs
...@@ -598,12 +573,8 @@ class GammaInc(BinaryScalarOp): ...@@ -598,12 +573,8 @@ class GammaInc(BinaryScalarOp):
nfunc_spec = ("scipy.special.gammainc", 2, 1) nfunc_spec = ("scipy.special.gammainc", 2, 1)
@staticmethod
def st_impl(k, x):
return scipy.special.gammainc(k, x)
def impl(self, k, x): def impl(self, k, x):
return GammaInc.st_impl(k, x) return scipy.special.gammainc(k, x)
def grad(self, inputs, grads): def grad(self, inputs, grads):
(k, x) = inputs (k, x) = inputs
...@@ -649,12 +620,8 @@ class GammaIncC(BinaryScalarOp): ...@@ -649,12 +620,8 @@ class GammaIncC(BinaryScalarOp):
nfunc_spec = ("scipy.special.gammaincc", 2, 1) nfunc_spec = ("scipy.special.gammaincc", 2, 1)
@staticmethod
def st_impl(k, x):
return scipy.special.gammaincc(k, x)
def impl(self, k, x): def impl(self, k, x):
return GammaIncC.st_impl(k, x) return scipy.special.gammaincc(k, x)
def grad(self, inputs, grads): def grad(self, inputs, grads):
(k, x) = inputs (k, x) = inputs
...@@ -700,12 +667,8 @@ class GammaIncInv(BinaryScalarOp): ...@@ -700,12 +667,8 @@ class GammaIncInv(BinaryScalarOp):
nfunc_spec = ("scipy.special.gammaincinv", 2, 1) nfunc_spec = ("scipy.special.gammaincinv", 2, 1)
@staticmethod
def st_impl(k, x):
return scipy.special.gammaincinv(k, x)
def impl(self, k, x): def impl(self, k, x):
return GammaIncInv.st_impl(k, x) return scipy.special.gammaincinv(k, x)
def grad(self, inputs, grads): def grad(self, inputs, grads):
(k, x) = inputs (k, x) = inputs
...@@ -729,12 +692,8 @@ class GammaIncCInv(BinaryScalarOp): ...@@ -729,12 +692,8 @@ class GammaIncCInv(BinaryScalarOp):
nfunc_spec = ("scipy.special.gammainccinv", 2, 1) nfunc_spec = ("scipy.special.gammainccinv", 2, 1)
@staticmethod
def st_impl(k, x):
return scipy.special.gammainccinv(k, x)
def impl(self, k, x): def impl(self, k, x):
return GammaIncCInv.st_impl(k, x) return scipy.special.gammainccinv(k, x)
def grad(self, inputs, grads): def grad(self, inputs, grads):
(k, x) = inputs (k, x) = inputs
...@@ -968,12 +927,8 @@ class GammaU(BinaryScalarOp): ...@@ -968,12 +927,8 @@ class GammaU(BinaryScalarOp):
# Note there is no basic SciPy version so no nfunc_spec. # Note there is no basic SciPy version so no nfunc_spec.
@staticmethod
def st_impl(k, x):
return scipy.special.gammaincc(k, x) * scipy.special.gamma(k)
def impl(self, k, x): def impl(self, k, x):
return GammaU.st_impl(k, x) return scipy.special.gammaincc(k, x) * scipy.special.gamma(k)
def c_support_code(self, **kwargs): def c_support_code(self, **kwargs):
return (C_CODE_PATH / "gamma.c").read_text(encoding="utf-8") return (C_CODE_PATH / "gamma.c").read_text(encoding="utf-8")
...@@ -1004,12 +959,8 @@ class GammaL(BinaryScalarOp): ...@@ -1004,12 +959,8 @@ class GammaL(BinaryScalarOp):
# Note there is no basic SciPy version so no nfunc_spec. # Note there is no basic SciPy version so no nfunc_spec.
@staticmethod
def st_impl(k, x):
return scipy.special.gammainc(k, x) * scipy.special.gamma(k)
def impl(self, k, x): def impl(self, k, x):
return GammaL.st_impl(k, x) return scipy.special.gammainc(k, x) * scipy.special.gamma(k)
def c_support_code(self, **kwargs): def c_support_code(self, **kwargs):
return (C_CODE_PATH / "gamma.c").read_text(encoding="utf-8") return (C_CODE_PATH / "gamma.c").read_text(encoding="utf-8")
...@@ -1040,12 +991,8 @@ class Jv(BinaryScalarOp): ...@@ -1040,12 +991,8 @@ class Jv(BinaryScalarOp):
nfunc_spec = ("scipy.special.jv", 2, 1) nfunc_spec = ("scipy.special.jv", 2, 1)
@staticmethod
def st_impl(v, x):
return scipy.special.jv(v, x)
def impl(self, v, x): def impl(self, v, x):
return self.st_impl(v, x) return scipy.special.jv(v, x)
def grad(self, inputs, grads): def grad(self, inputs, grads):
v, x = inputs v, x = inputs
...@@ -1069,12 +1016,8 @@ class J1(UnaryScalarOp): ...@@ -1069,12 +1016,8 @@ class J1(UnaryScalarOp):
nfunc_spec = ("scipy.special.j1", 1, 1) nfunc_spec = ("scipy.special.j1", 1, 1)
@staticmethod
def st_impl(x):
return scipy.special.j1(x)
def impl(self, x): def impl(self, x):
return self.st_impl(x) return scipy.special.j1(x)
def grad(self, inputs, grads): def grad(self, inputs, grads):
(x,) = inputs (x,) = inputs
...@@ -1100,12 +1043,8 @@ class J0(UnaryScalarOp): ...@@ -1100,12 +1043,8 @@ class J0(UnaryScalarOp):
nfunc_spec = ("scipy.special.j0", 1, 1) nfunc_spec = ("scipy.special.j0", 1, 1)
@staticmethod
def st_impl(x):
return scipy.special.j0(x)
def impl(self, x): def impl(self, x):
return self.st_impl(x) return scipy.special.j0(x)
def grad(self, inp, grads): def grad(self, inp, grads):
(x,) = inp (x,) = inp
...@@ -1131,12 +1070,8 @@ class Iv(BinaryScalarOp): ...@@ -1131,12 +1070,8 @@ class Iv(BinaryScalarOp):
nfunc_spec = ("scipy.special.iv", 2, 1) nfunc_spec = ("scipy.special.iv", 2, 1)
@staticmethod
def st_impl(v, x):
return scipy.special.iv(v, x)
def impl(self, v, x): def impl(self, v, x):
return self.st_impl(v, x) return scipy.special.iv(v, x)
def grad(self, inputs, grads): def grad(self, inputs, grads):
v, x = inputs v, x = inputs
...@@ -1160,12 +1095,8 @@ class I1(UnaryScalarOp): ...@@ -1160,12 +1095,8 @@ class I1(UnaryScalarOp):
nfunc_spec = ("scipy.special.i1", 1, 1) nfunc_spec = ("scipy.special.i1", 1, 1)
@staticmethod
def st_impl(x):
return scipy.special.i1(x)
def impl(self, x): def impl(self, x):
return self.st_impl(x) return scipy.special.i1(x)
def grad(self, inputs, grads): def grad(self, inputs, grads):
(x,) = inputs (x,) = inputs
...@@ -1186,12 +1117,8 @@ class I0(UnaryScalarOp): ...@@ -1186,12 +1117,8 @@ class I0(UnaryScalarOp):
nfunc_spec = ("scipy.special.i0", 1, 1) nfunc_spec = ("scipy.special.i0", 1, 1)
@staticmethod
def st_impl(x):
return scipy.special.i0(x)
def impl(self, x): def impl(self, x):
return self.st_impl(x) return scipy.special.i0(x)
def grad(self, inp, grads): def grad(self, inp, grads):
(x,) = inp (x,) = inp
...@@ -1212,12 +1139,8 @@ class Ive(BinaryScalarOp): ...@@ -1212,12 +1139,8 @@ class Ive(BinaryScalarOp):
nfunc_spec = ("scipy.special.ive", 2, 1) nfunc_spec = ("scipy.special.ive", 2, 1)
@staticmethod
def st_impl(v, x):
return scipy.special.ive(v, x)
def impl(self, v, x): def impl(self, v, x):
return self.st_impl(v, x) return scipy.special.ive(v, x)
def grad(self, inputs, grads): def grad(self, inputs, grads):
v, x = inputs v, x = inputs
...@@ -1241,12 +1164,8 @@ class Kve(BinaryScalarOp): ...@@ -1241,12 +1164,8 @@ class Kve(BinaryScalarOp):
nfunc_spec = ("scipy.special.kve", 2, 1) nfunc_spec = ("scipy.special.kve", 2, 1)
@staticmethod
def st_impl(v, x):
return scipy.special.kve(v, x)
def impl(self, v, x): def impl(self, v, x):
return self.st_impl(v, x) return scipy.special.kve(v, x)
def L_op(self, inputs, outputs, output_grads): def L_op(self, inputs, outputs, output_grads):
v, x = inputs v, x = inputs
...@@ -1327,8 +1246,7 @@ class Softplus(UnaryScalarOp): ...@@ -1327,8 +1246,7 @@ class Softplus(UnaryScalarOp):
"Accurately computing `\log(1-\exp(- \mid a \mid))` Assessed by the Rmpfr package" "Accurately computing `\log(1-\exp(- \mid a \mid))` Assessed by the Rmpfr package"
""" """
@staticmethod def impl(self, x):
def static_impl(x):
# If x is an int8 or uint8, numpy.exp will compute the result in # If x is an int8 or uint8, numpy.exp will compute the result in
# half-precision (float16), where we want float32. # half-precision (float16), where we want float32.
not_int8 = str(getattr(x, "dtype", "")) not in ("int8", "uint8") not_int8 = str(getattr(x, "dtype", "")) not in ("int8", "uint8")
...@@ -1343,9 +1261,6 @@ class Softplus(UnaryScalarOp): ...@@ -1343,9 +1261,6 @@ class Softplus(UnaryScalarOp):
else: else:
return x return x
def impl(self, x):
return Softplus.static_impl(x)
def grad(self, inp, grads): def grad(self, inp, grads):
(x,) = inp (x,) = inp
(gz,) = grads (gz,) = grads
...@@ -1408,16 +1323,12 @@ class Log1mexp(UnaryScalarOp): ...@@ -1408,16 +1323,12 @@ class Log1mexp(UnaryScalarOp):
"Accurately computing `\log(1-\exp(- \mid a \mid))` Assessed by the Rmpfr package" "Accurately computing `\log(1-\exp(- \mid a \mid))` Assessed by the Rmpfr package"
""" """
@staticmethod def impl(self, x):
def static_impl(x):
if x < np.log(0.5): if x < np.log(0.5):
return np.log1p(-np.exp(x)) return np.log1p(-np.exp(x))
else: else:
return np.log(-np.expm1(x)) return np.log(-np.expm1(x))
def impl(self, x):
return Log1mexp.static_impl(x)
def grad(self, inp, grads): def grad(self, inp, grads):
(x,) = inp (x,) = inp
(gz,) = grads (gz,) = grads
...@@ -1749,12 +1660,8 @@ class Hyp2F1(ScalarOp): ...@@ -1749,12 +1660,8 @@ class Hyp2F1(ScalarOp):
nin = 4 nin = 4
nfunc_spec = ("scipy.special.hyp2f1", 4, 1) nfunc_spec = ("scipy.special.hyp2f1", 4, 1)
@staticmethod
def st_impl(a, b, c, z):
return scipy.special.hyp2f1(a, b, c, z)
def impl(self, a, b, c, z): def impl(self, a, b, c, z):
return Hyp2F1.st_impl(a, b, c, z) return scipy.special.hyp2f1(a, b, c, z)
def grad(self, inputs, grads): def grad(self, inputs, grads):
a, b, c, z = inputs a, b, c, z = inputs
......
...@@ -10,15 +10,11 @@ class XlogX(ps.UnaryScalarOp): ...@@ -10,15 +10,11 @@ class XlogX(ps.UnaryScalarOp):
""" """
@staticmethod def impl(self, x):
def st_impl(x):
if x == 0.0: if x == 0.0:
return 0.0 return 0.0
return x * np.log(x) return x * np.log(x)
def impl(self, x):
return XlogX.st_impl(x)
def grad(self, inputs, grads): def grad(self, inputs, grads):
(x,) = inputs (x,) = inputs
(gz,) = grads (gz,) = grads
...@@ -45,15 +41,11 @@ class XlogY0(ps.BinaryScalarOp): ...@@ -45,15 +41,11 @@ class XlogY0(ps.BinaryScalarOp):
""" """
@staticmethod def impl(self, x, y):
def st_impl(x, y):
if x == 0.0: if x == 0.0:
return 0.0 return 0.0
return x * np.log(y) return x * np.log(y)
def impl(self, x, y):
return XlogY0.st_impl(x, y)
def grad(self, inputs, grads): def grad(self, inputs, grads):
x, y = inputs x, y = inputs
(gz,) = grads (gz,) = grads
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论