提交 7799bd08 authored 作者: Abhinav-Khot's avatar Abhinav-Khot 提交者: Ricardo Vieira

add negative value support for the digamma function

上级 7584614e
......@@ -389,6 +389,10 @@ class Psi(UnaryScalarOp):
#define ga_double double
#endif
#ifndef M_PI
#define M_PI 3.14159265358979323846
#endif
#ifndef _PSIFUNCDEFINED
#define _PSIFUNCDEFINED
DEVICE double _psi(ga_double x) {
......@@ -396,7 +400,8 @@ class Psi(UnaryScalarOp):
/*taken from
Bernardo, J. M. (1976). Algorithm AS 103:
Psi (Digamma) Function. Applied Statistics. 25 (3), 315-317.
http://www.uv.es/~bernardo/1976AppStatist.pdf */
http://www.uv.es/~bernardo/1976AppStatist.pdf
*/
ga_double y, R, psi_ = 0;
ga_double S = 1.0e-5;
......@@ -406,9 +411,21 @@ class Psi(UnaryScalarOp):
ga_double S5 = 3.968253968e-3;
ga_double D1 = -0.5772156649;
if (x <= 0) {
// the digamma function approaches infinity from one side and -infinity from the other, around negative integers and zero
if (x == floor(x)) {
return INFINITY; // note that scipy returns -INF for 0 and NaN for negative integers
}
// Use reflection formula
ga_double pi_x = M_PI * x;
ga_double cot_pi_x = cos(pi_x) / sin(pi_x);
return _psi(1.0 - x) + M_PI * cot_pi_x;
}
y = x;
if (y <= 0.0)
if (y <= 0)
return psi_;
if (y <= S)
......
......@@ -2,6 +2,7 @@ import itertools
import numpy as np
import pytest
import scipy
import scipy.special as sp
import pytensor.tensor as pt
......@@ -19,6 +20,7 @@ from pytensor.scalar.math import (
gammal,
gammau,
hyp2f1,
psi,
)
from tests.link.test_link import make_function
......@@ -149,3 +151,28 @@ def test_scalarloop_grad_mixed_dtypes(op, scalar_loop_grads):
(var.owner and isinstance(var.owner.op, ScalarLoop))
for var in ancestors(grad)
)
@pytest.mark.parametrize(
"linker",
["py", "c"],
)
def test_psi(linker):
x = float64("x")
out = psi(x)
fn = function([x], out, mode=Mode(linker=linker, optimizer="fast_run"))
fn.dprint()
x_test = np.float64(0.5)
np.testing.assert_allclose(
fn(x_test),
scipy.special.psi(x_test),
strict=True,
)
np.testing.assert_allclose(
fn(-x_test),
scipy.special.psi(-x_test),
strict=True,
)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论