提交 eddc85fc authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Allow shared RandomState/Generator updates in JAX compiled functions

上级 8171925a
import warnings
from numpy.random import Generator, RandomState
from aesara.compile.sharedvalue import SharedVariable, shared
from aesara.graph.basic import Constant
from aesara.link.basic import JITLinker
......@@ -7,10 +10,48 @@ from aesara.link.basic import JITLinker
class JAXLinker(JITLinker):
"""A `Linker` that JIT-compiles NumPy-based operations using JAX."""
def fgraph_convert(self, fgraph, **kwargs):
def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
from aesara.link.jax.dispatch import jax_funcify
from aesara.tensor.random.type import RandomType
shared_rng_inputs = [
inp
for inp in fgraph.inputs
if (isinstance(inp, SharedVariable) and isinstance(inp.type, RandomType))
]
# Replace any shared RNG inputs so that their values can be updated in place
# without affecting the original RNG container. This is necessary because
# JAX does not accept RandomState/Generators as inputs, and they will have to
# be typyfied
if shared_rng_inputs:
warnings.warn(
f"The RandomType SharedVariables {shared_rng_inputs} will not be used "
f"in the compiled JAX graph. Instead a copy will be used.",
UserWarning,
)
new_shared_rng_inputs = [
shared(inp.get_value(borrow=False)) for inp in shared_rng_inputs
]
return jax_funcify(fgraph, **kwargs)
fgraph.replace_all(
zip(shared_rng_inputs, new_shared_rng_inputs),
import_missing=True,
reason="JAXLinker.fgraph_convert",
)
for old_inp, new_inp in zip(shared_rng_inputs, new_shared_rng_inputs):
new_inp_storage = [new_inp.get_value(borrow=True)]
storage_map[new_inp] = new_inp_storage
old_inp_storage = storage_map.pop(old_inp)
input_storage[input_storage.index(old_inp_storage)] = new_inp_storage
fgraph.remove_input(
fgraph.inputs.index(old_inp), reason="JAXLinker.fgraph_convert"
)
return jax_funcify(
fgraph, input_storage=input_storage, storage_map=storage_map, **kwargs
)
def jit_compile(self, fn):
import jax
......@@ -32,11 +73,7 @@ class JAXLinker(JITLinker):
new_value = jax_typify(
sinput[0], dtype=getattr(sinput[0], "dtype", None)
)
# We need to remove the reference-based connection to the
# original `RandomState`/shared variable's storage, because
# subsequent attempts to use the same shared variable within
# other non-JAXified graphs will have problems.
sinput = [new_value]
sinput[0] = new_value
thunk_inputs.append(sinput)
return thunk_inputs
import re
import numpy as np
import pytest
from packaging.version import parse as version_parse
import aesara
import aesara.tensor as at
from aesara.compile.function import function
from aesara.compile.sharedvalue import shared
......@@ -79,8 +82,34 @@ def test_RandomStream():
srng = RandomStream(seed=123)
out = srng.normal() - srng.normal()
with pytest.warns(
UserWarning,
match=r"The RandomType SharedVariables \[.+\] will not be used",
):
fn = function([], out, mode=jax_mode)
jax_res_1 = fn()
jax_res_2 = fn()
assert np.array_equal(jax_res_1, jax_res_2)
assert not np.array_equal(jax_res_1, jax_res_2)
@pytest.mark.parametrize("rng_ctor", (np.random.RandomState, np.random.default_rng))
def test_random_updates(rng_ctor):
original_value = rng_ctor(seed=98)
rng = shared(original_value, name="original_rng", borrow=False)
next_rng, x = at.random.normal(name="x", rng=rng).owner.outputs
with pytest.warns(
UserWarning,
match=re.escape(
"The RandomType SharedVariables [original_rng] will not be used"
),
):
f = aesara.function([], [x], updates={rng: next_rng}, mode=jax_mode)
assert f() != f()
# Check that original rng variable content was not overwritten when calling jax_typify
assert all(
a == b if not isinstance(a, np.ndarray) else np.array_equal(a, b)
for a, b in zip(rng.get_value().__getstate__(), original_value.__getstate__())
)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论