Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
51210c39
提交
51210c39
authored
1月 24, 2023
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
1月 26, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Extend supported RandomVariables in JAX backend via NumPyro
Dependency is optional
上级
dcd24a36
隐藏空白字符变更
内嵌
并排
正在显示
5 个修改的文件
包含
206 行增加
和
8 行删除
+206
-8
test.yml
.github/workflows/test.yml
+1
-1
extra_ops.py
pytensor/link/jax/dispatch/extra_ops.py
+6
-1
random.py
pytensor/link/jax/dispatch/random.py
+82
-5
jax.py
pytensor/tensor/random/rewriting/jax.py
+18
-1
test_random.py
tests/link/jax/test_random.py
+99
-0
没有找到文件。
.github/workflows/test.yml
浏览文件 @
51210c39
...
...
@@ -117,7 +117,7 @@ jobs:
run
:
|
mamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark sympy
if [[ $INSTALL_NUMBA == "1" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numba>=0.55" numba-scipy; fi
mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib
mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib
numpyro
pip install -e ./
mamba list && pip freeze
python -c 'import pytensor; print(pytensor.config.__str__(print_doc=False))'
...
...
pytensor/link/jax/dispatch/extra_ops.py
浏览文件 @
51210c39
...
...
@@ -3,6 +3,7 @@ import warnings
import
jax.numpy
as
jnp
from
pytensor.link.jax.dispatch.basic
import
jax_funcify
from
pytensor.tensor.basic
import
infer_static_shape
from
pytensor.tensor.extra_ops
import
(
Bartlett
,
BroadcastTo
,
...
...
@@ -102,8 +103,12 @@ def jax_funcify_RavelMultiIndex(op, **kwargs):
@jax_funcify.register(BroadcastTo)
def jax_funcify_BroadcastTo(op, node, **kwargs):
    """Return a JAX implementation of `BroadcastTo`.

    The target shape is resolved at dispatch time from the node's shape
    inputs; dimensions whose size is statically known are baked in, while
    the remaining ones fall back to the runtime shape values.
    """
    shape = node.inputs[1:]
    static_shape = infer_static_shape(shape)[1]

    def broadcast_to(x, *shape):
        # Prefer the statically inferred extent for each dimension; use the
        # runtime value only where inference produced None.
        shape = tuple(
            st if st is not None else s for s, st in zip(shape, static_shape)
        )
        return jnp.broadcast_to(x, shape)

    return broadcast_to
...
...
pytensor/link/jax/dispatch/random.py
浏览文件 @
51210c39
from
functools
import
singledispatch
import
jax
import
numpy
as
np
from
numpy.random
import
Generator
,
RandomState
from
numpy.random.bit_generator
import
(
# type: ignore[attr-defined]
_coerce_to_uint32_array
,
...
...
@@ -12,6 +13,13 @@ from pytensor.link.jax.dispatch.shape import JAXShapeTuple
from
pytensor.tensor.shape
import
Shape
,
Shape_i
# NumPyro is an optional dependency: some RandomVariable dispatches rely on
# its samplers, so record whether it can be imported.
try:
    import numpyro  # noqa: F401

    numpyro_available = True
except ImportError:
    numpyro_available = False


# Integer codes identifying NumPy's bit-generator implementations.
numpy_bit_gens = {"MT19937": 0, "PCG64": 1, "Philox": 2, "SFC64": 3}
...
...
@@ -83,11 +91,8 @@ def jax_funcify_RandomVariable(op, node, **kwargs):
out_dtype
=
rv
.
type
.
dtype
out_size
=
rv
.
type
.
shape
if
isinstance
(
op
,
aer
.
MvNormalRV
):
# PyTensor sets the `size` to the concatenation of the support shape
# and the batch shape, while JAX explicitly requires the batch
# shape only for the multivariate normal.
out_size
=
node
.
outputs
[
1
]
.
type
.
shape
[:
-
1
]
if
op
.
ndim_supp
>
0
:
out_size
=
node
.
outputs
[
1
]
.
type
.
shape
[:
-
op
.
ndim_supp
]
# If one dimension has unknown size, either the size is determined
# by a `Shape` operator in which case JAX will compile, or it is
...
...
@@ -292,3 +297,75 @@ def jax_sample_fn_permutation(op):
return
(
rng
,
sample
)
return
sample_fn
@jax_sample_fn.register(aer.BinomialRV)
def jax_sample_fn_binomial(op):
    """Build a binomial sampler backed by NumPyro.

    Raises
    ------
    NotImplementedError
        If NumPyro is not installed.
    """
    if not numpyro_available:
        raise NotImplementedError(
            f"No JAX implementation for the given distribution: {op.name}. "
            "Implementation is available if NumPyro is installed."
        )

    from numpyro.distributions.util import binomial

    def sample_fn(rng, size, dtype, n, p):
        # Split the key so the state stored back differs from the one consumed.
        state = rng["jax_state"]
        state, draw_key = jax.random.split(state, 2)
        sample = binomial(key=draw_key, n=n, p=p, shape=size)
        rng["jax_state"] = state
        return (rng, sample)

    return sample_fn
@jax_sample_fn.register(aer.MultinomialRV)
def jax_sample_fn_multinomial(op):
    """Build a multinomial sampler backed by NumPyro.

    Raises
    ------
    NotImplementedError
        If NumPyro is not installed.
    """
    if not numpyro_available:
        raise NotImplementedError(
            f"No JAX implementation for the given distribution: {op.name}. "
            "Implementation is available if NumPyro is installed."
        )

    from numpyro.distributions.util import multinomial

    def sample_fn(rng, size, dtype, n, p):
        # Split the key so the state stored back differs from the one consumed.
        state = rng["jax_state"]
        state, draw_key = jax.random.split(state, 2)
        sample = multinomial(key=draw_key, n=n, p=p, shape=size)
        rng["jax_state"] = state
        return (rng, sample)

    return sample_fn
@jax_sample_fn.register(aer.VonMisesRV)
def jax_sample_fn_vonmises(op):
    """Build a von Mises sampler backed by NumPyro.

    NumPyro's ``von_mises_centered`` samples around zero, so the draws are
    shifted by ``mu`` and wrapped back into ``[-pi, pi)``.

    Raises
    ------
    NotImplementedError
        If NumPyro is not installed.
    """
    if not numpyro_available:
        raise NotImplementedError(
            f"No JAX implementation for the given distribution: {op.name}. "
            "Implementation is available if NumPyro is installed."
        )

    from numpyro.distributions.util import von_mises_centered

    def sample_fn(rng, size, dtype, mu, kappa):
        # Split the key so the state stored back differs from the one consumed.
        state = rng["jax_state"]
        state, draw_key = jax.random.split(state, 2)
        sample = von_mises_centered(
            key=draw_key, concentration=kappa, shape=size, dtype=dtype
        )
        # Recenter on mu and wrap the angles into [-pi, pi).
        sample = (sample + mu + np.pi) % (2.0 * np.pi) - np.pi
        rng["jax_state"] = state
        return (rng, sample)

    return sample_fn
pytensor/tensor/random/rewriting/jax.py
浏览文件 @
51210c39
...
...
@@ -2,10 +2,11 @@ from pytensor.compile import optdb
from
pytensor.graph.rewriting.basic
import
in2out
,
node_rewriter
from
pytensor.graph.rewriting.db
import
SequenceDB
from
pytensor.tensor
import
abs
as
abs_t
from
pytensor.tensor
import
exp
,
floor
,
log
,
log1p
,
reciprocal
,
sqrt
from
pytensor.tensor
import
broadcast_arrays
,
exp
,
floor
,
log
,
log1p
,
reciprocal
,
sqrt
from
pytensor.tensor.basic
import
MakeVector
,
cast
,
ones_like
,
switch
,
zeros_like
from
pytensor.tensor.elemwise
import
DimShuffle
from
pytensor.tensor.random.basic
import
(
BetaBinomialRV
,
ChiSquareRV
,
GenGammaRV
,
GeometricRV
,
...
...
@@ -14,6 +15,8 @@ from pytensor.tensor.random.basic import (
LogNormalRV
,
NegBinomialRV
,
WaldRV
,
beta
,
binomial
,
gamma
,
normal
,
poisson
,
...
...
@@ -133,6 +136,15 @@ def wald_from_normal_uniform(fgraph, node):
return
[
next_rng
,
cast
(
w
,
dtype
=
node
.
default_output
()
.
dtype
)]
@node_rewriter([BetaBinomialRV])
def beta_binomial_from_beta_binomial(fgraph, node):
    """Rewrite a `BetaBinomialRV` as a beta draw fed into a binomial draw.

    Drawing ``p ~ Beta(a, b)`` and then ``x ~ Binomial(n, p)`` yields the
    same distribution as a direct beta-binomial draw, which JAX lacks.
    """
    rng, *other_inputs, n, a, b = node.inputs
    # Broadcast the parameters so both intermediate RVs see matching shapes.
    n, a, b = broadcast_arrays(n, a, b)
    next_rng, p = beta.make_node(rng, *other_inputs, a, b).outputs
    next_rng, draws = binomial.make_node(next_rng, *other_inputs, n, p).outputs
    return [next_rng, draws]
random_vars_opt
=
SequenceDB
()
random_vars_opt
.
register
(
"lognormal_from_normal"
,
...
...
@@ -174,6 +186,11 @@ random_vars_opt.register(
in2out
(
wald_from_normal_uniform
),
"jax"
,
)
random_vars_opt
.
register
(
"beta_binomial_from_beta_binomial"
,
in2out
(
beta_binomial_from_beta_binomial
),
"jax"
,
)
optdb
.
register
(
"jax_random_vars_rewrites"
,
random_vars_opt
,
"jax"
,
position
=
110
)
optdb
.
register
(
...
...
tests/link/jax/test_random.py
浏览文件 @
51210c39
...
...
@@ -19,6 +19,9 @@ from tests.link.jax.test_basic import compare_jax_and_py, jax_mode, set_test_val
jax
=
pytest
.
importorskip
(
"jax"
)
from
pytensor.link.jax.dispatch.random
import
numpyro_available
# noqa: E402
def
test_random_RandomStream
():
"""Two successive calls of a compiled graph using `RandomStream` should
return different values.
...
...
@@ -377,6 +380,25 @@ def test_random_updates(rng_ctor):
# https://stackoverflow.com/a/48603469
lambda
mean
,
scale
:
(
mean
/
scale
,
0
,
scale
),
),
pytest
.
param
(
aer
.
vonmises
,
[
set_test_value
(
at
.
dvector
(),
np
.
array
([
-
0.5
,
1.3
],
dtype
=
np
.
float64
),
),
set_test_value
(
at
.
dvector
(),
np
.
array
([
5.5
,
13.0
],
dtype
=
np
.
float64
),
),
],
(
2
,),
"vonmises"
,
lambda
mu
,
kappa
:
(
kappa
,
mu
),
marks
=
pytest
.
mark
.
skipif
(
not
numpyro_available
,
reason
=
"VonMises dispatch requires numpyro"
),
),
],
)
def
test_random_RandomVariable
(
rv_op
,
dist_params
,
base_size
,
cdf_name
,
params_conv
):
...
...
@@ -519,6 +541,83 @@ def test_negative_binomial():
)
@pytest.mark.skipif(not numpyro_available, reason="Binomial dispatch requires numpyro")
def test_binomial():
    """Check binomial draws via the JAX/NumPyro dispatch against the
    analytical mean and standard deviation."""
    rng = shared(np.random.RandomState(123))
    n = np.array([10, 40])
    p = np.array([0.3, 0.7])
    g = at.random.binomial(n, p, size=(10_000, 2), rng=rng)
    g_fn = function([], g, mode=jax_mode)
    samples = g_fn()
    # mean = n * p, std = sqrt(n * p * (1 - p))
    np.testing.assert_allclose(samples.mean(axis=0), n * p, rtol=0.1)
    np.testing.assert_allclose(
        samples.std(axis=0), np.sqrt(n * p * (1 - p)), rtol=0.1
    )
@pytest.mark.skipif(
    not numpyro_available, reason="BetaBinomial dispatch requires numpyro"
)
def test_beta_binomial():
    """Check beta-binomial draws (rewritten to beta + binomial for JAX)
    against the analytical mean and standard deviation."""
    rng = shared(np.random.RandomState(123))
    n = np.array([10, 40])
    a = np.array([1.5, 13])
    b = np.array([0.5, 9])
    g = at.random.betabinom(n, a, b, size=(10_000, 2), rng=rng)
    g_fn = function([], g, mode=jax_mode)
    samples = g_fn()
    # mean = n * a / (a + b)
    np.testing.assert_allclose(samples.mean(axis=0), n * a / (a + b), rtol=0.1)
    # var = n * a * b * (a + b + n) / ((a + b)^2 * (a + b + 1))
    np.testing.assert_allclose(
        samples.std(axis=0),
        np.sqrt((n * a * b * (a + b + n)) / ((a + b) ** 2 * (a + b + 1))),
        rtol=0.1,
    )
@pytest.mark.skipif(
    not numpyro_available, reason="Multinomial dispatch requires numpyro"
)
def test_multinomial():
    """Check multinomial draws via the JAX/NumPyro dispatch against the
    analytical per-category mean and standard deviation."""
    rng = shared(np.random.RandomState(123))
    n = np.array([10, 40])
    p = np.array([[0.3, 0.7, 0.0], [0.1, 0.4, 0.5]])
    g = at.random.multinomial(n, p, size=(10_000, 2), rng=rng)
    g_fn = function([], g, mode=jax_mode)
    samples = g_fn()
    # Broadcast n against the trailing category axis of p.
    np.testing.assert_allclose(samples.mean(axis=0), n[..., None] * p, rtol=0.1)
    np.testing.assert_allclose(
        samples.std(axis=0), np.sqrt(n[..., None] * p * (1 - p)), rtol=0.1
    )
@pytest.mark.skipif(not numpyro_available, reason="VonMises dispatch requires numpyro")
def test_vonmises_mu_outside_circle():
    # Scipy implementation does not behave as PyTensor/NumPy for mu outside
    # the unit circle. We test that the random draws from the JAX dispatch
    # work as expected in these cases.
    rng = shared(np.random.RandomState(123))
    mu = np.array([-30, 40])
    kappa = np.array([100, 10])
    g = at.random.vonmises(mu, kappa, size=(10_000, 2), rng=rng)
    g_fn = function([], g, mode=jax_mode)
    samples = g_fn()
    # The sample mean should match mu wrapped back into [-pi, pi).
    np.testing.assert_allclose(
        samples.mean(axis=0),
        (mu + np.pi) % (2.0 * np.pi) - np.pi,
        rtol=0.1,
    )

    # Circvar only does the correct thing in more recent versions of Scipy
    # https://github.com/scipy/scipy/pull/5747
    # np.testing.assert_allclose(
    #     stats.circvar(samples, axis=0),
    #     1 - special.iv(1, kappa) / special.iv(0, kappa),
    #     rtol=0.1,
    # )
    # For now simply compare with the std of NumPy draws.
    rng = np.random.default_rng(123)
    ref_samples = rng.vonmises(mu, kappa, size=(10_000, 2))
    np.testing.assert_allclose(
        np.std(samples, axis=0), np.std(ref_samples, axis=0), rtol=0.1
    )
def
test_random_unimplemented
():
"""Compiling a graph with a non-supported `RandomVariable` should
raise an error.
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论