提交 848ce199 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Detect cases where `boolean_indexing_set_or_inc` does not apply

上级 0fced9aa
......@@ -20,9 +20,11 @@ def boolean_indexing_set_or_inc(fgraph, node):
"""
op = node.op
x = node.inputs[0]
y = node.inputs[1]
cond = node.inputs[2]
[x, y, cond] = node.inputs
# This rewrite only works when `y` is a scalar, so it can broadcast to the shape of x[cond]
if y.type.ndim > 0:
return
if not isinstance(cond, TensorVariable):
return
......
......@@ -5,7 +5,10 @@ import pytensor.tensor as at
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor import subtensor as at_subtensor
from pytensor.tensor.rewriting.jax import boolean_indexing_sum
from pytensor.tensor.rewriting.jax import (
boolean_indexing_set_or_inc,
boolean_indexing_sum,
)
from tests.link.jax.test_basic import compare_jax_and_py
......@@ -216,7 +219,7 @@ def test_jax_IncSubtensor_boolean_indexing_reexpressible():
This test ensures that `AdvancedIncSubtensor` `Op`s with boolean indexing is
replaced with an equivalent `Switch` `Op`, using the
`jax_boolean_indexing_set_of_inc` rewrite.
`boolean_indexing_set_of_inc` rewrite.
JAX forces users to re-express this logic manually, so this is an
improvement over its user interface.
......@@ -237,3 +240,12 @@ def test_jax_IncSubtensor_boolean_indexing_reexpressible():
assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([x_at], [out_at])
compare_jax_and_py(out_fg, [x_np])
def test_boolean_indexing_set_or_inc_not_applicable():
"""Test that `boolean_indexing_set_or_inc` does not return an invalid replacement in cases where it doesn't apply."""
x = at.vector("x")
mask = at.as_tensor(x) > 0
out = at_subtensor.set_subtensor(x[mask], [0, 1, 2])
fg = FunctionGraph([x], [out])
assert boolean_indexing_set_or_inc.transform(fg, fg.outputs[0].owner) is None
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论