Unverified 提交 ad55b69f authored 作者: Harshvir Sandhu's avatar Harshvir Sandhu 提交者: GitHub

Implement `tensor.special.logit` helper (#645)

上级 d28a5d0a
...@@ -8,7 +8,7 @@ from pytensor.graph.replace import _vectorize_node ...@@ -8,7 +8,7 @@ from pytensor.graph.replace import _vectorize_node
from pytensor.link.c.op import COp from pytensor.link.c.op import COp
from pytensor.tensor.basic import as_tensor_variable from pytensor.tensor.basic import as_tensor_variable
from pytensor.tensor.elemwise import get_normalized_batch_axes from pytensor.tensor.elemwise import get_normalized_batch_axes
from pytensor.tensor.math import gamma, gammaln, neg, sum from pytensor.tensor.math import gamma, gammaln, log, neg, sum
class SoftmaxGrad(COp): class SoftmaxGrad(COp):
...@@ -780,6 +780,14 @@ def factorial(n): ...@@ -780,6 +780,14 @@ def factorial(n):
return gamma(n + 1) return gamma(n + 1)
def logit(x):
"""
Logit function.
"""
return log(x / (1 - x))
def beta(a, b): def beta(a, b):
""" """
Beta function. Beta function.
...@@ -801,6 +809,7 @@ __all__ = [ ...@@ -801,6 +809,7 @@ __all__ = [
"log_softmax", "log_softmax",
"poch", "poch",
"factorial", "factorial",
"logit",
"beta", "beta",
"betaln", "betaln",
] ]
...@@ -3,6 +3,7 @@ import pytest ...@@ -3,6 +3,7 @@ import pytest
from scipy.special import beta as scipy_beta from scipy.special import beta as scipy_beta
from scipy.special import factorial as scipy_factorial from scipy.special import factorial as scipy_factorial
from scipy.special import log_softmax as scipy_log_softmax from scipy.special import log_softmax as scipy_log_softmax
from scipy.special import logit as scipy_logit
from scipy.special import poch as scipy_poch from scipy.special import poch as scipy_poch
from scipy.special import softmax as scipy_softmax from scipy.special import softmax as scipy_softmax
...@@ -18,6 +19,7 @@ from pytensor.tensor.special import ( ...@@ -18,6 +19,7 @@ from pytensor.tensor.special import (
betaln, betaln,
factorial, factorial,
log_softmax, log_softmax,
logit,
poch, poch,
softmax, softmax,
) )
...@@ -206,6 +208,18 @@ def test_factorial(n): ...@@ -206,6 +208,18 @@ def test_factorial(n):
) )
def test_logit():
x = vector("x")
actual_fn = function([x], logit(x), allow_input_downcast=True)
x_test = np.linspace(0, 1)
actual = actual_fn(x_test)
expected = scipy_logit(x_test)
np.testing.assert_allclose(
actual, expected, rtol=1e-7 if config.floatX == "float64" else 1e-5
)
def test_beta(): def test_beta():
_a, _b = vectors("a", "b") _a, _b = vectors("a", "b")
actual_fn = function([_a, _b], beta(_a, _b)) actual_fn = function([_a, _b], beta(_a, _b))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论