Commit 92d5450f, authored by Ricardo Vieira, committed by Ricardo Vieira

Implement Polygamma Op

Parent: f6be5213
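Note (not part of the diff): a minimal usage sketch of the new op, checked against SciPy. It assumes polygamma is re-exported at the top-level pytensor.tensor namespace, which follows from its addition to pytensor.tensor.math.__all__ below.

    import numpy as np
    import scipy.special
    import pytensor
    import pytensor.tensor as pt

    n = pt.vector("n", dtype="int64")
    x = pt.vector("x")
    fn = pytensor.function([n, x], pt.polygamma(n, x))

    # polygamma(0, x) is digamma, polygamma(1, x) is trigamma, etc.
    np.testing.assert_allclose(
        fn([0, 1, 2], [0.5, 0.9, 2.5]),
        scipy.special.polygamma([0, 1, 2], [0.5, 0.9, 2.5]),
    )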
...
@@ -101,7 +101,7 @@ def grad_undefined(op, x_pos, x, comment=""):
     return (
         NullType(
             "This variable is Null because the grad method for "
-            f"input {x_pos} ({x}) of the {op} op is not implemented. {comment}"
+            f"input {x_pos} ({x}) of the {op} op is undefined. {comment}"
         )
     )()
...
...
@@ -13,7 +13,7 @@
 import scipy.special
 import scipy.stats
 from pytensor.configdefaults import config
-from pytensor.gradient import grad_not_implemented
+from pytensor.gradient import grad_not_implemented, grad_undefined
 from pytensor.scalar.basic import BinaryScalarOp, ScalarOp, UnaryScalarOp
 from pytensor.scalar.basic import abs as scalar_abs
 from pytensor.scalar.basic import (
@@ -473,8 +473,12 @@ class TriGamma(UnaryScalarOp):
     def impl(self, x):
         return TriGamma.st_impl(x)

-    def grad(self, inputs, outputs_gradients):
-        raise NotImplementedError()
+    def L_op(self, inputs, outputs, outputs_gradients):
+        (x,) = inputs
+        (g_out,) = outputs_gradients
+        if x.type in complex_types:
+            raise NotImplementedError("gradient not implemented for complex types")
+        return [g_out * polygamma(2, x)]

     def c_support_code(self, **kwargs):
         # The implementation has been copied from
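Note (not part of the diff): the new TriGamma.L_op encodes d/dx psi^(1)(x) = psi^(2)(x), i.e. the derivative of the trigamma function is the polygamma function of order 2. A quick sanity check using SciPy alone, via central finite differences:

    import numpy as np
    import scipy.special

    x, eps = 1.3, 1e-6
    fd = (
        scipy.special.polygamma(1, x + eps) - scipy.special.polygamma(1, x - eps)
    ) / (2 * eps)
    np.testing.assert_allclose(fd, scipy.special.polygamma(2, x), rtol=1e-5)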
@@ -541,7 +545,52 @@ class TriGamma(UnaryScalarOp):
         raise NotImplementedError("only floating point is implemented")

-tri_gamma = TriGamma(upgrade_to_float, name="tri_gamma")
+# Scipy polygamma does not support complex inputs: https://github.com/scipy/scipy/issues/7410
+tri_gamma = TriGamma(upgrade_to_float_no_complex, name="tri_gamma")
+
+
+class PolyGamma(BinaryScalarOp):
+    """Polygamma function of order n evaluated at x.
+
+    It corresponds to the (n+1)-th derivative of the log gamma function.
+
+    TODO: Because the first input is discrete and the output is continuous,
+    the default elemwise inplace optimization won't work, as it always tries
+    to store the results in the first input.
+    """
+
+    nfunc_spec = ("scipy.special.polygamma", 2, 1)
+
+    @staticmethod
+    def output_types_preference(n_type, x_type):
+        if n_type not in discrete_types:
+            raise TypeError(
+                f"Polygamma order parameter must be discrete, got {n_type} dtype"
+            )
+        # Scipy doesn't support complex inputs either
+        return upgrade_to_float_no_complex(x_type)
+
+    @staticmethod
+    def st_impl(n, x):
+        return scipy.special.polygamma(n, x)
+
+    def impl(self, n, x):
+        return PolyGamma.st_impl(n, x)
+
+    def L_op(self, inputs, outputs, output_gradients):
+        (n, x) = inputs
+        (g_out,) = output_gradients
+        if x.type in complex_types:
+            raise NotImplementedError("gradient not implemented for complex types")
+        # The gradient wrt the discrete order n is undefined, not merely unimplemented
+        return [
+            grad_undefined(self, 0, n),
+            g_out * self(n + 1, x),
+        ]
+
+    def c_code(self, *args, **kwargs):
+        raise NotImplementedError()
+
+
+polygamma = PolyGamma(name="polygamma")

 class Chi2SF(BinaryScalarOp):
...
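Note (not part of the diff): the docstring's claim that psi^(n) is the (n+1)-th derivative of log-gamma can be cross-checked through the classical identity psi^(n)(x) = (-1)^(n+1) n! zeta(n+1, x) for n >= 1, where zeta is the Hurwitz zeta function:

    import math
    import numpy as np
    import scipy.special

    n, x = 3, 2.5
    np.testing.assert_allclose(
        scipy.special.polygamma(n, x),
        (-1) ** (n + 1) * math.factorial(n) * scipy.special.zeta(n + 1, x),
    )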
...
@@ -1369,6 +1369,11 @@ def tri_gamma(a):
     """second derivative of the log gamma function"""


+@scalar_elemwise
+def polygamma(n, x):
+    """Polygamma function of order n evaluated at x"""
+
+
 @scalar_elemwise
 def chi2sf(x, k):
     """chi squared survival function"""
...
@@ -3008,6 +3013,7 @@ __all__ = [
     "psi",
     "digamma",
     "tri_gamma",
+    "polygamma",
     "chi2sf",
     "gammainc",
     "gammaincc",
...
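Note (not part of the diff): because the tensor-level polygamma is built with @scalar_elemwise, the discrete order broadcasts elementwise against x like any other binary op. A small sketch:

    import numpy as np
    import pytensor.tensor as pt

    n = pt.vector("n", dtype="int64")
    x = pt.scalar("x")
    out = pt.polygamma(n, x)  # n broadcasts against the scalar x
    print(out.eval({n: np.array([0, 1, 2]), x: 0.5}))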
...
@@ -52,6 +52,7 @@ from pytensor.tensor.math import (
 from pytensor.tensor.math import abs as at_abs
 from pytensor.tensor.math import (
     add,
+    digamma,
     dot,
     eq,
     erf,
@@ -68,7 +69,7 @@ from pytensor.tensor.math import (
     makeKeepDims,
 )
 from pytensor.tensor.math import max as at_max
-from pytensor.tensor.math import maximum, mul, neg
+from pytensor.tensor.math import maximum, mul, neg, polygamma
 from pytensor.tensor.math import pow as at_pow
 from pytensor.tensor.math import (
     prod,
@@ -81,7 +82,7 @@ from pytensor.tensor.math import (
     sub,
 )
 from pytensor.tensor.math import sum as at_sum
-from pytensor.tensor.math import true_div
+from pytensor.tensor.math import tri_gamma, true_div
 from pytensor.tensor.rewriting.basic import (
     alloc_like,
     broadcasted_by,
@@ -3638,3 +3639,22 @@ def local_useless_conj(fgraph, node):
     x = node.inputs[0]
     if x.type.dtype not in complex_dtypes:
         return [x]
+
+
+local_polygamma_to_digamma = PatternNodeRewriter(
+    (polygamma, 0, "x"),
+    (digamma, "x"),
+    allow_multiple_clients=True,
+    name="local_polygamma_to_digamma",
+)
+
+register_specialize(local_polygamma_to_digamma)
+
+
+local_polygamma_to_tri_gamma = PatternNodeRewriter(
+    (polygamma, 1, "x"),
+    (tri_gamma, "x"),
+    allow_multiple_clients=True,
+    name="local_polygamma_to_tri_gamma",
+)
+
+register_specialize(local_polygamma_to_tri_gamma)
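Note (not part of the diff): once these rewriters are registered in the specialize phase, a polygamma with a constant order of 0 or 1 compiles down to the cheaper digamma / tri_gamma ops. A sketch of how to observe this, assuming pytensor.dprint for printing the compiled graph:

    import pytensor
    import pytensor.tensor as pt

    x = pt.vector("x")
    fn = pytensor.function([x], pt.polygamma(0, x))
    pytensor.dprint(fn)  # expect an Elemwise over Psi (digamma), not PolyGamma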
...
@@ -20,6 +20,7 @@ from pytensor.tensor.math import (
     iv,
     log,
     log1mexp,
+    polygamma,
     psi,
     sigmoid,
     softplus,
@@ -178,6 +179,20 @@ def test_tri_gamma():
     compare_jax_and_py(fg, [np.array([3.0, 5.0])])


+def test_polygamma():
+    n = vector("n", dtype="int32")
+    x = vector("x", dtype="float32")
+    out = polygamma(n, x)
+    fg = FunctionGraph([n, x], [out])
+    compare_jax_and_py(
+        fg,
+        [
+            np.array([0, 1, 2]).astype("int32"),
+            np.array([0.5, 0.9, 2.5]).astype("float32"),
+        ],
+    )
+
+
 def test_log1mexp():
     x = vector("x")
     out = log1mexp(x)
...
...
@@ -29,7 +29,7 @@ from pytensor.graph.rewriting.db import RewriteDatabaseQuery
 from pytensor.graph.rewriting.utils import is_same_graph, rewrite_graph
 from pytensor.misc.safe_asarray import _asarray
 from pytensor.printing import debugprint
-from pytensor.scalar import Pow
+from pytensor.scalar import PolyGamma, Pow, Psi, TriGamma
 from pytensor.tensor import inplace
 from pytensor.tensor.basic import Alloc, constant, join, second, switch
 from pytensor.tensor.blas import Dot22, Gemv
@@ -69,7 +69,7 @@ from pytensor.tensor.math import (
 from pytensor.tensor.math import max as at_max
 from pytensor.tensor.math import maximum
 from pytensor.tensor.math import min as at_min
-from pytensor.tensor.math import minimum, mul, neg, neq
+from pytensor.tensor.math import minimum, mul, neg, neq, polygamma
 from pytensor.tensor.math import pow as pt_pow
 from pytensor.tensor.math import (
     prod,
@@ -4236,3 +4236,19 @@ def test_logdiffexp():
     np.testing.assert_almost_equal(
         f(x_test, y_test), np.log(np.exp(x_test) - np.exp(y_test))
     )
+
+
+def test_polygamma_specialization():
+    x = vector("x")
+
+    y1 = polygamma(0, x)
+    y2 = polygamma(1, x)
+    y3 = polygamma(2, x)
+
+    fn = pytensor.function(
+        [x], [y1, y2, y3], mode=get_default_mode().including("specialize")
+    )
+    fn_outs = fn.maker.fgraph.outputs
+    assert isinstance(fn_outs[0].owner.op.scalar_op, Psi)
+    assert isinstance(fn_outs[1].owner.op.scalar_op, TriGamma)
+    assert isinstance(fn_outs[2].owner.op.scalar_op, PolyGamma)
...
@@ -7,6 +7,7 @@ from itertools import product

 import numpy as np
 import pytest
+import scipy.special
 from numpy.testing import assert_array_equal
 from scipy.special import logsumexp as scipy_logsumexp
@@ -64,6 +65,7 @@ from pytensor.tensor.math import (
     cov,
     deg2rad,
     dense_dot,
+    digamma,
     dot,
     eq,
     exp,
@@ -93,6 +95,7 @@ from pytensor.tensor.math import (
     neg,
     neq,
     outer,
+    polygamma,
     power,
     ptp,
     rad2deg,
@@ -3470,3 +3473,44 @@ class TestMatMul:
         fn = function([x, y], x @ y, mode="FAST_RUN")
         [node] = fn.maker.fgraph.apply_nodes
         assert isinstance(node.op, Dot22)
+
+
+class TestPolyGamma:
+    def test_basic(self):
+        n = vector("n", dtype="int64")
+        x = scalar("x")
+        np.testing.assert_allclose(
+            polygamma(n, x).eval({n: [0, 1], x: 0.5}),
+            scipy.special.polygamma([0, 1], 0.5),
+        )
+
+    def test_continuous_n_raises(self):
+        n = scalar("n", dtype="float64")
+        with pytest.raises(TypeError, match="must be discrete"):
+            polygamma(n, 0.5)
+
+    def test_complex_x_raises(self):
+        x = scalar(dtype="complex128")
+        with pytest.raises(TypeError, match="complex argument not supported"):
+            polygamma(0, x)
+
+    def test_output_dtype(self):
+        n = scalar("n", dtype="int64")
+        assert polygamma(n, scalar("x", dtype="float32")).dtype == "float32"
+        assert polygamma(n, scalar("x", dtype="float64")).dtype == "float64"
+        assert polygamma(n, scalar("x", dtype="int32")).dtype == "float64"
+
+    def test_grad_x(self):
+        x = scalar("x")
+        op_grad = grad(polygamma(0, x), wrt=x)
+        ref_grad = grad(digamma(x), wrt=x)
+        np.testing.assert_allclose(
+            op_grad.eval({x: 0.9}),
+            ref_grad.eval({x: 0.9}),
+        )
+
+    def test_grad_n_undefined(self):
+        n = scalar(dtype="int64")
+        with pytest.raises(NullTypeGradError):
+            grad(polygamma(n, 0.5), wrt=n)
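Note (not part of the diff): test_grad_x only exercises order 0. By the chain rule implemented in PolyGamma.L_op, d/dx psi^(n)(x) = psi^(n+1)(x) for any discrete order, which can be checked against SciPy directly:

    import numpy as np
    import scipy.special
    import pytensor.tensor as pt
    from pytensor.gradient import grad

    x = pt.scalar("x")
    g = grad(pt.polygamma(2, x), wrt=x)  # should equal polygamma(3, x)
    np.testing.assert_allclose(g.eval({x: 0.9}), scipy.special.polygamma(3, 0.9))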