Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
d647578c
提交
d647578c
authored
10月 31, 2022
作者:
Rémi Louf
提交者:
Thomas Wiecki
12月 10, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Refactor JAX implementations of `RandomVariable`
上级
054ad0fd
隐藏空白字符变更
内嵌
并排
正在显示
1 个修改的文件
包含
185 行增加
和
19 行删除
+185
-19
random.py
pytensor/link/jax/dispatch/random.py
+185
-19
没有找到文件。
pytensor/link/jax/dispatch/random.py
浏览文件 @
d647578c
from
functools
import
singledispatch
import
jax
import
jax
import
jax.numpy
as
jnp
from
numpy.random
import
Generator
,
RandomState
from
numpy.random
import
Generator
,
RandomState
from
numpy.random.bit_generator
import
(
# type: ignore[attr-defined]
from
numpy.random.bit_generator
import
(
# type: ignore[attr-defined]
_coerce_to_uint32_array
,
_coerce_to_uint32_array
,
)
)
import
pytensor.tensor.random.basic
as
aer
from
pytensor.link.jax.dispatch.basic
import
jax_funcify
,
jax_typify
from
pytensor.link.jax.dispatch.basic
import
jax_funcify
,
jax_typify
from
pytensor.tensor.
random.op
import
RandomVariabl
e
from
pytensor.tensor.
shape
import
Shap
e
numpy_bit_gens
=
{
"MT19937"
:
0
,
"PCG64"
:
1
,
"Philox"
:
2
,
"SFC64"
:
3
}
numpy_bit_gens
=
{
"MT19937"
:
0
,
"PCG64"
:
1
,
"Philox"
:
2
,
"SFC64"
:
3
}
...
@@ -46,25 +48,189 @@ def jax_typify_Generator(rng, **kwargs):
...
@@ -46,25 +48,189 @@ def jax_typify_Generator(rng, **kwargs):
return
state
return
state
@jax_funcify.register
(
RandomVariable
)
@jax_funcify.register
(
aer
.
RandomVariable
)
def
jax_funcify_RandomVariable
(
op
,
node
,
**
kwargs
):
def
jax_funcify_RandomVariable
(
op
,
node
,
**
kwargs
):
"""JAX implementation of random variables."""
rv
=
node
.
outputs
[
1
]
out_dtype
=
rv
.
type
.
dtype
out_size
=
rv
.
type
.
shape
if
isinstance
(
op
,
aer
.
MvNormalRV
):
# PyTensor sets the `size` to the concatenation of the support shape
# and the batch shape, while JAX explicitly requires the batch
# shape only for the multivariate normal.
out_size
=
node
.
outputs
[
1
]
.
type
.
shape
[:
-
1
]
# If one dimension has unknown size, either the size is determined
# by a `Shape` operator in which case JAX will compile, or it is
# not and we fail gracefully.
if
None
in
out_size
:
if
not
isinstance
(
node
.
inputs
[
1
]
.
owner
.
op
,
Shape
):
raise
NotImplementedError
(
"""JAX random variables require concrete values for the `size` parameter of the distributions.
Concrete values are either constants, or the shape of an array.
"""
)
def
sample_fn
(
rng
,
size
,
dtype
,
*
parameters
):
return
jax_sample_fn
(
op
)(
rng
,
size
,
out_dtype
,
*
parameters
)
else
:
def
sample_fn
(
rng
,
size
,
dtype
,
*
parameters
):
return
jax_sample_fn
(
op
)(
rng
,
out_size
,
out_dtype
,
*
parameters
)
return
sample_fn
@singledispatch
def
jax_sample_fn
(
op
):
name
=
op
.
name
raise
NotImplementedError
(
f
"No JAX implementation for the given distribution: {name}"
)
@jax_sample_fn.register
(
aer
.
BetaRV
)
@jax_sample_fn.register
(
aer
.
DirichletRV
)
@jax_sample_fn.register
(
aer
.
PoissonRV
)
@jax_sample_fn.register
(
aer
.
MvNormalRV
)
def
jax_sample_fn_generic
(
op
):
"""Generic JAX implementation of random variables."""
name
=
op
.
name
jax_op
=
getattr
(
jax
.
random
,
name
)
def
sample_fn
(
rng
,
size
,
dtype
,
*
parameters
):
rng_key
=
rng
[
"jax_state"
]
sample
=
jax_op
(
rng_key
,
*
parameters
,
shape
=
size
,
dtype
=
dtype
)
rng
[
"jax_state"
]
=
jax
.
random
.
split
(
rng_key
,
num
=
1
)[
0
]
return
(
rng
,
sample
)
return
sample_fn
@jax_sample_fn.register
(
aer
.
CauchyRV
)
@jax_sample_fn.register
(
aer
.
LaplaceRV
)
@jax_sample_fn.register
(
aer
.
LogisticRV
)
@jax_sample_fn.register
(
aer
.
NormalRV
)
def
jax_sample_fn_loc_scale
(
op
):
"""JAX implementation of random variables in the loc-scale families.
JAX only implements the standard version of random variables in the
loc-scale family. We thus need to translate and rescale the results
manually.
"""
name
=
op
.
name
jax_op
=
getattr
(
jax
.
random
,
name
)
def
sample_fn
(
rng
,
size
,
dtype
,
*
parameters
):
rng_key
=
rng
[
"jax_state"
]
loc
,
scale
=
parameters
sample
=
loc
+
jax_op
(
rng_key
,
size
,
dtype
)
*
scale
rng
[
"jax_state"
]
=
jax
.
random
.
split
(
rng_key
,
num
=
1
)[
0
]
return
(
rng
,
sample
)
return
sample_fn
@jax_sample_fn.register
(
aer
.
BernoulliRV
)
@jax_sample_fn.register
(
aer
.
CategoricalRV
)
def
jax_sample_fn_no_dtype
(
op
):
"""Generic JAX implementation of random variables."""
name
=
op
.
name
jax_op
=
getattr
(
jax
.
random
,
name
)
def
sample_fn
(
rng
,
size
,
dtype
,
*
parameters
):
rng_key
=
rng
[
"jax_state"
]
sample
=
jax_op
(
rng_key
,
*
parameters
,
shape
=
size
)
rng
[
"jax_state"
]
=
jax
.
random
.
split
(
rng_key
,
num
=
1
)[
0
]
return
(
rng
,
sample
)
return
sample_fn
@jax_sample_fn.register
(
aer
.
RandIntRV
)
@jax_sample_fn.register
(
aer
.
UniformRV
)
def
jax_sample_fn_uniform
(
op
):
"""JAX implementation of random variables with uniform density.
We need to pass the arguments as keyword arguments since the order
of arguments is not the same.
"""
name
=
op
.
name
jax_op
=
getattr
(
jax
.
random
,
name
)
def
sample_fn
(
rng
,
size
,
dtype
,
*
parameters
):
rng_key
=
rng
[
"jax_state"
]
minval
,
maxval
=
parameters
sample
=
jax_op
(
rng_key
,
shape
=
size
,
dtype
=
dtype
,
minval
=
minval
,
maxval
=
maxval
)
rng
[
"jax_state"
]
=
jax
.
random
.
split
(
rng_key
,
num
=
1
)[
0
]
return
(
rng
,
sample
)
return
sample_fn
@jax_sample_fn.register
(
aer
.
ParetoRV
)
@jax_sample_fn.register
(
aer
.
GammaRV
)
def
jax_sample_fn_shape_rate
(
op
):
"""JAX implementation of random variables in the shape-rate family.
JAX only implements the standard version of random variables in the
shape-rate family. We thus need to rescale the results manually.
"""
name
=
op
.
name
name
=
op
.
name
jax_op
=
getattr
(
jax
.
random
,
name
)
# TODO Make sure there's a 1-to-1 correspondance with names
def
sample_fn
(
rng
,
size
,
dtype
,
*
parameters
):
if
not
hasattr
(
jax
.
random
,
name
):
rng_key
=
rng
[
"jax_state"
]
raise
NotImplementedError
(
(
shape
,
rate
)
=
parameters
f
"No JAX conversion for the given distribution: {name}"
sample
=
jax_op
(
rng_key
,
shape
,
size
,
dtype
)
/
rate
)
rng
[
"jax_state"
]
=
jax
.
random
.
split
(
rng_key
,
num
=
1
)[
0
]
return
(
rng
,
sample
)
dtype
=
node
.
outputs
[
1
]
.
dtype
return
sample_fn
def
random_variable
(
rng
,
size
,
dtype_num
,
*
args
):
if
not
op
.
inplace
:
rng
=
rng
.
copy
()
@jax_sample_fn.register
(
aer
.
ExponentialRV
)
prng
=
rng
[
"jax_state"
]
def
jax_sample_fn_exponential
(
op
):
data
=
getattr
(
jax
.
random
,
name
)(
key
=
prng
,
shape
=
size
)
"""JAX implementation of `ExponentialRV`."""
smpl_value
=
jnp
.
array
(
data
,
dtype
=
dtype
)
rng
[
"jax_state"
]
=
jax
.
random
.
split
(
prng
,
num
=
1
)[
0
]
def
sample_fn
(
rng
,
size
,
dtype
,
*
parameters
):
rng_key
=
rng
[
"jax_state"
]
(
scale
,)
=
parameters
sample
=
jax
.
random
.
exponential
(
rng_key
,
size
,
dtype
)
*
scale
rng
[
"jax_state"
]
=
jax
.
random
.
split
(
rng_key
,
num
=
1
)[
0
]
return
(
rng
,
sample
)
return
sample_fn
@jax_sample_fn.register
(
aer
.
ChoiceRV
)
def
jax_funcify_choice
(
op
):
"""JAX implementation of `ChoiceRV`."""
def
sample_fn
(
rng
,
size
,
dtype
,
*
parameters
):
rng_key
=
rng
[
"jax_state"
]
(
a
,
p
,
replace
)
=
parameters
smpl_value
=
jax
.
random
.
choice
(
rng_key
,
a
,
size
,
replace
,
p
)
rng
[
"jax_state"
]
=
jax
.
random
.
split
(
rng_key
,
num
=
1
)[
0
]
return
(
rng
,
smpl_value
)
return
(
rng
,
smpl_value
)
return
random_variable
return
sample_fn
@jax_sample_fn.register
(
aer
.
PermutationRV
)
def
jax_sample_fn_permutation
(
op
):
"""JAX implementation of `PermutationRV`."""
def
sample_fn
(
rng
,
size
,
dtype
,
*
parameters
):
rng_key
=
rng
[
"jax_state"
]
(
x
,)
=
parameters
sample
=
jax
.
random
.
permutation
(
rng_key
,
x
)
rng
[
"jax_state"
]
=
jax
.
random
.
split
(
rng_key
,
num
=
1
)[
0
]
return
(
rng
,
sample
)
return
sample_fn
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论