Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
6d8ba993
提交
6d8ba993
authored
12月 11, 2021
作者:
Brandon T. Willard
提交者:
Brandon T. Willard
12月 11, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Generalize and update the JAX Op conversion docs
上级
34375f41
显示空白字符变更
内嵌
并排
正在显示
1 个修改的文件
包含
107 行增加
和
109 行删除
+107
-109
jax_op.rst
doc/extending/jax_op.rst
+107
-109
没有找到文件。
doc/extending/jax_op.rst
浏览文件 @
6d8ba993
Tutorial on adding JAX Ops to Aesara
Adding JAX and Numba support for `Op`\s
====================================
====================================
===
Aesara is able to convert its graphs into JAX compiled functions. In order to do
Aesara is able to convert its graphs into JAX and Numba compiled functions. In order to do
this, each ``Op`` in the graph must have a JAX implementation. This tutorial
this, each :class:`Op` in an Aesara graph must have an equivalent JAX/Numba implementation function.
will explain how JAX implementations are created for an ``Op``.
Step 1: Identify the Aesara Op you’d like to JAXify
This tutorial will explain how JAX and Numba implementations are created for an :class:`Op`. It will
===================================================
focus specifically on the JAX case, but the same mechanisms are used for Numba as well.
Determine which Aesara Op you’d like supported with JAX and identify the
Step 1: Identify the Aesara :class:`Op` you’d like to implement in JAX
function signature and return values. This will come in handy as we need
----------------------------------------------------------------------
to know what we want JAX to do.
| Here are the examples for ``eye`` and ``ifelse`` from Aesara from the
Find the source for the Aesara :class:`Op` you’d like to be supported in JAX, and
compiled doc and codebase respectively
identify the function signature and return values. These can be determined by
| https://aesara.readthedocs.io/en/latest/library/tensor/basic.html?highlight=eye#aesara.tensor.eye
looking at the :meth:`Op.make_node` implementation. In general, one needs to be familiar
| https://github.com/aesara-devs/aesara/blob/main/aesara/ifelse.py#L35
with Aesara :class:`Op`\s in order to provide a conversion implementation, so first read
:ref:`extending_aesara` if you are not familiar.
Step 2: Find the relevant JAX method (or something close)
For example, the :class:`Eye`\ :class:`Op` current has an :meth:`Op.make_node` as follows:
=========================================================
With a precise idea of what the Aesara Op does we need to figure out how
to implement it in JAX. In easiest scenario JAX has a similarly named
method that does the same thing. For example with the ``eye`` operator
we find the paired ``jax.numpy.eye`` method.
https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.eye.html?highlight=eye
For ifelse we’ll need to recreate the functionality with some custom
logic.
.. code:: python
.. code:: python
def ifelse(cond, *args, n_outs=n_outs):
def make_node(self, n, m, k):
res = jax.lax.cond(
n = as_tensor_variable(n)
cond, lambda _: args[:n_outs], lambda _: args[n_outs:], operand=None
m = as_tensor_variable(m)
k = as_tensor_variable(k)
assert n.ndim == 0
assert m.ndim == 0
assert k.ndim == 0
return Apply(
self,
[n, m, k],
[TensorType(dtype=self.dtype, broadcastable=(False, False))()],
)
)
return res if n_outs > 1 else res[0]
*Code in context:*
https://github.com/aesara-devs/aesara/blob/main/aesara/link/jax/dispatch.py#L583
Step 3: Register the function with the jax_funcify dispatcher
The :class:`Apply` instance that's returned specifies the exact types of inputs that
=============================================================
our JAX implementation will receive and the exact types of outputs it's expected to
return--both in terms of their data types and number of dimensions.
The actual inputs our implementation will receive are necessarily numeric values
or NumPy :class:`ndarray`\s; all that :meth:`Op.make_node` tells us is the
general signature of the underlying computation.
With the Aesara Op replicated in JAX we’ll need to now register this
More specifically, the :class:`Apply` implies that the inputs come from values that are
function with the Aesara JAX Linker. This is done through the dispatcher
automatically converted to Aesara variables via :func:`as_tensor_variable`, and
decorator and closure as seen below. If unsure how dispatching works a
the ``assert``\s that follow imply that they must be scalars. According to this
short tutorial on dispatching is at the bottom.
logic, the inputs could have any data type (e.g. floats, ints), so our JAX
implementation must be able to handle all the possible data types.
The linker functions should be added to ``dispatch`` module linked
It also tells us that there's only one return value, that it has a data type
below.
determined by :attr:`Eye.dtype`, and that it has two non-broadcastable
https://github.com/aesara-devs/aesara/blob/main/aesara/link/jax/dispatch.py
dimensions. The latter implies that the result is necessarily a matrix. The
former implies that our JAX implementation will need to access the :attr:`dtype`
attribute of the Aesara :class:`Eye`\ :class:`Op` it's converting.
Here’s an example for the Eye Op.
Next, we can look at the :meth:`Op.perform` implementation to see exactly
how the inputs and outputs are used to compute the outputs for an :class:`Op`
in Python. This method is effectively what needs to be implemented in JAX.
.. code:: python
from aesara.tensor.basic import Eye
@jax_funcify.register(Eye) # The decorator
Step 2: Find the relevant JAX method (or something close)
def jax_funcify_Eye(op): # The function that takes an Op and returns its JAX equivalent
---------------------------------------------------------
dtype = op.dtype
def eye(N, M, k):
With a precise idea of what the Aesara :class:`Op` does we need to figure out how
return jnp.eye(N, M, k, dtype=dtype)
to implement it in JAX. In the best case scenario, JAX has a similarly named
function that performs exactly the same computations as the :class:`Op`. For
example, the :class:`Eye` operator has a JAX equivalent: :func:`jax.numpy.eye`
(see `the documentation <https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.eye.html?highlight=eye>`_).
return eye
If we wanted to implement an :class:`Op` like :class:`IfElse`, we might need to
recreate the functionality with some custom logic. In many cases, at least some
custom logic is needed to reformat the inputs and outputs so that they exactly
match the `Op`'s.
*Code in context:*
Here's an example for :class:`IfElse`:
https://github.com/aesara-devs/aesara/blob/main/aesara/link/jax/dispatch.py#L1071
Step 4: Write tests
.. code:: python
===================
Test that your registered Op is working correctly by adding a test to
def ifelse(cond, *args, n_outs=n_outs):
the ``test_jax.py`` test suite. The test should ensure that Aesara Op,
res = jax.lax.cond(
when included as part of a function graph, passes the tests in
cond, lambda _: args[:n_outs], lambda _: args[n_outs:], operand=None
``compare_jax_and_py`` test method. What this test method does is
)
compile the same function graph in Python and JAX and check that the
return res if n_outs > 1 else res[0]
numerical output is similar between the JAX and Python output, as well
object types to ensure correct compilation.
https://github.com/aesara-devs/aesara/blob/main/tests/link/test_jax.py
.. code:: python
Step 3: Register the function with the `jax_funcify` dispatcher
---------------------------------------------------------------
def test_jax_eye():
With the Aesara `Op` replicated in JAX, we’ll need to register the
"""Tests jaxification of the Eye operator"""
function with the Aesara JAX `Linker`. This is done through the use of
out = aet.eye(3) # Initialize an Aesara Op
`singledispatch`. If you don't know how `singledispatch` works, see the
out_fg = aesara.graph.fg.FunctionGraph([], [out]) # Create an Aesara FunctionGraph
`Python documentation <https://docs.python.org/3/library/functools.html#functools.singledispatch>`_.
compare_jax_and_py(out_fg, []) # Pas the graph and any inputs to testing function
The relevant dispatch functions created by `singledispatch` are :func:`aesara.link.numba.dispatch.numba_funcify` and
:func:`aesara.link.jax.dispatch.jax_funcify`.
*Code in context:*
Here’s an example for the `Eye`\ `Op`:
https://github.com/aesara-devs/aesara/blob/056fcee1434818d0aed9234e01c754ed88d0f27a/tests/link/test_jax.py#L250
Step 5: Wait for CI pass and Code Review
.. code:: python
========================================
Create a pull request and ensure CI passes. If it does wait for a code
import jax.numpy as jnp
review and a likely merge!
https://github.com/aesara-devs/aesara/pulls
from aesara.tensor.basic import Eye
from aesara.link.jax.dispatch import jax_funcify
Appendix: What does singledispatcher do?
========================================
In short a dispatcher figures out what “the right thing” is to do based
@jax_funcify.register(Eye)
on the type of the first argument to the function. It’s easiest
def jax_funcify_Eye(op):
explained with an example. One is provided below in addition to the
python docs.
https://docs.python.org/3/library/functools.html#functools.singledispatch
# Obtain necessary "static" attributes from the Op being converted
dtype = op.dtype
.. code:: ipython3
# Create a JAX jit-able function that implements the Op
def eye(N, M, k):
return jnp.eye(N, M, k, dtype=dtype)
from functools import singledispatch
return eye
class Cow:
pass
cow = Cow()
class Dog:
Step 4: Write tests
pass
-------------------
dog = Dog()
@singledispatch
Test that your registered `Op` is working correctly by adding tests to the
def greeting(animal):
appropriate test suites in Aesara (e.g. in ``tests.link.test_jax`` and one of
print("This animal has not been registered")
the modules in ``tests.link.numba.dispatch``). The tests should ensure that your implementation can
handle the appropriate types of inputs and produce outputs equivalent to `Op.perform`.
Check the existing tests for the general outline of these kinds of tests. In
most cases, a helper function can be used to easily verify the correspondence
between a JAX/Numba implementation and its `Op`.
@greeting.register(Cow)
For example, the :func:`compare_jax_and_py` function streamlines the steps
def cow_greeting(animal):
involved in making comparisons with `Op.perform`.
print("Mooooo")
@greeting.register(Dog)
Here's a small example of a test for :class:`Eye`:
def dog_greeting(animal):
print("Woof")
.. code:: python
greeting(cow)
import aesara.tensor as aet
greeting(dog)
greeting("A string object")
def test_jax_Eye():
"""Test JAX conversion of the `Eye` `Op`."""
.. parsed-literal::
# Create a symbolic input for `Eye`
x_at = aet.scalar()
Mooooo
# Create a variable that is the output of an `Eye` `Op`
Woof
eye_var = aet.eye(x_at)
Animal has not been registered
# Create an Aesara `FunctionGraph`
out_fg = FunctionGraph(outputs=[eye_var])
This is what allows the JAX Linker to determine which the correct
# Pass the graph and any inputs to the testing function
JAXification Op is as we’ve registered it with the Aesara Op
compare_jax_and_py(out_fg, [3])
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论