提交 a824aee4 authored 作者: Rémi Louf's avatar Rémi Louf 提交者: Ricardo Vieira

Raise when an array is resized with a boolean mask

上级 aa7ce086
......@@ -13,10 +13,33 @@ from pytensor.tensor.subtensor import (
from pytensor.tensor.type_other import MakeSlice
BOOLEAN_MASK_ERROR = """JAX does not support resizing arrays with boolean
masks. In some cases, however, it is possible to re-express your model
in a form that JAX can compile:
>>> import pytensor.tensor as at
>>> x_at = at.vector('x')
>>> y_at = x_at[x_at > 0].sum()
can be re-expressed as:
>>> import pytensor.tensor as at
>>> x_at = at.vector('x')
>>> y_at = at.where(x_at > 0, x_at, 0).sum()
"""
def assert_indices_jax_compatible(node):
ilist = node.inputs[1]
if ilist.type.dtype == "bool":
raise NotImplementedError(BOOLEAN_MASK_ERROR)
@jax_funcify.register(Subtensor)
@jax_funcify.register(AdvancedSubtensor)
@jax_funcify.register(AdvancedSubtensor1)
def jax_funcify_Subtensor(op, **kwargs):
def jax_funcify_Subtensor(op, node, **kwargs):
assert_indices_jax_compatible(node)
idx_list = getattr(op, "idx_list", None)
......
......@@ -47,16 +47,24 @@ def test_jax_Subtensors():
compare_jax_and_py(out_fg, [])
@pytest.mark.xfail(
version_parse(jax.__version__) >= version_parse("0.2.12"),
reason="Omnistaging cannot be disabled",
)
def test_jax_Subtensors_omni():
x_at = at.arange(3 * 4 * 5).reshape((3, 4, 5))
# Boolean indices
def test_jax_Subtensor_boolean_mask():
"""JAX does not support resizing arrays with boolean masks."""
x_at = at.arange(-5, 5)
out_at = x_at[x_at < 0]
assert isinstance(out_at.owner.op, at_subtensor.AdvancedSubtensor)
with pytest.raises(NotImplementedError):
out_fg = FunctionGraph([], [out_at])
compare_jax_and_py(out_fg, [])
@pytest.mark.xfail(
reason="Re-expressible boolean logic. We need a rewrite PyTensor-side."
)
def test_jax_Subtensor_boolean_mask_reexpressible():
"""Some boolean logic can be re-expressed and JIT-compiled"""
x_at = at.arange(-5, 5)
out_at = x_at[x_at < 0].sum()
out_fg = FunctionGraph([], [out_at])
compare_jax_and_py(out_fg, [])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论