提交 f52c5678 authored 作者: Rémi Louf's avatar Rémi Louf 提交者: Ricardo Vieira

Raise when trying to slice with a dynamic length

上级 97bd1ac6
...@@ -28,31 +28,43 @@ can be re-expressed as: ...@@ -28,31 +28,43 @@ can be re-expressed as:
>>> y_at = at.where(x_at > 0, x_at, 0).sum() >>> y_at = at.where(x_at > 0, x_at, 0).sum()
""" """
DYNAMIC_SLICE_LENGTH_ERROR = """JAX does not support slicing arrays with a dynamic
slice length.
"""
def assert_indices_jax_compatible(node, idx_list):
from pytensor.graph.basic import Constant
from pytensor.tensor.var import TensorVariable
def assert_indices_jax_compatible(node): ilist = indices_from_subtensor(node.inputs[1:], idx_list)
ilist = node.inputs[1] for idx in ilist:
if ilist.type.dtype == "bool":
raise NotImplementedError(BOOLEAN_MASK_ERROR) if isinstance(idx, TensorVariable):
if idx.type.dtype == "bool":
raise NotImplementedError(BOOLEAN_MASK_ERROR)
elif isinstance(idx, slice):
for slice_arg in (idx.start, idx.stop, idx.step):
if slice_arg is not None and not isinstance(slice_arg, Constant):
raise NotImplementedError(DYNAMIC_SLICE_LENGTH_ERROR)
@jax_funcify.register(Subtensor) @jax_funcify.register(Subtensor)
@jax_funcify.register(AdvancedSubtensor) @jax_funcify.register(AdvancedSubtensor)
@jax_funcify.register(AdvancedSubtensor1) @jax_funcify.register(AdvancedSubtensor1)
def jax_funcify_Subtensor(op, node, **kwargs): def jax_funcify_Subtensor(op, node, **kwargs):
assert_indices_jax_compatible(node)
idx_list = getattr(op, "idx_list", None) idx_list = getattr(op, "idx_list", None)
assert_indices_jax_compatible(node, idx_list)
def subtensor(x, *ilists): def subtensor_constant(x, *ilists):
indices = indices_from_subtensor(ilists, idx_list) indices = indices_from_subtensor(ilists, idx_list)
if len(indices) == 1: if len(indices) == 1:
indices = indices[0] indices = indices[0]
return x.__getitem__(indices) return x.__getitem__(indices)
return subtensor return subtensor_constant
@jax_funcify.register(IncSubtensor) @jax_funcify.register(IncSubtensor)
......
import jax
import numpy as np import numpy as np
import pytest import pytest
from jax._src.errors import NonConcreteBooleanIndexError from jax._src.errors import NonConcreteBooleanIndexError
from packaging.version import parse as version_parse
import pytensor.tensor as at import pytensor.tensor as at
from pytensor.configdefaults import config from pytensor.configdefaults import config
...@@ -11,7 +9,7 @@ from pytensor.tensor import subtensor as at_subtensor ...@@ -11,7 +9,7 @@ from pytensor.tensor import subtensor as at_subtensor
from tests.link.jax.test_basic import compare_jax_and_py from tests.link.jax.test_basic import compare_jax_and_py
def test_jax_Subtensors(): def test_jax_Subtensor_constant():
# Basic indices # Basic indices
x_at = at.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))) x_at = at.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)))
out_at = x_at[1, 2, 0] out_at = x_at[1, 2, 0]
...@@ -19,6 +17,16 @@ def test_jax_Subtensors(): ...@@ -19,6 +17,16 @@ def test_jax_Subtensors():
out_fg = FunctionGraph([], [out_at]) out_fg = FunctionGraph([], [out_at])
compare_jax_and_py(out_fg, []) compare_jax_and_py(out_fg, [])
out_at = x_at[1:, 1, :]
assert isinstance(out_at.owner.op, at_subtensor.Subtensor)
out_fg = FunctionGraph([], [out_at])
compare_jax_and_py(out_fg, [])
out_at = x_at[:2, 1, :]
assert isinstance(out_at.owner.op, at_subtensor.Subtensor)
out_fg = FunctionGraph([], [out_at])
compare_jax_and_py(out_fg, [])
out_at = x_at[1:2, 1, :] out_at = x_at[1:2, 1, :]
assert isinstance(out_at.owner.op, at_subtensor.Subtensor) assert isinstance(out_at.owner.op, at_subtensor.Subtensor)
out_fg = FunctionGraph([], [out_at]) out_fg = FunctionGraph([], [out_at])
...@@ -46,6 +54,21 @@ def test_jax_Subtensors(): ...@@ -46,6 +54,21 @@ def test_jax_Subtensors():
out_fg = FunctionGraph([], [out_at]) out_fg = FunctionGraph([], [out_at])
compare_jax_and_py(out_fg, []) compare_jax_and_py(out_fg, [])
# Flipping
out_at = x_at[::-1]
out_fg = FunctionGraph([], [out_at])
compare_jax_and_py(out_fg, [])
@pytest.mark.xfail(reason="`a` should be specified as static when JIT-compiling")
def test_jax_Subtensor_dynamic():
a = at.iscalar("a")
x = at.arange(3)
out_at = x[:a]
assert isinstance(out_at.owner.op, at_subtensor.Subtensor)
out_fg = FunctionGraph([a], [out_at])
compare_jax_and_py(out_fg, [1])
def test_jax_Subtensor_boolean_mask(): def test_jax_Subtensor_boolean_mask():
"""JAX does not support resizing arrays with boolean masks.""" """JAX does not support resizing arrays with boolean masks."""
...@@ -53,7 +76,7 @@ def test_jax_Subtensor_boolean_mask(): ...@@ -53,7 +76,7 @@ def test_jax_Subtensor_boolean_mask():
out_at = x_at[x_at < 0] out_at = x_at[x_at < 0]
assert isinstance(out_at.owner.op, at_subtensor.AdvancedSubtensor) assert isinstance(out_at.owner.op, at_subtensor.AdvancedSubtensor)
with pytest.raises(NotImplementedError): with pytest.raises(NotImplementedError, match="resizing arrays with boolean"):
out_fg = FunctionGraph([], [out_at]) out_fg = FunctionGraph([], [out_at])
compare_jax_and_py(out_fg, []) compare_jax_and_py(out_fg, [])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论