提交 88476b4f authored 作者: Abhinav-Khot's avatar Abhinav-Khot 提交者: Ricardo Vieira

Add support for negative values in psi

上级 7799bd08
...@@ -378,17 +378,6 @@ class Psi(UnaryScalarOp): ...@@ -378,17 +378,6 @@ class Psi(UnaryScalarOp):
def c_support_code(self, **kwargs): def c_support_code(self, **kwargs):
return """ return """
// For GPU support
#ifdef WITHIN_KERNEL
#define DEVICE WITHIN_KERNEL
#else
#define DEVICE
#endif
#ifndef ga_double
#define ga_double double
#endif
#ifndef M_PI #ifndef M_PI
#define M_PI 3.14159265358979323846 #define M_PI 3.14159265358979323846
#endif #endif
...@@ -397,51 +386,48 @@ class Psi(UnaryScalarOp): ...@@ -397,51 +386,48 @@ class Psi(UnaryScalarOp):
#define _PSIFUNCDEFINED #define _PSIFUNCDEFINED
DEVICE double _psi(ga_double x) { DEVICE double _psi(ga_double x) {
/*taken from /*taken from
Bernardo, J. M. (1976). Algorithm AS 103: Bernardo, J. M. (1976). Algorithm AS 103:
Psi (Digamma) Function. Applied Statistics. 25 (3), 315-317. Psi (Digamma) Function. Applied Statistics. 25 (3), 315-317.
http://www.uv.es/~bernardo/1976AppStatist.pdf http://www.uv.es/~bernardo/1976AppStatist.pdf
*/ */
ga_double y, R, psi_ = 0;
ga_double S = 1.0e-5;
ga_double C = 8.5;
ga_double S3 = 8.333333333e-2;
ga_double S4 = 8.333333333e-3;
ga_double S5 = 3.968253968e-3;
ga_double D1 = -0.5772156649;
if (x <= 0) {
// the digamma function approaches infinity from one side and -infinity from the other, around negative integers and zero
if (x == floor(x)) {
return INFINITY; // note that scipy returns -INF for 0 and NaN for negative integers
}
// Use reflection formula double y, R, psi_ = 0;
ga_double pi_x = M_PI * x; double S = 1.0e-5;
ga_double cot_pi_x = cos(pi_x) / sin(pi_x); double C = 8.5;
return _psi(1.0 - x) + M_PI * cot_pi_x; double S3 = 8.333333333e-2;
} double S4 = 8.333333333e-3;
double S5 = 3.968253968e-3;
double D1 = -0.5772156649;
y = x; if (x <= 0) {
// the digamma function approaches infinity from one side and -infinity from the other, around negative integers and zero
if (x == floor(x)) {
return INFINITY; // note that scipy returns -INF for 0 and NaN for negative integers
}
// Use reflection formula
ga_double pi_x = M_PI * x;
ga_double cot_pi_x = cos(pi_x) / sin(pi_x);
return _psi(1.0 - x) - M_PI * cot_pi_x;
}
if (y <= 0) y = x;
return psi_;
if (y <= S) if (y <= S)
return D1 - 1.0/y; return D1 - 1.0/y;
while (y < C) { while (y < C) {
psi_ = psi_ - 1.0 / y; psi_ = psi_ - 1.0 / y;
y = y + 1; y = y + 1;
} }
R = 1.0 / y; R = 1.0 / y;
psi_ = psi_ + log(y) - .5 * R ; psi_ = psi_ + log(y) - .5 * R ;
R= R*R; R= R*R;
psi_ = psi_ - R * (S3 - R * (S4 - R * S5)); psi_ = psi_ - R * (S3 - R * (S4 - R * S5));
return psi_; return psi_;
} }
#endif #endif
""" """
...@@ -450,8 +436,8 @@ class Psi(UnaryScalarOp): ...@@ -450,8 +436,8 @@ class Psi(UnaryScalarOp):
(x,) = inp (x,) = inp
(z,) = out (z,) = out
if node.inputs[0].type in float_types: if node.inputs[0].type in float_types:
return f"""{z} = dtype = "npy_" + node.outputs[0].dtype
_psi({x});""" return f"{z} = ({dtype}) _psi({x});"
raise NotImplementedError("only floating point is implemented") raise NotImplementedError("only floating point is implemented")
......
...@@ -155,7 +155,7 @@ def test_scalarloop_grad_mixed_dtypes(op, scalar_loop_grads): ...@@ -155,7 +155,7 @@ def test_scalarloop_grad_mixed_dtypes(op, scalar_loop_grads):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"linker", "linker",
["py", "c"], ["py", "cvm"],
) )
def test_psi(linker): def test_psi(linker):
x = float64("x") x = float64("x")
...@@ -164,7 +164,7 @@ def test_psi(linker): ...@@ -164,7 +164,7 @@ def test_psi(linker):
fn = function([x], out, mode=Mode(linker=linker, optimizer="fast_run")) fn = function([x], out, mode=Mode(linker=linker, optimizer="fast_run"))
fn.dprint() fn.dprint()
x_test = np.float64(0.5) x_test = np.float64(0.7)
np.testing.assert_allclose( np.testing.assert_allclose(
fn(x_test), fn(x_test),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论