提交 88476b4f authored 作者: Abhinav-Khot's avatar Abhinav-Khot 提交者: Ricardo Vieira

Add support for negative values in psi

上级 7799bd08
......@@ -378,17 +378,6 @@ class Psi(UnaryScalarOp):
def c_support_code(self, **kwargs):
return """
// For GPU support
#ifdef WITHIN_KERNEL
#define DEVICE WITHIN_KERNEL
#else
#define DEVICE
#endif
#ifndef ga_double
#define ga_double double
#endif
#ifndef M_PI
#define M_PI 3.14159265358979323846
#endif
......@@ -397,51 +386,48 @@ class Psi(UnaryScalarOp):
#define _PSIFUNCDEFINED
DEVICE double _psi(ga_double x) {
/*taken from
Bernardo, J. M. (1976). Algorithm AS 103:
Psi (Digamma) Function. Applied Statistics. 25 (3), 315-317.
http://www.uv.es/~bernardo/1976AppStatist.pdf
*/
ga_double y, R, psi_ = 0;
ga_double S = 1.0e-5;
ga_double C = 8.5;
ga_double S3 = 8.333333333e-2;
ga_double S4 = 8.333333333e-3;
ga_double S5 = 3.968253968e-3;
ga_double D1 = -0.5772156649;
if (x <= 0) {
// the digamma function approaches infinity from one side and -infinity from the other, around negative integers and zero
if (x == floor(x)) {
return INFINITY; // note that scipy returns -INF for 0 and NaN for negative integers
}
/*taken from
Bernardo, J. M. (1976). Algorithm AS 103:
Psi (Digamma) Function. Applied Statistics. 25 (3), 315-317.
http://www.uv.es/~bernardo/1976AppStatist.pdf
*/
// Use reflection formula
ga_double pi_x = M_PI * x;
ga_double cot_pi_x = cos(pi_x) / sin(pi_x);
return _psi(1.0 - x) + M_PI * cot_pi_x;
}
double y, R, psi_ = 0;
double S = 1.0e-5;
double C = 8.5;
double S3 = 8.333333333e-2;
double S4 = 8.333333333e-3;
double S5 = 3.968253968e-3;
double D1 = -0.5772156649;
y = x;
if (x <= 0) {
// the digamma function approaches infinity from one side and -infinity from the other, around negative integers and zero
if (x == floor(x)) {
return INFINITY; // note that scipy returns -INF for 0 and NaN for negative integers
}
// Use reflection formula
ga_double pi_x = M_PI * x;
ga_double cot_pi_x = cos(pi_x) / sin(pi_x);
return _psi(1.0 - x) - M_PI * cot_pi_x;
}
if (y <= 0)
return psi_;
y = x;
if (y <= S)
return D1 - 1.0/y;
if (y <= S)
return D1 - 1.0/y;
while (y < C) {
psi_ = psi_ - 1.0 / y;
y = y + 1;
}
while (y < C) {
psi_ = psi_ - 1.0 / y;
y = y + 1;
}
R = 1.0 / y;
psi_ = psi_ + log(y) - .5 * R ;
R= R*R;
psi_ = psi_ - R * (S3 - R * (S4 - R * S5));
R = 1.0 / y;
psi_ = psi_ + log(y) - .5 * R ;
R= R*R;
psi_ = psi_ - R * (S3 - R * (S4 - R * S5));
return psi_;
return psi_;
}
#endif
"""
......@@ -450,8 +436,8 @@ class Psi(UnaryScalarOp):
(x,) = inp
(z,) = out
if node.inputs[0].type in float_types:
return f"""{z} =
_psi({x});"""
dtype = "npy_" + node.outputs[0].dtype
return f"{z} = ({dtype}) _psi({x});"
raise NotImplementedError("only floating point is implemented")
......
......@@ -155,7 +155,7 @@ def test_scalarloop_grad_mixed_dtypes(op, scalar_loop_grads):
@pytest.mark.parametrize(
"linker",
["py", "c"],
["py", "cvm"],
)
def test_psi(linker):
x = float64("x")
......@@ -164,7 +164,7 @@ def test_psi(linker):
fn = function([x], out, mode=Mode(linker=linker, optimizer="fast_run"))
fn.dprint()
x_test = np.float64(0.5)
x_test = np.float64(0.7)
np.testing.assert_allclose(
fn(x_test),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论