提交 61c40a8d authored 作者: Rémi Louf's avatar Rémi Louf 提交者: Thomas Wiecki

Rewrite `size` input of `RandomVariable`s in JAX backend

上级 0d1f65f8
......@@ -449,7 +449,7 @@ else:
JAX = Mode(
JAXLinker(),
RewriteDatabaseQuery(include=["fast_run"], exclude=["cxx_only", "BlasOpt"]),
RewriteDatabaseQuery(include=["fast_run", "jax"], exclude=["cxx_only", "BlasOpt"]),
)
NUMBA = Mode(
NumbaLinker(),
......
......@@ -8,6 +8,7 @@ from numpy.random.bit_generator import ( # type: ignore[attr-defined]
import pytensor.tensor.random.basic as aer
from pytensor.link.jax.dispatch.basic import jax_funcify, jax_typify
from pytensor.link.jax.dispatch.shape import JAXShapeTuple
from pytensor.tensor.shape import Shape, Shape_i
......@@ -28,7 +29,7 @@ or the shape of an array:
def assert_size_argument_jax_compatible(node):
"""Assert whether the current node can be compiled.
"""Assert whether the current node can be JIT-compiled by JAX.
JAX can JIT-compile `jax.random` functions when the `size` argument
is a concrete value, i.e. either a constant or the shape of any
......@@ -37,7 +38,7 @@ def assert_size_argument_jax_compatible(node):
"""
size = node.inputs[1]
size_op = size.owner.op
if not isinstance(size_op, (Shape, Shape_i)):
if not isinstance(size_op, (Shape, Shape_i, JAXShapeTuple)):
raise NotImplementedError(SIZE_NOT_COMPATIBLE)
......
import jax.numpy as jnp
from pytensor.graph import Constant
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op
from pytensor.link.jax.dispatch.basic import jax_funcify
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape, Unbroadcast
from pytensor.tensor.type import TensorType
class JAXShapeTuple(Op):
"""Dummy Op that represents a `size` specified as a tuple."""
def make_node(self, *inputs):
dtype = inputs[0].type.dtype
otype = TensorType(dtype, shape=(len(inputs),))
return Apply(self, inputs, [otype()])
def perform(self, *inputs):
return tuple(inputs)
@jax_funcify.register(JAXShapeTuple)
def jax_funcify_JAXShapeTuple(op, **kwargs):
def shape_tuple_fn(*x):
return tuple(x)
return shape_tuple_fn
@jax_funcify.register(Reshape)
......
# TODO: This is for backward-compatibility; remove when reasonable.
from pytensor.tensor.random.rewriting.basic import *
# isort: off
# Register JAX specializations
import pytensor.tensor.random.rewriting.jax
# isort: on
from pytensor.compile import optdb
from pytensor.graph.rewriting.basic import in2out, node_rewriter
from pytensor.tensor.basic import MakeVector
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.random.op import RandomVariable
@node_rewriter([RandomVariable])
def size_parameter_as_tuple(fgraph, node):
"""Replace `MakeVector` and `DimShuffle` (when used to transform a scalar
into a 1d vector) when they are found as the input of a `size` or `shape`
parameter by `JAXShapeTuple` during transpilation.
The JAX implementations of `MakeVector` and `DimShuffle` always return JAX
`TracedArrays`, but JAX only accepts concrete values as inputs for the `size`
or `shape` parameter. When these `Op`s are used to convert scalar or tuple
inputs, however, we can avoid tracing by making them return a tuple of their
inputs instead.
Note that JAX does not accept scalar inputs for the `size` or `shape`
parameters, and this rewrite also ensures that scalar inputs are turned into
tuples during transpilation.
"""
from pytensor.link.jax.dispatch.shape import JAXShapeTuple
size_arg = node.inputs[1]
size_node = size_arg.owner
if size_node is None:
return
if isinstance(size_node.op, JAXShapeTuple):
return
if isinstance(size_node.op, MakeVector) or (
isinstance(size_node.op, DimShuffle)
and size_node.op.input_broadcastable == ()
and size_node.op.new_order == ("x",)
):
# Here PyTensor converted a tuple or list to a tensor
new_size_args = JAXShapeTuple()(*size_node.inputs)
new_inputs = list(node.inputs)
new_inputs[1] = new_size_args
new_node = node.clone_with_new_inputs(new_inputs)
return new_node.outputs
optdb.register(
"jax_size_parameter_as_tuple", in2out(size_parameter_as_tuple), "jax", position=100
)
......@@ -27,7 +27,7 @@ def set_pytensor_flags():
jax = pytest.importorskip("jax")
opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
opts = RewriteDatabaseQuery(include=["jax"], exclude=["cxx_only", "BlasOpt"])
jax_mode = Mode(JAXLinker(), opts)
py_mode = Mode("py", opts)
......
......@@ -454,8 +454,18 @@ def test_random_concrete_shape():
assert jax_fn(np.ones((2, 3))).shape == (2, 3)
@pytest.mark.xfail(reason="size argument specified as a tuple is a `DimShuffle` node")
def test_random_concrete_shape_subtensor():
"""JAX should compile when a concrete value is passed for the `size` parameter.
This test ensures that the `DimShuffle` `Op` used by PyTensor to turn scalar
inputs into 1d vectors is replaced by an `Op` that turns concrete scalar
inputs into tuples of concrete values using the `jax_size_parameter_as_tuple`
rewrite.
JAX does not accept scalars as `size` or `shape` arguments, so this is a
slight improvement over their API.
"""
rng = shared(np.random.RandomState(123))
x_at = at.dmatrix()
out = at.random.normal(0, 1, size=x_at.shape[1], rng=rng)
......@@ -463,8 +473,15 @@ def test_random_concrete_shape_subtensor():
assert jax_fn(np.ones((2, 3))).shape == (3,)
@pytest.mark.xfail(reason="size argument specified as a tuple is a `MakeVector` node")
def test_random_concrete_shape_subtensor_tuple():
"""JAX should compile when a tuple of concrete values is passed for the `size` parameter.
This test ensures that the `MakeVector` `Op` used by PyTensor to turn tuple
inputs into 1d vectors is replaced by an `Op` that turns a tuple of concrete
scalar inputs into tuples of concrete values using the
`jax_size_parameter_as_tuple` rewrite.
"""
rng = shared(np.random.RandomState(123))
x_at = at.dmatrix()
out = at.random.normal(0, 1, size=(x_at.shape[0],), rng=rng)
......@@ -472,7 +489,9 @@ def test_random_concrete_shape_subtensor_tuple():
assert jax_fn(np.ones((2, 3))).shape == (2,)
@pytest.mark.xfail(reason="`size_at` should be specified as a static argument")
@pytest.mark.xfail(
reason="`size_at` should be specified as a static argument", strict=True
)
def test_random_concrete_shape_graph_input():
rng = shared(np.random.RandomState(123))
size_at = at.scalar()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论