Unverified commit ad55b69f authored by Harshvir Sandhu, committed by GitHub

Implement `tensor.special.logit` helper (#645)

Parent: d28a5d0a
......@@ -8,7 +8,7 @@ from pytensor.graph.replace import _vectorize_node
from pytensor.link.c.op import COp
from pytensor.tensor.basic import as_tensor_variable
from pytensor.tensor.elemwise import get_normalized_batch_axes
from pytensor.tensor.math import gamma, gammaln, neg, sum
from pytensor.tensor.math import gamma, gammaln, log, neg, sum
class SoftmaxGrad(COp):
......@@ -780,6 +780,14 @@ def factorial(n):
return gamma(n + 1)
def logit(x):
    """
    Logit function.

    Computes the log-odds ``log(x / (1 - x))`` elementwise; the inverse
    of the standard logistic (sigmoid) function.
    """
    odds = x / (1 - x)
    return log(odds)
def beta(a, b):
"""
Beta function.
......@@ -801,6 +809,7 @@ __all__ = [
"log_softmax",
"poch",
"factorial",
"logit",
"beta",
"betaln",
]
......@@ -3,6 +3,7 @@ import pytest
from scipy.special import beta as scipy_beta
from scipy.special import factorial as scipy_factorial
from scipy.special import log_softmax as scipy_log_softmax
from scipy.special import logit as scipy_logit
from scipy.special import poch as scipy_poch
from scipy.special import softmax as scipy_softmax
......@@ -18,6 +19,7 @@ from pytensor.tensor.special import (
betaln,
factorial,
log_softmax,
logit,
poch,
softmax,
)
......@@ -206,6 +208,18 @@ def test_factorial(n):
)
def test_logit():
    """Check that `logit` agrees with `scipy.special.logit` on [0, 1]."""
    sym_x = vector("x")
    compiled = function([sym_x], logit(sym_x), allow_input_downcast=True)

    # Endpoints 0 and 1 map to -inf/+inf in both implementations, which
    # assert_allclose treats as equal.
    grid = np.linspace(0, 1)
    rtol = 1e-7 if config.floatX == "float64" else 1e-5
    np.testing.assert_allclose(compiled(grid), scipy_logit(grid), rtol=rtol)
def test_beta():
_a, _b = vectors("a", "b")
actual_fn = function([_a, _b], beta(_a, _b))
......
Markdown formatting is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment