提交 2c91b5a3 authored 作者: kc611's avatar kc611 提交者: Brandon T. Willard

Add support for NumPy Generator types in JAX backend

上级 5611cf71
...@@ -6,7 +6,8 @@ import jax ...@@ -6,7 +6,8 @@ import jax
import jax.numpy as jnp import jax.numpy as jnp
import jax.scipy as jsp import jax.scipy as jsp
import numpy as np import numpy as np
from numpy.random import RandomState from numpy.random import Generator, RandomState
from numpy.random.bit_generator import _coerce_to_uint32_array
from aesara.compile.ops import DeepCopyOp, ViewOp from aesara.compile.ops import DeepCopyOp, ViewOp
from aesara.configdefaults import config from aesara.configdefaults import config
...@@ -105,6 +106,33 @@ def jax_typify_ndarray(data, dtype=None, **kwargs): ...@@ -105,6 +106,33 @@ def jax_typify_ndarray(data, dtype=None, **kwargs):
def jax_typify_RandomState(state, **kwargs): def jax_typify_RandomState(state, **kwargs):
state = state.get_state(legacy=False) state = state.get_state(legacy=False)
state["bit_generator"] = numpy_bit_gens[state["bit_generator"]] state["bit_generator"] = numpy_bit_gens[state["bit_generator"]]
# XXX: Is this a reasonable approach?
state["jax_state"] = state["state"]["key"][0:2]
return state
@jax_typify.register(Generator)
def jax_typify_Generator(rng, **kwargs):
state = rng.__getstate__()
state["bit_generator"] = numpy_bit_gens[state["bit_generator"]]
# XXX: Is this a reasonable approach?
state["jax_state"] = _coerce_to_uint32_array(state["state"]["state"])[0:2]
# The "state" and "inc" values in a NumPy `Generator` are 128 bits, which
# JAX can't handle, so we split these values into arrays of 32 bit integers
# and then combine the first two into a single 64 bit integers.
#
# XXX: Depending on how we expect these values to be used, is this approach
# reasonable?
#
# TODO: We might as well remove these altogether, since this conversion
# should only occur once (e.g. when the graph is converted/JAX-compiled),
# and, from then on, we use the custom "jax_state" value.
inc_32 = _coerce_to_uint32_array(state["state"]["inc"])
state_32 = _coerce_to_uint32_array(state["state"]["state"])
state["state"]["inc"] = inc_32[0] << 32 | inc_32[1]
state["state"]["state"] = state_32[0] << 32 | state_32[1]
return state return state
...@@ -999,7 +1027,7 @@ def jax_funcify_BatchedDot(op, **kwargs): ...@@ -999,7 +1027,7 @@ def jax_funcify_BatchedDot(op, **kwargs):
@jax_funcify.register(RandomVariable) @jax_funcify.register(RandomVariable)
def jax_funcify_RandomVariable(op, **kwargs): def jax_funcify_RandomVariable(op, node, **kwargs):
name = op.name name = op.name
if not hasattr(jax.random, name): if not hasattr(jax.random, name):
...@@ -1007,13 +1035,15 @@ def jax_funcify_RandomVariable(op, **kwargs): ...@@ -1007,13 +1035,15 @@ def jax_funcify_RandomVariable(op, **kwargs):
f"No JAX conversion for the given distribution: {name}" f"No JAX conversion for the given distribution: {name}"
) )
def random_variable(rng, size, dtype, *args): dtype = node.outputs[1].dtype
prng = jax.random.PRNGKey(rng["state"]["key"][0])
dtype = jnp.dtype(dtype) def random_variable(rng, size, dtype_num, *args):
if not op.inplace:
rng = rng.copy()
prng = rng["jax_state"]
data = getattr(jax.random, name)(key=prng, shape=size) data = getattr(jax.random, name)(key=prng, shape=size)
smpl_value = jnp.array(data, dtype=dtype) smpl_value = jnp.array(data, dtype=dtype)
prng = jax.random.split(prng, num=1)[0] rng["jax_state"] = jax.random.split(prng, num=1)[0]
jax.ops.index_update(rng["state"]["key"], 0, prng[0])
return (rng, smpl_value) return (rng, smpl_value)
return random_variable return random_variable
from numpy.random import RandomState from numpy.random import Generator, RandomState
from aesara.graph.basic import Constant from aesara.graph.basic import Constant
from aesara.link.basic import JITLinker from aesara.link.basic import JITLinker
...@@ -28,7 +28,7 @@ class JAXLinker(JITLinker): ...@@ -28,7 +28,7 @@ class JAXLinker(JITLinker):
thunk_inputs = [] thunk_inputs = []
for n in self.fgraph.inputs: for n in self.fgraph.inputs:
sinput = storage_map[n] sinput = storage_map[n]
if isinstance(sinput[0], RandomState): if isinstance(sinput[0], (RandomState, Generator)):
new_value = jax_typify( new_value = jax_typify(
sinput[0], dtype=getattr(sinput[0], "dtype", None) sinput[0], dtype=getattr(sinput[0], "dtype", None)
) )
......
...@@ -1188,12 +1188,38 @@ def test_extra_ops_omni(): ...@@ -1188,12 +1188,38 @@ def test_extra_ops_omni():
compare_jax_and_py(fgraph, []) compare_jax_and_py(fgraph, [])
@pytest.mark.xfail(reason="The RNG states are not 1:1", raises=AssertionError) @pytest.mark.parametrize(
def test_random(): "at_dist, dist_params, rng, size",
rng = shared(np.random.RandomState(123)) [
out = normal(rng=rng) (
normal,
(),
shared(np.random.RandomState(123)),
10000,
),
(
normal,
(),
shared(np.random.default_rng(123)),
10000,
),
],
)
def test_random_stats(at_dist, dist_params, rng, size):
# The RNG states are not 1:1, so the best we can do is check some summary
# statistics of the samples
out = normal(*dist_params, rng=rng, size=size)
fgraph = FunctionGraph([out.owner.inputs[0]], [out], clone=False) fgraph = FunctionGraph([out.owner.inputs[0]], [out], clone=False)
compare_jax_and_py(fgraph, [])
def assert_fn(x, y):
(x,) = x
(y,) = y
assert x.dtype.kind == y.dtype.kind
d = 2 if config.floatX == "float64" else 1
np.testing.assert_array_almost_equal(np.abs(x.mean()), np.abs(y.mean()), d)
compare_jax_and_py(fgraph, [], assert_fn=assert_fn)
def test_random_unimplemented(): def test_random_unimplemented():
...@@ -1218,7 +1244,6 @@ def test_random_unimplemented(): ...@@ -1218,7 +1244,6 @@ def test_random_unimplemented():
compare_jax_and_py(fgraph, []) compare_jax_and_py(fgraph, [])
@pytest.mark.xfail(reason="Generators not yet supported in JAX")
def test_RandomStream(): def test_RandomStream():
srng = RandomStream(seed=123) srng = RandomStream(seed=123)
out = srng.normal() - srng.normal() out = srng.normal() - srng.normal()
...@@ -1228,11 +1253,3 @@ def test_RandomStream(): ...@@ -1228,11 +1253,3 @@ def test_RandomStream():
jax_res_2 = fn() jax_res_2 = fn()
assert np.array_equal(jax_res_1, jax_res_2) assert np.array_equal(jax_res_1, jax_res_2)
@pytest.mark.xfail(reason="Generators not yet supported in JAX")
def test_random_generators():
rng = shared(np.random.default_rng(123))
out = normal(rng=rng)
fgraph = FunctionGraph([out.owner.inputs[0]], [out], clone=False)
compare_jax_and_py(fgraph, [])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论