Unverified 提交 55b00d16 authored 作者: Nirdesh Bhandari's avatar Nirdesh Bhandari 提交者: GitHub

Implement isfinite helper

上级 893a4c74
......@@ -26,6 +26,7 @@ from pytensor.tensor.basic import (
concatenate,
constant,
expand_dims,
ones_like,
stack,
switch,
)
......@@ -881,6 +882,39 @@ def isinf(a):
return isinf_(a)
def isfinite(a):
"""isfinite(a)
Computes element-wise detection of finite values (i.e., not NaN or infinite).
Parameters
----------
a : TensorLike
Input tensor
Returns
-------
TensorVariable
Output tensor of type bool, with 1 (True) where elements are finite,
and 0 (False) elsewhere.
Examples
--------
>>> import pytensor
>>> import pytensor.tensor as pt
>>> import numpy as np
>>> x = pt.vector("x")
>>> f = pytensor.function([x], pt.isfinite(x))
>>> f([1, np.inf, -np.inf, np.nan, 3])
array([ True, False, False, False, True])
"""
a = as_tensor_variable(a)
if a.dtype in discrete_dtypes:
return ones_like(a, dtype="bool")
return ~isnan_(a) & ~isinf_(a)
def allclose(a, b, rtol=1.0e-5, atol=1.0e-8, equal_nan=False):
"""
Implement Numpy's ``allclose`` on tensors.
......@@ -4241,6 +4275,7 @@ __all__ = [
"invert",
"iround",
"isclose",
"isfinite",
"isinf",
"isnan",
"isneginf",
......
......@@ -77,6 +77,7 @@ from pytensor.tensor.math import (
expm1,
floor,
isclose,
isfinite,
isinf,
isnan,
isnan_,
......@@ -771,6 +772,28 @@ def test_isnan():
f([[0, 1, 2]])
def test_isfinite():
x_float = matrix(dtype="float32")
data_float = np.array(
[[1.0, np.nan, np.inf], [-np.inf, 0.0, -2.5]], dtype="float32"
)
y_float = isfinite(x_float)
assert y_float.dtype == "bool"
assert_array_equal(y_float.eval({x_float: data_float}), np.isfinite(data_float))
x_int = imatrix()
data_int = np.array([[0, 1, 2], [-1, 5, 10]], dtype="int32")
y_int = isfinite(x_int)
assert y_int.dtype == "bool"
assert_array_equal(y_int.eval({x_int: data_int}), np.isfinite(data_int))
x_bool = matrix(dtype="bool")
data_bool = np.array([[True, False, True], [False, True, False]], dtype="bool")
y_bool = isfinite(x_bool)
assert y_bool.dtype == "bool"
assert_array_equal(y_bool.eval({x_bool: data_bool}), np.isfinite(data_bool))
class TestMaxAndArgmax:
def setup_method(self):
Max.debug = 0
......@@ -3119,7 +3142,7 @@ class TestProd:
assert not unpickled_prod.no_zeros_in_input
class TestIsInfIsNan:
class TestIsInfIsNanIsFinite:
def setup_method(self):
self.test_vals = [
np.array(x, dtype=config.floatX)
......@@ -3159,6 +3182,9 @@ class TestIsInfIsNan:
def test_isnan(self):
self.run_isfunc(isnan, np.isnan)
def test_isfinite(self):
self.run_isfunc(isfinite, np.isfinite)
class TestSumProdReduceDtype:
mode = get_default_mode().excluding("local_cut_useless_reduce")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论