提交 c3bc2cc8 authored 作者: kc611's avatar kc611 提交者: Thomas Wiecki

Implement JAX conversions for RandomVariables and RandomState types

上级 8e0b1560
......@@ -5,6 +5,8 @@ from warnings import warn
import jax
import jax.numpy as jnp
import jax.scipy as jsp
import numpy as np
from numpy.random import RandomState
from aesara.compile.ops import DeepCopyOp, ViewOp
from aesara.configdefaults import config
......@@ -52,6 +54,7 @@ from aesara.tensor.nlinalg import (
)
from aesara.tensor.nnet.basic import Softmax
from aesara.tensor.nnet.sigm import ScalarSoftplus
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
from aesara.tensor.slinalg import Cholesky, Solve
from aesara.tensor.subtensor import ( # This is essentially `np.take`; Boolean mask indexing and setting
......@@ -66,6 +69,10 @@ from aesara.tensor.subtensor import ( # This is essentially `np.take`; Boolean
from aesara.tensor.type_other import MakeSlice
# For use with JAX since JAX doesn't support 'str' arguments
numpy_bit_gens = {"MT19937": 0, "PCG64": 1, "Philox": 2, "SFC64": 3}
if config.floatX == "float64":
jax.config.update("jax_enable_x64", True)
else:
......@@ -125,21 +132,18 @@ def compose_jax_funcs(out_node, fgraph_inputs, memo=None):
i_dtype = getattr(i, "dtype", None)
def jax_inputs_func(*inputs, i_dtype=i_dtype, idx=idx):
return jnp.array(inputs[idx], dtype=jnp.dtype(i_dtype))
return jax_typify(inputs[idx], i_dtype)
input_f = jax_inputs_func
elif i.owner is None:
# This input is something like a `aesara.graph.basic.Constant`
# This input is something like an `aesara.graph.basic.Constant`
i_dtype = getattr(i, "dtype", None)
i_data = i.data
def jax_data_func(*inputs, i_dtype=i_dtype, i_data=i_data):
if i_dtype is None:
return i_data
else:
return jnp.array(i_data, dtype=jnp.dtype(i_dtype))
return jax_typify(i_data, i_dtype)
input_f = jax_data_func
else:
......@@ -171,7 +175,6 @@ def compose_jax_funcs(out_node, fgraph_inputs, memo=None):
def jax_func(*inputs):
func_args = [fn(*inputs) for fn in input_funcs]
# func_args = jax.tree_map(lambda fn: fn(*inputs), input_funcs)
return return_func(*func_args)
jax_funcs.append(update_wrapper(jax_func, return_func))
......@@ -184,9 +187,31 @@ def compose_jax_funcs(out_node, fgraph_inputs, memo=None):
return jax_funcs
@singledispatch
def jax_typify(data, dtype):
    """Convert instances of Aesara `Type`s to JAX types.

    Parameters
    ----------
    data
        The value to convert.
    dtype
        Target dtype for the resulting JAX array; when ``None`` the data is
        passed through unchanged.

    """
    # Generic fallback used when no type-specific converter is registered:
    # either pass the value through untouched or wrap it in a JAX array.
    # (The original had a redundant `if dtype is not None` check followed by
    # an unreachable `raise NotImplementedError`; the two branches below are
    # exhaustive, so the raise could never fire.)
    if dtype is None:
        return data
    return jnp.array(data, dtype=dtype)
@jax_typify.register(np.ndarray)
def jax_typify_ndarray(data, dtype):
    """Convert a NumPy ``ndarray`` to a JAX array with the requested dtype."""
    converted = jnp.array(data, dtype=dtype)
    return converted
@jax_typify.register(RandomState)
def jax_typify_RandomState(state, dtype):
    """Convert a NumPy `RandomState` into a plain state ``dict`` usable by JAX.

    The ``dtype`` argument is ignored; it is only present to match the
    `jax_typify` signature.
    """
    state_dict = state.get_state(legacy=False)
    # JAX doesn't support `str` arguments, so replace the bit-generator name
    # with its integer code from `numpy_bit_gens`.
    state_dict["bit_generator"] = numpy_bit_gens[state_dict["bit_generator"]]
    return state_dict
@singledispatch
def jax_funcify(op):
    """Create a JAX compatible function from an Aesara `Op`.

    Raises
    ------
    NotImplementedError
        When no converter has been registered for ``type(op)``.

    """
    # Note: the diff artifact left two docstrings stacked here (the old one
    # followed by the new); only the current one is kept.
    raise NotImplementedError(f"No JAX conversion for the given `Op`: {op}")
......@@ -617,8 +642,6 @@ def jax_funcify_Subtensor(op):
else:
cdata = ilists
# breakpoint()
if len(cdata) == 1:
cdata = cdata[0]
......@@ -1082,3 +1105,24 @@ def jax_funcify_BatchedDot(op):
return jnp.einsum("nij,njk->nik", a, b)
return batched_dot
@jax_funcify.register(RandomVariable)
def jax_funcify_RandomVariable(op):
    """Create a JAX sampling function for a `RandomVariable` `Op`.

    The distribution is looked up by ``op.name`` in the ``jax.random``
    namespace.

    Raises
    ------
    NotImplementedError
        When ``jax.random`` has no sampler with the `Op`'s name.

    """
    name = op.name

    if not hasattr(jax.random, name):
        raise NotImplementedError(
            f"No JAX conversion for the given distribution: {name}"
        )

    def random_variable(rng, size, dtype, *args):
        # `rng` is the state dict produced by `jax_typify_RandomState`; seed a
        # JAX PRNG from the first word of its key array.
        prng = jax.random.PRNGKey(rng["state"]["key"][0])
        dtype = jnp.dtype(dtype)
        # NOTE(review): the distribution parameters in `*args` (e.g. loc and
        # scale for `normal`) are ignored — only the shape is honored.
        # Confirm whether this is an intended limitation.
        data = getattr(jax.random, name)(key=prng, shape=size)
        smpl_value = jnp.array(data, dtype=dtype)
        # Advance the key and write it back so successive draws differ.
        # BUG FIX: `jax.ops.index_update` is purely functional — it returns
        # the updated array instead of mutating its input, so the result must
        # be assigned back (the original discarded it, leaving the stored
        # state unchanged).
        prng = jax.random.split(prng, num=1)[0]
        rng["state"]["key"] = jax.ops.index_update(
            rng["state"]["key"], 0, prng[0]
        )
        return (rng, smpl_value)

    return random_variable
from collections.abc import Sequence
from warnings import warn
from numpy.random import RandomState
from aesara.graph.basic import Constant
from aesara.link.basic import Container, PerformLinker
from aesara.link.utils import gc_helper, map_storage, streamline
......@@ -44,7 +46,7 @@ class JAXLinker(PerformLinker):
"""
import jax
from aesara.link.jax.jax_dispatch import jax_funcify
from aesara.link.jax.jax_dispatch import jax_funcify, jax_typify
output_nodes = [o.owner for o in self.fgraph.outputs]
......@@ -59,7 +61,17 @@ class JAXLinker(PerformLinker):
n for n, i in enumerate(self.fgraph.inputs) if isinstance(i, Constant)
]
thunk_inputs = [storage_map[n] for n in self.fgraph.inputs]
thunk_inputs = []
for n in self.fgraph.inputs:
sinput = storage_map[n]
if isinstance(sinput[0], RandomState):
new_value = jax_typify(sinput[0], getattr(sinput[0], "dtype", None))
# We need to remove the reference-based connection to the
# original `RandomState`/shared variable's storage, because
# subsequent attempts to use the same shared variable within
# other non-JAXified graphs will have problems.
sinput = [new_value]
thunk_inputs.append(sinput)
thunks = []
......
......@@ -7,7 +7,7 @@ import aesara.scalar.basic as aes
from aesara.compile.function import function
from aesara.compile.mode import Mode
from aesara.compile.ops import DeepCopyOp, ViewOp
from aesara.compile.sharedvalue import shared
from aesara.compile.sharedvalue import SharedVariable, shared
from aesara.configdefaults import config
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import get_test_value
......@@ -29,6 +29,7 @@ from aesara.tensor.math import clip, cosh, gammaln, log
from aesara.tensor.math import max as aet_max
from aesara.tensor.math import maximum, prod
from aesara.tensor.math import sum as aet_sum
from aesara.tensor.random.basic import normal
from aesara.tensor.shape import Shape, Shape_i, SpecifyShape, reshape
from aesara.tensor.type import (
dscalar,
......@@ -90,7 +91,8 @@ def compare_jax_and_py(
jax_mode = Mode(JAXLinker(), opts)
py_mode = Mode("py", opts)
aesara_jax_fn = function(fgraph.inputs, fgraph.outputs, mode=jax_mode)
fn_inputs = [i for i in fgraph.inputs if not isinstance(i, SharedVariable)]
aesara_jax_fn = function(fn_inputs, fgraph.outputs, mode=jax_mode)
jax_res = aesara_jax_fn(*inputs)
if must_be_device_array:
......@@ -101,7 +103,7 @@ def compare_jax_and_py(
else:
assert isinstance(jax_res, jax.interpreters.xla.DeviceArray)
aesara_py_fn = function(fgraph.inputs, fgraph.outputs, mode=py_mode)
aesara_py_fn = function(fn_inputs, fgraph.outputs, mode=py_mode)
py_res = aesara_py_fn(*inputs)
if len(fgraph.outputs) > 1:
......@@ -965,3 +967,11 @@ def test_extra_ops():
)
fgraph = FunctionGraph([], [out])
compare_jax_and_py(fgraph, [])
@pytest.mark.xfail(reason="The RNG states are not 1:1", raises=AssertionError)
def test_random():
    # A shared NumPy `RandomState` provides the graph's RNG input.
    shared_rng = shared(np.random.RandomState(123))
    draw = normal(rng=shared_rng)
    # `clone=False` keeps the original variables, so the shared RNG input
    # compared below is the very one used to build `draw`.
    fgraph = FunctionGraph([draw.owner.inputs[0]], [draw], clone=False)
    compare_jax_and_py(fgraph, [])
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论