Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
c3bc2cc8
提交
c3bc2cc8
authored
2月 14, 2021
作者:
kc611
提交者:
Thomas Wiecki
3月 12, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Implement JAX conversions for RandomVariables and RandomState types
上级
8e0b1560
隐藏空白字符变更
内嵌
并排
正在显示
3 个修改的文件
包含
81 行增加
和
15 行删除
+81
-15
jax_dispatch.py
aesara/link/jax/jax_dispatch.py
+54
-10
jax_linker.py
aesara/link/jax/jax_linker.py
+14
-2
test_jax.py
tests/link/test_jax.py
+13
-3
没有找到文件。
aesara/link/jax/jax_dispatch.py
浏览文件 @
c3bc2cc8
...
...
@@ -5,6 +5,8 @@ from warnings import warn
import
jax
import
jax.numpy
as
jnp
import
jax.scipy
as
jsp
import
numpy
as
np
from
numpy.random
import
RandomState
from
aesara.compile.ops
import
DeepCopyOp
,
ViewOp
from
aesara.configdefaults
import
config
...
...
@@ -52,6 +54,7 @@ from aesara.tensor.nlinalg import (
)
from
aesara.tensor.nnet.basic
import
Softmax
from
aesara.tensor.nnet.sigm
import
ScalarSoftplus
from
aesara.tensor.random.op
import
RandomVariable
from
aesara.tensor.shape
import
Reshape
,
Shape
,
Shape_i
,
SpecifyShape
from
aesara.tensor.slinalg
import
Cholesky
,
Solve
from
aesara.tensor.subtensor
import
(
# This is essentially `np.take`; Boolean mask indexing and setting
...
...
@@ -66,6 +69,10 @@ from aesara.tensor.subtensor import ( # This is essentially `np.take`; Boolean
from
aesara.tensor.type_other
import
MakeSlice
# For use with JAX since JAX doesn't support 'str' arguments
numpy_bit_gens
=
{
"MT19937"
:
0
,
"PCG64"
:
1
,
"Philox"
:
2
,
"SFC64"
:
3
}
if
config
.
floatX
==
"float64"
:
jax
.
config
.
update
(
"jax_enable_x64"
,
True
)
else
:
...
...
@@ -125,21 +132,18 @@ def compose_jax_funcs(out_node, fgraph_inputs, memo=None):
i_dtype
=
getattr
(
i
,
"dtype"
,
None
)
def
jax_inputs_func
(
*
inputs
,
i_dtype
=
i_dtype
,
idx
=
idx
):
return
j
np
.
array
(
inputs
[
idx
],
dtype
=
jnp
.
dtype
(
i_dtype
)
)
return
j
ax_typify
(
inputs
[
idx
],
i_dtype
)
input_f
=
jax_inputs_func
elif
i
.
owner
is
None
:
# This input is something like a `aesara.graph.basic.Constant`
# This input is something like a
n
`aesara.graph.basic.Constant`
i_dtype
=
getattr
(
i
,
"dtype"
,
None
)
i_data
=
i
.
data
def
jax_data_func
(
*
inputs
,
i_dtype
=
i_dtype
,
i_data
=
i_data
):
if
i_dtype
is
None
:
return
i_data
else
:
return
jnp
.
array
(
i_data
,
dtype
=
jnp
.
dtype
(
i_dtype
))
return
jax_typify
(
i_data
,
i_dtype
)
input_f
=
jax_data_func
else
:
...
...
@@ -171,7 +175,6 @@ def compose_jax_funcs(out_node, fgraph_inputs, memo=None):
def
jax_func
(
*
inputs
):
func_args
=
[
fn
(
*
inputs
)
for
fn
in
input_funcs
]
# func_args = jax.tree_map(lambda fn: fn(*inputs), input_funcs)
return
return_func
(
*
func_args
)
jax_funcs
.
append
(
update_wrapper
(
jax_func
,
return_func
))
...
...
@@ -184,9 +187,31 @@ def compose_jax_funcs(out_node, fgraph_inputs, memo=None):
return
jax_funcs
@singledispatch
def
jax_typify
(
data
,
dtype
):
"""Convert instances of Aesara `Type`s to JAX types."""
if
dtype
is
None
:
return
data
if
dtype
is
not
None
:
return
jnp
.
array
(
data
,
dtype
=
dtype
)
raise
NotImplementedError
(
f
"No JAX conversion for data and dtype: {data}, {dtype}"
)
@jax_typify.register
(
np
.
ndarray
)
def
jax_typify_ndarray
(
data
,
dtype
):
return
jnp
.
array
(
data
,
dtype
=
dtype
)
@jax_typify.register
(
RandomState
)
def
jax_typify_RandomState
(
state
,
dtype
):
state
=
state
.
get_state
(
legacy
=
False
)
state
[
"bit_generator"
]
=
numpy_bit_gens
[
state
[
"bit_generator"
]]
return
state
@singledispatch
def
jax_funcify
(
op
):
"""Create a JAX
"perform" function for an Aesara `Variable` and its
`Op`."""
"""Create a JAX
compatible function from an Aesara
`Op`."""
raise
NotImplementedError
(
f
"No JAX conversion for the given `Op`: {op}"
)
...
...
@@ -617,8 +642,6 @@ def jax_funcify_Subtensor(op):
else
:
cdata
=
ilists
# breakpoint()
if
len
(
cdata
)
==
1
:
cdata
=
cdata
[
0
]
...
...
@@ -1082,3 +1105,24 @@ def jax_funcify_BatchedDot(op):
return
jnp
.
einsum
(
"nij,njk->nik"
,
a
,
b
)
return
batched_dot
@jax_funcify.register
(
RandomVariable
)
def
jax_funcify_RandomVariable
(
op
):
name
=
op
.
name
if
not
hasattr
(
jax
.
random
,
name
):
raise
NotImplementedError
(
f
"No JAX conversion for the given distribution: {name}"
)
def
random_variable
(
rng
,
size
,
dtype
,
*
args
):
prng
=
jax
.
random
.
PRNGKey
(
rng
[
"state"
][
"key"
][
0
])
dtype
=
jnp
.
dtype
(
dtype
)
data
=
getattr
(
jax
.
random
,
name
)(
key
=
prng
,
shape
=
size
)
smpl_value
=
jnp
.
array
(
data
,
dtype
=
dtype
)
prng
=
jax
.
random
.
split
(
prng
,
num
=
1
)[
0
]
jax
.
ops
.
index_update
(
rng
[
"state"
][
"key"
],
0
,
prng
[
0
])
return
(
rng
,
smpl_value
)
return
random_variable
aesara/link/jax/jax_linker.py
浏览文件 @
c3bc2cc8
from
collections.abc
import
Sequence
from
warnings
import
warn
from
numpy.random
import
RandomState
from
aesara.graph.basic
import
Constant
from
aesara.link.basic
import
Container
,
PerformLinker
from
aesara.link.utils
import
gc_helper
,
map_storage
,
streamline
...
...
@@ -44,7 +46,7 @@ class JAXLinker(PerformLinker):
"""
import
jax
from
aesara.link.jax.jax_dispatch
import
jax_funcify
from
aesara.link.jax.jax_dispatch
import
jax_funcify
,
jax_typify
output_nodes
=
[
o
.
owner
for
o
in
self
.
fgraph
.
outputs
]
...
...
@@ -59,7 +61,17 @@ class JAXLinker(PerformLinker):
n
for
n
,
i
in
enumerate
(
self
.
fgraph
.
inputs
)
if
isinstance
(
i
,
Constant
)
]
thunk_inputs
=
[
storage_map
[
n
]
for
n
in
self
.
fgraph
.
inputs
]
thunk_inputs
=
[]
for
n
in
self
.
fgraph
.
inputs
:
sinput
=
storage_map
[
n
]
if
isinstance
(
sinput
[
0
],
RandomState
):
new_value
=
jax_typify
(
sinput
[
0
],
getattr
(
sinput
[
0
],
"dtype"
,
None
))
# We need to remove the reference-based connection to the
# original `RandomState`/shared variable's storage, because
# subsequent attempts to use the same shared variable within
# other non-JAXified graphs will have problems.
sinput
=
[
new_value
]
thunk_inputs
.
append
(
sinput
)
thunks
=
[]
...
...
tests/link/test_jax.py
浏览文件 @
c3bc2cc8
...
...
@@ -7,7 +7,7 @@ import aesara.scalar.basic as aes
from
aesara.compile.function
import
function
from
aesara.compile.mode
import
Mode
from
aesara.compile.ops
import
DeepCopyOp
,
ViewOp
from
aesara.compile.sharedvalue
import
shared
from
aesara.compile.sharedvalue
import
SharedVariable
,
shared
from
aesara.configdefaults
import
config
from
aesara.graph.fg
import
FunctionGraph
from
aesara.graph.op
import
get_test_value
...
...
@@ -29,6 +29,7 @@ from aesara.tensor.math import clip, cosh, gammaln, log
from
aesara.tensor.math
import
max
as
aet_max
from
aesara.tensor.math
import
maximum
,
prod
from
aesara.tensor.math
import
sum
as
aet_sum
from
aesara.tensor.random.basic
import
normal
from
aesara.tensor.shape
import
Shape
,
Shape_i
,
SpecifyShape
,
reshape
from
aesara.tensor.type
import
(
dscalar
,
...
...
@@ -90,7 +91,8 @@ def compare_jax_and_py(
jax_mode
=
Mode
(
JAXLinker
(),
opts
)
py_mode
=
Mode
(
"py"
,
opts
)
aesara_jax_fn
=
function
(
fgraph
.
inputs
,
fgraph
.
outputs
,
mode
=
jax_mode
)
fn_inputs
=
[
i
for
i
in
fgraph
.
inputs
if
not
isinstance
(
i
,
SharedVariable
)]
aesara_jax_fn
=
function
(
fn_inputs
,
fgraph
.
outputs
,
mode
=
jax_mode
)
jax_res
=
aesara_jax_fn
(
*
inputs
)
if
must_be_device_array
:
...
...
@@ -101,7 +103,7 @@ def compare_jax_and_py(
else
:
assert
isinstance
(
jax_res
,
jax
.
interpreters
.
xla
.
DeviceArray
)
aesara_py_fn
=
function
(
f
graph
.
inputs
,
fgraph
.
outputs
,
mode
=
py_mode
)
aesara_py_fn
=
function
(
f
n_
inputs
,
fgraph
.
outputs
,
mode
=
py_mode
)
py_res
=
aesara_py_fn
(
*
inputs
)
if
len
(
fgraph
.
outputs
)
>
1
:
...
...
@@ -965,3 +967,11 @@ def test_extra_ops():
)
fgraph
=
FunctionGraph
([],
[
out
])
compare_jax_and_py
(
fgraph
,
[])
@pytest.mark.xfail
(
reason
=
"The RNG states are not 1:1"
,
raises
=
AssertionError
)
def
test_random
():
rng
=
shared
(
np
.
random
.
RandomState
(
123
))
out
=
normal
(
rng
=
rng
)
fgraph
=
FunctionGraph
([
out
.
owner
.
inputs
[
0
]],
[
out
],
clone
=
False
)
compare_jax_and_py
(
fgraph
,
[])
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论