提交 2c91b5a3 authored 作者: kc611's avatar kc611 提交者: Brandon T. Willard

Add support for NumPy Generator types in JAX backend

上级 5611cf71
......@@ -6,7 +6,8 @@ import jax
import jax.numpy as jnp
import jax.scipy as jsp
import numpy as np
from numpy.random import RandomState
from numpy.random import Generator, RandomState
from numpy.random.bit_generator import _coerce_to_uint32_array
from aesara.compile.ops import DeepCopyOp, ViewOp
from aesara.configdefaults import config
......@@ -105,6 +106,33 @@ def jax_typify_ndarray(data, dtype=None, **kwargs):
def jax_typify_RandomState(state, **kwargs):
state = state.get_state(legacy=False)
state["bit_generator"] = numpy_bit_gens[state["bit_generator"]]
# XXX: Is this a reasonable approach?
state["jax_state"] = state["state"]["key"][0:2]
return state
@jax_typify.register(Generator)
def jax_typify_Generator(rng, **kwargs):
state = rng.__getstate__()
state["bit_generator"] = numpy_bit_gens[state["bit_generator"]]
# XXX: Is this a reasonable approach?
state["jax_state"] = _coerce_to_uint32_array(state["state"]["state"])[0:2]
# The "state" and "inc" values in a NumPy `Generator` are 128 bits, which
# JAX can't handle, so we split these values into arrays of 32 bit integers
# and then combine the first two into a single 64 bit integers.
#
# XXX: Depending on how we expect these values to be used, is this approach
# reasonable?
#
# TODO: We might as well remove these altogether, since this conversion
# should only occur once (e.g. when the graph is converted/JAX-compiled),
# and, from then on, we use the custom "jax_state" value.
inc_32 = _coerce_to_uint32_array(state["state"]["inc"])
state_32 = _coerce_to_uint32_array(state["state"]["state"])
state["state"]["inc"] = inc_32[0] << 32 | inc_32[1]
state["state"]["state"] = state_32[0] << 32 | state_32[1]
return state
......@@ -999,7 +1027,7 @@ def jax_funcify_BatchedDot(op, **kwargs):
@jax_funcify.register(RandomVariable)
def jax_funcify_RandomVariable(op, **kwargs):
def jax_funcify_RandomVariable(op, node, **kwargs):
name = op.name
if not hasattr(jax.random, name):
......@@ -1007,13 +1035,15 @@ def jax_funcify_RandomVariable(op, **kwargs):
f"No JAX conversion for the given distribution: {name}"
)
def random_variable(rng, size, dtype, *args):
prng = jax.random.PRNGKey(rng["state"]["key"][0])
dtype = jnp.dtype(dtype)
dtype = node.outputs[1].dtype
def random_variable(rng, size, dtype_num, *args):
if not op.inplace:
rng = rng.copy()
prng = rng["jax_state"]
data = getattr(jax.random, name)(key=prng, shape=size)
smpl_value = jnp.array(data, dtype=dtype)
prng = jax.random.split(prng, num=1)[0]
jax.ops.index_update(rng["state"]["key"], 0, prng[0])
rng["jax_state"] = jax.random.split(prng, num=1)[0]
return (rng, smpl_value)
return random_variable
from numpy.random import RandomState
from numpy.random import Generator, RandomState
from aesara.graph.basic import Constant
from aesara.link.basic import JITLinker
......@@ -28,7 +28,7 @@ class JAXLinker(JITLinker):
thunk_inputs = []
for n in self.fgraph.inputs:
sinput = storage_map[n]
if isinstance(sinput[0], RandomState):
if isinstance(sinput[0], (RandomState, Generator)):
new_value = jax_typify(
sinput[0], dtype=getattr(sinput[0], "dtype", None)
)
......
......@@ -1188,12 +1188,38 @@ def test_extra_ops_omni():
compare_jax_and_py(fgraph, [])
@pytest.mark.xfail(reason="The RNG states are not 1:1", raises=AssertionError)
def test_random():
rng = shared(np.random.RandomState(123))
out = normal(rng=rng)
@pytest.mark.parametrize(
"at_dist, dist_params, rng, size",
[
(
normal,
(),
shared(np.random.RandomState(123)),
10000,
),
(
normal,
(),
shared(np.random.default_rng(123)),
10000,
),
],
)
def test_random_stats(at_dist, dist_params, rng, size):
# The RNG states are not 1:1, so the best we can do is check some summary
# statistics of the samples
out = normal(*dist_params, rng=rng, size=size)
fgraph = FunctionGraph([out.owner.inputs[0]], [out], clone=False)
compare_jax_and_py(fgraph, [])
def assert_fn(x, y):
(x,) = x
(y,) = y
assert x.dtype.kind == y.dtype.kind
d = 2 if config.floatX == "float64" else 1
np.testing.assert_array_almost_equal(np.abs(x.mean()), np.abs(y.mean()), d)
compare_jax_and_py(fgraph, [], assert_fn=assert_fn)
def test_random_unimplemented():
......@@ -1218,7 +1244,6 @@ def test_random_unimplemented():
compare_jax_and_py(fgraph, [])
@pytest.mark.xfail(reason="Generators not yet supported in JAX")
def test_RandomStream():
srng = RandomStream(seed=123)
out = srng.normal() - srng.normal()
......@@ -1228,11 +1253,3 @@ def test_RandomStream():
jax_res_2 = fn()
assert np.array_equal(jax_res_1, jax_res_2)
@pytest.mark.xfail(reason="Generators not yet supported in JAX")
def test_random_generators():
rng = shared(np.random.default_rng(123))
out = normal(rng=rng)
fgraph = FunctionGraph([out.owner.inputs[0]], [out], clone=False)
compare_jax_and_py(fgraph, [])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论