Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
383d4efe
提交
383d4efe
authored
12月 03, 2022
作者:
Rémi Louf
提交者:
Thomas Wiecki
12月 10, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add tests for JAX `RandomVariable` implementations
上级
d647578c
隐藏空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
366 行增加
和
57 行删除
+366
-57
test_basic.py
tests/link/jax/test_basic.py
+5
-0
test_random.py
tests/link/jax/test_random.py
+361
-57
没有找到文件。
tests/link/jax/test_basic.py
浏览文件 @
383d4efe
...
@@ -207,3 +207,8 @@ def test_jax_checkandraise():
...
@@ -207,3 +207,8 @@ def test_jax_checkandraise():
with
pytest
.
warns
(
UserWarning
):
with
pytest
.
warns
(
UserWarning
):
function
((
p
,),
res
,
mode
=
jax_mode
)
function
((
p
,),
res
,
mode
=
jax_mode
)
def
set_test_value
(
x
,
v
):
x
.
tag
.
test_value
=
v
return
x
tests/link/jax/test_random.py
浏览文件 @
383d4efe
...
@@ -2,61 +2,345 @@ import re
...
@@ -2,61 +2,345 @@ import re
import
numpy
as
np
import
numpy
as
np
import
pytest
import
pytest
from
packaging.version
import
parse
as
version_parse
import
scipy.stats
as
stats
import
pytensor
import
pytensor
import
pytensor.tensor
as
at
import
pytensor.tensor
as
at
import
pytensor.tensor.random
as
aer
from
pytensor.compile.function
import
function
from
pytensor.compile.function
import
function
from
pytensor.compile.sharedvalue
import
shared
from
pytensor.compile.sharedvalue
import
SharedVariable
,
shared
from
pytensor.
configdefaults
import
config
from
pytensor.
graph.basic
import
Constant
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.tensor.random.basic
import
RandomVariable
from
pytensor.tensor.random.basic
import
RandomVariable
from
pytensor.tensor.random.utils
import
RandomStream
from
pytensor.tensor.random.utils
import
RandomStream
from
tests.link.jax.test_basic
import
compare_jax_and_py
,
jax_mode
from
tests.link.jax.test_basic
import
compare_jax_and_py
,
jax_mode
,
set_test_value
jax
=
pytest
.
importorskip
(
"jax"
)
jax
=
pytest
.
importorskip
(
"jax"
)
@pytest.mark.xfail
(
def
test_random_RandomStream
():
version_parse
(
jax
.
__version__
)
>=
version_parse
(
"0.2.26"
),
"""Two successive calls of a compiled graph using `RandomStream` should
reason
=
"JAX samplers require concrete/static shape values?"
,
return different values.
)
"""
srng
=
RandomStream
(
seed
=
123
)
out
=
srng
.
normal
()
-
srng
.
normal
()
with
pytest
.
warns
(
UserWarning
,
match
=
r"The RandomType SharedVariables \[.+\] will not be used"
,
):
fn
=
function
([],
out
,
mode
=
jax_mode
)
jax_res_1
=
fn
()
jax_res_2
=
fn
()
assert
not
np
.
array_equal
(
jax_res_1
,
jax_res_2
)
@pytest.mark.parametrize
(
"rng_ctor"
,
(
np
.
random
.
RandomState
,
np
.
random
.
default_rng
))
def
test_random_updates
(
rng_ctor
):
original_value
=
rng_ctor
(
seed
=
98
)
rng
=
shared
(
original_value
,
name
=
"original_rng"
,
borrow
=
False
)
next_rng
,
x
=
at
.
random
.
normal
(
name
=
"x"
,
rng
=
rng
)
.
owner
.
outputs
with
pytest
.
warns
(
UserWarning
,
match
=
re
.
escape
(
"The RandomType SharedVariables [original_rng] will not be used"
),
):
f
=
pytensor
.
function
([],
[
x
],
updates
=
{
rng
:
next_rng
},
mode
=
jax_mode
)
assert
f
()
!=
f
()
# Check that original rng variable content was not overwritten when calling jax_typify
assert
all
(
a
==
b
if
not
isinstance
(
a
,
np
.
ndarray
)
else
np
.
array_equal
(
a
,
b
)
for
a
,
b
in
zip
(
rng
.
get_value
()
.
__getstate__
(),
original_value
.
__getstate__
())
)
@pytest.mark.parametrize
(
@pytest.mark.parametrize
(
"
at_dist, dist_params, rng, size
"
,
"
rv_op, dist_params, base_size, cdf_name, params_conv
"
,
[
[
(
(
at
.
random
.
normal
,
aer
.
beta
,
(),
[
shared
(
np
.
random
.
RandomState
(
123
)),
set_test_value
(
10000
,
at
.
dvector
(),
np
.
array
([
1.0
,
2.0
],
dtype
=
np
.
float64
),
),
set_test_value
(
at
.
dscalar
(),
np
.
array
(
1.0
,
dtype
=
np
.
float64
),
),
],
(
2
,),
"beta"
,
lambda
*
args
:
args
,
),
(
aer
.
cauchy
,
[
set_test_value
(
at
.
dvector
(),
np
.
array
([
1.0
,
2.0
],
dtype
=
np
.
float64
),
),
set_test_value
(
at
.
dscalar
(),
np
.
array
(
1.0
,
dtype
=
np
.
float64
),
),
],
(
2
,),
"cauchy"
,
lambda
*
args
:
args
,
),
(
aer
.
exponential
,
[
set_test_value
(
at
.
dvector
(),
np
.
array
([
1.0
,
2.0
],
dtype
=
np
.
float64
),
),
],
(
2
,),
"expon"
,
lambda
*
args
:
(
0
,
args
[
0
]),
),
(
aer
.
gamma
,
[
set_test_value
(
at
.
dvector
(),
np
.
array
([
1.0
,
2.0
],
dtype
=
np
.
float64
),
),
set_test_value
(
at
.
dscalar
(),
np
.
array
(
1.0
,
dtype
=
np
.
float64
),
),
],
(
2
,),
"gamma"
,
lambda
a
,
b
:
(
a
,
0.0
,
b
),
),
(
aer
.
laplace
,
[
set_test_value
(
at
.
dvector
(),
np
.
array
([
1.0
,
2.0
],
dtype
=
np
.
float64
)),
set_test_value
(
at
.
dscalar
(),
np
.
array
(
1.0
,
dtype
=
np
.
float64
)),
],
(
2
,),
"laplace"
,
lambda
*
args
:
args
,
),
(
aer
.
logistic
,
[
set_test_value
(
at
.
dvector
(),
np
.
array
([
1.0
,
2.0
],
dtype
=
np
.
float64
),
),
set_test_value
(
at
.
dscalar
(),
np
.
array
(
1.0
,
dtype
=
np
.
float64
),
),
],
(
2
,),
"logistic"
,
lambda
*
args
:
args
,
),
(
aer
.
normal
,
[
set_test_value
(
at
.
lvector
(),
np
.
array
([
1
,
2
],
dtype
=
np
.
int64
),
),
set_test_value
(
at
.
dscalar
(),
np
.
array
(
1.0
,
dtype
=
np
.
float64
),
),
],
(
2
,),
"norm"
,
lambda
*
args
:
args
,
),
(
aer
.
pareto
,
[
set_test_value
(
at
.
dvector
(),
np
.
array
([
1.0
,
2.0
],
dtype
=
np
.
float64
),
)
],
(
2
,),
"pareto"
,
lambda
*
args
:
args
,
),
(
aer
.
poisson
,
[
set_test_value
(
at
.
dvector
(),
np
.
array
([
1000.0
,
2000.0
],
dtype
=
np
.
float64
),
),
],
(
2
,),
"poisson"
,
lambda
*
args
:
args
,
),
),
(
(
at
.
random
.
normal
,
aer
.
randint
,
[
set_test_value
(
at
.
lscalar
(),
np
.
array
(
0
,
dtype
=
np
.
int64
),
),
set_test_value
(
# high-value necessary since test on cdf
at
.
lscalar
(),
np
.
array
(
1000
,
dtype
=
np
.
int64
),
),
],
(),
(),
shared
(
np
.
random
.
default_rng
(
123
)),
"randint"
,
10000
,
lambda
*
args
:
args
,
),
(
aer
.
uniform
,
[
set_test_value
(
at
.
dvector
(),
np
.
array
([
1.0
,
2.0
],
dtype
=
np
.
float64
),
),
set_test_value
(
at
.
dscalar
(),
np
.
array
(
1000.0
,
dtype
=
np
.
float64
),
),
],
(
2
,),
"uniform"
,
lambda
*
args
:
args
,
),
),
],
],
)
)
def
test_random_stats
(
at_dist
,
dist_params
,
rng
,
size
):
def
test_random_RandomVariable
(
rv_op
,
dist_params
,
base_size
,
cdf_name
,
params_conv
):
# The RNG states are not 1:1, so the best we can do is check some summary
"""The JAX samplers are not one-to-one with NumPy samplers so we
# statistics of the samples
need to use a statistical test to make sure that the transpilation
out
=
at
.
random
.
normal
(
*
dist_params
,
rng
=
rng
,
size
=
size
)
is correct.
fgraph
=
FunctionGraph
([
out
.
owner
.
inputs
[
0
]],
[
out
],
clone
=
False
)
Parameters
----------
rv_op
The transpiled `RandomVariable` `Op`.
dist_params
The parameters passed to the op.
def
assert_fn
(
x
,
y
):
"""
(
x
,)
=
x
rng
=
shared
(
np
.
random
.
RandomState
(
29402
))
(
y
,)
=
y
g
=
rv_op
(
*
dist_params
,
size
=
(
10
_000
,)
+
base_size
,
rng
=
rng
)
assert
x
.
dtype
.
kind
==
y
.
dtype
.
kind
g_fn
=
function
(
dist_params
,
g
,
mode
=
jax_mode
)
samples
=
g_fn
(
*
[
i
.
tag
.
test_value
for
i
in
g_fn
.
maker
.
fgraph
.
inputs
if
not
isinstance
(
i
,
(
SharedVariable
,
Constant
))
]
)
d
=
2
if
config
.
floatX
==
"float64"
else
1
bcast_dist_args
=
np
.
broadcast_arrays
(
*
[
i
.
tag
.
test_value
for
i
in
dist_params
])
np
.
testing
.
assert_array_almost_equal
(
np
.
abs
(
x
.
mean
()),
np
.
abs
(
y
.
mean
()),
d
)
compare_jax_and_py
(
fgraph
,
[],
assert_fn
=
assert_fn
)
for
idx
in
np
.
ndindex
(
*
base_size
):
cdf_params
=
params_conv
(
*
tuple
(
arg
[
idx
]
for
arg
in
bcast_dist_args
))
test_res
=
stats
.
cramervonmises
(
samples
[(
Ellipsis
,)
+
idx
],
cdf_name
,
args
=
cdf_params
)
assert
test_res
.
pvalue
>
0.1
@pytest.mark.parametrize
(
"size"
,
[(),
(
4
,)])
def
test_random_bernoulli
(
size
):
rng
=
shared
(
np
.
random
.
RandomState
(
123
))
g
=
at
.
random
.
bernoulli
(
0.5
,
size
=
(
1000
,)
+
size
,
rng
=
rng
)
g_fn
=
function
([],
g
,
mode
=
jax_mode
)
samples
=
g_fn
()
np
.
testing
.
assert_allclose
(
samples
.
mean
(
axis
=
0
),
0.5
,
1
)
def
test_random_mvnormal
():
rng
=
shared
(
np
.
random
.
RandomState
(
123
))
mu
=
np
.
ones
(
4
)
cov
=
np
.
eye
(
4
)
g
=
at
.
random
.
multivariate_normal
(
mu
,
cov
,
size
=
(
10000
,),
rng
=
rng
)
g_fn
=
function
([],
g
,
mode
=
jax_mode
)
samples
=
g_fn
()
np
.
testing
.
assert_allclose
(
samples
.
mean
(
axis
=
0
),
mu
,
atol
=
0.1
)
@pytest.mark.parametrize
(
"parameter, size"
,
[
(
np
.
ones
(
4
),
()),
(
np
.
ones
(
4
),
(
2
,
4
)),
],
)
def
test_random_dirichlet
(
parameter
,
size
):
rng
=
shared
(
np
.
random
.
RandomState
(
123
))
g
=
at
.
random
.
dirichlet
(
parameter
,
size
=
(
1000
,)
+
size
,
rng
=
rng
)
g_fn
=
function
([],
g
,
mode
=
jax_mode
)
samples
=
g_fn
()
np
.
testing
.
assert_allclose
(
samples
.
mean
(
axis
=
0
),
0.5
,
1
)
def
test_random_choice
():
# Elements are picked at equal frequency
num_samples
=
10000
rng
=
shared
(
np
.
random
.
RandomState
(
123
))
g
=
at
.
random
.
choice
(
np
.
arange
(
4
),
size
=
num_samples
,
rng
=
rng
)
g_fn
=
function
([],
g
,
mode
=
jax_mode
)
samples
=
g_fn
()
np
.
testing
.
assert_allclose
(
np
.
sum
(
samples
==
3
)
/
num_samples
,
0.25
,
2
)
# `replace=False` produces unique results
rng
=
shared
(
np
.
random
.
RandomState
(
123
))
g
=
at
.
random
.
choice
(
np
.
arange
(
100
),
replace
=
False
,
size
=
99
,
rng
=
rng
)
g_fn
=
function
([],
g
,
mode
=
jax_mode
)
samples
=
g_fn
()
assert
len
(
np
.
unique
(
samples
))
==
99
# We can pass an array with probabilities
rng
=
shared
(
np
.
random
.
RandomState
(
123
))
g
=
at
.
random
.
choice
(
np
.
arange
(
3
),
p
=
np
.
array
([
1.0
,
0.0
,
0.0
]),
size
=
10
,
rng
=
rng
)
g_fn
=
function
([],
g
,
mode
=
jax_mode
)
samples
=
g_fn
()
np
.
testing
.
assert_allclose
(
samples
,
np
.
zeros
(
10
))
def
test_random_categorical
():
rng
=
shared
(
np
.
random
.
RandomState
(
123
))
g
=
at
.
random
.
categorical
(
0.25
*
np
.
ones
(
4
),
size
=
(
10000
,
4
),
rng
=
rng
)
g_fn
=
function
([],
g
,
mode
=
jax_mode
)
samples
=
g_fn
()
np
.
testing
.
assert_allclose
(
samples
.
mean
(
axis
=
0
),
6
/
4
,
1
)
def
test_random_permutation
():
array
=
np
.
arange
(
4
)
rng
=
shared
(
np
.
random
.
RandomState
(
123
))
g
=
at
.
random
.
permutation
(
array
,
rng
=
rng
)
g_fn
=
function
([],
g
,
mode
=
jax_mode
)
permuted
=
g_fn
()
with
pytest
.
raises
(
AssertionError
):
np
.
testing
.
assert_allclose
(
array
,
permuted
)
def
test_random_unimplemented
():
def
test_random_unimplemented
():
"""Compiling a graph with a non-supported `RandomVariable` should
raise an error.
"""
class
NonExistentRV
(
RandomVariable
):
class
NonExistentRV
(
RandomVariable
):
name
=
"non-existent"
name
=
"non-existent"
ndim_supp
=
0
ndim_supp
=
0
...
@@ -78,38 +362,58 @@ def test_random_unimplemented():
...
@@ -78,38 +362,58 @@ def test_random_unimplemented():
compare_jax_and_py
(
fgraph
,
[])
compare_jax_and_py
(
fgraph
,
[])
def
test_RandomStream
():
def
test_random_custom_implementation
():
srng
=
RandomStream
(
seed
=
123
)
"""We can register a JAX implementation for user-defined `RandomVariable`s"""
out
=
srng
.
normal
()
-
srng
.
normal
()
with
pytest
.
warns
(
class
CustomRV
(
RandomVariable
):
UserWarning
,
name
=
"non-existent"
match
=
r"The RandomType SharedVariables \[.+\] will not be used"
,
ndim_supp
=
0
):
ndims_params
=
[]
fn
=
function
([],
out
,
mode
=
jax_mode
)
dtype
=
"floatX"
jax_res_1
=
fn
()
jax_res_2
=
fn
()
assert
not
np
.
array_equal
(
jax_res_1
,
jax_res_2
)
def
__call__
(
self
,
size
=
None
,
**
kwargs
):
return
super
()
.
__call__
(
size
=
size
,
**
kwargs
)
def
rng_fn
(
cls
,
rng
,
size
):
return
0
@pytest.mark.parametrize
(
"rng_ctor"
,
(
np
.
random
.
RandomState
,
np
.
random
.
default_rng
))
from
pytensor.link.jax.dispatch.random
import
jax_sample_fn
def
test_random_updates
(
rng_ctor
):
original_value
=
rng_ctor
(
seed
=
98
)
rng
=
shared
(
original_value
,
name
=
"original_rng"
,
borrow
=
False
)
next_rng
,
x
=
at
.
random
.
normal
(
name
=
"x"
,
rng
=
rng
)
.
owner
.
outputs
with
pytest
.
warns
(
@jax_sample_fn.register
(
CustomRV
)
UserWarning
,
def
jax_sample_fn_custom
(
op
):
match
=
re
.
escape
(
def
sample_fn
(
rng
,
size
,
dtype
,
*
parameters
):
"The RandomType SharedVariables [original_rng] will not be used"
return
(
rng
,
0
)
),
):
f
=
pytensor
.
function
([],
[
x
],
updates
=
{
rng
:
next_rng
},
mode
=
jax_mode
)
assert
f
()
!=
f
()
# Check that original rng variable content was not overwritten when calling jax_typify
return
sample_fn
assert
all
(
a
==
b
if
not
isinstance
(
a
,
np
.
ndarray
)
else
np
.
array_equal
(
a
,
b
)
nonexistentrv
=
CustomRV
()
for
a
,
b
in
zip
(
rng
.
get_value
()
.
__getstate__
(),
original_value
.
__getstate__
())
rng
=
shared
(
np
.
random
.
RandomState
(
123
))
)
out
=
nonexistentrv
(
rng
=
rng
)
fgraph
=
FunctionGraph
([
out
.
owner
.
inputs
[
0
]],
[
out
],
clone
=
False
)
compare_jax_and_py
(
fgraph
,
[])
def
test_random_concrete_shape
():
"""JAX should compile when a `RandomVariable` is passed a concrete shape.
There are three quantities that JAX considers as concrete:
1. Constants known at compile time;
2. The shape of an array.
3. `static_argnums` parameters
This test makes sure that graphs with `RandomVariable`s compile when the
`size` parameter satisfies either of these criteria.
"""
rng
=
shared
(
np
.
random
.
RandomState
(
123
))
x_at
=
at
.
dmatrix
()
f
=
at
.
random
.
normal
(
0
,
1
,
size
=
(
3
,),
rng
=
rng
)
g
=
at
.
random
.
normal
(
f
,
1
,
size
=
x_at
.
shape
,
rng
=
rng
)
g_fn
=
function
([
x_at
],
g
,
mode
=
jax_mode
)
_
=
g_fn
(
np
.
ones
((
2
,
3
)))
# This should compile, and `size_at` be passed to the list of `static_argnums`.
with
pytest
.
raises
(
NotImplementedError
):
size_at
=
at
.
scalar
()
g
=
at
.
random
.
normal
(
f
,
1
,
size
=
size_at
,
rng
=
rng
)
g_fn
=
function
([
size_at
],
g
,
mode
=
jax_mode
)
_
=
g_fn
(
10
)
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论