提交 6d8ba993 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Generalize and update the JAX Op conversion docs

上级 34375f41
Tutorial on adding JAX Ops to Aesara Adding JAX and Numba support for `Op`\s
==================================== =======================================
Aesara is able to convert its graphs into JAX compiled functions. In order to do Aesara is able to convert its graphs into JAX and Numba compiled functions. In order to do
this, each ``Op`` in the graph must have a JAX implementation. This tutorial this, each :class:`Op` in an Aesara graph must have an equivalent JAX/Numba implementation function.
will explain how JAX implementations are created for an ``Op``.
Step 1: Identify the Aesara Op you’d like to JAXify This tutorial will explain how JAX and Numba implementations are created for an :class:`Op`. It will
=================================================== focus specifically on the JAX case, but the same mechanisms are used for Numba as well.
Determine which Aesara Op you’d like supported with JAX and identify the Step 1: Identify the Aesara :class:`Op` you’d like to implement in JAX
function signature and return values. This will come in handy as we need ----------------------------------------------------------------------
to know what we want JAX to do.
Find the source for the Aesara :class:`Op` you’d like to be supported in JAX, and
identify the function signature and return values. These can be determined by
looking at the :meth:`Op.make_node` implementation. In general, one needs to be familiar
with Aesara :class:`Op`\s in order to provide a conversion implementation, so first read
:ref:`extending_aesara` if you are not familiar.
For example, the :class:`Eye`\ :class:`Op` current has an :meth:`Op.make_node` as follows:
.. code:: python
def make_node(self, n, m, k):
n = as_tensor_variable(n)
m = as_tensor_variable(m)
k = as_tensor_variable(k)
assert n.ndim == 0
assert m.ndim == 0
assert k.ndim == 0
return Apply(
self,
[n, m, k],
[TensorType(dtype=self.dtype, broadcastable=(False, False))()],
)
The :class:`Apply` instance that's returned specifies the exact types of inputs that
our JAX implementation will receive and the exact types of outputs it's expected to
return--both in terms of their data types and number of dimensions.
The actual inputs our implementation will receive are necessarily numeric values
or NumPy :class:`ndarray`\s; all that :meth:`Op.make_node` tells us is the
general signature of the underlying computation.
More specifically, the :class:`Apply` implies that the inputs come from values that are
automatically converted to Aesara variables via :func:`as_tensor_variable`, and
the ``assert``\s that follow imply that they must be scalars. According to this
logic, the inputs could have any data type (e.g. floats, ints), so our JAX
implementation must be able to handle all the possible data types.
It also tells us that there's only one return value, that it has a data type
determined by :attr:`Eye.dtype`, and that it has two non-broadcastable
dimensions. The latter implies that the result is necessarily a matrix. The
former implies that our JAX implementation will need to access the :attr:`dtype`
attribute of the Aesara :class:`Eye`\ :class:`Op` it's converting.
Next, we can look at the :meth:`Op.perform` implementation to see exactly
how the inputs and outputs are used to compute the outputs for an :class:`Op`
in Python. This method is effectively what needs to be implemented in JAX.
| Here are the examples for ``eye`` and ``ifelse`` from Aesara from the
compiled doc and codebase respectively
| https://aesara.readthedocs.io/en/latest/library/tensor/basic.html?highlight=eye#aesara.tensor.eye
| https://github.com/aesara-devs/aesara/blob/main/aesara/ifelse.py#L35
Step 2: Find the relevant JAX method (or something close) Step 2: Find the relevant JAX method (or something close)
========================================================= ---------------------------------------------------------
With a precise idea of what the Aesara Op does we need to figure out how With a precise idea of what the Aesara :class:`Op` does we need to figure out how
to implement it in JAX. In easiest scenario JAX has a similarly named to implement it in JAX. In the best case scenario, JAX has a similarly named
method that does the same thing. For example with the ``eye`` operator function that performs exactly the same computations as the :class:`Op`. For
we find the paired ``jax.numpy.eye`` method. example, the :class:`Eye` operator has a JAX equivalent: :func:`jax.numpy.eye`
(see `the documentation <https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.eye.html?highlight=eye>`_).
https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.eye.html?highlight=eye If we wanted to implement an :class:`Op` like :class:`IfElse`, we might need to
recreate the functionality with some custom logic. In many cases, at least some
custom logic is needed to reformat the inputs and outputs so that they exactly
match the `Op`'s.
For ifelse we’ll need to recreate the functionality with some custom Here's an example for :class:`IfElse`:
logic.
.. code:: python .. code:: python
...@@ -38,118 +82,72 @@ logic. ...@@ -38,118 +82,72 @@ logic.
) )
return res if n_outs > 1 else res[0] return res if n_outs > 1 else res[0]
*Code in context:*
https://github.com/aesara-devs/aesara/blob/main/aesara/link/jax/dispatch.py#L583
Step 3: Register the function with the jax_funcify dispatcher Step 3: Register the function with the `jax_funcify` dispatcher
============================================================= ---------------------------------------------------------------
With the Aesara Op replicated in JAX we’ll need to now register this With the Aesara `Op` replicated in JAX, we’ll need to register the
function with the Aesara JAX Linker. This is done through the dispatcher function with the Aesara JAX `Linker`. This is done through the use of
decorator and closure as seen below. If unsure how dispatching works a `singledispatch`. If you don't know how `singledispatch` works, see the
short tutorial on dispatching is at the bottom. `Python documentation <https://docs.python.org/3/library/functools.html#functools.singledispatch>`_.
The linker functions should be added to ``dispatch`` module linked The relevant dispatch functions created by `singledispatch` are :func:`aesara.link.numba.dispatch.numba_funcify` and
below. :func:`aesara.link.jax.dispatch.jax_funcify`.
https://github.com/aesara-devs/aesara/blob/main/aesara/link/jax/dispatch.py
Here’s an example for the Eye Op. Here’s an example for the `Eye`\ `Op`:
.. code:: python .. code:: python
import jax.numpy as jnp
from aesara.tensor.basic import Eye from aesara.tensor.basic import Eye
from aesara.link.jax.dispatch import jax_funcify
@jax_funcify.register(Eye) # The decorator @jax_funcify.register(Eye)
def jax_funcify_Eye(op): # The function that takes an Op and returns its JAX equivalent def jax_funcify_Eye(op):
# Obtain necessary "static" attributes from the Op being converted
dtype = op.dtype dtype = op.dtype
# Create a JAX jit-able function that implements the Op
def eye(N, M, k): def eye(N, M, k):
return jnp.eye(N, M, k, dtype=dtype) return jnp.eye(N, M, k, dtype=dtype)
return eye return eye
*Code in context:*
https://github.com/aesara-devs/aesara/blob/main/aesara/link/jax/dispatch.py#L1071
Step 4: Write tests Step 4: Write tests
=================== -------------------
Test that your registered Op is working correctly by adding a test to
the ``test_jax.py`` test suite. The test should ensure that Aesara Op,
when included as part of a function graph, passes the tests in
``compare_jax_and_py`` test method. What this test method does is
compile the same function graph in Python and JAX and check that the
numerical output is similar between the JAX and Python output, as well
object types to ensure correct compilation.
https://github.com/aesara-devs/aesara/blob/main/tests/link/test_jax.py
.. code:: python
def test_jax_eye():
"""Tests jaxification of the Eye operator"""
out = aet.eye(3) # Initialize an Aesara Op
out_fg = aesara.graph.fg.FunctionGraph([], [out]) # Create an Aesara FunctionGraph
compare_jax_and_py(out_fg, []) # Pas the graph and any inputs to testing function Test that your registered `Op` is working correctly by adding tests to the
appropriate test suites in Aesara (e.g. in ``tests.link.test_jax`` and one of
the modules in ``tests.link.numba.dispatch``). The tests should ensure that your implementation can
handle the appropriate types of inputs and produce outputs equivalent to `Op.perform`.
Check the existing tests for the general outline of these kinds of tests. In
most cases, a helper function can be used to easily verify the correspondence
between a JAX/Numba implementation and its `Op`.
*Code in context:* For example, the :func:`compare_jax_and_py` function streamlines the steps
https://github.com/aesara-devs/aesara/blob/056fcee1434818d0aed9234e01c754ed88d0f27a/tests/link/test_jax.py#L250 involved in making comparisons with `Op.perform`.
Step 5: Wait for CI pass and Code Review Here's a small example of a test for :class:`Eye`:
========================================
Create a pull request and ensure CI passes. If it does wait for a code
review and a likely merge!
https://github.com/aesara-devs/aesara/pulls
Appendix: What does singledispatcher do?
========================================
In short a dispatcher figures out what “the right thing” is to do based
on the type of the first argument to the function. It’s easiest
explained with an example. One is provided below in addition to the
python docs.
https://docs.python.org/3/library/functools.html#functools.singledispatch
.. code:: ipython3
from functools import singledispatch
class Cow:
pass
cow = Cow()
class Dog:
pass
dog = Dog()
@singledispatch
def greeting(animal):
print("This animal has not been registered")
@greeting.register(Cow)
def cow_greeting(animal):
print("Mooooo")
@greeting.register(Dog)
def dog_greeting(animal):
print("Woof")
.. code:: python
greeting(cow) import aesara.tensor as aet
greeting(dog)
greeting("A string object")
def test_jax_Eye():
"""Test JAX conversion of the `Eye` `Op`."""
.. parsed-literal:: # Create a symbolic input for `Eye`
x_at = aet.scalar()
Mooooo # Create a variable that is the output of an `Eye` `Op`
Woof eye_var = aet.eye(x_at)
Animal has not been registered
# Create an Aesara `FunctionGraph`
out_fg = FunctionGraph(outputs=[eye_var])
This is what allows the JAX Linker to determine which the correct # Pass the graph and any inputs to the testing function
JAXification Op is as we’ve registered it with the Aesara Op compare_jax_and_py(out_fg, [3])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论