testgroup / pytensor · Commits

Commit c9333bcf
Authored Apr 12, 2021 by Brandon T. Willard
Committed by Brandon T. Willard on Apr 14, 2021
Create flat functions via AST for JAXification of FunctionGraphs
Parent: 334c86fb
Showing 3 changed files, with 237 additions and 153 deletions:

  aesara/link/jax/jax_dispatch.py (+107, -104)
  aesara/link/jax/jax_linker.py (+38, -48)
  tests/link/test_jax.py (+92, -1)
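The core of the change: instead of recursively composing per-node closures, the `FunctionGraph` is flattened into generated Python source, parsed with `ast`, compiled, and `exec`'d, so JAX traces one flat function. Below is a minimal, self-contained sketch of that technique; all names (`compile_flat_function`, `f`, `g`, `flat_fn`) are illustrative, not Aesara's API.

import ast

def compile_flat_function(src, env, name):
    # Parse the generated source, compile it, and exec it into `env` so the
    # resulting function resolves its callees (`f`, `g`) from `env`.
    tree = ast.parse(src)
    code = compile(tree, "<generated>", mode="exec")
    exec(code, env)
    return env[name]

# A two-node "graph" flattened by hand: out = g(f(x), y)
env = {"f": lambda x: x + 1, "g": lambda a, b: a * b}
src = """
def flat_fn(x, y):
    v0 = f(x)
    v1 = g(v0, y)
    return (v1,)
"""
flat_fn = compile_flat_function(src, env, "flat_fn")
assert flat_fn(2, 10) == (30,)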
aesara/link/jax/jax_dispatch.py
+import ast
+import re
 import warnings
-from collections.abc import Sequence
-from functools import reduce, singledispatch, update_wrapper
+from collections import Counter
+from functools import reduce, singledispatch
+from keyword import iskeyword
+from tempfile import NamedTemporaryFile
+from textwrap import indent
+from types import FunctionType
 from warnings import warn

 import jax
...
@@ -11,8 +17,10 @@ from numpy.random import RandomState
 from aesara.compile.ops import DeepCopyOp, ViewOp
 from aesara.configdefaults import config
+from aesara.graph.basic import Constant, Variable
 from aesara.graph.fg import FunctionGraph
 from aesara.ifelse import IfElse
+from aesara.link.utils import map_storage
 from aesara.scalar.basic import Cast, Clip, Composite, Identity, ScalarOp, Second
 from aesara.scan.op import Scan
 from aesara.scan.utils import scan_args as ScanArgs
...
@@ -95,102 +103,6 @@ subtensor_ops = (Subtensor, AdvancedSubtensor1, AdvancedSubtensor)
 incsubtensor_ops = (IncSubtensor, AdvancedIncSubtensor1)


-def compose_jax_funcs(out_node, fgraph_inputs, memo=None):
-    """Compose JAX implementations of node operations.
-
-    This function walks the graph given by the `Apply` node, `out_node`, and
-    creates JAX JIT-able functions for its input and output variables.
-
-    Parameters
-    ----------
-    out_node: aesara.graph.basic.Apply
-        The node for which we want to construct a JAX JIT-able function.
-    fgraph_inputs: List[Variable]
-        The inputs--in a `FunctionGraph` sense--to `out_node`.
-    memo: Mapping (Optional)
-        A map from visited nodes to their JAX functions.
-
-    Outputs
-    -------
-    A `function` object that represents the composed JAX operations and takes
-    the same form of inputs as `fgraph_inputs`.
-
-    """
-    if memo is None:
-        memo = {}
-
-    if out_node in memo:
-        return memo[out_node]
-
-    jax_return_func = jax_funcify(out_node.op)
-
-    # We create a list of JAX-able functions that produce the values of each
-    # input variable for `out_node`.
-    input_funcs = []
-    for i in out_node.inputs:
-        if i in fgraph_inputs:
-            # This input is a top-level input (i.e. an input to the
-            # `FunctionGraph` in which this `out_node` resides)
-            idx = fgraph_inputs.index(i)
-
-            i_dtype = getattr(i, "dtype", None)
-
-            def jax_inputs_func(*inputs, i_dtype=i_dtype, idx=idx):
-                return jax_typify(inputs[idx], i_dtype)
-
-            input_f = jax_inputs_func
-
-        elif i.owner is None:
-            # This input is something like an `aesara.graph.basic.Constant`
-            i_dtype = getattr(i, "dtype", None)
-            i_data = i.data
-
-            def jax_data_func(*inputs, i_dtype=i_dtype, i_data=i_data):
-                return jax_typify(i_data, i_dtype)
-
-            input_f = jax_data_func
-        else:
-            # This input is the output of another node, so we need to
-            # generate a JAX-able function for its subgraph
-            input_f = compose_jax_funcs(i.owner, fgraph_inputs, memo)
-
-            if i.owner.nout > 1:
-                # This input is one of multiple outputs from the `i.owner`
-                # node, and we need to determine exactly which one it is and
-                # create a JAX-able function that returns only it.
-                out_idx = i.owner.outputs.index(i)
-                (out_fn,) = input_f
-
-                def jax_multiout_func(*inputs, out_idx=out_idx, out_fn=out_fn):
-                    return out_fn(*inputs)[out_idx]
-
-                input_f = jax_multiout_func
-
-        assert callable(input_f)
-
-        input_funcs.append(input_f)
-
-    if not isinstance(jax_return_func, Sequence):
-        jax_return_func = [jax_return_func]
-
-    jax_funcs = []
-    for return_func in jax_return_func:
-
-        def jax_func(*inputs):
-            func_args = [fn(*inputs) for fn in input_funcs]
-            return return_func(*func_args)
-
-        jax_funcs.append(update_wrapper(jax_func, return_func))
-
-    if len(out_node.outputs) == 1:
-        jax_funcs = jax_funcs[0]
-
-    memo[out_node] = jax_funcs
-
-    return jax_funcs
-
-
 @singledispatch
 def jax_typify(data, dtype):
     """Convert instances of Aesara `Type`s to JAX types."""
...
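For contrast, the removed `compose_jax_funcs` above embodies the closure-composition style this commit retires: every node becomes a nested function that first evaluates its inputs' functions, so deep graphs yield deep call chains. A minimal sketch of that pattern (illustrative names only, not Aesara's code):

def compose(node_impl, input_fns):
    # Each node's callable first evaluates the callables for its inputs,
    # then applies its own implementation; nesting mirrors graph depth.
    def composed(*fgraph_inputs):
        args = [fn(*fgraph_inputs) for fn in input_fns]
        return node_impl(*args)
    return composed

# out = (x + y) * 2, composed from per-node closures:
get_x = lambda x, y: x
get_y = lambda x, y: y
add = compose(lambda a, b: a + b, [get_x, get_y])
double = compose(lambda a: a * 2, [add])
assert double(3, 4) == 14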
@@ -213,7 +125,7 @@ def jax_typify_RandomState(state, dtype):
 @singledispatch
-def jax_funcify(op):
+def jax_funcify(op, **kwargs):
     """Create a JAX compatible function from an Aesara `Op`."""
     raise NotImplementedError(f"No JAX conversion for the given `Op`: {op}")
...
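A note on the `**kwargs` added to the base dispatcher: it lets call sites thread linker-specific arguments (such as the storage maps used below) through `jax_funcify` without every registered implementation declaring them. A small `singledispatch` sketch of the pattern, with made-up dispatch types:

from functools import singledispatch

@singledispatch
def funcify(op, **kwargs):
    raise NotImplementedError(f"No conversion for: {op}")

@funcify.register(int)
def _(op, **kwargs):
    # Ignores the extra keyword arguments entirely.
    return lambda *args: op

@funcify.register(str)
def _(op, storage_map=None, **kwargs):
    # Picks out one keyword argument and ignores the rest.
    return lambda *args: (op, storage_map)

print(funcify(3)())                    # 3
print(funcify("x", storage_map={})())  # ('x', {})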
@@ -458,8 +370,17 @@ def jax_funcify_Elemwise(op):
 @jax_funcify.register(Composite)
 def jax_funcify_Composite(op):
+    # This approach basically gets rid of the fused `Elemwise` by turning each
+    # `Op` in the `Composite` back into individually broadcasted NumPy-like
+    # operations.
+    # TODO: A better approach would involve something like `jax.vmap` or some
+    # other operation that can perform the broadcasting that `Elemwise` does.
     jax_impl = jax_funcify(op.fgraph)
-    return jax_impl
+
+    def composite(*args):
+        return jax_impl(*args)[0]
+
+    return composite


 @jax_funcify.register(Scan)
...
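Regarding the TODO above: `jax.vmap` can lift a scalar-only function over array inputs, which is roughly the broadcasting that `Elemwise` provides. A hypothetical sketch, not the commit's code:

import jax
import jax.numpy as jnp

def scalar_composite(x, y):
    # Stands in for a compiled `Composite` that only handles scalars.
    return x + y * 2

# vmap maps the scalar function over the leading axis of each input.
vectorized = jax.vmap(scalar_composite)
print(vectorized(jnp.arange(3.0), jnp.arange(3.0)))  # [0. 3. 6.]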
@@ -684,12 +605,94 @@ def jax_funcify_AdvancedIncSubtensor(op):
 @jax_funcify.register(FunctionGraph)
-def jax_funcify_FunctionGraph(fgraph):
-    out_nodes = [r.owner for r in fgraph.outputs if r.owner is not None]
-    jax_funcs = [compose_jax_funcs(o, fgraph.inputs) for o in out_nodes]
-    return jax_funcs
+def jax_funcify_FunctionGraph(
+    fgraph, order=None, input_storage=None, output_storage=None, storage_map=None
+):
+
+    if order is None:
+        order = fgraph.toposort()
+
+    input_storage, output_storage, storage_map = map_storage(
+        fgraph, order, input_storage, output_storage, storage_map
+    )
+
+    global_env = {}
+
+    fgraph_name = "jax_funcified_fgraph"
+
+    def unique_name(x, names_counter=Counter([fgraph_name]), obj_to_names={}):
+        if x in obj_to_names:
+            return obj_to_names[x]
+
+        if isinstance(x, Variable):
+            name = re.sub("[^0-9a-zA-Z]+", "_", x.name) if x.name else ""
+            name = (
+                name
+                if (name.isidentifier() and not iskeyword(name))
+                else x.auto_name
+            )
+        elif isinstance(x, FunctionType):
+            name = x.__name__
+        else:
+            name = type(x).__name__
+
+        name_suffix = names_counter.get(name, "")
+        local_name = f"{name}{name_suffix}"
+
+        names_counter.update((name,))
+        obj_to_names[x] = local_name
+
+        return local_name
+
+    body_assigns = []
+    for node in order:
+        jax_func = jax_funcify(node.op)
+
+        # Create a local alias with a unique name
+        local_jax_func_name = unique_name(jax_func)
+        global_env[local_jax_func_name] = jax_func
+
+        node_input_names = []
+        for i in node.inputs:
+            local_input_name = unique_name(i)
+            if storage_map[i][0] is not None or isinstance(i, Constant):
+                # Constants need to be assigned locally and referenced
+                global_env[local_input_name] = jax_typify(storage_map[i][0], None)
+                # TODO: We could attempt to use the storage arrays directly
+                # E.g. `local_input_name = f"{local_input_name}[0]"`
+            node_input_names.append(local_input_name)
+
+        node_output_names = [unique_name(v) for v in node.outputs]
+
+        body_assigns.append(
+            f"{', '.join(node_output_names)} = {local_jax_func_name}({', '.join(node_input_names)})"
+        )
+
+    fgraph_input_names = [unique_name(v) for v in fgraph.inputs]
+    fgraph_output_names = [unique_name(v) for v in fgraph.outputs]
+    joined_body_assigns = indent("\n".join(body_assigns), "    ")
+
+    if len(fgraph_output_names) == 1:
+        fgraph_return_src = f"({fgraph_output_names[0]},)"
+    else:
+        fgraph_return_src = ", ".join(fgraph_output_names)
+
+    fgraph_def_src = f"""
+def {fgraph_name}({", ".join(fgraph_input_names)}):
+{joined_body_assigns}
+    return {fgraph_return_src}
+"""
+
+    fgraph_def_ast = ast.parse(fgraph_def_src)
+
+    # Create source code to be (at least temporarily) associated with the
+    # compiled function (e.g. for easier debugging)
+    with NamedTemporaryFile(delete=False) as f:
+        filename = f.name
+        f.write(fgraph_def_src.encode())
+
+    mod_code = compile(fgraph_def_ast, filename, mode="exec")
+
+    exec(mod_code, global_env, locals())
+
+    fgraph_def = locals()[fgraph_name]
+
+    return fgraph_def


 @jax_funcify.register(CAReduce)
...
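To make the output of `jax_funcify_FunctionGraph` concrete, here is a hypothetical example of the flat function it might generate for a graph computing `exp(x + y)`. The identifiers are stand-ins for whatever `unique_name` yields (sanitized `Variable.name`s, or `auto_name` fallbacks), and the `elemwise_fn*` aliases stand in for entries the real code binds into `global_env`:

import jax.numpy as jnp

# Stand-ins for the uniquely named per-`Op` JAX implementations.
def elemwise_fn(a, b):
    return a + b

def elemwise_fn1(a):
    return jnp.exp(a)

# Hypothetical generated source: one assignment per `Apply` node, in
# topological order, ending with a single return tuple.
def jax_funcified_fgraph(x, y):
    auto_name = elemwise_fn(x, y)
    auto_name1 = elemwise_fn1(auto_name)
    return (auto_name1,)

print(jax_funcified_fgraph(1.0, 2.0))  # (~20.0855,) i.e. exp(3.0)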
aesara/link/jax/jax_linker.py

-from collections.abc import Sequence
 from warnings import warn

 from numpy.random import RandomState
...
@@ -23,7 +22,9 @@ class JAXLinker(PerformLinker):
     allow_non_jax = False

-    def create_jax_thunks(self, compute_map, storage_map):
+    def create_jax_thunks(
+        self, compute_map, order, input_storage, output_storage, storage_map
+    ):
         """Create a thunk for each output of the `Linker`s `FunctionGraph`.

         This is differs from the other thunk-making function in that it only
...
@@ -51,9 +52,12 @@ class JAXLinker(PerformLinker):
         output_nodes = [o.owner for o in self.fgraph.outputs]

         # Create a JAX-compilable function from our `FunctionGraph`
-        jaxed_fgraph_outputs = jax_funcify(self.fgraph)
-
-        assert len(jaxed_fgraph_outputs) == len(output_nodes)
+        jaxed_fgraph = jax_funcify(
+            self.fgraph,
+            input_storage=input_storage,
+            output_storage=output_storage,
+            storage_map=storage_map,
+        )

         # I suppose we can consider `Constant`s to be "static" according to
         # JAX.
...
@@ -75,52 +79,36 @@ class JAXLinker(PerformLinker):
         thunks = []

-        for node, jax_funcs in zip(output_nodes, jaxed_fgraph_outputs):
-
-            thunk_outputs = [storage_map[n] for n in node.outputs]
-
-            if not isinstance(jax_funcs, Sequence):
-                jax_funcs = [jax_funcs]
-
-            jax_impl_jits = [
-                jax.jit(jax_func, static_argnums) for jax_func in jax_funcs
-            ]
-
-            def thunk(
-                node=node,
-                jax_impl_jits=jax_impl_jits,
-                thunk_outputs=thunk_outputs,
-            ):
-                outputs = [
-                    jax_impl_jit(*[x[0] for x in thunk_inputs])
-                    for jax_impl_jit in jax_impl_jits
-                ]
-
-                if len(jax_impl_jits) < len(node.outputs):
-                    # In this case, the JAX function will output a single
-                    # output that contains the other outputs.
-                    # This happens for multi-output `Op`s that directly
-                    # correspond to multi-output JAX functions (e.g. `SVD` and
-                    # `jax.numpy.linalg.svd`).
-                    outputs = outputs[0]
-
-                for o_node, o_storage, o_val in zip(
-                    node.outputs, thunk_outputs, outputs
-                ):
-                    compute_map[o_node][0] = True
-                    if len(o_storage) > 1:
-                        assert len(o_storage) == len(o_val)
-                        for i, o_sub_val in enumerate(o_val):
-                            o_storage[i] = o_sub_val
-                    else:
-                        o_storage[0] = o_val
-                return outputs
-
-            thunk.inputs = thunk_inputs
-            thunk.outputs = thunk_outputs
-            thunk.lazy = False
-
-            thunks.append(thunk)
-
-        return thunks, output_nodes
+        thunk_outputs = [storage_map[n] for n in self.fgraph.outputs]
+
+        fgraph_jit = jax.jit(jaxed_fgraph, static_argnums)
+
+        def thunk(
+            fgraph=self.fgraph,
+            fgraph_jit=fgraph_jit,
+            thunk_inputs=thunk_inputs,
+            thunk_outputs=thunk_outputs,
+        ):
+            outputs = fgraph_jit(*[x[0] for x in thunk_inputs])
+
+            for o_node, o_storage, o_val in zip(
+                fgraph.outputs, thunk_outputs, outputs
+            ):
+                compute_map[o_node][0] = True
+                if len(o_storage) > 1:
+                    assert len(o_storage) == len(o_val)
+                    for i, o_sub_val in enumerate(o_val):
+                        o_storage[i] = o_sub_val
+                else:
+                    o_storage[0] = o_val
+            return outputs
+
+        thunk.inputs = thunk_inputs
+        thunk.outputs = thunk_outputs
+        thunk.lazy = False
+
+        thunks.append(thunk)
+
+        # This is a bit hackish, but we only return one of the output nodes
+        return thunks, output_nodes[:1]

     def make_all(self, input_storage=None, output_storage=None, storage_map=None):
         fgraph = self.fgraph
...
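The upshot for the linker: a single `jax.jit`-compiled callable now populates every output storage cell, replacing one jitted function per output node. A minimal sketch of that thunk pattern, with hypothetical one-element storage cells standing in for Aesara's storage maps:

import jax
import jax.numpy as jnp

# One jitted function backs all outputs, mirroring the single-thunk design.
@jax.jit
def flat_fn(x, y):
    return (x + y, x * y)

# Storage cells are one-element lists, as in Aesara's storage maps.
input_storage = [[jnp.float32(2.0)], [jnp.float32(3.0)]]
output_storage = [[None], [None]]

def thunk():
    # Read the input cells, run the jitted function once, and write each
    # result back into its output cell.
    outputs = flat_fn(*[cell[0] for cell in input_storage])
    for cell, value in zip(output_storage, outputs):
        cell[0] = value
    return outputs

thunk()
print(output_storage)  # [[5.0], [6.0]]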
@@ -138,7 +126,9 @@ class JAXLinker(PerformLinker):
         try:
             # We need to create thunk functions that will populate the output
             # storage arrays with the JAX-computed values.
-            thunks, nodes = self.create_jax_thunks(compute_map, storage_map)
+            thunks, nodes = self.create_jax_thunks(
+                compute_map, nodes, input_storage, output_storage, storage_map
+            )
         except NotImplementedError as e:
             if not self.allow_non_jax:
...
tests/link/test_jax.py

...
@@ -10,11 +10,13 @@ from aesara.compile.mode import Mode
 from aesara.compile.ops import DeepCopyOp, ViewOp
 from aesara.compile.sharedvalue import SharedVariable, shared
 from aesara.configdefaults import config
+from aesara.graph.basic import Apply
 from aesara.graph.fg import FunctionGraph
-from aesara.graph.op import get_test_value
+from aesara.graph.op import Op, get_test_value
 from aesara.graph.optdb import Query
 from aesara.ifelse import ifelse
 from aesara.link.jax import JAXLinker
+from aesara.scalar.basic import Composite
 from aesara.scan.basic import scan
 from aesara.tensor import basic as aet
 from aesara.tensor import blas as aet_blas
...
@@ -24,6 +26,7 @@ from aesara.tensor import nlinalg as aet_nlinalg
 from aesara.tensor import nnet as aet_nnet
 from aesara.tensor import slinalg as aet_slinalg
 from aesara.tensor import subtensor as aet_subtensor
+from aesara.tensor.elemwise import Elemwise
 from aesara.tensor.math import MaxAndArgmax
 from aesara.tensor.math import all as aet_all
 from aesara.tensor.math import clip, cosh, gammaln, log
...
@@ -295,6 +298,94 @@ def test_jax_basic():
     )


+def test_jax_Composite():
+    x_s = aes.float64("x")
+    y_s = aes.float64("y")
+
+    comp_op = Elemwise(Composite([x_s, y_s], [x_s + y_s * 2]))
+
+    x = vector("x")
+    y = vector("y")
+
+    out = comp_op(x, y)
+
+    out_fg = FunctionGraph([x, y], [out])
+
+    test_input_vals = [
+        np.arange(10).astype(config.floatX),
+        np.arange(10, 20).astype(config.floatX),
+    ]
+    _ = compare_jax_and_py(out_fg, test_input_vals)
+
+
+def test_jax_FunctionGraph_names():
+    import inspect
+
+    from aesara.link.jax.jax_dispatch import jax_funcify
+
+    x = scalar("1x")
+    y = scalar("_")
+    z = scalar()
+    q = scalar("def")
+
+    out_fg = FunctionGraph([x, y, z, q], [x, y, z, q], clone=False)
+    out_jx = jax_funcify(out_fg)
+    sig = inspect.signature(out_jx)
+    assert (x.auto_name, "_", z.auto_name, q.auto_name) == tuple(
+        sig.parameters.keys()
+    )
+    assert (1, 2, 3, 4) == out_jx(1, 2, 3, 4)
+
+
+def test_jax_FunctionGraph_once():
+    """Make sure that an output is only computed once when it's referenced multiple times."""
+    from aesara.link.jax.jax_dispatch import jax_funcify
+
+    x = vector("x")
+    y = vector("y")
+
+    class TestOp(Op):
+        def __init__(self):
+            self.called = 0
+
+        def make_node(self, *args):
+            return Apply(self, list(args), [x.type() for x in args])
+
+        def perform(self, inputs, outputs):
+            for i, inp in enumerate(inputs):
+                outputs[i][0] = inp[0]
+
+    @jax_funcify.register(TestOp)
+    def jax_funcify_TestOp(op):
+        def func(*args, op=op):
+            op.called += 1
+            return list(args)
+
+        return func
+
+    op1 = TestOp()
+    op2 = TestOp()
+
+    q, r = op1(x, y)
+    outs = op2(q + r, q + r)
+
+    out_fg = FunctionGraph([x, y], outs, clone=False)
+    assert len(out_fg.outputs) == 2
+
+    out_jx = jax_funcify(out_fg)
+
+    x_val = np.r_[1, 2].astype(config.floatX)
+    y_val = np.r_[2, 3].astype(config.floatX)
+
+    res = out_jx(x_val, y_val)
+    assert len(res) == 2
+    assert op1.called == 1
+    assert op2.called == 1
+
+    res = out_jx(x_val, y_val)
+    assert len(res) == 2
+    assert op1.called == 2
+    assert op2.called == 2
+
+
 def test_jax_eye():
     """Tests jaxification of the Eye operator"""
     out = aet.eye(3)
...