Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
6d8ba993
提交
6d8ba993
authored
12月 11, 2021
作者:
Brandon T. Willard
提交者:
Brandon T. Willard
12月 11, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Generalize and update the JAX Op conversion docs
上级
34375f41
隐藏空白字符变更
内嵌
并排
正在显示
1 个修改的文件
包含
107 行增加
和
109 行删除
+107
-109
jax_op.rst
doc/extending/jax_op.rst
+107
-109
没有找到文件。
doc/extending/jax_op.rst
浏览文件 @
6d8ba993
Tutorial on adding JAX Ops to Aesara
====================================
Adding JAX and Numba support for `Op`\s
====================================
===
Aesara is able to convert its graphs into JAX compiled functions. In order to do
this, each ``Op`` in the graph must have a JAX implementation. This tutorial
will explain how JAX implementations are created for an ``Op``.
Aesara is able to convert its graphs into JAX and Numba compiled functions. In order to do
this, each :class:`Op` in an Aesara graph must have an equivalent JAX/Numba implementation function.
Step 1: Identify the Aesara Op you’d like to JAXify
===================================================
This tutorial will explain how JAX and Numba implementations are created for an :class:`Op`. It will
focus specifically on the JAX case, but the same mechanisms are used for Numba as well.
Determine which Aesara Op you’d like supported with JAX and identify the
function signature and return values. This will come in handy as we need
to know what we want JAX to do.
Step 1: Identify the Aesara :class:`Op` you’d like to implement in JAX
----------------------------------------------------------------------
Find the source for the Aesara :class:`Op` you’d like to be supported in JAX, and
identify the function signature and return values. These can be determined by
looking at the :meth:`Op.make_node` implementation. In general, one needs to be familiar
with Aesara :class:`Op`\s in order to provide a conversion implementation, so first read
:ref:`extending_aesara` if you are not familiar.
For example, the :class:`Eye`\ :class:`Op` current has an :meth:`Op.make_node` as follows:
.. code:: python
def make_node(self, n, m, k):
n = as_tensor_variable(n)
m = as_tensor_variable(m)
k = as_tensor_variable(k)
assert n.ndim == 0
assert m.ndim == 0
assert k.ndim == 0
return Apply(
self,
[n, m, k],
[TensorType(dtype=self.dtype, broadcastable=(False, False))()],
)
The :class:`Apply` instance that's returned specifies the exact types of inputs that
our JAX implementation will receive and the exact types of outputs it's expected to
return--both in terms of their data types and number of dimensions.
The actual inputs our implementation will receive are necessarily numeric values
or NumPy :class:`ndarray`\s; all that :meth:`Op.make_node` tells us is the
general signature of the underlying computation.
More specifically, the :class:`Apply` implies that the inputs come from values that are
automatically converted to Aesara variables via :func:`as_tensor_variable`, and
the ``assert``\s that follow imply that they must be scalars. According to this
logic, the inputs could have any data type (e.g. floats, ints), so our JAX
implementation must be able to handle all the possible data types.
It also tells us that there's only one return value, that it has a data type
determined by :attr:`Eye.dtype`, and that it has two non-broadcastable
dimensions. The latter implies that the result is necessarily a matrix. The
former implies that our JAX implementation will need to access the :attr:`dtype`
attribute of the Aesara :class:`Eye`\ :class:`Op` it's converting.
Next, we can look at the :meth:`Op.perform` implementation to see exactly
how the inputs and outputs are used to compute the outputs for an :class:`Op`
in Python. This method is effectively what needs to be implemented in JAX.
| Here are the examples for ``eye`` and ``ifelse`` from Aesara from the
compiled doc and codebase respectively
| https://aesara.readthedocs.io/en/latest/library/tensor/basic.html?highlight=eye#aesara.tensor.eye
| https://github.com/aesara-devs/aesara/blob/main/aesara/ifelse.py#L35
Step 2: Find the relevant JAX method (or something close)
=========================================================
---------------------------------------------------------
With a precise idea of what the Aesara Op does we need to figure out how
to implement it in JAX. In easiest scenario JAX has a similarly named
method that does the same thing. For example with the ``eye`` operator
we find the paired ``jax.numpy.eye`` method.
With a precise idea of what the Aesara :class:`Op` does we need to figure out how
to implement it in JAX. In the best case scenario, JAX has a similarly named
function that performs exactly the same computations as the :class:`Op`. For
example, the :class:`Eye` operator has a JAX equivalent: :func:`jax.numpy.eye`
(see `the documentation <https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.eye.html?highlight=eye>`_).
https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.eye.html?highlight=eye
If we wanted to implement an :class:`Op` like :class:`IfElse`, we might need to
recreate the functionality with some custom logic. In many cases, at least some
custom logic is needed to reformat the inputs and outputs so that they exactly
match the `Op`'s.
For ifelse we’ll need to recreate the functionality with some custom
logic.
Here's an example for :class:`IfElse`:
.. code:: python
...
...
@@ -38,118 +82,72 @@ logic.
)
return res if n_outs > 1 else res[0]
*Code in context:*
https://github.com/aesara-devs/aesara/blob/main/aesara/link/jax/dispatch.py#L583
Step 3: Register the function with the
jax_funcify
dispatcher
=============================================================
Step 3: Register the function with the
`jax_funcify`
dispatcher
---------------------------------------------------------------
With the Aesara
Op replicated in JAX we’ll need to now register this
function with the Aesara JAX
Linker. This is done through the dispatcher
decorator and closure as seen below. If unsure how dispatching works a
short tutorial on dispatching is at the bottom
.
With the Aesara
`Op` replicated in JAX, we’ll need to register the
function with the Aesara JAX
`Linker`. This is done through the use of
`singledispatch`. If you don't know how `singledispatch` works, see the
`Python documentation <https://docs.python.org/3/library/functools.html#functools.singledispatch>`_
.
The linker functions should be added to ``dispatch`` module linked
below.
https://github.com/aesara-devs/aesara/blob/main/aesara/link/jax/dispatch.py
The relevant dispatch functions created by `singledispatch` are :func:`aesara.link.numba.dispatch.numba_funcify` and
:func:`aesara.link.jax.dispatch.jax_funcify`.
Here’s an example for the
Eye Op.
Here’s an example for the
`Eye`\ `Op`:
.. code:: python
import jax.numpy as jnp
from aesara.tensor.basic import Eye
from aesara.link.jax.dispatch import jax_funcify
@jax_funcify.register(Eye) # The decorator
def jax_funcify_Eye(op): # The function that takes an Op and returns its JAX equivalent
@jax_funcify.register(Eye)
def jax_funcify_Eye(op):
# Obtain necessary "static" attributes from the Op being converted
dtype = op.dtype
# Create a JAX jit-able function that implements the Op
def eye(N, M, k):
return jnp.eye(N, M, k, dtype=dtype)
return eye
*Code in context:*
https://github.com/aesara-devs/aesara/blob/main/aesara/link/jax/dispatch.py#L1071
Step 4: Write tests
===================
Test that your registered Op is working correctly by adding a test to
the ``test_jax.py`` test suite. The test should ensure that Aesara Op,
when included as part of a function graph, passes the tests in
``compare_jax_and_py`` test method. What this test method does is
compile the same function graph in Python and JAX and check that the
numerical output is similar between the JAX and Python output, as well
object types to ensure correct compilation.
https://github.com/aesara-devs/aesara/blob/main/tests/link/test_jax.py
.. code:: python
def test_jax_eye():
"""Tests jaxification of the Eye operator"""
out = aet.eye(3) # Initialize an Aesara Op
out_fg = aesara.graph.fg.FunctionGraph([], [out]) # Create an Aesara FunctionGraph
-------------------
compare_jax_and_py(out_fg, []) # Pas the graph and any inputs to testing function
Test that your registered `Op` is working correctly by adding tests to the
appropriate test suites in Aesara (e.g. in ``tests.link.test_jax`` and one of
the modules in ``tests.link.numba.dispatch``). The tests should ensure that your implementation can
handle the appropriate types of inputs and produce outputs equivalent to `Op.perform`.
Check the existing tests for the general outline of these kinds of tests. In
most cases, a helper function can be used to easily verify the correspondence
between a JAX/Numba implementation and its `Op`.
*Code in context:*
https://github.com/aesara-devs/aesara/blob/056fcee1434818d0aed9234e01c754ed88d0f27a/tests/link/test_jax.py#L250
For example, the :func:`compare_jax_and_py` function streamlines the steps
involved in making comparisons with `Op.perform`.
Step 5: Wait for CI pass and Code Review
========================================
Create a pull request and ensure CI passes. If it does wait for a code
review and a likely merge!
https://github.com/aesara-devs/aesara/pulls
Appendix: What does singledispatcher do?
========================================
In short a dispatcher figures out what “the right thing” is to do based
on the type of the first argument to the function. It’s easiest
explained with an example. One is provided below in addition to the
python docs.
https://docs.python.org/3/library/functools.html#functools.singledispatch
.. code:: ipython3
from functools import singledispatch
class Cow:
pass
cow = Cow()
class Dog:
pass
dog = Dog()
@singledispatch
def greeting(animal):
print("This animal has not been registered")
@greeting.register(Cow)
def cow_greeting(animal):
print("Mooooo")
@greeting.register(Dog)
def dog_greeting(animal):
print("Woof")
Here's a small example of a test for :class:`Eye`:
.. code:: python
greeting(cow)
greeting(dog)
greeting("A string object")
import aesara.tensor as aet
def test_jax_Eye():
"""Test JAX conversion of the `Eye` `Op`."""
.. parsed-literal::
# Create a symbolic input for `Eye`
x_at = aet.scalar()
Mooooo
Woof
Animal has not been registered
# Create a variable that is the output of an `Eye` `Op`
eye_var = aet.eye(x_at)
# Create an Aesara `FunctionGraph`
out_fg = FunctionGraph(outputs=[eye_var])
This is what allows the JAX Linker to determine which the correct
JAXification Op is as we’ve registered it with the Aesara Op
# Pass the graph and any inputs to the testing function
compare_jax_and_py(out_fg, [3])
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论