提交 2ccd9cca authored 作者: Ricardo's avatar Ricardo 提交者: Brandon T. Willard

Separate failing conditions in test_jax_IncSubtensor

Also avoid using `at.arange` in these tests as it always yields `ConcretizationTypeError`s in more recent versions of JAX
上级 2e35e6c4
......@@ -3,6 +3,7 @@ from typing import Optional
import numpy as np
import pytest
from jax._src.errors import NonConcreteBooleanIndexError
from packaging.version import parse as version_parse
import aesara.scalar.basic as aes
......@@ -674,15 +675,11 @@ def test_jax_Subtensors_omni():
compare_jax_and_py(out_fg, [])
@pytest.mark.xfail(
version_parse(jax.__version__) >= version_parse("0.2.12"),
reason="Omnistaging cannot be disabled",
)
def test_jax_IncSubtensor():
rng = np.random.default_rng(213234)
x_np = rng.uniform(-1, 1, size=(3, 4, 5)).astype(config.floatX)
x_at = at.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(config.floatX)
x_at = at.constant(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(config.floatX))
# "Set" basic indices
st_at = at.as_tensor_variable(np.array(-10.0, dtype=config.floatX))
......@@ -707,7 +704,7 @@ def test_jax_IncSubtensor():
rng.uniform(-1, 1, size=(2, 4, 5)).astype(config.floatX)
)
out_at = at_subtensor.set_subtensor(x_at[np.r_[0, 2]], st_at)
assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor1)
assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_at])
compare_jax_and_py(out_fg, [])
......@@ -717,14 +714,8 @@ def test_jax_IncSubtensor():
out_fg = FunctionGraph([], [out_at])
compare_jax_and_py(out_fg, [])
st_at = at.as_tensor_variable(x_np[[0, 2], 0, :3])
out_at = at_subtensor.set_subtensor(x_at[[0, 2], 0, :3], st_at)
assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_at])
compare_jax_and_py(out_fg, [])
# "Set" boolean indices
mask_at = at.as_tensor_variable(x_np) > 0
mask_at = at.constant(x_np > 0)
out_at = at_subtensor.set_subtensor(x_at[mask_at], 0.0)
assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_at])
......@@ -753,7 +744,7 @@ def test_jax_IncSubtensor():
rng.uniform(-1, 1, size=(2, 4, 5)).astype(config.floatX)
)
out_at = at_subtensor.inc_subtensor(x_at[np.r_[0, 2]], st_at)
assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor1)
assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_at])
compare_jax_and_py(out_fg, [])
......@@ -763,17 +754,49 @@ def test_jax_IncSubtensor():
out_fg = FunctionGraph([], [out_at])
compare_jax_and_py(out_fg, [])
st_at = at.as_tensor_variable(x_np[[0, 2], 0, :3])
out_at = at_subtensor.inc_subtensor(x_at[[0, 2], 0, :3], st_at)
# "Increment" boolean indices
mask_at = at.constant(x_np > 0)
out_at = at_subtensor.set_subtensor(x_at[mask_at], 1.0)
assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_at])
compare_jax_and_py(out_fg, [])
# "Increment" boolean indices
def test_jax_IncSubtensors_unsupported():
rng = np.random.default_rng(213234)
x_np = rng.uniform(-1, 1, size=(3, 4, 5)).astype(config.floatX)
x_at = at.constant(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(config.floatX))
mask_at = at.as_tensor(x_np) > 0
out_at = at_subtensor.set_subtensor(x_at[mask_at], 0.0)
assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_at])
with pytest.raises(
NonConcreteBooleanIndexError, match="Array boolean indices must be concrete"
):
compare_jax_and_py(out_fg, [])
mask_at = at.as_tensor_variable(x_np) > 0
out_at = at_subtensor.set_subtensor(x_at[mask_at], 1.0)
assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_at])
with pytest.raises(
NonConcreteBooleanIndexError, match="Array boolean indices must be concrete"
):
compare_jax_and_py(out_fg, [])
st_at = at.as_tensor_variable(x_np[[0, 2], 0, :3])
out_at = at_subtensor.set_subtensor(x_at[[0, 2], 0, :3], st_at)
assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_at])
with pytest.raises(IndexError, match="Array slice indices must have static"):
compare_jax_and_py(out_fg, [])
st_at = at.as_tensor_variable(x_np[[0, 2], 0, :3])
out_at = at_subtensor.inc_subtensor(x_at[[0, 2], 0, :3], st_at)
assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_at])
with pytest.raises(IndexError, match="Array slice indices must have static"):
compare_jax_and_py(out_fg, [])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论