提交 684a929e authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Luciano Paz

Remove false positive check for supported Subtensors operations in JAX

The check was failing incorrectly for cases that are supported such as constant Boolean arrays. Besides that, user may dispatch without necessarily jitting the graph. There is no reason to fail eagerly.
上级 17fa8b13
......@@ -31,36 +31,20 @@ slice length.
"""
def subtensor_assert_indices_jax_compatible(node, idx_list):
from pytensor.graph.basic import Constant
from pytensor.tensor.variable import TensorVariable
ilist = indices_from_subtensor(node.inputs[1:], idx_list)
for idx in ilist:
if isinstance(idx, TensorVariable):
if idx.type.dtype == "bool":
raise NotImplementedError(BOOLEAN_MASK_ERROR)
elif isinstance(idx, slice):
for slice_arg in (idx.start, idx.stop, idx.step):
if slice_arg is not None and not isinstance(slice_arg, Constant):
raise NotImplementedError(DYNAMIC_SLICE_LENGTH_ERROR)
@jax_funcify.register(Subtensor)
@jax_funcify.register(AdvancedSubtensor)
@jax_funcify.register(AdvancedSubtensor1)
def jax_funcify_Subtensor(op, node, **kwargs):
idx_list = getattr(op, "idx_list", None)
subtensor_assert_indices_jax_compatible(node, idx_list)
def subtensor_constant(x, *ilists):
def subtensor(x, *ilists):
indices = indices_from_subtensor(ilists, idx_list)
if len(indices) == 1:
indices = indices[0]
return x.__getitem__(indices)
return subtensor_constant
return subtensor
@jax_funcify.register(IncSubtensor)
......
......@@ -5,6 +5,7 @@ import pytensor.tensor as pt
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor import subtensor as pt_subtensor
from pytensor.tensor import tensor
from pytensor.tensor.rewriting.jax import (
boolean_indexing_set_or_inc,
boolean_indexing_sum,
......@@ -13,54 +14,62 @@ from tests.link.jax.test_basic import compare_jax_and_py
def test_jax_Subtensor_constant():
shape = (3, 4, 5)
x_pt = tensor("x", shape=shape, dtype="int")
x_np = np.arange(np.prod(shape)).reshape(shape)
# Basic indices
x_pt = pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)))
out_pt = x_pt[1, 2, 0]
assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor)
out_fg = FunctionGraph([], [out_pt])
compare_jax_and_py(out_fg, [])
out_fg = FunctionGraph([x_pt], [out_pt])
compare_jax_and_py(out_fg, [x_np])
out_pt = x_pt[1:, 1, :]
assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor)
out_fg = FunctionGraph([], [out_pt])
compare_jax_and_py(out_fg, [])
out_fg = FunctionGraph([x_pt], [out_pt])
compare_jax_and_py(out_fg, [x_np])
out_pt = x_pt[:2, 1, :]
assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor)
out_fg = FunctionGraph([], [out_pt])
compare_jax_and_py(out_fg, [])
out_fg = FunctionGraph([x_pt], [out_pt])
compare_jax_and_py(out_fg, [x_np])
out_pt = x_pt[1:2, 1, :]
assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor)
out_fg = FunctionGraph([], [out_pt])
compare_jax_and_py(out_fg, [])
out_fg = FunctionGraph([x_pt], [out_pt])
compare_jax_and_py(out_fg, [x_np])
# Advanced indexing
out_pt = pt_subtensor.advanced_subtensor1(x_pt, [1, 2])
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor1)
out_fg = FunctionGraph([], [out_pt])
compare_jax_and_py(out_fg, [])
out_fg = FunctionGraph([x_pt], [out_pt])
compare_jax_and_py(out_fg, [x_np])
out_pt = x_pt[[1, 2], [2, 3]]
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor)
out_fg = FunctionGraph([], [out_pt])
compare_jax_and_py(out_fg, [])
out_fg = FunctionGraph([x_pt], [out_pt])
compare_jax_and_py(out_fg, [x_np])
# Advanced and basic indexing
out_pt = x_pt[[1, 2], :]
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor)
out_fg = FunctionGraph([], [out_pt])
compare_jax_and_py(out_fg, [])
out_fg = FunctionGraph([x_pt], [out_pt])
compare_jax_and_py(out_fg, [x_np])
out_pt = x_pt[[1, 2], :, [3, 4]]
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor)
out_fg = FunctionGraph([], [out_pt])
compare_jax_and_py(out_fg, [])
out_fg = FunctionGraph([x_pt], [out_pt])
compare_jax_and_py(out_fg, [x_np])
# Flipping
out_pt = x_pt[::-1]
out_fg = FunctionGraph([], [out_pt])
compare_jax_and_py(out_fg, [])
out_fg = FunctionGraph([x_pt], [out_pt])
compare_jax_and_py(out_fg, [x_np])
# Boolean indexing should work if indexes are constant
out_pt = x_pt[np.random.binomial(1, 0.5, size=(3, 4, 5))]
out_fg = FunctionGraph([x_pt], [out_pt])
compare_jax_and_py(out_fg, [x_np])
@pytest.mark.xfail(reason="`a` should be specified as static when JIT-compiling")
......@@ -73,8 +82,10 @@ def test_jax_Subtensor_dynamic():
compare_jax_and_py(out_fg, [1])
def test_jax_Subtensor_boolean_mask():
"""JAX does not support resizing arrays with boolean masks."""
def test_jax_Subtensor_dynamic_boolean_mask():
"""JAX does not support resizing arrays with dynamic boolean masks."""
from jax.errors import NonConcreteBooleanIndexError
x_pt = pt.vector("x", dtype="float64")
out_pt = x_pt[x_pt < 0]
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor)
......@@ -82,7 +93,7 @@ def test_jax_Subtensor_boolean_mask():
out_fg = FunctionGraph([x_pt], [out_pt])
x_pt_test = np.arange(-5, 5)
with pytest.raises(NotImplementedError, match="resizing arrays with boolean"):
with pytest.raises(NonConcreteBooleanIndexError):
compare_jax_and_py(out_fg, [x_pt_test])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论