Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
90da9e61
提交
90da9e61
authored
5月 13, 2024
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
5月 13, 2024
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix bug in JAX implementation of RandomVariables with implicit size
上级
17a5e424
隐藏空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
48 行增加
和
14 行删除
+48
-14
random.py
pytensor/link/jax/dispatch/random.py
+20
-14
test_random.py
tests/link/jax/test_random.py
+28
-0
没有找到文件。
pytensor/link/jax/dispatch/random.py
浏览文件 @
90da9e61
...
...
@@ -103,6 +103,9 @@ def jax_funcify_RandomVariable(op, node, **kwargs):
assert_size_argument_jax_compatible
(
node
)
def
sample_fn
(
rng
,
size
,
dtype
,
*
parameters
):
# PyTensor uses empty size to represent size = None
if
jax
.
numpy
.
asarray
(
size
)
.
shape
==
(
0
,):
size
=
None
return
jax_sample_fn
(
op
)(
rng
,
size
,
out_dtype
,
*
parameters
)
else
:
...
...
@@ -161,6 +164,8 @@ def jax_sample_fn_loc_scale(op):
rng_key
=
rng
[
"jax_state"
]
rng_key
,
sampling_key
=
jax
.
random
.
split
(
rng_key
,
2
)
loc
,
scale
=
parameters
if
size
is
None
:
size
=
jax
.
numpy
.
broadcast_arrays
(
loc
,
scale
)[
0
]
.
shape
sample
=
loc
+
jax_op
(
sampling_key
,
size
,
dtype
)
*
scale
rng
[
"jax_state"
]
=
rng_key
return
(
rng
,
sample
)
...
...
@@ -184,15 +189,16 @@ def jax_sample_fn_bernoulli(op):
@jax_sample_fn.register
(
ptr
.
CategoricalRV
)
def
jax_sample_fn_no_dtype
(
op
):
"""Generic JAX implementation of random variables."""
name
=
op
.
name
jax_op
=
getattr
(
jax
.
random
,
name
)
def
jax_sample_fn_categorical
(
op
):
"""JAX implementation of `CategoricalRV`."""
def
sample_fn
(
rng
,
size
,
dtype
,
*
parameters
):
# We need a separate dispatch because Categorical expects logits in JAX
def
sample_fn
(
rng
,
size
,
dtype
,
p
):
rng_key
=
rng
[
"jax_state"
]
rng_key
,
sampling_key
=
jax
.
random
.
split
(
rng_key
,
2
)
sample
=
jax_op
(
sampling_key
,
*
parameters
,
shape
=
size
)
logits
=
jax
.
scipy
.
special
.
logit
(
p
)
sample
=
jax
.
random
.
categorical
(
sampling_key
,
logits
=
logits
,
shape
=
size
)
rng
[
"jax_state"
]
=
rng_key
return
(
rng
,
sample
)
...
...
@@ -243,6 +249,8 @@ def jax_sample_fn_shape_scale(op):
def
sample_fn
(
rng
,
size
,
dtype
,
shape
,
scale
):
rng_key
=
rng
[
"jax_state"
]
rng_key
,
sampling_key
=
jax
.
random
.
split
(
rng_key
,
2
)
if
size
is
None
:
size
=
jax
.
numpy
.
broadcast_arrays
(
shape
,
scale
)[
0
]
.
shape
sample
=
jax_op
(
sampling_key
,
shape
,
size
,
dtype
)
*
scale
rng
[
"jax_state"
]
=
rng_key
return
(
rng
,
sample
)
...
...
@@ -254,10 +262,11 @@ def jax_sample_fn_shape_scale(op):
def
jax_sample_fn_exponential
(
op
):
"""JAX implementation of `ExponentialRV`."""
def
sample_fn
(
rng
,
size
,
dtype
,
*
parameters
):
def
sample_fn
(
rng
,
size
,
dtype
,
scale
):
rng_key
=
rng
[
"jax_state"
]
rng_key
,
sampling_key
=
jax
.
random
.
split
(
rng_key
,
2
)
(
scale
,)
=
parameters
if
size
is
None
:
size
=
jax
.
numpy
.
asarray
(
scale
)
.
shape
sample
=
jax
.
random
.
exponential
(
sampling_key
,
size
,
dtype
)
*
scale
rng
[
"jax_state"
]
=
rng_key
return
(
rng
,
sample
)
...
...
@@ -269,14 +278,11 @@ def jax_sample_fn_exponential(op):
def
jax_sample_fn_t
(
op
):
"""JAX implementation of `StudentTRV`."""
def
sample_fn
(
rng
,
size
,
dtype
,
*
parameters
):
def
sample_fn
(
rng
,
size
,
dtype
,
df
,
loc
,
scale
):
rng_key
=
rng
[
"jax_state"
]
rng_key
,
sampling_key
=
jax
.
random
.
split
(
rng_key
,
2
)
(
df
,
loc
,
scale
,
)
=
parameters
if
size
is
None
:
size
=
jax
.
numpy
.
broadcast_arrays
(
df
,
loc
,
scale
)[
0
]
.
shape
sample
=
loc
+
jax
.
random
.
t
(
sampling_key
,
df
,
size
,
dtype
)
*
scale
rng
[
"jax_state"
]
=
rng_key
return
(
rng
,
sample
)
...
...
tests/link/jax/test_random.py
浏览文件 @
90da9e61
...
...
@@ -509,6 +509,34 @@ def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_c
assert
test_res
.
pvalue
>
0.01
@pytest.mark.parametrize
(
"rv_fn"
,
[
lambda
param_that_implies_size
:
ptr
.
normal
(
loc
=
0
,
scale
=
pt
.
exp
(
param_that_implies_size
)
),
lambda
param_that_implies_size
:
ptr
.
exponential
(
scale
=
pt
.
exp
(
param_that_implies_size
)
),
lambda
param_that_implies_size
:
ptr
.
gamma
(
shape
=
1
,
scale
=
pt
.
exp
(
param_that_implies_size
)
),
lambda
param_that_implies_size
:
ptr
.
t
(
df
=
3
,
loc
=
param_that_implies_size
,
scale
=
1
),
],
)
def
test_size_implied_by_broadcasted_parameters
(
rv_fn
):
# We need a parameter with untyped shapes to test broadcasting does not result in identical draws
param_that_implies_size
=
pt
.
matrix
(
"param_that_implies_size"
,
shape
=
(
None
,
None
))
rv
=
rv_fn
(
param_that_implies_size
)
draws
=
rv
.
eval
({
param_that_implies_size
:
np
.
zeros
((
2
,
2
))},
mode
=
jax_mode
)
assert
draws
.
shape
==
(
2
,
2
)
assert
np
.
unique
(
draws
)
.
size
==
4
@pytest.mark.parametrize
(
"size"
,
[(),
(
4
,)])
def
test_random_bernoulli
(
size
):
rng
=
shared
(
np
.
random
.
RandomState
(
123
))
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论