提交 7799bd08 authored 作者: Abhinav-Khot's avatar Abhinav-Khot 提交者: Ricardo Vieira

add negative value support for the digamma function

上级 7584614e
...@@ -389,6 +389,10 @@ class Psi(UnaryScalarOp): ...@@ -389,6 +389,10 @@ class Psi(UnaryScalarOp):
#define ga_double double #define ga_double double
#endif #endif
#ifndef M_PI
#define M_PI 3.14159265358979323846
#endif
#ifndef _PSIFUNCDEFINED #ifndef _PSIFUNCDEFINED
#define _PSIFUNCDEFINED #define _PSIFUNCDEFINED
DEVICE double _psi(ga_double x) { DEVICE double _psi(ga_double x) {
...@@ -396,7 +400,8 @@ class Psi(UnaryScalarOp): ...@@ -396,7 +400,8 @@ class Psi(UnaryScalarOp):
/*taken from /*taken from
Bernardo, J. M. (1976). Algorithm AS 103: Bernardo, J. M. (1976). Algorithm AS 103:
Psi (Digamma) Function. Applied Statistics. 25 (3), 315-317. Psi (Digamma) Function. Applied Statistics. 25 (3), 315-317.
http://www.uv.es/~bernardo/1976AppStatist.pdf */ http://www.uv.es/~bernardo/1976AppStatist.pdf
*/
ga_double y, R, psi_ = 0; ga_double y, R, psi_ = 0;
ga_double S = 1.0e-5; ga_double S = 1.0e-5;
...@@ -406,10 +411,22 @@ class Psi(UnaryScalarOp): ...@@ -406,10 +411,22 @@ class Psi(UnaryScalarOp):
ga_double S5 = 3.968253968e-3; ga_double S5 = 3.968253968e-3;
ga_double D1 = -0.5772156649; ga_double D1 = -0.5772156649;
if (x <= 0) {
// the digamma function approaches infinity from one side and -infinity from the other, around negative integers and zero
if (x == floor(x)) {
return INFINITY; // note that scipy returns -INF for 0 and NaN for negative integers
}
// Use reflection formula
ga_double pi_x = M_PI * x;
ga_double cot_pi_x = cos(pi_x) / sin(pi_x);
return _psi(1.0 - x) + M_PI * cot_pi_x;
}
y = x; y = x;
if (y <= 0.0) if (y <= 0)
return psi_; return psi_;
if (y <= S) if (y <= S)
return D1 - 1.0/y; return D1 - 1.0/y;
......
...@@ -2,6 +2,7 @@ import itertools ...@@ -2,6 +2,7 @@ import itertools
import numpy as np import numpy as np
import pytest import pytest
import scipy
import scipy.special as sp import scipy.special as sp
import pytensor.tensor as pt import pytensor.tensor as pt
...@@ -19,6 +20,7 @@ from pytensor.scalar.math import ( ...@@ -19,6 +20,7 @@ from pytensor.scalar.math import (
gammal, gammal,
gammau, gammau,
hyp2f1, hyp2f1,
psi,
) )
from tests.link.test_link import make_function from tests.link.test_link import make_function
...@@ -149,3 +151,28 @@ def test_scalarloop_grad_mixed_dtypes(op, scalar_loop_grads): ...@@ -149,3 +151,28 @@ def test_scalarloop_grad_mixed_dtypes(op, scalar_loop_grads):
(var.owner and isinstance(var.owner.op, ScalarLoop)) (var.owner and isinstance(var.owner.op, ScalarLoop))
for var in ancestors(grad) for var in ancestors(grad)
) )
@pytest.mark.parametrize(
"linker",
["py", "c"],
)
def test_psi(linker):
x = float64("x")
out = psi(x)
fn = function([x], out, mode=Mode(linker=linker, optimizer="fast_run"))
fn.dprint()
x_test = np.float64(0.5)
np.testing.assert_allclose(
fn(x_test),
scipy.special.psi(x_test),
strict=True,
)
np.testing.assert_allclose(
fn(-x_test),
scipy.special.psi(-x_test),
strict=True,
)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论