Unverified 提交 c77d1eff authored 作者: Emeka Okoli's avatar Emeka Okoli 提交者: GitHub

Implement Scan based `filter` helper (#1717)

上级 0b731c27
...@@ -198,3 +198,63 @@ def foldr( ...@@ -198,3 +198,63 @@ def foldr(
name=name, name=name,
return_updates=return_updates, return_updates=return_updates,
) )
def filter(
fn,
sequences,
non_sequences=None,
go_backwards=False,
mode=None,
name=None,
):
"""Construct a `Scan` `Op` that functions like `filter`.
Parameters
----------
fn : callable
Predicate function returning a boolean tensor.
sequences : list
Sequences to filter.
non_sequences : list
Non-iterated arguments passed to `fn`.
go_backwards : bool
Whether to iterate in reverse.
mode : str or None
See ``scan``.
name : str or None
See ``scan``.
Notes
-----
If the predicate function `fn` returns multiple boolean masks (one per sequence),
each mask will be applied to its corresponding sequence. If it returns a single mask,
that mask will be broadcast to all sequences.
"""
mask, _ = scan(
fn=fn,
sequences=sequences,
outputs_info=None,
non_sequences=non_sequences,
go_backwards=go_backwards,
mode=mode,
name=name,
)
if isinstance(mask, (list, tuple)):
# One mask per sequence
if not isinstance(sequences, (list, tuple)):
raise TypeError(
"If multiple masks are returned, sequences must be a list or tuple."
)
if len(mask) != len(sequences):
raise ValueError("Number of masks must match number of sequences.")
filtered_sequences = [seq[m] for seq, m in zip(sequences, mask)]
else:
# Single mask applied to all sequences
if isinstance(sequences, (list, tuple)):
filtered_sequences = [seq[mask] for seq in sequences]
else:
filtered_sequences = sequences[mask]
return filtered_sequences
...@@ -4,6 +4,7 @@ import pytest ...@@ -4,6 +4,7 @@ import pytest
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor import config, function, grad, shared from pytensor import config, function, grad, shared
from pytensor.compile.mode import FAST_RUN from pytensor.compile.mode import FAST_RUN
from pytensor.scan.views import filter as pt_filter
from pytensor.scan.views import foldl, foldr from pytensor.scan.views import foldl, foldr
from pytensor.scan.views import map as pt_map from pytensor.scan.views import map as pt_map
from pytensor.scan.views import reduce as pt_reduce from pytensor.scan.views import reduce as pt_reduce
...@@ -166,3 +167,42 @@ def test_foldr_memory_consumption(return_updates): ...@@ -166,3 +167,42 @@ def test_foldr_memory_consumption(return_updates):
gx = grad(o, x) gx = grad(o, x)
f2 = function([], gx) f2 = function([], gx)
utt.assert_allclose(f2(), np.ones((10,))) utt.assert_allclose(f2(), np.ones((10,)))
def test_filter():
v = pt.vector("v")
def fn(x):
return pt.eq(x % 2, 0)
filtered = pt_filter(fn, v)
f = function([v], filtered, allow_input_downcast=True)
rng = np.random.default_rng(utt.fetch_seed())
vals = rng.integers(0, 10, size=(10,))
expected = vals[vals % 2 == 0]
result = f(vals)
utt.assert_allclose(expected, result)
def test_filter_multiple_masks():
v1 = pt.vector("v1")
v2 = pt.vector("v2")
def fn(x1, x2):
# Mask v1 for even numbers, mask v2 for numbers > 5
return pt.eq(x1 % 2, 0), pt.gt(x2, 5)
filtered_v1, filtered_v2 = pt_filter(fn, [v1, v2])
f = function([v1, v2], [filtered_v1, filtered_v2], allow_input_downcast=True)
vals1 = np.arange(10)
vals2 = np.arange(10)
expected_v1 = vals1[vals1 % 2 == 0]
expected_v2 = vals2[vals2 > 5]
result_v1, result_v2 = f(vals1, vals2)
utt.assert_allclose(expected_v1, result_v1)
utt.assert_allclose(expected_v2, result_v2)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论