提交 bb028ae2 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Inline static size inputs in JAX implementation of RandomVariables

This gets around some limitations in JAX jitting system
上级 863efc01
......@@ -8,6 +8,7 @@ from numpy.random.bit_generator import ( # type: ignore[attr-defined]
)
import pytensor.tensor.random.basic as ptr
from pytensor.graph import Constant
from pytensor.link.jax.dispatch.basic import jax_funcify, jax_typify
from pytensor.link.jax.dispatch.shape import JAXShapeTuple
from pytensor.tensor.shape import Shape, Shape_i
......@@ -91,15 +92,26 @@ def jax_funcify_RandomVariable(op, node, **kwargs):
"""JAX implementation of random variables."""
rv = node.outputs[1]
out_dtype = rv.type.dtype
out_size = rv.type.shape
static_shape = rv.type.shape
batch_ndim = op.batch_ndim(node)
out_size = node.default_output().type.shape[:batch_ndim]
# Try to pass static size directly to JAX
static_size = static_shape[:batch_ndim]
if None in static_size:
# Sometimes size can be constant folded during rewrites,
# without the RandomVariable node being updated with new static types
size_param = node.inputs[1]
if isinstance(size_param, Constant):
size_tuple = tuple(size_param.data)
# PyTensor uses empty size to represent size = None
if len(size_tuple):
static_size = tuple(size_param.data)
# If one dimension has unknown size, either the size is determined
# by a `Shape` operator in which case JAX will compile, or it is
# not and we fail gracefully.
if None in out_size:
if None in static_size:
assert_size_argument_jax_compatible(node)
def sample_fn(rng, size, dtype, *parameters):
......@@ -111,7 +123,9 @@ def jax_funcify_RandomVariable(op, node, **kwargs):
else:
def sample_fn(rng, size, dtype, *parameters):
return jax_sample_fn(op, node=node)(rng, out_size, out_dtype, *parameters)
return jax_sample_fn(op, node=node)(
rng, static_size, out_dtype, *parameters
)
return sample_fn
......
......@@ -5,6 +5,7 @@ import scipy.stats as stats
import pytensor
import pytensor.tensor as pt
import pytensor.tensor.random.basic as ptr
from pytensor import clone_replace
from pytensor.compile.function import function
from pytensor.compile.sharedvalue import SharedVariable, shared
from pytensor.graph.basic import Constant
......@@ -26,11 +27,11 @@ jax = pytest.importorskip("jax")
from pytensor.link.jax.dispatch.random import numpyro_available # noqa: E402
def compile_random_function(*args, **kwargs):
def compile_random_function(*args, mode="JAX", **kwargs):
with pytest.warns(
UserWarning, match=r"The RandomType SharedVariables \[.+\] will not be used"
):
return function(*args, **kwargs)
return function(*args, mode=mode, **kwargs)
def test_random_RandomStream():
......@@ -896,3 +897,24 @@ def test_random_concrete_shape_graph_input():
out = pt.random.normal(0, 1, size=size_pt, rng=rng)
jax_fn = compile_random_function([size_pt], out, mode=jax_mode)
assert jax_fn(10).shape == (10,)
def test_constant_shape_after_graph_rewriting():
size = pt.vector("size", shape=(2,), dtype=int)
x = pt.random.normal(size=size)
assert x.type.shape == (None, None)
with pytest.raises(TypeError):
compile_random_function([size], x)([2, 5])
# Rebuild with strict=False so output type is not updated
# This reflects cases where size is constant folded during rewrites but the RV node is not recreated
new_x = clone_replace(x, {size: pt.constant([2, 5])}, rebuild_strict=True)
assert new_x.type.shape == (None, None)
assert compile_random_function([], new_x)().shape == (2, 5)
# Rebuild with strict=True, so output type is updated
# This uses a different path in the dispatch implementation
new_x = clone_replace(x, {size: pt.constant([2, 5])}, rebuild_strict=False)
assert new_x.type.shape == (2, 5)
assert compile_random_function([], new_x)().shape == (2, 5)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论