提交 b2229fc7 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

xfail JAX tests after mandatory "omnistaging" update

上级 64453a33
import warnings
from collections.abc import Sequence
from functools import reduce, singledispatch, update_wrapper
from warnings import warn
......@@ -86,6 +87,10 @@ try:
jax.config.disable_omnistaging()
except AttributeError:
pass
except Exception as e:
# The version might be >= 0.2.12, which means that omnistaging can't be
# disabled
warnings.warn(f"JAX omnistaging couldn't be disabled: {e}")
subtensor_ops = (Subtensor, AdvancedSubtensor1, AdvancedSubtensor)
incsubtensor_ops = (IncSubtensor, AdvancedIncSubtensor1)
......@@ -653,7 +658,6 @@ def jax_funcify_Subtensor(op):
cdata = cdata[0]
return x.__getitem__(cdata)
# return x.take(ilists, axis=0)
return subtensor
......
......@@ -2,6 +2,7 @@ from functools import partial
import numpy as np
import pytest
from packaging.version import parse as version_parse
import aesara.scalar.basic as aes
from aesara.compile.function import function
......@@ -164,12 +165,20 @@ def test_jax_shape_ops():
compare_jax_and_py(x_fg, [], must_be_device_array=False)
@pytest.mark.xfail(
version_parse(jax.__version__) >= version_parse("0.2.12"),
reason="Omnistaging cannot be disabled",
)
def test_jax_specify_shape():
x_np = np.zeros((20, 3))
x = SpecifyShape()(aet.as_tensor_variable(x_np), (20, 3))
x_fg = FunctionGraph([], [x])
compare_jax_and_py(x_fg, [])
with config.change_flags(compute_test_value="off"):
x = SpecifyShape()(aet.as_tensor_variable(x_np), (2, 3))
x_fg = FunctionGraph([], [x])
......@@ -326,6 +335,12 @@ def test_jax_basic_multiout():
out_fg = FunctionGraph([x], outs)
compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn)
@pytest.mark.xfail(
version_parse(jax.__version__) >= version_parse("0.2.12"),
reason="Omnistaging cannot be disabled",
)
def test_jax_basic_multiout_omni():
# Test that a single output of a multi-output `Op` can be used as input to
# another `Op`
x = dvector()
......@@ -335,6 +350,10 @@ def test_jax_basic_multiout():
compare_jax_and_py(out_fg, [np.r_[1, 2]])
@pytest.mark.xfail(
version_parse(jax.__version__) >= version_parse("0.2.12"),
reason="Omnistaging cannot be disabled",
)
def test_jax_scan_multiple_output():
"""Test a scan implementation of a SEIR model.
......@@ -425,6 +444,10 @@ def test_jax_scan_multiple_output():
compare_jax_and_py(out_fg, test_input_vals)
@pytest.mark.xfail(
version_parse(jax.__version__) >= version_parse("0.2.12"),
reason="Omnistaging cannot be disabled",
)
def test_jax_scan_tap_output():
a_aet = scalar("a")
......@@ -472,12 +495,6 @@ def test_jax_Subtensors():
out_fg = FunctionGraph([], [out_aet])
compare_jax_and_py(out_fg, [])
# Boolean indices
out_aet = x_aet[x_aet < 0]
assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedSubtensor)
out_fg = FunctionGraph([], [out_aet])
compare_jax_and_py(out_fg, [])
# Advanced indexing
out_aet = x_aet[[1, 2]]
assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedSubtensor1)
......@@ -501,6 +518,24 @@ def test_jax_Subtensors():
compare_jax_and_py(out_fg, [])
@pytest.mark.xfail(
version_parse(jax.__version__) >= version_parse("0.2.12"),
reason="Omnistaging cannot be disabled",
)
def test_jax_Subtensors_omni():
x_aet = aet.arange(3 * 4 * 5).reshape((3, 4, 5))
# Boolean indices
out_aet = x_aet[x_aet < 0]
assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedSubtensor)
out_fg = FunctionGraph([], [out_aet])
compare_jax_and_py(out_fg, [])
@pytest.mark.xfail(
version_parse(jax.__version__) >= version_parse("0.2.12"),
reason="Omnistaging cannot be disabled",
)
def test_jax_IncSubtensor():
x_np = np.random.uniform(-1, 1, size=(3, 4, 5)).astype(config.floatX)
x_aet = aet.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(config.floatX)
......@@ -659,6 +694,10 @@ def test_jax_MakeVector():
compare_jax_and_py(x_fg, [])
@pytest.mark.xfail(
version_parse(jax.__version__) >= version_parse("0.2.12"),
reason="Omnistaging cannot be disabled",
)
def test_jax_Reshape():
a = vector("a")
x = reshape(a, (2, 2))
......@@ -703,6 +742,10 @@ def test_jax_Dimshuffle():
compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)])
@pytest.mark.xfail(
version_parse(jax.__version__) >= version_parse("0.2.12"),
reason="Omnistaging cannot be disabled",
)
def test_jax_Join():
a = matrix("a")
b = matrix("b")
......@@ -821,6 +864,10 @@ def test_nnet():
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
@pytest.mark.xfail(
version_parse(jax.__version__) >= version_parse("0.2.12"),
reason="Omnistaging cannot be disabled",
)
def test_tensor_basics():
y = vector("y")
y.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX)
......@@ -970,11 +1017,7 @@ def test_extra_ops():
fgraph = FunctionGraph([a], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
# This function also cannot take symbolic input.
c = aet.as_tensor(5)
out = aet_extra_ops.bartlett(c)
fgraph = FunctionGraph([], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
with pytest.raises(NotImplementedError):
out = aet_extra_ops.fill_diagonal(a, c)
......@@ -998,6 +1041,21 @@ def test_extra_ops():
fgraph, [get_test_value(i) for i in fgraph.inputs], must_be_device_array=False
)
@pytest.mark.xfail(
version_parse(jax.__version__) >= version_parse("0.2.12"),
reason="Omnistaging cannot be disabled",
)
def test_extra_ops_omni():
a = matrix("a")
a.tag.test_value = np.arange(6, dtype=config.floatX).reshape((3, 2))
# This function also cannot take symbolic input.
c = aet.as_tensor(5)
out = aet_extra_ops.bartlett(c)
fgraph = FunctionGraph([], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
multi_index = np.unravel_index(np.arange(np.product((3, 4))), (3, 4))
out = aet_extra_ops.ravel_multi_index(multi_index, (3, 4))
fgraph = FunctionGraph([], [out])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论