提交 2ccd9cca authored 作者: Ricardo's avatar Ricardo 提交者: Brandon T. Willard

Separate failing conditions in test_jax_IncSubtensor

Also avoid using `at.arange` in these tests as it always yields `ConcretizationTypeError`s in more recent versions of JAX
上级 2e35e6c4
...@@ -3,6 +3,7 @@ from typing import Optional ...@@ -3,6 +3,7 @@ from typing import Optional
import numpy as np import numpy as np
import pytest import pytest
from jax._src.errors import NonConcreteBooleanIndexError
from packaging.version import parse as version_parse from packaging.version import parse as version_parse
import aesara.scalar.basic as aes import aesara.scalar.basic as aes
...@@ -674,15 +675,11 @@ def test_jax_Subtensors_omni(): ...@@ -674,15 +675,11 @@ def test_jax_Subtensors_omni():
compare_jax_and_py(out_fg, []) compare_jax_and_py(out_fg, [])
@pytest.mark.xfail(
version_parse(jax.__version__) >= version_parse("0.2.12"),
reason="Omnistaging cannot be disabled",
)
def test_jax_IncSubtensor(): def test_jax_IncSubtensor():
rng = np.random.default_rng(213234) rng = np.random.default_rng(213234)
x_np = rng.uniform(-1, 1, size=(3, 4, 5)).astype(config.floatX) x_np = rng.uniform(-1, 1, size=(3, 4, 5)).astype(config.floatX)
x_at = at.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(config.floatX) x_at = at.constant(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(config.floatX))
# "Set" basic indices # "Set" basic indices
st_at = at.as_tensor_variable(np.array(-10.0, dtype=config.floatX)) st_at = at.as_tensor_variable(np.array(-10.0, dtype=config.floatX))
...@@ -707,7 +704,7 @@ def test_jax_IncSubtensor(): ...@@ -707,7 +704,7 @@ def test_jax_IncSubtensor():
rng.uniform(-1, 1, size=(2, 4, 5)).astype(config.floatX) rng.uniform(-1, 1, size=(2, 4, 5)).astype(config.floatX)
) )
out_at = at_subtensor.set_subtensor(x_at[np.r_[0, 2]], st_at) out_at = at_subtensor.set_subtensor(x_at[np.r_[0, 2]], st_at)
assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor1) assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_at]) out_fg = FunctionGraph([], [out_at])
compare_jax_and_py(out_fg, []) compare_jax_and_py(out_fg, [])
...@@ -717,14 +714,8 @@ def test_jax_IncSubtensor(): ...@@ -717,14 +714,8 @@ def test_jax_IncSubtensor():
out_fg = FunctionGraph([], [out_at]) out_fg = FunctionGraph([], [out_at])
compare_jax_and_py(out_fg, []) compare_jax_and_py(out_fg, [])
st_at = at.as_tensor_variable(x_np[[0, 2], 0, :3])
out_at = at_subtensor.set_subtensor(x_at[[0, 2], 0, :3], st_at)
assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_at])
compare_jax_and_py(out_fg, [])
# "Set" boolean indices # "Set" boolean indices
mask_at = at.as_tensor_variable(x_np) > 0 mask_at = at.constant(x_np > 0)
out_at = at_subtensor.set_subtensor(x_at[mask_at], 0.0) out_at = at_subtensor.set_subtensor(x_at[mask_at], 0.0)
assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor) assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_at]) out_fg = FunctionGraph([], [out_at])
...@@ -753,7 +744,7 @@ def test_jax_IncSubtensor(): ...@@ -753,7 +744,7 @@ def test_jax_IncSubtensor():
rng.uniform(-1, 1, size=(2, 4, 5)).astype(config.floatX) rng.uniform(-1, 1, size=(2, 4, 5)).astype(config.floatX)
) )
out_at = at_subtensor.inc_subtensor(x_at[np.r_[0, 2]], st_at) out_at = at_subtensor.inc_subtensor(x_at[np.r_[0, 2]], st_at)
assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor1) assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_at]) out_fg = FunctionGraph([], [out_at])
compare_jax_and_py(out_fg, []) compare_jax_and_py(out_fg, [])
...@@ -763,18 +754,50 @@ def test_jax_IncSubtensor(): ...@@ -763,18 +754,50 @@ def test_jax_IncSubtensor():
out_fg = FunctionGraph([], [out_at]) out_fg = FunctionGraph([], [out_at])
compare_jax_and_py(out_fg, []) compare_jax_and_py(out_fg, [])
st_at = at.as_tensor_variable(x_np[[0, 2], 0, :3]) # "Increment" boolean indices
out_at = at_subtensor.inc_subtensor(x_at[[0, 2], 0, :3], st_at) mask_at = at.constant(x_np > 0)
out_at = at_subtensor.set_subtensor(x_at[mask_at], 1.0)
assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor) assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_at]) out_fg = FunctionGraph([], [out_at])
compare_jax_and_py(out_fg, []) compare_jax_and_py(out_fg, [])
# "Increment" boolean indices
def test_jax_IncSubtensors_unsupported():
rng = np.random.default_rng(213234)
x_np = rng.uniform(-1, 1, size=(3, 4, 5)).astype(config.floatX)
x_at = at.constant(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(config.floatX))
mask_at = at.as_tensor(x_np) > 0
out_at = at_subtensor.set_subtensor(x_at[mask_at], 0.0)
assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_at])
with pytest.raises(
NonConcreteBooleanIndexError, match="Array boolean indices must be concrete"
):
compare_jax_and_py(out_fg, [])
mask_at = at.as_tensor_variable(x_np) > 0 mask_at = at.as_tensor_variable(x_np) > 0
out_at = at_subtensor.set_subtensor(x_at[mask_at], 1.0) out_at = at_subtensor.set_subtensor(x_at[mask_at], 1.0)
assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor) assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_at]) out_fg = FunctionGraph([], [out_at])
compare_jax_and_py(out_fg, []) with pytest.raises(
NonConcreteBooleanIndexError, match="Array boolean indices must be concrete"
):
compare_jax_and_py(out_fg, [])
st_at = at.as_tensor_variable(x_np[[0, 2], 0, :3])
out_at = at_subtensor.set_subtensor(x_at[[0, 2], 0, :3], st_at)
assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_at])
with pytest.raises(IndexError, match="Array slice indices must have static"):
compare_jax_and_py(out_fg, [])
st_at = at.as_tensor_variable(x_np[[0, 2], 0, :3])
out_at = at_subtensor.inc_subtensor(x_at[[0, 2], 0, :3], st_at)
assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_at])
with pytest.raises(IndexError, match="Array slice indices must have static"):
compare_jax_and_py(out_fg, [])
def test_jax_ifelse(): def test_jax_ifelse():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论