Unverified commit b67ff220 authored by Adrian Seyboldt, committed by GitHub

Implement `wrap_jax` and rename `as_op` to `wrap_py` (#1614)

Parent 7779b07b
@@ -208,7 +208,7 @@ jobs:
  micromamba install --yes -q "python~=${PYTHON_VERSION}" mkl "numpy${NUMPY_VERSION}" scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock pytest-sphinx;
  fi
  if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" "numba>=0.57"; fi
- if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" jax jaxlib numpyro && pip install tfp-nightly; fi
+ if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" jax jaxlib numpyro equinox && pip install tfp-nightly; fi
  if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" pytorch pytorch-cuda=12.1 "mkl<=2024.0" -c pytorch -c nvidia; fi
  if [[ $INSTALL_XARRAY == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" xarray xarray-einstats; fi
...
@@ -25,4 +25,4 @@ dependencies:
  - ablog
  - pip
  - pip:
- - -e ..
+ - -e ..[jax]
@@ -803,10 +803,10 @@ You can omit the :meth:`Rop` functions. Try to implement the testing apparatus d
  :download:`Solution<extending_pytensor_solution_1.py>`
- :func:`as_op`
+ :func:`wrap_py`
  -------------
- :func:`as_op` is a Python decorator that converts a Python function into a
+ :func:`wrap_py` is a Python decorator that converts a Python function into a
  basic PyTensor :class:`Op` that will call the supplied function during execution.
  This isn't the recommended way to build an :class:`Op`, but allows for a quick implementation.
@@ -839,11 +839,11 @@ It takes an optional :meth:`infer_shape` parameter that must have this signature
  inputs PyTensor variables that were declared.
  .. note::
-     The python function wrapped by the :func:`as_op` decorator needs to return a new
+     The python function wrapped by the :func:`wrap_py` decorator needs to return a new
      data allocation, no views or in place modification of the input.
- :func:`as_op` Example
+ :func:`wrap_py` Example
  ^^^^^^^^^^^^^^^^^^^^^
  .. testcode:: asop
@@ -852,14 +852,14 @@ It takes an optional :meth:`infer_shape` parameter that must have this signature
      import pytensor.tensor as pt
      import numpy as np
      from pytensor import function
-     from pytensor.compile.ops import as_op
+     from pytensor.compile.ops import wrap_py
      def infer_shape_numpy_dot(fgraph, node, input_shapes):
          ashp, bshp = input_shapes
          return [ashp[:-1] + bshp[-1:]]
-     @as_op(
+     @wrap_py(
          itypes=[pt.dmatrix, pt.dmatrix],
          otypes=[pt.dmatrix],
          infer_shape=infer_shape_numpy_dot,
...
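For readers skimming the diff, here is a minimal end-to-end sketch of the renamed decorator, assembled from the documentation example above (the decorator arguments and `infer_shape_numpy_dot` come from that example; the compile-and-call lines at the end are illustrative additions, not part of the diff):

    import numpy as np
    import pytensor
    import pytensor.tensor as pt
    from pytensor.compile.ops import wrap_py

    def infer_shape_numpy_dot(fgraph, node, input_shapes):
        ashp, bshp = input_shapes
        return [ashp[:-1] + bshp[-1:]]

    @wrap_py(
        itypes=[pt.dmatrix, pt.dmatrix],
        otypes=[pt.dmatrix],
        infer_shape=infer_shape_numpy_dot,
    )
    def numpy_dot(a, b):
        # Must return freshly allocated data, not a view of the inputs.
        return np.dot(a, b)

    x = pt.dmatrix("x")
    y = pt.dmatrix("y")
    f = pytensor.function([x, y], numpy_dot(x, y))
    print(f(np.eye(2), np.ones((2, 2))))  # 2x2 matrix of ones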
@@ -167,9 +167,9 @@ class TestSumDiffOp(utt.InferShapeTester):
  import numpy as np
- # as_op exercice
+ # wrap_py exercice
  import pytensor
- from pytensor.compile.ops import as_op
+ from pytensor.compile.ops import wrap_py
  def infer_shape_numpy_dot(fgraph, node, input_shapes):
@@ -177,7 +177,7 @@ def infer_shape_numpy_dot(fgraph, node, input_shapes):
      return [ashp[:-1] + bshp[-1:]]
- @as_op(
+ @wrap_py(
      itypes=[pt.fmatrix, pt.fmatrix],
      otypes=[pt.fmatrix],
      infer_shape=infer_shape_numpy_dot,
@@ -192,7 +192,7 @@ def infer_shape_numpy_add_sub(fgraph, node, input_shapes):
      return [ashp[0]]
- @as_op(
+ @wrap_py(
      itypes=[pt.fmatrix, pt.fmatrix],
      otypes=[pt.fmatrix],
      infer_shape=infer_shape_numpy_add_sub,
@@ -201,7 +201,7 @@ def numpy_add(a, b):
      return np.add(a, b)
- @as_op(
+ @wrap_py(
      itypes=[pt.fmatrix, pt.fmatrix],
      otypes=[pt.fmatrix],
      infer_shape=infer_shape_numpy_add_sub,
...
@@ -61,10 +61,16 @@ Convert to Variable
  .. autofunction:: pytensor.as_symbolic(...)
+ Wrap JAX functions
+ ==================
+ .. autofunction:: wrap_jax(...)
+     Alias for :func:`pytensor.link.jax.ops.wrap_jax`
  Debug
  =====
  .. autofunction:: pytensor.dprint(...)
      Alias for :func:`pytensor.printing.debugprint`
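A minimal usage sketch for the `Wrap JAX functions` entry added above, mirroring the example in the `wrap_jax` docstring further down in this diff (assumes jax is installed):

    import jax.numpy as jnp
    import pytensor
    import pytensor.tensor as pt
    from pytensor import wrap_jax

    @wrap_jax
    def add(x, y):
        return jnp.add(x, y)

    x = pt.scalar("x")
    y = pt.scalar("y")
    f = pytensor.function([x, y], [add(x, y)])
    print(f(1, 2))  # [array(3.)]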
@@ -166,7 +166,7 @@ from pytensor.scan import checkpoints
  from pytensor.scan.basic import scan
  from pytensor.scan.views import foldl, foldr, map, reduce
  from pytensor.compile.builders import OpFromGraph
+ from pytensor.link.jax.ops import wrap_jax
  # isort: on
...
@@ -56,6 +56,7 @@ from pytensor.compile.ops import (
      register_deep_copy_op_c_code,
      register_view_op_c_code,
      view_op,
+     wrap_py,
  )
  from pytensor.compile.profiling import ProfileStats
  from pytensor.compile.sharedvalue import SharedVariable, shared, shared_constructor
""" """
This file contains auxiliary Ops, used during the compilation phase and Ops This file contains auxiliary Ops, used during the compilation phase and Ops
building class (:class:`FromFunctionOp`) and decorator (:func:`as_op`) that building class (:class:`FromFunctionOp`) and decorator (:func:`wrap_py`) that
help make new Ops more rapidly. help make new Ops more rapidly.
""" """
...@@ -268,12 +268,12 @@ class FromFunctionOp(Op): ...@@ -268,12 +268,12 @@ class FromFunctionOp(Op):
obj = load_back(mod, name) obj = load_back(mod, name)
except (ImportError, KeyError, AttributeError): except (ImportError, KeyError, AttributeError):
raise pickle.PicklingError( raise pickle.PicklingError(
f"Can't pickle as_op(), not found as {mod}.{name}" f"Can't pickle wrap_py(), not found as {mod}.{name}"
) )
else: else:
if obj is not self: if obj is not self:
raise pickle.PicklingError( raise pickle.PicklingError(
f"Can't pickle as_op(), not the object at {mod}.{name}" f"Can't pickle wrap_py(), not the object at {mod}.{name}"
) )
return load_back, (mod, name) return load_back, (mod, name)
...@@ -282,6 +282,18 @@ class FromFunctionOp(Op): ...@@ -282,6 +282,18 @@ class FromFunctionOp(Op):
def as_op(itypes, otypes, infer_shape=None): def as_op(itypes, otypes, infer_shape=None):
import warnings
warnings.warn(
"pytensor.as_op is deprecated and will be removed in a future release. "
"Please use pytensor.wrap_py instead.",
FutureWarning,
stacklevel=2,
)
return wrap_py(itypes, otypes, infer_shape)
def wrap_py(itypes, otypes, infer_shape=None):
""" """
Decorator that converts a function into a basic PyTensor op that will call Decorator that converts a function into a basic PyTensor op that will call
the supplied function as its implementation. the supplied function as its implementation.
...@@ -301,7 +313,7 @@ def as_op(itypes, otypes, infer_shape=None): ...@@ -301,7 +313,7 @@ def as_op(itypes, otypes, infer_shape=None):
Examples Examples
-------- --------
@as_op(itypes=[pytensor.tensor.fmatrix, pytensor.tensor.fmatrix], @wrap_py(itypes=[pytensor.tensor.fmatrix, pytensor.tensor.fmatrix],
otypes=[pytensor.tensor.fmatrix]) otypes=[pytensor.tensor.fmatrix])
def numpy_dot(a, b): def numpy_dot(a, b):
return numpy.dot(a, b) return numpy.dot(a, b)
......
@@ -13,6 +13,7 @@ from pytensor.configdefaults import config
  from pytensor.graph import Constant
  from pytensor.graph.fg import FunctionGraph
  from pytensor.ifelse import IfElse
+ from pytensor.link.jax.ops import JAXOp
  from pytensor.link.utils import fgraph_to_python
  from pytensor.raise_op import CheckAndRaise
@@ -142,3 +143,8 @@ def jax_funcify_OpFromGraph(ofg: OpFromGraph, node=None, **kwargs) -> Callable:
          return fgraph_fn(*inputs)
      return opfromgraph
+ @jax_funcify.register(JAXOp)
+ def jax_op_funcify(op, **kwargs):
+     return op.perform_jax
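The registration above is what lets a `JAXOp` run natively when a graph is compiled with the JAX backend: the dispatcher simply hands back `op.perform_jax`. A hedged sketch of that path (the `double` function and names are illustrative, assuming jax is installed and PyTensor's "JAX" compilation mode is available):

    import jax.numpy as jnp
    import pytensor
    import pytensor.tensor as pt
    from pytensor import wrap_jax

    @wrap_jax
    def double(x):
        return 2.0 * x

    x = pt.tensor("x", shape=(3,))
    # Compiling in JAX mode routes the JAXOp through jax_funcify -> perform_jax.
    f = pytensor.function([x], double(x), mode="JAX")
    print(f([1.0, 2.0, 3.0]))  # each element doubled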
"""Convert a jax function to a pytensor compatible function."""
from collections.abc import Sequence
from functools import wraps
import numpy as np
from pytensor.compile.function import function
from pytensor.compile.mode import Mode
from pytensor.gradient import DisconnectedType
from pytensor.graph import Apply, Op, Variable
from pytensor.tensor.basic import infer_static_shape
from pytensor.tensor.type import TensorType
class JAXOp(Op):
"""
JAXOp is a PyTensor Op that wraps a JAX function, providing both forward
computation and reverse-mode differentiation (via VJP).
Parameters
----------
input_types : list
A list of PyTensor types for each input variable.
output_types : list
A list of PyTensor types for each output variable.
jax_function : callable
The JAX function that computes outputs from inputs. It should
always return a tuple of outputs, even if there is only one output.
name : str, optional
A custom name for the Op instance. If provided, the class name will be
updated accordingly.
Example
-------
This example defines a simple function that sums two input arrays, where the inputs have a dynamic shape.
>>> import numpy as np
>>> import jax
>>> import jax.numpy as jnp
>>> import pytensor
>>> import pytensor.tensor as pt
>>> from pytensor.tensor import TensorType
>>>
>>> # Create the jax function that sums the input array.
>>> def sum_function(x, y):
... return (jnp.sum(x + y),)
>>>
>>> # Create the input and output types, input has a dynamic shape.
>>> input_type = TensorType("float32", shape=(None,))
>>> output_type = TensorType("float32", shape=())
>>>
>>> # Instantiate a JAXOp
>>> op = JAXOp(
... [input_type, input_type], [output_type], sum_function, name="DummyJAXOp"
... )
>>> # Define symbolic input variables.
>>> x = pt.tensor("x", dtype="float32", shape=(2,))
>>> y = pt.tensor("y", dtype="float32", shape=(2,))
>>> # Compile a PyTensor function.
>>> result = op(x, y)
>>> f = pytensor.function([x, y], [result])
>>> print(
... f(
... np.array([2.0, 3.0], dtype=np.float32),
... np.array([4.0, 5.0], dtype=np.float32),
... )
... )
[array(14., dtype=float32)]
>>>
>>> # Compute the gradient of op(x, y) with respect to x.
>>> g = pt.grad(result, x)
>>> grad_f = pytensor.function([x, y], [g])
>>> print(
... grad_f(
... np.array([2.0, 3.0], dtype=np.float32),
... np.array([4.0, 5.0], dtype=np.float32),
... )
... )
[array([1., 1.], dtype=float32)]
"""
__props__ = ("input_types", "output_types", "jax_func")
def __init__(self, input_types, output_types, jax_function, name=None):
import jax
self.input_types = tuple(input_types)
self.output_types = tuple(output_types)
self.jax_func = jax_function
self.jitted_func = jax.jit(jax_function)
self.name = name
super().__init__()
def __repr__(self):
base = self.__class__.__name__
props = list(self.__props__)
if self.name is not None:
props.insert(0, "name")
props = ", ".join(f"{prop}={getattr(self, prop)}" for prop in props)
return f"{base}({props})"
def make_node(self, *inputs: Variable) -> Apply:
"""Create an Apply node with the given inputs and inferred outputs."""
if len(inputs) != len(self.input_types):
raise ValueError(
f"Op {self} expected {len(self.input_types)} inputs, got {len(inputs)}"
)
filtered_inputs = [
inp_type.filter_variable(inp)
for inp, inp_type in zip(inputs, self.input_types)
]
outputs = [output_type() for output_type in self.output_types]
return Apply(self, filtered_inputs, outputs)
def perform(self, node, inputs, outputs):
"""Execute the JAX function and store results in output storage."""
results = self.jitted_func(*inputs)
if not isinstance(results, tuple):
raise TypeError("JAX function must return a tuple of outputs.")
if len(results) != len(outputs):
raise ValueError(
f"JAX function returned {len(results)} outputs, but "
f"{len(outputs)} were expected."
)
for output_container, result, out_type in zip(
outputs, results, self.output_types
):
output_container[0] = np.array(result, dtype=out_type.dtype)
def perform_jax(self, *inputs):
"""Execute the JAX function directly, returning JAX arrays."""
outputs = self.jitted_func(*inputs)
if not isinstance(outputs, tuple):
raise TypeError("JAX function must return a tuple of outputs.")
if len(outputs) == 1:
return outputs[0]
return outputs
def grad(self, inputs, output_gradients):
"""Compute gradients using JAX's vector-Jacobian product (VJP)."""
import jax
# Find indices of outputs that need gradients
connected_output_indices = [
i
for i, output_grad in enumerate(output_gradients)
if not isinstance(output_grad.type, DisconnectedType)
]
num_inputs = len(inputs)
def vjp_operation(*args):
"""VJP operation that computes gradients w.r.t. inputs."""
input_values = args[:num_inputs]
cotangent_vectors = args[num_inputs:]
assert len(cotangent_vectors) == len(connected_output_indices)
def restricted_function(*input_values):
"""Restricted function that only returns connected outputs."""
outputs = self.jax_func(*input_values)
return [
outputs[i].astype(self.output_types[i].dtype)
for i in connected_output_indices
]
_primals, vjp_function = jax.vjp(restricted_function, *input_values)
output_dtypes = [
self.output_types[i].dtype for i in connected_output_indices
]
return vjp_function(
[
cotangent.astype(dtype)
for cotangent, dtype in zip(
cotangent_vectors, output_dtypes, strict=True
)
]
)
if self.name is not None:
name = "vjp_" + self.name
else:
name = "vjp_jax_op"
# Create VJP operation
vjp_op = JAXOp(
self.input_types
+ tuple(self.output_types[i] for i in connected_output_indices),
[self.input_types[i] for i in range(num_inputs)],
vjp_operation,
name=name,
)
return vjp_op(
*[*inputs, *[output_gradients[i] for i in connected_output_indices]],
return_list=True,
)
def wrap_jax(jax_function=None, *, allow_eval=True):
"""Return a PyTensor-compatible function from a JAX jittable function.
This decorator wraps a JAX function so that it accepts and returns
`pytensor.Variable` objects. The JAX-jittable function can accept any
nested Python structure (a `Pytree
<https://jax.readthedocs.io/en/latest/pytrees.html>`_) as input, and might
return any nested Python structure.
Parameters
----------
jax_function : Callable, optional
A JAX function to be wrapped. If None, returns a decorator function.
allow_eval : bool, default=True
Whether to allow evaluation of symbolic shapes when input shapes are
not fully determined.
Returns
-------
Callable
A function that wraps the given JAX function so that it can be called with
pytensor.Variable inputs and returns pytensor.Variable outputs.
Examples
--------
>>> import jax.numpy as jnp
>>> import pytensor
>>> import pytensor.tensor as pt
>>> from pytensor import wrap_jax
>>> @wrap_jax
... def add(x, y):
... return jnp.add(x, y)
>>> x = pt.scalar("x")
>>> y = pt.scalar("y")
>>> result = add(x, y)
>>> f = pytensor.function([x, y], [result])
>>> print(f(1, 2))
[array(3.)]
We can also pass arbitrary jax pytree structures as inputs and outputs:
>>> import jax
>>> import jax.numpy as jnp
>>> import pytensor.tensor as pt
>>> from pytensor import wrap_jax
>>> @wrap_jax
... def complex_function(x, y, scale=1.0):
... return {
... "sum": jnp.add(x, y) * scale,
... }
>>> x = pt.vector("x", shape=(3,))
>>> y = pt.vector("y", shape=(3,))
>>> result = complex_function(x, y, scale=2.0)
>>> f = pytensor.function([x, y], [result["sum"]])
Or Equinox modules:
>>> x = pt.tensor("x", shape=(3,))  # doctest: +SKIP
>>> y = pt.tensor("y", shape=(3,))  # doctest: +SKIP
>>> import equinox as eqx  # doctest: +SKIP
>>> mlp = eqx.nn.MLP(
...     3, 3, 3, depth=2, activation=jnp.tanh, key=jax.random.key(0)
... )  # doctest: +SKIP
>>> mlp = eqx.tree_at(lambda m: m.layers[0].bias, mlp, y)  # doctest: +SKIP
>>> @wrap_jax  # doctest: +SKIP
... def neural_network(x, mlp):  # doctest: +SKIP
...     return mlp(x)  # doctest: +SKIP
>>> out = neural_network(x, mlp)  # doctest: +SKIP
If the input shapes are not fully determined, and valid
input shapes cannot be inferred by evaluating the inputs either,
an error will be raised:
>>> import jax.numpy as jnp
>>> import pytensor.tensor as pt
>>> @wrap_jax
... def add(x, y):
... return jnp.add(x, y)
>>> x = pt.vector("x") # shape is not fully determined
>>> y = pt.vector("y") # shape is not fully determined
>>> result = add(x, y)
ValueError: Could not compile a function to infer example shapes. Please provide inputs with fully determined shapes by calling pt.specify_shape.
...
"""
def decorator(func):
name = func.__name__
try:
import jax
except ImportError as e:
raise ImportError(
"The wrap_jax decorator requires jax to be installed."
) from e
@wraps(func)
def wrapper(*args, **kwargs):
# Partition inputs into dynamic PyTensor variables and static variables.
# Static variables don't participate in the computational graph.
pytensor_variables, static_values = _eqx_partition(
(args, kwargs), lambda x: isinstance(x, Variable)
)
# Flatten the PyTensor variables for processing
variables_flat, variables_treedef = jax.tree.flatten(pytensor_variables)
input_types = [var.type for var in variables_flat]
# Determine output types by calling the function through jax.eval_shape
output_types, output_treedef, output_static = _find_output_types(
func,
variables_flat,
variables_treedef,
static_values,
allow_eval=allow_eval,
)
def flattened_function(*flat_variables):
"""Execute the original function with flattened inputs."""
variables = jax.tree.unflatten(variables_treedef, flat_variables)
reconstructed_args, reconstructed_kwargs = _eqx_combine(
variables, static_values
)
function_outputs = func(*reconstructed_args, **reconstructed_kwargs)
array_outputs, _ = _eqx_partition(function_outputs, _is_array)
flattened_outputs, _ = jax.tree.flatten(array_outputs)
return tuple(flattened_outputs)
# Create the JAX operation
jax_op_instance = JAXOp(
input_types,
output_types,
flattened_function,
name=name,
)
# Execute the operation and reconstruct the output structure
flattened_results = jax_op_instance(*variables_flat)
if not isinstance(flattened_results, Sequence):
flattened_results = [flattened_results]
output_variables = jax.tree.unflatten(output_treedef, flattened_results)
final_outputs = _eqx_combine(output_variables, output_static)
return final_outputs
return wrapper
if jax_function is None:
return decorator
else:
return decorator(jax_function)
def _find_output_types(
jax_function, inputs_flat, input_treedef, static_input, *, allow_eval=True
):
"""Determine output types with jax.eval_shape on dummy inputs."""
import jax
import jax.numpy as jnp
resolved_input_shapes = []
requires_shape_evaluation = False
for variable in inputs_flat:
# If shape is already fully determined, use it directly
if not any(dimension is None for dimension in variable.type.shape):
resolved_input_shapes.append(variable.type.shape)
continue
# Try to infer static shape
_, inferred_shape = infer_static_shape(variable.shape)
if not any(dimension is None for dimension in inferred_shape):
resolved_input_shapes.append(inferred_shape)
continue
# Shape still has undetermined dimensions
if not allow_eval:
raise ValueError(
f"Input variable {variable} has undetermined shape dimensions. "
"Please provide inputs with fully determined shapes by calling "
"pt.specify_shape."
)
requires_shape_evaluation = True
resolved_input_shapes.append(variable.shape)
if requires_shape_evaluation:
try:
shape_evaluation_function = function(
[],
resolved_input_shapes,
on_unused_input="ignore",
mode=Mode(linker="py", optimizer="fast_compile"),
)
except Exception as e:
raise ValueError(
"Could not compile a function to infer example shapes. "
"Please provide inputs with fully determined shapes by "
"calling pt.specify_shape."
) from e
resolved_input_shapes = shape_evaluation_function()
# Determine output types using jax.eval_shape with dummy inputs
output_metadata_storage = {}
dummy_input_arrays = [
jnp.ones(shape, dtype=variable.type.dtype)
for variable, shape in zip(inputs_flat, resolved_input_shapes, strict=True)
]
def wrapped_jax_function(input_arrays):
"""Wrapper to extract output metadata during shape evaluation."""
variables = jax.tree.unflatten(input_treedef, input_arrays)
reconstructed_args, reconstructed_kwargs = _eqx_combine(variables, static_input)
function_outputs = jax_function(*reconstructed_args, **reconstructed_kwargs)
array_outputs, static_outputs = _eqx_partition(function_outputs, _is_array)
# Store metadata for later use
output_metadata_storage["output_static"] = static_outputs
flattened_outputs, output_structure = jax.tree.flatten(array_outputs)
output_metadata_storage["output_treedef"] = output_structure
return flattened_outputs
output_shapes_flat = jax.eval_shape(wrapped_jax_function, dummy_input_arrays)
output_treedef = output_metadata_storage["output_treedef"]
output_static = output_metadata_storage["output_static"]
# If we used shape evaluation, set all output shapes to unknown
if requires_shape_evaluation:
output_types = [
TensorType(
dtype=output_shape.dtype, shape=tuple(None for _ in output_shape.shape)
)
for output_shape in output_shapes_flat
]
else:
output_types = [
TensorType(dtype=output_shape.dtype, shape=output_shape.shape)
for output_shape in output_shapes_flat
]
return output_types, output_treedef, output_static
# From the equinox library, licensed under Apache 2.0
# https://github.com/patrick-kidger/equinox
#
# Copied here to avoid a dependency on equinox just for these functions.
def _eqx_combine(*pytrees, is_leaf=None):
"""Combines multiple PyTrees into one PyTree, by replacing `None` leaves.
!!! example
```python
pytree1 = [None, 1, 2]
pytree2 = [0, None, None]
equinox.combine(pytree1, pytree2) # [0, 1, 2]
```
!!! tip
The idea is that `equinox.combine` should be used to undo a call to
[`equinox.filter`][] or [`equinox.partition`][].
**Arguments:**
- `*pytrees`: a sequence of PyTrees all with the same structure.
- `is_leaf`: As [`equinox.partition`][].
**Returns:**
A PyTree with the same structure as its inputs. Each leaf will be the first
non-`None` leaf found in the corresponding leaves of `pytrees` as they are
iterated over.
"""
import jax
if is_leaf is None:
_is_leaf = _is_none
else:
_is_leaf = lambda x: _is_none(x) or is_leaf(x) # noqa: E731
return jax.tree.map(_combine, *pytrees, is_leaf=_is_leaf)
def _eqx_partition(
pytree,
filter_spec,
replace=None,
is_leaf=None,
):
"""Splits a PyTree into two pieces. Equivalent to
`filter(...), filter(..., inverse=True)`, but slightly more efficient.
!!! info
See also [`equinox.combine`][] to reconstitute the PyTree again.
"""
import jax
filter_tree = jax.tree.map(_make_filter_tree(is_leaf), filter_spec, pytree)
left = jax.tree.map(lambda mask, x: x if mask else replace, filter_tree, pytree)
right = jax.tree.map(lambda mask, x: replace if mask else x, filter_tree, pytree)
return left, right
def _make_filter_tree(is_leaf):
import jax
import jax.core
def _filter_tree(mask, arg):
if isinstance(mask, jax.core.Tracer):
raise ValueError("`filter_spec` leaf values cannot be traced arrays.")
if isinstance(mask, bool):
return jax.tree.map(lambda _: mask, arg, is_leaf=is_leaf)
elif callable(mask):
return jax.tree.map(mask, arg, is_leaf=is_leaf)
else:
raise ValueError(
"`filter_spec` must consist of booleans and callables only."
)
return _filter_tree
def _is_array(element) -> bool:
"""Returns `True` if `element` is a JAX array or NumPy array."""
import jax
return isinstance(element, np.ndarray | np.generic | jax.Array)
def _combine(*args):
for arg in args:
if arg is not None:
return arg
return None
def _is_none(x):
return x is None
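As background for the `grad` method of `JAXOp` defined above, which wraps `jax.vjp` in a second `JAXOp`: here is a small standalone sketch of the VJP mechanics in plain JAX (illustrative only, not part of this diff; the wrapped function returns a tuple of outputs, as `JAXOp` expects):

    import jax
    import jax.numpy as jnp

    def f(x, y):
        return (jnp.sum(x * y),)  # tuple of outputs, as JAXOp expects

    x = jnp.array([1.0, 2.0, 3.0])
    y = jnp.array([4.0, 5.0, 6.0])

    _primals, vjp_fn = jax.vjp(f, x, y)
    # A cotangent of 1.0 for the scalar output maps back to input gradients:
    grad_x, grad_y = vjp_fn((jnp.array(1.0),))
    print(grad_x)  # [4. 5. 6.]  d(sum(x*y))/dx
    print(grad_y)  # [1. 2. 3.]  d(sum(x*y))/dy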
  import pickle
  import numpy as np
+ import pytest
  from pytensor import function
- from pytensor.compile.ops import as_op
+ from pytensor.compile.ops import as_op, wrap_py
  from pytensor.tensor.type import dmatrix, dvector
  from tests import unittest_tools as utt
- @as_op([dmatrix, dmatrix], dmatrix)
+ @wrap_py([dmatrix, dmatrix], dmatrix)
  def mul(a, b):
      """
      This is for test_pickle, since the function still has to be
@@ -21,6 +22,21 @@ class TestOpDecorator(utt.InferShapeTester):
      def test_1arg(self):
          x = dmatrix("x")
+         @wrap_py(dmatrix, dvector)
+         def cumprod(x):
+             return np.cumprod(x)
+         fn = function([x], cumprod(x))
+         r = fn([[1.5, 5], [2, 2]])
+         r0 = np.array([1.5, 7.5, 15.0, 30.0])
+         assert np.allclose(r, r0), (r, r0)
+     def test_deprecation(self):
+         x = dmatrix("x")
+         with pytest.warns(FutureWarning):
          @as_op(dmatrix, dvector)
          def cumprod(x):
              return np.cumprod(x)
@@ -37,7 +53,7 @@ class TestOpDecorator(utt.InferShapeTester):
          y = dvector("y")
          y.tag.test_value = [0, 0, 0, 0]
-         @as_op([dmatrix, dvector], dvector)
+         @wrap_py([dmatrix, dvector], dvector)
          def cumprod_plus(x, y):
              return np.cumprod(x) + y
@@ -57,7 +73,7 @@ class TestOpDecorator(utt.InferShapeTester):
          x, y = shapes
          return [y]
-         @as_op([dmatrix, dvector], dvector, infer_shape)
+         @wrap_py([dmatrix, dvector], dvector, infer_shape)
          def cumprod_plus(x, y):
              return np.cumprod(x) + y
...
import numpy as np
import pytest
from pytensor import config, grad, wrap_jax
from pytensor.compile.sharedvalue import shared
from pytensor.link.jax.ops import JAXOp
from pytensor.scalar import all_types
from pytensor.tensor import TensorType, tensor
from tests.link.jax.test_basic import compare_jax_and_py
jax = pytest.importorskip("jax")
def test_two_inputs_single_output():
rng = np.random.default_rng(1)
x = tensor("x", shape=(2,))
y = tensor("y", shape=(2,))
test_values = [
rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y)
]
def f(x, y):
return jax.nn.sigmoid(x + y)
# Test with wrap_jax decorator
out = wrap_jax(f)(x, y)
grad_out = grad(out.sum(), [x, y])
compare_jax_and_py([x, y], [out, *grad_out], test_values)
with jax.disable_jit():
compare_jax_and_py([x, y], [out, *grad_out], test_values)
def f(x, y):
return (jax.nn.sigmoid(x + y),)
# Test direct JAXOp usage
jax_op = JAXOp(
[x.type, y.type],
[TensorType(config.floatX, shape=(2,))],
f,
)
out = jax_op(x, y)
grad_out = grad(out.sum(), [x, y])
compare_jax_and_py([x, y], [out, *grad_out], test_values)
def test_op_returns_list():
x = tensor("x", shape=(2,))
y = tensor("y", shape=(2,))
test_values = [np.ones((2,)).astype(config.floatX) for inp in (x, y)]
def f(x, y):
return jax.nn.sigmoid(x + y)
# Test direct JAXOp usage
jax_op = JAXOp(
[x.type, y.type],
[TensorType(config.floatX, shape=(2,))],
f,
)
with pytest.raises(TypeError, match="tuple of outputs"):
out = jax_op(x, y)
grad_out = grad(out.sum(), [x, y])
compare_jax_and_py([x, y], [out, *grad_out], test_values)
def test_two_inputs_tuple_output():
rng = np.random.default_rng(2)
x = tensor("x", shape=(2,))
y = tensor("y", shape=(2,))
test_values = [
rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y)
]
def f(x, y):
return jax.nn.sigmoid(x + y), y * 2
# Test with wrap_jax decorator
out1, out2 = wrap_jax(f)(x, y)
grad_out = grad((out1 + out2).sum(), [x, y])
compare_jax_and_py([x, y], [out1, out2, *grad_out], test_values)
with jax.disable_jit():
# must_be_device_array is False because, with jit compilation disabled,
# inputs are no longer automatically converted to jax.Array
compare_jax_and_py(
[x, y], [out1, out2, *grad_out], test_values, must_be_device_array=False
)
# Test direct JAXOp usage
jax_op = JAXOp(
[x.type, y.type],
[TensorType(config.floatX, shape=(2,)), TensorType(config.floatX, shape=(2,))],
f,
)
out1, out2 = jax_op(x, y)
grad_out = grad((out1 + out2).sum(), [x, y])
compare_jax_and_py([x, y], [out1, out2, *grad_out], test_values)
def test_two_inputs_list_output_one_unused_output():
# One output is unused, to test whether the wrapper can handle DisconnectedType
rng = np.random.default_rng(3)
x = tensor("x", shape=(2,))
y = tensor("y", shape=(2,))
test_values = [
rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y)
]
def f(x, y):
return (jax.nn.sigmoid(x + y), y * 2)
# Test with wrap_jax decorator
out, _ = wrap_jax(f)(x, y)
grad_out = grad(out.sum(), [x, y])
compare_jax_and_py([x, y], [out, *grad_out], test_values)
with jax.disable_jit():
compare_jax_and_py([x, y], [out, *grad_out], test_values)
# Test direct JAXOp usage
jax_op = JAXOp(
[x.type, y.type],
[TensorType(config.floatX, shape=(2,)), TensorType(config.floatX, shape=(2,))],
f,
)
out, _ = jax_op(x, y)
grad_out = grad(out.sum(), [x, y])
compare_jax_and_py([x, y], [out, *grad_out], test_values)
def test_single_input_tuple_output():
rng = np.random.default_rng(4)
x = tensor("x", shape=(2,))
test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)]
def f(x):
return jax.nn.sigmoid(x), x * 2
# Test with wrap_jax decorator
out1, out2 = wrap_jax(f)(x)
grad_out = grad(out1.sum(), [x])
compare_jax_and_py([x], [out1, out2, *grad_out], test_values)
with jax.disable_jit():
compare_jax_and_py(
[x], [out1, out2, *grad_out], test_values, must_be_device_array=False
)
# Test direct JAXOp usage
jax_op = JAXOp(
[x.type],
[TensorType(config.floatX, shape=(2,)), TensorType(config.floatX, shape=(2,))],
f,
)
out1, out2 = jax_op(x)
grad_out = grad(out1.sum(), [x])
compare_jax_and_py([x], [out1, out2, *grad_out], test_values)
def test_scalar_input_tuple_output():
rng = np.random.default_rng(5)
x = tensor("x", shape=())
test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)]
def f(x):
return jax.nn.sigmoid(x), x
# Test with wrap_jax decorator
out1, out2 = wrap_jax(f)(x)
grad_out = grad(out1.sum(), [x])
compare_jax_and_py([x], [out1, out2, *grad_out], test_values)
with jax.disable_jit():
compare_jax_and_py(
[x], [out1, out2, *grad_out], test_values, must_be_device_array=False
)
# Test direct JAXOp usage
jax_op = JAXOp(
[x.type],
[TensorType(config.floatX, shape=()), TensorType(config.floatX, shape=())],
f,
)
out1, out2 = jax_op(x)
grad_out = grad(out1.sum(), [x])
compare_jax_and_py([x], [out1, out2, *grad_out], test_values)
def test_single_input_list_output():
rng = np.random.default_rng(6)
x = tensor("x", shape=(2,))
test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)]
def f(x):
return (jax.nn.sigmoid(x), 2 * x)
# Test with wrap_jax decorator
out1, out2 = wrap_jax(f)(x)
grad_out = grad(out1.sum(), [x])
compare_jax_and_py([x], [out1, out2, *grad_out], test_values)
with jax.disable_jit():
compare_jax_and_py(
[x], [out1, out2, *grad_out], test_values, must_be_device_array=False
)
# Test direct JAXOp usage, with unspecified output shapes
jax_op = JAXOp(
[x.type],
[
TensorType(config.floatX, shape=(None,)),
TensorType(config.floatX, shape=(None,)),
],
f,
)
out1, out2 = jax_op(x)
grad_out = grad(out1.sum(), [x])
compare_jax_and_py([x], [out1, out2, *grad_out], test_values)
def test_pytree_input_tuple_output():
rng = np.random.default_rng(7)
x = tensor("x", shape=(2,))
y = tensor("y", shape=(2,))
y_tmp = {"y": y, "y2": [y**2]}
test_values = [
rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y)
]
@wrap_jax
def f(x, y):
return jax.nn.sigmoid(x), 2 * x + y["y"] + y["y2"][0]
# Test with wrap_jax decorator
out = f(x, y_tmp)
grad_out = grad(out[1].sum(), [x, y])
compare_jax_and_py([x, y], [out[0], out[1], *grad_out], test_values)
with jax.disable_jit():
compare_jax_and_py(
[x, y], [out[0], out[1], *grad_out], test_values, must_be_device_array=False
)
def test_pytree_input_pytree_output():
rng = np.random.default_rng(8)
x = tensor("x", shape=(3,))
y = tensor("y", shape=(1,))
y_tmp = {"a": y, "b": [y**2]}
test_values = [
rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y)
]
@wrap_jax
def f(x, y):
return x, jax.tree_util.tree_map(lambda x: jax.numpy.exp(x), y)
# Test with wrap_jax decorator
out = f(x, y_tmp)
grad_out = grad(out[1]["b"][0].sum(), [x, y])
compare_jax_and_py([x, y], [out[0], out[1]["a"], *grad_out], test_values)
with jax.disable_jit():
compare_jax_and_py(
[x, y],
[out[0], out[1]["a"], *grad_out],
test_values,
must_be_device_array=False,
)
def test_pytree_input_with_non_graph_args():
rng = np.random.default_rng(9)
x = tensor("x", shape=(3,))
y = tensor("y", shape=(1,))
y_tmp = {"a": y, "b": [y**2]}
test_values = [
rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y)
]
@wrap_jax
def f(x, y, depth, which_variable):
if which_variable == "x":
var = x
elif which_variable == "y":
var = y["a"] + y["b"][0]
else:
return "Unsupported argument"
for _ in range(depth):
var = jax.nn.sigmoid(var)
return var
# Test with wrap_jax decorator
# arguments depth and which_variable are not part of the graph
out = f(x, y_tmp, depth=3, which_variable="x")
grad_out = grad(out.sum(), [x])
compare_jax_and_py([x, y], [out[0], *grad_out], test_values)
with jax.disable_jit():
compare_jax_and_py([x, y], [out[0], *grad_out], test_values)
out = f(x, y_tmp, depth=7, which_variable="y")
grad_out = grad(out.sum(), [x])
compare_jax_and_py([x, y], [out[0], *grad_out], test_values)
with jax.disable_jit():
compare_jax_and_py([x, y], [out[0], *grad_out], test_values)
out = f(x, y_tmp, depth=10, which_variable="z")
assert out == "Unsupported argument"
def test_unused_matrix_product():
# A matrix output is unused, to test whether the wrapper can handle a
# DisconnectedType with a larger dimension.
rng = np.random.default_rng(10)
x = tensor("x", shape=(3,))
y = tensor("y", shape=(3,))
test_values = [
rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y)
]
def f(x, y):
return x[:, None] @ y[None], jax.numpy.exp(x)
# Test with wrap_jax decorator
out = wrap_jax(f)(x, y)
grad_out = grad(out[1].sum(), [x])
compare_jax_and_py([x, y], [out[1], *grad_out], test_values)
with jax.disable_jit():
compare_jax_and_py([x, y], [out[1], *grad_out], test_values)
# Test direct JAXOp usage
jax_op = JAXOp(
[x.type, y.type],
[
TensorType(config.floatX, shape=(3, 3)),
TensorType(config.floatX, shape=(3,)),
],
f,
)
out = jax_op(x, y)
grad_out = grad(out[1].sum(), [x])
compare_jax_and_py([x, y], [out[1], *grad_out], test_values)
def test_unknown_static_shape():
rng = np.random.default_rng(11)
x = tensor("x", shape=(3,))
y = tensor("y", shape=(3,))
test_values = [
rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y)
]
x_cumsum = x.cumsum() # Now x_cumsum has an unknown shape
def f(x, y):
return (x * jax.numpy.ones(3),)
(out,) = wrap_jax(f)(x_cumsum, y)
grad_out = grad(out.sum(), [x])
compare_jax_and_py([x, y], [out, *grad_out], test_values)
with jax.disable_jit():
compare_jax_and_py([x, y], [out, *grad_out], test_values)
# Test direct JAXOp usage
jax_op = JAXOp(
[x.type, y.type],
[TensorType(config.floatX, shape=(None,))],
f,
)
out = jax_op(x_cumsum, y)
grad_out = grad(out.sum(), [x])
compare_jax_and_py([x, y], [out, *grad_out], test_values)
def test_nn():
eqx = pytest.importorskip("equinox")
nn = pytest.importorskip("equinox.nn")
rng = np.random.default_rng(13)
x = tensor("x", shape=(3,))
y = tensor("y", shape=(3,))
test_values = [
rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y)
]
x = tensor("x", shape=(3,))
y = tensor("y", shape=(3,))
mlp = nn.MLP(3, 3, 3, depth=2, activation=jax.numpy.tanh, key=jax.random.key(0))
mlp = eqx.tree_at(lambda m: m.layers[0].bias, mlp, y)
@wrap_jax
def f(x, mlp):
return mlp(x)
out = f(x, mlp)
grad_out = grad(out.sum(), [x])
compare_jax_and_py([x, y], [out, *grad_out], test_values)
with jax.disable_jit():
compare_jax_and_py([x, y], [out, *grad_out], test_values)
def test_no_inputs():
def f():
return jax.numpy.array(42.0)
out = wrap_jax(f)()
assert out.eval() == 42.0
def test_unknown_shape():
x = tensor("x", shape=(None,))
def f(x):
return x * 2
with pytest.raises(ValueError, match="Please provide inputs"):
wrap_jax(f)(x)
def test_unknown_shape_with_eval():
x = shared(np.ones(3))
assert x.type.shape == (None,)
def f(x):
return x * 2
out = wrap_jax(f)(x)
grad_out = grad(out.sum(), [x])
compare_jax_and_py([], [out, *grad_out], [])
with jax.disable_jit():
compare_jax_and_py([], [out, *grad_out], [], must_be_device_array=False)
with pytest.raises(ValueError, match="Please provide inputs"):
wrap_jax(f, allow_eval=False)(x)
def test_decorator_forms():
x = tensor("x", shape=(3,))
y = tensor("y", shape=(3,))
@wrap_jax
def the_name1(x, y):
return (x + y).sum()
@wrap_jax(allow_eval=True)
def the_name2(x, y):
return (x + y).sum()
the_name1(x, y)
the_name2(x, y)
def test_repr():
x = tensor("x", shape=(3,))
y = tensor("y", shape=(3,))
def the_name(x, y):
return (x + y).sum()
jax_op = wrap_jax(the_name)
assert "the_name" in repr(jax_op(x, y).owner.op)
(grad_x, _) = grad(jax_op(x, y), [x, y])
assert "vjp_the_name" in repr(grad_x.owner.op)
class TestDtypes:
@pytest.mark.parametrize("in_dtype", list(map(str, all_types)))
@pytest.mark.parametrize("out_dtype", list(map(str, all_types)))
def test_different_in_output(self, in_dtype, out_dtype):
x = tensor("x", shape=(3,), dtype=in_dtype)
y = tensor("y", shape=(3,), dtype=in_dtype)
if "int" in in_dtype:
test_values = [
np.random.randint(0, 10, size=(inp.type.shape)).astype(inp.type.dtype)
for inp in (x, y)
]
else:
test_values = [
np.random.normal(size=(inp.type.shape)).astype(inp.type.dtype)
for inp in (x, y)
]
@wrap_jax
def f(x, y):
out = jax.numpy.add(x, y)
return jax.numpy.real(out).astype(out_dtype)
out = f(x, y)
assert out.dtype == out_dtype
if "float" in in_dtype and "float" in out_dtype:
grad_out = grad(out[0], [x, y])
assert grad_out[0].dtype == in_dtype
compare_jax_and_py([x, y], [out, *grad_out], test_values)
else:
compare_jax_and_py([x, y], [out], test_values)
with jax.disable_jit():
if "float" in in_dtype and "float" in out_dtype:
compare_jax_and_py([x, y], [out, *grad_out], test_values)
else:
compare_jax_and_py([x, y], [out], test_values)
@pytest.mark.parametrize("in1_dtype", list(map(str, all_types)))
@pytest.mark.parametrize("in2_dtype", list(map(str, all_types)))
def test_test_different_inputs(self, in1_dtype, in2_dtype):
x = tensor("x", shape=(3,), dtype=in1_dtype)
y = tensor("y", shape=(3,), dtype=in2_dtype)
if "int" in in1_dtype:
test_values = [np.random.randint(0, 10, size=(3,)).astype(x.type.dtype)]
else:
test_values = [np.random.normal(size=(3,)).astype(x.type.dtype)]
if "int" in in2_dtype:
test_values.append(np.random.randint(0, 10, size=(3,)).astype(y.type.dtype))
else:
test_values.append(np.random.normal(size=(3,)).astype(y.type.dtype))
@wrap_jax
def f(x, y):
out = jax.numpy.add(x, y)
return jax.numpy.real(out).astype(in1_dtype)
out = f(x, y)
assert out.dtype == in1_dtype
if "float" in in1_dtype and "float" in in2_dtype:
# In principle, the gradient should also be defined if the second input is
# an integer, but it doesn't work for some reason.
grad_out = grad(out[0], [x])
assert grad_out[0].dtype == in1_dtype
inputs = [x, y]
outputs = [out, *grad_out]
else:
inputs = [x, y]
outputs = [out]
compare_jax_and_py(inputs, outputs, test_values)
with jax.disable_jit():
if "float" in in1_dtype and "float" in in2_dtype:
compare_jax_and_py([x, y], [out, *grad_out], test_values)
else:
compare_jax_and_py([x, y], [out], test_values)