提交 b2229fc7 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

xfail JAX tests after mandatory "omnistaging" update

上级 64453a33
import warnings
from collections.abc import Sequence from collections.abc import Sequence
from functools import reduce, singledispatch, update_wrapper from functools import reduce, singledispatch, update_wrapper
from warnings import warn from warnings import warn
...@@ -86,6 +87,10 @@ try: ...@@ -86,6 +87,10 @@ try:
jax.config.disable_omnistaging() jax.config.disable_omnistaging()
except AttributeError: except AttributeError:
pass pass
except Exception as e:
# The version might be >= 0.2.12, which means that omnistaging can't be
# disabled
warnings.warn(f"JAX omnistaging couldn't be disabled: {e}")
subtensor_ops = (Subtensor, AdvancedSubtensor1, AdvancedSubtensor) subtensor_ops = (Subtensor, AdvancedSubtensor1, AdvancedSubtensor)
incsubtensor_ops = (IncSubtensor, AdvancedIncSubtensor1) incsubtensor_ops = (IncSubtensor, AdvancedIncSubtensor1)
...@@ -653,7 +658,6 @@ def jax_funcify_Subtensor(op): ...@@ -653,7 +658,6 @@ def jax_funcify_Subtensor(op):
cdata = cdata[0] cdata = cdata[0]
return x.__getitem__(cdata) return x.__getitem__(cdata)
# return x.take(ilists, axis=0)
return subtensor return subtensor
......
...@@ -17,3 +17,4 @@ diff-cover ...@@ -17,3 +17,4 @@ diff-cover
pre-commit pre-commit
isort isort
pypolyagamma pypolyagamma
packaging
...@@ -2,6 +2,7 @@ from functools import partial ...@@ -2,6 +2,7 @@ from functools import partial
import numpy as np import numpy as np
import pytest import pytest
from packaging.version import parse as version_parse
import aesara.scalar.basic as aes import aesara.scalar.basic as aes
from aesara.compile.function import function from aesara.compile.function import function
...@@ -164,12 +165,20 @@ def test_jax_shape_ops(): ...@@ -164,12 +165,20 @@ def test_jax_shape_ops():
compare_jax_and_py(x_fg, [], must_be_device_array=False) compare_jax_and_py(x_fg, [], must_be_device_array=False)
@pytest.mark.xfail(
version_parse(jax.__version__) >= version_parse("0.2.12"),
reason="Omnistaging cannot be disabled",
)
def test_jax_specify_shape():
x_np = np.zeros((20, 3))
x = SpecifyShape()(aet.as_tensor_variable(x_np), (20, 3)) x = SpecifyShape()(aet.as_tensor_variable(x_np), (20, 3))
x_fg = FunctionGraph([], [x]) x_fg = FunctionGraph([], [x])
compare_jax_and_py(x_fg, []) compare_jax_and_py(x_fg, [])
with config.change_flags(compute_test_value="off"): with config.change_flags(compute_test_value="off"):
x = SpecifyShape()(aet.as_tensor_variable(x_np), (2, 3)) x = SpecifyShape()(aet.as_tensor_variable(x_np), (2, 3))
x_fg = FunctionGraph([], [x]) x_fg = FunctionGraph([], [x])
...@@ -326,6 +335,12 @@ def test_jax_basic_multiout(): ...@@ -326,6 +335,12 @@ def test_jax_basic_multiout():
out_fg = FunctionGraph([x], outs) out_fg = FunctionGraph([x], outs)
compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn) compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn)
@pytest.mark.xfail(
version_parse(jax.__version__) >= version_parse("0.2.12"),
reason="Omnistaging cannot be disabled",
)
def test_jax_basic_multiout_omni():
# Test that a single output of a multi-output `Op` can be used as input to # Test that a single output of a multi-output `Op` can be used as input to
# another `Op` # another `Op`
x = dvector() x = dvector()
...@@ -335,6 +350,10 @@ def test_jax_basic_multiout(): ...@@ -335,6 +350,10 @@ def test_jax_basic_multiout():
compare_jax_and_py(out_fg, [np.r_[1, 2]]) compare_jax_and_py(out_fg, [np.r_[1, 2]])
@pytest.mark.xfail(
version_parse(jax.__version__) >= version_parse("0.2.12"),
reason="Omnistaging cannot be disabled",
)
def test_jax_scan_multiple_output(): def test_jax_scan_multiple_output():
"""Test a scan implementation of a SEIR model. """Test a scan implementation of a SEIR model.
...@@ -425,6 +444,10 @@ def test_jax_scan_multiple_output(): ...@@ -425,6 +444,10 @@ def test_jax_scan_multiple_output():
compare_jax_and_py(out_fg, test_input_vals) compare_jax_and_py(out_fg, test_input_vals)
@pytest.mark.xfail(
version_parse(jax.__version__) >= version_parse("0.2.12"),
reason="Omnistaging cannot be disabled",
)
def test_jax_scan_tap_output(): def test_jax_scan_tap_output():
a_aet = scalar("a") a_aet = scalar("a")
...@@ -472,12 +495,6 @@ def test_jax_Subtensors(): ...@@ -472,12 +495,6 @@ def test_jax_Subtensors():
out_fg = FunctionGraph([], [out_aet]) out_fg = FunctionGraph([], [out_aet])
compare_jax_and_py(out_fg, []) compare_jax_and_py(out_fg, [])
# Boolean indices
out_aet = x_aet[x_aet < 0]
assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedSubtensor)
out_fg = FunctionGraph([], [out_aet])
compare_jax_and_py(out_fg, [])
# Advanced indexing # Advanced indexing
out_aet = x_aet[[1, 2]] out_aet = x_aet[[1, 2]]
assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedSubtensor1) assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedSubtensor1)
...@@ -501,6 +518,24 @@ def test_jax_Subtensors(): ...@@ -501,6 +518,24 @@ def test_jax_Subtensors():
compare_jax_and_py(out_fg, []) compare_jax_and_py(out_fg, [])
@pytest.mark.xfail(
version_parse(jax.__version__) >= version_parse("0.2.12"),
reason="Omnistaging cannot be disabled",
)
def test_jax_Subtensors_omni():
x_aet = aet.arange(3 * 4 * 5).reshape((3, 4, 5))
# Boolean indices
out_aet = x_aet[x_aet < 0]
assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedSubtensor)
out_fg = FunctionGraph([], [out_aet])
compare_jax_and_py(out_fg, [])
@pytest.mark.xfail(
version_parse(jax.__version__) >= version_parse("0.2.12"),
reason="Omnistaging cannot be disabled",
)
def test_jax_IncSubtensor(): def test_jax_IncSubtensor():
x_np = np.random.uniform(-1, 1, size=(3, 4, 5)).astype(config.floatX) x_np = np.random.uniform(-1, 1, size=(3, 4, 5)).astype(config.floatX)
x_aet = aet.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(config.floatX) x_aet = aet.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(config.floatX)
...@@ -659,6 +694,10 @@ def test_jax_MakeVector(): ...@@ -659,6 +694,10 @@ def test_jax_MakeVector():
compare_jax_and_py(x_fg, []) compare_jax_and_py(x_fg, [])
@pytest.mark.xfail(
version_parse(jax.__version__) >= version_parse("0.2.12"),
reason="Omnistaging cannot be disabled",
)
def test_jax_Reshape(): def test_jax_Reshape():
a = vector("a") a = vector("a")
x = reshape(a, (2, 2)) x = reshape(a, (2, 2))
...@@ -703,6 +742,10 @@ def test_jax_Dimshuffle(): ...@@ -703,6 +742,10 @@ def test_jax_Dimshuffle():
compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)]) compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)])
@pytest.mark.xfail(
version_parse(jax.__version__) >= version_parse("0.2.12"),
reason="Omnistaging cannot be disabled",
)
def test_jax_Join(): def test_jax_Join():
a = matrix("a") a = matrix("a")
b = matrix("b") b = matrix("b")
...@@ -821,6 +864,10 @@ def test_nnet(): ...@@ -821,6 +864,10 @@ def test_nnet():
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
@pytest.mark.xfail(
version_parse(jax.__version__) >= version_parse("0.2.12"),
reason="Omnistaging cannot be disabled",
)
def test_tensor_basics(): def test_tensor_basics():
y = vector("y") y = vector("y")
y.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX) y.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX)
...@@ -970,11 +1017,7 @@ def test_extra_ops(): ...@@ -970,11 +1017,7 @@ def test_extra_ops():
fgraph = FunctionGraph([a], [out]) fgraph = FunctionGraph([a], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
# This function also cannot take symbolic input.
c = aet.as_tensor(5) c = aet.as_tensor(5)
out = aet_extra_ops.bartlett(c)
fgraph = FunctionGraph([], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
with pytest.raises(NotImplementedError): with pytest.raises(NotImplementedError):
out = aet_extra_ops.fill_diagonal(a, c) out = aet_extra_ops.fill_diagonal(a, c)
...@@ -998,6 +1041,21 @@ def test_extra_ops(): ...@@ -998,6 +1041,21 @@ def test_extra_ops():
fgraph, [get_test_value(i) for i in fgraph.inputs], must_be_device_array=False fgraph, [get_test_value(i) for i in fgraph.inputs], must_be_device_array=False
) )
@pytest.mark.xfail(
version_parse(jax.__version__) >= version_parse("0.2.12"),
reason="Omnistaging cannot be disabled",
)
def test_extra_ops_omni():
a = matrix("a")
a.tag.test_value = np.arange(6, dtype=config.floatX).reshape((3, 2))
# This function also cannot take symbolic input.
c = aet.as_tensor(5)
out = aet_extra_ops.bartlett(c)
fgraph = FunctionGraph([], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
multi_index = np.unravel_index(np.arange(np.product((3, 4))), (3, 4)) multi_index = np.unravel_index(np.arange(np.product((3, 4))), (3, 4))
out = aet_extra_ops.ravel_multi_index(multi_index, (3, 4)) out = aet_extra_ops.ravel_multi_index(multi_index, (3, 4))
fgraph = FunctionGraph([], [out]) fgraph = FunctionGraph([], [out])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论