Unverified 提交 71592152 authored 作者: Tanish's avatar Tanish 提交者: GitHub

Update `tensor.where` to allow for case with only condition (#844)

上级 d3bd1f15
......@@ -760,7 +760,31 @@ def switch(cond, ift, iff):
"""if cond then ift else iff"""
where = switch
def where(cond, ift=None, iff=None, **kwargs):
"""
where(condition, [ift, iff])
Return elements chosen from `ift` or `iff` depending on `condition`.
Note: When only condition is provided, this function is a shorthand for `as_tensor(condition).nonzero()`.
Parameters
----------
condition : tensor_like, bool
Where True, yield `ift`, otherwise yield `iff`.
x, y : tensor_like
Values from which to choose.
Returns
-------
out : TensorVariable
A tensor with elements from `ift` where `condition` is True, and elements from `iff` elsewhere.
"""
if ift is not None and iff is not None:
return switch(cond, ift, iff, **kwargs)
elif ift is None and iff is None:
return as_tensor(cond).nonzero(**kwargs)
else:
raise ValueError("either both or neither of ift and iff should be given")
@scalar_elemwise
......
......@@ -87,6 +87,7 @@ from pytensor.tensor.basic import (
triu_indices,
triu_indices_from,
vertical_stack,
where,
zeros_like,
)
from pytensor.tensor.blockwise import Blockwise
......@@ -4608,3 +4609,20 @@ def test_vectorize_join(axis, broadcasting_y):
vectorize_pt(x_test, y_test),
vectorize_np(x_test, y_test),
)
def test_where():
a = np.arange(10)
cond = a < 5
ift = np.pi
iff = np.e
# Test for all 3 inputs
np.testing.assert_allclose(np.where(cond, ift, iff), where(cond, ift, iff).eval())
# Test for only condition input
for np_output, pt_output in zip(np.where(cond), where(cond)):
np.testing.assert_allclose(np_output, pt_output.eval())
# Test for error
with pytest.raises(ValueError, match="either both"):
where(cond, ift)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论