提交 92d5450f authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Implement Polygamma Op

上级 f6be5213
......@@ -101,7 +101,7 @@ def grad_undefined(op, x_pos, x, comment=""):
return (
NullType(
"This variable is Null because the grad method for "
f"input {x_pos} ({x}) of the {op} op is not implemented. {comment}"
f"input {x_pos} ({x}) of the {op} op is undefined. {comment}"
)
)()
......
......@@ -13,7 +13,7 @@ import scipy.special
import scipy.stats
from pytensor.configdefaults import config
from pytensor.gradient import grad_not_implemented
from pytensor.gradient import grad_not_implemented, grad_undefined
from pytensor.scalar.basic import BinaryScalarOp, ScalarOp, UnaryScalarOp
from pytensor.scalar.basic import abs as scalar_abs
from pytensor.scalar.basic import (
......@@ -473,8 +473,12 @@ class TriGamma(UnaryScalarOp):
def impl(self, x):
return TriGamma.st_impl(x)
def grad(self, inputs, outputs_gradients):
raise NotImplementedError()
def L_op(self, inputs, outputs, outputs_gradients):
(x,) = inputs
(g_out,) = outputs_gradients
if x in complex_types:
raise NotImplementedError("gradient not implemented for complex types")
return [g_out * polygamma(2, x)]
def c_support_code(self, **kwargs):
# The implementation has been copied from
......@@ -541,7 +545,52 @@ class TriGamma(UnaryScalarOp):
raise NotImplementedError("only floating point is implemented")
tri_gamma = TriGamma(upgrade_to_float, name="tri_gamma")
# Scipy polygamma does not support complex inputs: https://github.com/scipy/scipy/issues/7410
tri_gamma = TriGamma(upgrade_to_float_no_complex, name="tri_gamma")
class PolyGamma(BinaryScalarOp):
"""Polygamma function of order n evaluated at x.
It corresponds to the (n+1)th derivative of the log gamma function.
TODO: Because the first input is discrete and the output is continuous,
the default elemwise inplace won't work, as it always tries to store the results in the first input.
"""
nfunc_spec = ("scipy.special.polygamma", 2, 1)
@staticmethod
def output_types_preference(n_type, x_type):
if n_type not in discrete_types:
raise TypeError(
f"Polygamma order parameter must be discrete, got {n_type} dtype"
)
# Scipy doesn't support it
return upgrade_to_float_no_complex(x_type)
@staticmethod
def st_impl(n, x):
return scipy.special.polygamma(n, x)
def impl(self, n, x):
return PolyGamma.st_impl(n, x)
def L_op(self, inputs, outputs, output_gradients):
(n, x) = inputs
(g_out,) = output_gradients
if x in complex_types:
raise NotImplementedError("gradient not implemented for complex types")
return [
grad_undefined(self, 0, n),
g_out * self(n + 1, x),
]
def c_code(self, *args, **kwargs):
raise NotImplementedError()
polygamma = PolyGamma(name="polygamma")
class Chi2SF(BinaryScalarOp):
......
......@@ -1369,6 +1369,11 @@ def tri_gamma(a):
"""second derivative of the log gamma function"""
@scalar_elemwise
def polygamma(n, x):
"""Polygamma function of order n evaluated at x"""
@scalar_elemwise
def chi2sf(x, k):
"""chi squared survival function"""
......@@ -3008,6 +3013,7 @@ __all__ = [
"psi",
"digamma",
"tri_gamma",
"polygamma",
"chi2sf",
"gammainc",
"gammaincc",
......
......@@ -52,6 +52,7 @@ from pytensor.tensor.math import (
from pytensor.tensor.math import abs as at_abs
from pytensor.tensor.math import (
add,
digamma,
dot,
eq,
erf,
......@@ -68,7 +69,7 @@ from pytensor.tensor.math import (
makeKeepDims,
)
from pytensor.tensor.math import max as at_max
from pytensor.tensor.math import maximum, mul, neg
from pytensor.tensor.math import maximum, mul, neg, polygamma
from pytensor.tensor.math import pow as at_pow
from pytensor.tensor.math import (
prod,
......@@ -81,7 +82,7 @@ from pytensor.tensor.math import (
sub,
)
from pytensor.tensor.math import sum as at_sum
from pytensor.tensor.math import true_div
from pytensor.tensor.math import tri_gamma, true_div
from pytensor.tensor.rewriting.basic import (
alloc_like,
broadcasted_by,
......@@ -3638,3 +3639,22 @@ def local_useless_conj(fgraph, node):
x = node.inputs[0]
if x.type.dtype not in complex_dtypes:
return [x]
local_polygamma_to_digamma = PatternNodeRewriter(
(polygamma, 0, "x"),
(digamma, "x"),
allow_multiple_clients=True,
name="local_polygamma_to_digamma",
)
register_specialize(local_polygamma_to_digamma)
local_polygamma_to_tri_gamma = PatternNodeRewriter(
(polygamma, 1, "x"),
(tri_gamma, "x"),
allow_multiple_clients=True,
name="local_polygamma_to_tri_gamma",
)
register_specialize(local_polygamma_to_tri_gamma)
......@@ -20,6 +20,7 @@ from pytensor.tensor.math import (
iv,
log,
log1mexp,
polygamma,
psi,
sigmoid,
softplus,
......@@ -178,6 +179,20 @@ def test_tri_gamma():
compare_jax_and_py(fg, [np.array([3.0, 5.0])])
def test_polygamma():
n = vector("n", dtype="int32")
x = vector("x", dtype="float32")
out = polygamma(n, x)
fg = FunctionGraph([n, x], [out])
compare_jax_and_py(
fg,
[
np.array([0, 1, 2]).astype("int32"),
np.array([0.5, 0.9, 2.5]).astype("float32"),
],
)
def test_log1mexp():
x = vector("x")
out = log1mexp(x)
......
......@@ -29,7 +29,7 @@ from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.graph.rewriting.utils import is_same_graph, rewrite_graph
from pytensor.misc.safe_asarray import _asarray
from pytensor.printing import debugprint
from pytensor.scalar import Pow
from pytensor.scalar import PolyGamma, Pow, Psi, TriGamma
from pytensor.tensor import inplace
from pytensor.tensor.basic import Alloc, constant, join, second, switch
from pytensor.tensor.blas import Dot22, Gemv
......@@ -69,7 +69,7 @@ from pytensor.tensor.math import (
from pytensor.tensor.math import max as at_max
from pytensor.tensor.math import maximum
from pytensor.tensor.math import min as at_min
from pytensor.tensor.math import minimum, mul, neg, neq
from pytensor.tensor.math import minimum, mul, neg, neq, polygamma
from pytensor.tensor.math import pow as pt_pow
from pytensor.tensor.math import (
prod,
......@@ -4236,3 +4236,19 @@ def test_logdiffexp():
np.testing.assert_almost_equal(
f(x_test, y_test), np.log(np.exp(x_test) - np.exp(y_test))
)
def test_polygamma_specialization():
x = vector("x")
y1 = polygamma(0, x)
y2 = polygamma(1, x)
y3 = polygamma(2, x)
fn = pytensor.function(
[x], [y1, y2, y3], mode=get_default_mode().including("specialize")
)
fn_outs = fn.maker.fgraph.outputs
assert isinstance(fn_outs[0].owner.op.scalar_op, Psi)
assert isinstance(fn_outs[1].owner.op.scalar_op, TriGamma)
assert isinstance(fn_outs[2].owner.op.scalar_op, PolyGamma)
......@@ -7,6 +7,7 @@ from itertools import product
import numpy as np
import pytest
import scipy.special
from numpy.testing import assert_array_equal
from scipy.special import logsumexp as scipy_logsumexp
......@@ -64,6 +65,7 @@ from pytensor.tensor.math import (
cov,
deg2rad,
dense_dot,
digamma,
dot,
eq,
exp,
......@@ -93,6 +95,7 @@ from pytensor.tensor.math import (
neg,
neq,
outer,
polygamma,
power,
ptp,
rad2deg,
......@@ -3470,3 +3473,44 @@ class TestMatMul:
fn = function([x, y], x @ y, mode="FAST_RUN")
[node] = fn.maker.fgraph.apply_nodes
assert isinstance(node.op, Dot22)
class TestPolyGamma:
def test_basic(self):
n = vector("n", dtype="int64")
x = scalar("x")
np.testing.assert_allclose(
polygamma(n, x).eval({n: [0, 1], x: 0.5}),
scipy.special.polygamma([0, 1], 0.5),
)
def test_continuous_n_raises(self):
n = scalar("n", dtype="float64")
with pytest.raises(TypeError, match="must be discrete"):
polygamma(n, 0.5)
def test_complex_x_raises(self):
x = scalar(dtype="complex128")
with pytest.raises(TypeError, match="complex argument not supported"):
polygamma(0, x)
def test_output_dtype(self):
n = scalar("n", dtype="int64")
polygamma(n, scalar("x", dtype="float32")).dtype == "float32"
polygamma(n, scalar("x", dtype="float64")).dtype == "float64"
polygamma(n, scalar("x", dtype="int32")).dtype == "float64"
def test_grad_x(self):
x = scalar("x")
op_grad = grad(polygamma(0, x), wrt=x)
ref_grad = grad(digamma(x), wrt=x)
np.testing.assert_allclose(
op_grad.eval({x: 0.9}),
ref_grad.eval({x: 0.9}),
)
def test_grad_n_undefined(self):
n = scalar(dtype="int64")
with pytest.raises(NullTypeGradError):
grad(polygamma(n, 0.5), wrt=n)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论