Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
b6266896
提交
b6266896
authored
12月 09, 2022
作者:
Adrien Corenflos
提交者:
Thomas Wiecki
12月 14, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Split RNG keys before using them in JAX backend
上级
5c63ee70
隐藏空白字符变更
内嵌
并排
正在显示
1 个修改的文件
包含
32 行增加
和
20 行删除
+32
-20
random.py
pytensor/link/jax/dispatch/random.py
+32
-20
没有找到文件。
pytensor/link/jax/dispatch/random.py
浏览文件 @
b6266896
...
...
@@ -125,8 +125,9 @@ def jax_sample_fn_generic(op):
def
sample_fn
(
rng
,
size
,
dtype
,
*
parameters
):
rng_key
=
rng
[
"jax_state"
]
sample
=
jax_op
(
rng_key
,
*
parameters
,
shape
=
size
,
dtype
=
dtype
)
rng
[
"jax_state"
]
=
jax
.
random
.
split
(
rng_key
,
num
=
1
)[
0
]
rng_key
,
sampling_key
=
jax
.
random
.
split
(
rng_key
,
2
)
sample
=
jax_op
(
sampling_key
,
*
parameters
,
shape
=
size
,
dtype
=
dtype
)
rng
[
"jax_state"
]
=
rng_key
return
(
rng
,
sample
)
return
sample_fn
...
...
@@ -151,9 +152,10 @@ def jax_sample_fn_loc_scale(op):
def
sample_fn
(
rng
,
size
,
dtype
,
*
parameters
):
rng_key
=
rng
[
"jax_state"
]
rng_key
,
sampling_key
=
jax
.
random
.
split
(
rng_key
,
2
)
loc
,
scale
=
parameters
sample
=
loc
+
jax_op
(
r
ng_key
,
size
,
dtype
)
*
scale
rng
[
"jax_state"
]
=
jax
.
random
.
split
(
rng_key
,
num
=
1
)[
0
]
sample
=
loc
+
jax_op
(
sampli
ng_key
,
size
,
dtype
)
*
scale
rng
[
"jax_state"
]
=
rng_key
return
(
rng
,
sample
)
return
sample_fn
...
...
@@ -168,8 +170,9 @@ def jax_sample_fn_no_dtype(op):
def
sample_fn
(
rng
,
size
,
dtype
,
*
parameters
):
rng_key
=
rng
[
"jax_state"
]
sample
=
jax_op
(
rng_key
,
*
parameters
,
shape
=
size
)
rng
[
"jax_state"
]
=
jax
.
random
.
split
(
rng_key
,
num
=
1
)[
0
]
rng_key
,
sampling_key
=
jax
.
random
.
split
(
rng_key
,
2
)
sample
=
jax_op
(
sampling_key
,
*
parameters
,
shape
=
size
)
rng
[
"jax_state"
]
=
rng_key
return
(
rng
,
sample
)
return
sample_fn
...
...
@@ -189,9 +192,12 @@ def jax_sample_fn_uniform(op):
def
sample_fn
(
rng
,
size
,
dtype
,
*
parameters
):
rng_key
=
rng
[
"jax_state"
]
rng_key
,
sampling_key
=
jax
.
random
.
split
(
rng_key
,
2
)
minval
,
maxval
=
parameters
sample
=
jax_op
(
rng_key
,
shape
=
size
,
dtype
=
dtype
,
minval
=
minval
,
maxval
=
maxval
)
rng
[
"jax_state"
]
=
jax
.
random
.
split
(
rng_key
,
num
=
1
)[
0
]
sample
=
jax_op
(
sampling_key
,
shape
=
size
,
dtype
=
dtype
,
minval
=
minval
,
maxval
=
maxval
)
rng
[
"jax_state"
]
=
rng_key
return
(
rng
,
sample
)
return
sample_fn
...
...
@@ -211,9 +217,10 @@ def jax_sample_fn_shape_rate(op):
def
sample_fn
(
rng
,
size
,
dtype
,
*
parameters
):
rng_key
=
rng
[
"jax_state"
]
rng_key
,
sampling_key
=
jax
.
random
.
split
(
rng_key
,
2
)
(
shape
,
rate
)
=
parameters
sample
=
jax_op
(
r
ng_key
,
shape
,
size
,
dtype
)
/
rate
rng
[
"jax_state"
]
=
jax
.
random
.
split
(
rng_key
,
num
=
1
)[
0
]
sample
=
jax_op
(
sampli
ng_key
,
shape
,
size
,
dtype
)
/
rate
rng
[
"jax_state"
]
=
rng_key
return
(
rng
,
sample
)
return
sample_fn
...
...
@@ -225,9 +232,10 @@ def jax_sample_fn_exponential(op):
def
sample_fn
(
rng
,
size
,
dtype
,
*
parameters
):
rng_key
=
rng
[
"jax_state"
]
rng_key
,
sampling_key
=
jax
.
random
.
split
(
rng_key
,
2
)
(
scale
,)
=
parameters
sample
=
jax
.
random
.
exponential
(
r
ng_key
,
size
,
dtype
)
*
scale
rng
[
"jax_state"
]
=
jax
.
random
.
split
(
rng_key
,
num
=
1
)[
0
]
sample
=
jax
.
random
.
exponential
(
sampli
ng_key
,
size
,
dtype
)
*
scale
rng
[
"jax_state"
]
=
rng_key
return
(
rng
,
sample
)
return
sample_fn
...
...
@@ -239,13 +247,14 @@ def jax_sample_fn_t(op):
def
sample_fn
(
rng
,
size
,
dtype
,
*
parameters
):
rng_key
=
rng
[
"jax_state"
]
rng_key
,
sampling_key
=
jax
.
random
.
split
(
rng_key
,
2
)
(
df
,
loc
,
scale
,
)
=
parameters
sample
=
loc
+
jax
.
random
.
t
(
r
ng_key
,
df
,
size
,
dtype
)
*
scale
rng
[
"jax_state"
]
=
jax
.
random
.
split
(
rng_key
,
num
=
1
)[
0
]
sample
=
loc
+
jax
.
random
.
t
(
sampli
ng_key
,
df
,
size
,
dtype
)
*
scale
rng
[
"jax_state"
]
=
rng_key
return
(
rng
,
sample
)
return
sample_fn
...
...
@@ -257,9 +266,10 @@ def jax_funcify_choice(op):
def
sample_fn
(
rng
,
size
,
dtype
,
*
parameters
):
rng_key
=
rng
[
"jax_state"
]
rng_key
,
sampling_key
=
jax
.
random
.
split
(
rng_key
,
2
)
(
a
,
p
,
replace
)
=
parameters
smpl_value
=
jax
.
random
.
choice
(
r
ng_key
,
a
,
size
,
replace
,
p
)
rng
[
"jax_state"
]
=
jax
.
random
.
split
(
rng_key
,
num
=
1
)[
0
]
smpl_value
=
jax
.
random
.
choice
(
sampli
ng_key
,
a
,
size
,
replace
,
p
)
rng
[
"jax_state"
]
=
rng_key
return
(
rng
,
smpl_value
)
return
sample_fn
...
...
@@ -271,9 +281,10 @@ def jax_sample_fn_permutation(op):
def
sample_fn
(
rng
,
size
,
dtype
,
*
parameters
):
rng_key
=
rng
[
"jax_state"
]
rng_key
,
sampling_key
=
jax
.
random
.
split
(
rng_key
,
2
)
(
x
,)
=
parameters
sample
=
jax
.
random
.
permutation
(
r
ng_key
,
x
)
rng
[
"jax_state"
]
=
jax
.
random
.
split
(
rng_key
,
num
=
1
)[
0
]
sample
=
jax
.
random
.
permutation
(
sampli
ng_key
,
x
)
rng
[
"jax_state"
]
=
rng_key
return
(
rng
,
sample
)
return
sample_fn
...
...
@@ -285,10 +296,11 @@ def jax_sample_fn_lognormal(op):
def
sample_fn
(
rng
,
size
,
dtype
,
*
parameters
):
rng_key
=
rng
[
"jax_state"
]
rng_key
,
sampling_key
=
jax
.
random
.
split
(
rng_key
,
2
)
loc
,
scale
=
parameters
sample
=
loc
+
jax
.
random
.
normal
(
r
ng_key
,
size
,
dtype
)
*
scale
sample
=
loc
+
jax
.
random
.
normal
(
sampli
ng_key
,
size
,
dtype
)
*
scale
sample_exp
=
jax
.
numpy
.
exp
(
sample
)
rng
[
"jax_state"
]
=
jax
.
random
.
split
(
rng_key
,
num
=
1
)[
0
]
rng
[
"jax_state"
]
=
rng_key
return
(
rng
,
sample_exp
)
return
sample_fn
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论