提交 eddc85fc authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Allow shared RandomState/Generator updates in JAX compiled functions

上级 8171925a
import warnings
from numpy.random import Generator, RandomState from numpy.random import Generator, RandomState
from aesara.compile.sharedvalue import SharedVariable, shared
from aesara.graph.basic import Constant from aesara.graph.basic import Constant
from aesara.link.basic import JITLinker from aesara.link.basic import JITLinker
...@@ -7,10 +10,48 @@ from aesara.link.basic import JITLinker ...@@ -7,10 +10,48 @@ from aesara.link.basic import JITLinker
class JAXLinker(JITLinker): class JAXLinker(JITLinker):
"""A `Linker` that JIT-compiles NumPy-based operations using JAX.""" """A `Linker` that JIT-compiles NumPy-based operations using JAX."""
def fgraph_convert(self, fgraph, **kwargs): def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
from aesara.link.jax.dispatch import jax_funcify from aesara.link.jax.dispatch import jax_funcify
from aesara.tensor.random.type import RandomType
shared_rng_inputs = [
inp
for inp in fgraph.inputs
if (isinstance(inp, SharedVariable) and isinstance(inp.type, RandomType))
]
# Replace any shared RNG inputs so that their values can be updated in place
# without affecting the original RNG container. This is necessary because
# JAX does not accept RandomState/Generators as inputs, and they will have to
# be typyfied
if shared_rng_inputs:
warnings.warn(
f"The RandomType SharedVariables {shared_rng_inputs} will not be used "
f"in the compiled JAX graph. Instead a copy will be used.",
UserWarning,
)
new_shared_rng_inputs = [
shared(inp.get_value(borrow=False)) for inp in shared_rng_inputs
]
fgraph.replace_all(
zip(shared_rng_inputs, new_shared_rng_inputs),
import_missing=True,
reason="JAXLinker.fgraph_convert",
)
for old_inp, new_inp in zip(shared_rng_inputs, new_shared_rng_inputs):
new_inp_storage = [new_inp.get_value(borrow=True)]
storage_map[new_inp] = new_inp_storage
old_inp_storage = storage_map.pop(old_inp)
input_storage[input_storage.index(old_inp_storage)] = new_inp_storage
fgraph.remove_input(
fgraph.inputs.index(old_inp), reason="JAXLinker.fgraph_convert"
)
return jax_funcify(fgraph, **kwargs) return jax_funcify(
fgraph, input_storage=input_storage, storage_map=storage_map, **kwargs
)
def jit_compile(self, fn): def jit_compile(self, fn):
import jax import jax
...@@ -32,11 +73,7 @@ class JAXLinker(JITLinker): ...@@ -32,11 +73,7 @@ class JAXLinker(JITLinker):
new_value = jax_typify( new_value = jax_typify(
sinput[0], dtype=getattr(sinput[0], "dtype", None) sinput[0], dtype=getattr(sinput[0], "dtype", None)
) )
# We need to remove the reference-based connection to the sinput[0] = new_value
# original `RandomState`/shared variable's storage, because
# subsequent attempts to use the same shared variable within
# other non-JAXified graphs will have problems.
sinput = [new_value]
thunk_inputs.append(sinput) thunk_inputs.append(sinput)
return thunk_inputs return thunk_inputs
import re
import numpy as np import numpy as np
import pytest import pytest
from packaging.version import parse as version_parse from packaging.version import parse as version_parse
import aesara
import aesara.tensor as at import aesara.tensor as at
from aesara.compile.function import function from aesara.compile.function import function
from aesara.compile.sharedvalue import shared from aesara.compile.sharedvalue import shared
...@@ -79,8 +82,34 @@ def test_RandomStream(): ...@@ -79,8 +82,34 @@ def test_RandomStream():
srng = RandomStream(seed=123) srng = RandomStream(seed=123)
out = srng.normal() - srng.normal() out = srng.normal() - srng.normal()
fn = function([], out, mode=jax_mode) with pytest.warns(
UserWarning,
match=r"The RandomType SharedVariables \[.+\] will not be used",
):
fn = function([], out, mode=jax_mode)
jax_res_1 = fn() jax_res_1 = fn()
jax_res_2 = fn() jax_res_2 = fn()
assert np.array_equal(jax_res_1, jax_res_2) assert not np.array_equal(jax_res_1, jax_res_2)
@pytest.mark.parametrize("rng_ctor", (np.random.RandomState, np.random.default_rng))
def test_random_updates(rng_ctor):
original_value = rng_ctor(seed=98)
rng = shared(original_value, name="original_rng", borrow=False)
next_rng, x = at.random.normal(name="x", rng=rng).owner.outputs
with pytest.warns(
UserWarning,
match=re.escape(
"The RandomType SharedVariables [original_rng] will not be used"
),
):
f = aesara.function([], [x], updates={rng: next_rng}, mode=jax_mode)
assert f() != f()
# Check that original rng variable content was not overwritten when calling jax_typify
assert all(
a == b if not isinstance(a, np.ndarray) else np.array_equal(a, b)
for a, b in zip(rng.get_value().__getstate__(), original_value.__getstate__())
)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论