提交 d127fc14 authored 作者: kc611's avatar kc611 提交者: Thomas Wiecki

Add a test for JAX conversion of RandomStream output

上级 c3bc2cc8
...@@ -30,6 +30,7 @@ from aesara.tensor.math import max as aet_max ...@@ -30,6 +30,7 @@ from aesara.tensor.math import max as aet_max
from aesara.tensor.math import maximum, prod from aesara.tensor.math import maximum, prod
from aesara.tensor.math import sum as aet_sum from aesara.tensor.math import sum as aet_sum
from aesara.tensor.random.basic import normal from aesara.tensor.random.basic import normal
from aesara.tensor.random.utils import RandomStream
from aesara.tensor.shape import Shape, Shape_i, SpecifyShape, reshape from aesara.tensor.shape import Shape, Shape_i, SpecifyShape, reshape
from aesara.tensor.type import ( from aesara.tensor.type import (
dscalar, dscalar,
...@@ -47,6 +48,10 @@ from aesara.tensor.type import ( ...@@ -47,6 +48,10 @@ from aesara.tensor.type import (
jax = pytest.importorskip("jax") jax = pytest.importorskip("jax")
opts = Query(include=[None], exclude=["cxx_only", "BlasOpt"])
jax_mode = Mode(JAXLinker(), opts)
py_mode = Mode("py", opts)
@pytest.fixture(scope="module", autouse=True) @pytest.fixture(scope="module", autouse=True)
def set_aesara_flags(): def set_aesara_flags():
...@@ -87,10 +92,6 @@ def compare_jax_and_py( ...@@ -87,10 +92,6 @@ def compare_jax_and_py(
if assert_fn is None: if assert_fn is None:
assert_fn = partial(np.testing.assert_allclose, rtol=1e-4) assert_fn = partial(np.testing.assert_allclose, rtol=1e-4)
opts = Query(include=[None], exclude=["cxx_only", "BlasOpt"])
jax_mode = Mode(JAXLinker(), opts)
py_mode = Mode("py", opts)
fn_inputs = [i for i in fgraph.inputs if not isinstance(i, SharedVariable)] fn_inputs = [i for i in fgraph.inputs if not isinstance(i, SharedVariable)]
aesara_jax_fn = function(fn_inputs, fgraph.outputs, mode=jax_mode) aesara_jax_fn = function(fn_inputs, fgraph.outputs, mode=jax_mode)
jax_res = aesara_jax_fn(*inputs) jax_res = aesara_jax_fn(*inputs)
...@@ -975,3 +976,14 @@ def test_random(): ...@@ -975,3 +976,14 @@ def test_random():
out = normal(rng=rng) out = normal(rng=rng)
fgraph = FunctionGraph([out.owner.inputs[0]], [out], clone=False) fgraph = FunctionGraph([out.owner.inputs[0]], [out], clone=False)
compare_jax_and_py(fgraph, []) compare_jax_and_py(fgraph, [])
def test_RandomStream():
srng = RandomStream(seed=123)
out = srng.normal() - srng.normal()
fn = function([], out, mode=jax_mode)
jax_res_1 = fn()
jax_res_2 = fn()
assert np.array_equal(jax_res_1, jax_res_2)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论