提交 0d1f65f8 authored 作者: Rémi Louf's avatar Rémi Louf 提交者: Thomas Wiecki

Raise when the `RandomVariable` will not compile

上级 a110e82b
......@@ -8,12 +8,39 @@ from numpy.random.bit_generator import ( # type: ignore[attr-defined]
import pytensor.tensor.random.basic as aer
from pytensor.link.jax.dispatch.basic import jax_funcify, jax_typify
from pytensor.tensor.shape import Shape
from pytensor.tensor.shape import Shape, Shape_i
# Integer codes identifying NumPy bit-generator classes when a NumPy
# `RandomState` is typified for JAX (see `jax_typify_RandomState` below).
numpy_bit_gens = {"MT19937": 0, "PCG64": 1, "Philox": 2, "SFC64": 3}
# User-facing error message raised when a `RandomVariable`'s `size`
# argument cannot be made concrete for JAX's JIT compiler.  Do not edit
# casually: this text is what users see in the `NotImplementedError`.
SIZE_NOT_COMPATIBLE = """JAX random variables require concrete values for the `size` parameter of the distributions.
Concrete values are either constants:
>>> import pytensor.tensor as at
>>> x_rv = at.random.normal(0, 1, size=(3, 2))
or the shape of an array:
>>> m = at.matrix()
>>> x_rv = at.random.normal(0, 1, size=m.shape)
"""
def assert_size_argument_jax_compatible(node):
    """Assert that this node's `size` argument can be JIT-compiled by JAX.

    JAX can JIT-compile `jax.random` functions only when the `size`
    argument is a concrete value, i.e. either a constant or the shape of
    a traced value (a `Shape`/`Shape_i` node).

    Parameters
    ----------
    node
        The `RandomVariable` apply node whose second input is `size`.

    Raises
    ------
    NotImplementedError
        If `size` is not produced by a `Shape` or `Shape_i` operator.
    """
    size = node.inputs[1]
    # A root graph input (e.g. `at.scalar()` passed as `size`) has no
    # owner; previously this raised `AttributeError` on `size.owner.op`
    # instead of the intended, explanatory `NotImplementedError`.
    if size.owner is None or not isinstance(size.owner.op, (Shape, Shape_i)):
        raise NotImplementedError(SIZE_NOT_COMPATIBLE)
@jax_typify.register(RandomState)
def jax_typify_RandomState(state, **kwargs):
state = state.get_state(legacy=False)
......@@ -65,12 +92,7 @@ def jax_funcify_RandomVariable(op, node, **kwargs):
# by a `Shape` operator in which case JAX will compile, or it is
# not and we fail gracefully.
if None in out_size:
if not isinstance(node.inputs[1].owner.op, Shape):
raise NotImplementedError(
"""JAX random variables require concrete values for the `size` parameter of the distributions.
Concrete values are either constants, or the shape of an array.
"""
)
assert_size_argument_jax_compatible(node)
def sample_fn(rng, size, dtype, *parameters):
return jax_sample_fn(op)(rng, size, out_dtype, *parameters)
......
......@@ -449,14 +449,33 @@ def test_random_concrete_shape():
"""
rng = shared(np.random.RandomState(123))
x_at = at.dmatrix()
f = at.random.normal(0, 1, size=(3,), rng=rng)
g = at.random.normal(f, 1, size=x_at.shape, rng=rng)
g_fn = function([x_at], g, mode=jax_mode)
_ = g_fn(np.ones((2, 3)))
out = at.random.normal(0, 1, size=x_at.shape, rng=rng)
jax_fn = function([x_at], out, mode=jax_mode)
assert jax_fn(np.ones((2, 3))).shape == (2, 3)
# This should compile, and `size_at` be passed to the list of `static_argnums`.
with pytest.raises(NotImplementedError):
size_at = at.scalar()
g = at.random.normal(f, 1, size=size_at, rng=rng)
g_fn = function([size_at], g, mode=jax_mode)
_ = g_fn(10)
@pytest.mark.xfail(reason="size argument specified as a tuple is a `DimShuffle` node")
def test_random_concrete_shape_subtensor():
    """Drawing with `size` taken from a single element of an input's shape."""
    shared_rng = shared(np.random.RandomState(123))
    matrix_at = at.dmatrix()
    rv = at.random.normal(0, 1, size=matrix_at.shape[1], rng=shared_rng)
    compiled = function([matrix_at], rv, mode=jax_mode)
    result = compiled(np.ones((2, 3)))
    assert result.shape == (3,)
@pytest.mark.xfail(reason="size argument specified as a tuple is a `MakeVector` node")
def test_random_concrete_shape_subtensor_tuple():
    """Drawing with `size` given as a tuple of shape elements."""
    shared_rng = shared(np.random.RandomState(123))
    matrix_at = at.dmatrix()
    rv = at.random.normal(0, 1, size=(matrix_at.shape[0],), rng=shared_rng)
    compiled = function([matrix_at], rv, mode=jax_mode)
    result = compiled(np.ones((2, 3)))
    assert result.shape == (2,)
@pytest.mark.xfail(reason="`size_at` should be specified as a static argument")
def test_random_concrete_shape_graph_input():
    """Drawing with `size` supplied as a graph input (a scalar variable)."""
    shared_rng = shared(np.random.RandomState(123))
    size_at = at.scalar()
    rv = at.random.normal(0, 1, size=size_at, rng=shared_rng)
    compiled = function([size_at], rv, mode=jax_mode)
    result = compiled(10)
    assert result.shape == (10,)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论