Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
0d1f65f8
提交
0d1f65f8
authored
12月 05, 2022
作者:
Rémi Louf
提交者:
Thomas Wiecki
12月 10, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Raise when the `RandomVariable` will not compile
上级
a110e82b
隐藏空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
58 行增加
和
17 行删除
+58
-17
random.py
pytensor/link/jax/dispatch/random.py
+29
-7
test_random.py
tests/link/jax/test_random.py
+29
-10
没有找到文件。
pytensor/link/jax/dispatch/random.py
浏览文件 @
0d1f65f8
...
...
@@ -8,12 +8,39 @@ from numpy.random.bit_generator import ( # type: ignore[attr-defined]
import
pytensor.tensor.random.basic
as
aer
from
pytensor.link.jax.dispatch.basic
import
jax_funcify
,
jax_typify
from
pytensor.tensor.shape
import
Shape
from
pytensor.tensor.shape
import
Shape
,
Shape_i
numpy_bit_gens
=
{
"MT19937"
:
0
,
"PCG64"
:
1
,
"Philox"
:
2
,
"SFC64"
:
3
}
SIZE_NOT_COMPATIBLE
=
"""JAX random variables require concrete values for the `size` parameter of the distributions.
Concrete values are either constants:
>>> import pytensor.tensor as at
>>> x_rv = at.random.normal(0, 1, size=(3, 2))
or the shape of an array:
>>> m = at.matrix()
>>> x_rv = at.random.normal(0, 1, size=m.shape)
"""
def
assert_size_argument_jax_compatible
(
node
):
"""Assert whether the current node can be compiled.
JAX can JIT-compile `jax.random` functions when the `size` argument
is a concrete value, i.e. either a constant or the shape of any
traced value.
"""
size
=
node
.
inputs
[
1
]
size_op
=
size
.
owner
.
op
if
not
isinstance
(
size_op
,
(
Shape
,
Shape_i
)):
raise
NotImplementedError
(
SIZE_NOT_COMPATIBLE
)
@jax_typify.register
(
RandomState
)
def
jax_typify_RandomState
(
state
,
**
kwargs
):
state
=
state
.
get_state
(
legacy
=
False
)
...
...
@@ -65,12 +92,7 @@ def jax_funcify_RandomVariable(op, node, **kwargs):
# by a `Shape` operator in which case JAX will compile, or it is
# not and we fail gracefully.
if
None
in
out_size
:
if
not
isinstance
(
node
.
inputs
[
1
]
.
owner
.
op
,
Shape
):
raise
NotImplementedError
(
"""JAX random variables require concrete values for the `size` parameter of the distributions.
Concrete values are either constants, or the shape of an array.
"""
)
assert_size_argument_jax_compatible
(
node
)
def
sample_fn
(
rng
,
size
,
dtype
,
*
parameters
):
return
jax_sample_fn
(
op
)(
rng
,
size
,
out_dtype
,
*
parameters
)
...
...
tests/link/jax/test_random.py
浏览文件 @
0d1f65f8
...
...
@@ -449,14 +449,33 @@ def test_random_concrete_shape():
"""
rng
=
shared
(
np
.
random
.
RandomState
(
123
))
x_at
=
at
.
dmatrix
()
f
=
at
.
random
.
normal
(
0
,
1
,
size
=
(
3
,),
rng
=
rng
)
g
=
at
.
random
.
normal
(
f
,
1
,
size
=
x_at
.
shape
,
rng
=
rng
)
g_fn
=
function
([
x_at
],
g
,
mode
=
jax_mode
)
_
=
g_fn
(
np
.
ones
((
2
,
3
)))
out
=
at
.
random
.
normal
(
0
,
1
,
size
=
x_at
.
shape
,
rng
=
rng
)
jax_fn
=
function
([
x_at
],
out
,
mode
=
jax_mode
)
assert
jax_fn
(
np
.
ones
((
2
,
3
)))
.
shape
==
(
2
,
3
)
# This should compile, and `size_at` be passed to the list of `static_argnums`.
with
pytest
.
raises
(
NotImplementedError
):
size_at
=
at
.
scalar
()
g
=
at
.
random
.
normal
(
f
,
1
,
size
=
size_at
,
rng
=
rng
)
g_fn
=
function
([
size_at
],
g
,
mode
=
jax_mode
)
_
=
g_fn
(
10
)
@pytest.mark.xfail
(
reason
=
"size argument specified as a tuple is a `DimShuffle` node"
)
def
test_random_concrete_shape_subtensor
():
rng
=
shared
(
np
.
random
.
RandomState
(
123
))
x_at
=
at
.
dmatrix
()
out
=
at
.
random
.
normal
(
0
,
1
,
size
=
x_at
.
shape
[
1
],
rng
=
rng
)
jax_fn
=
function
([
x_at
],
out
,
mode
=
jax_mode
)
assert
jax_fn
(
np
.
ones
((
2
,
3
)))
.
shape
==
(
3
,)
@pytest.mark.xfail
(
reason
=
"size argument specified as a tuple is a `MakeVector` node"
)
def
test_random_concrete_shape_subtensor_tuple
():
rng
=
shared
(
np
.
random
.
RandomState
(
123
))
x_at
=
at
.
dmatrix
()
out
=
at
.
random
.
normal
(
0
,
1
,
size
=
(
x_at
.
shape
[
0
],),
rng
=
rng
)
jax_fn
=
function
([
x_at
],
out
,
mode
=
jax_mode
)
assert
jax_fn
(
np
.
ones
((
2
,
3
)))
.
shape
==
(
2
,)
@pytest.mark.xfail
(
reason
=
"`size_at` should be specified as a static argument"
)
def
test_random_concrete_shape_graph_input
():
rng
=
shared
(
np
.
random
.
RandomState
(
123
))
size_at
=
at
.
scalar
()
out
=
at
.
random
.
normal
(
0
,
1
,
size
=
size_at
,
rng
=
rng
)
jax_fn
=
function
([
size_at
],
out
,
mode
=
jax_mode
)
assert
jax_fn
(
10
)
.
shape
==
(
10
,)
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论