Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
2aecb956
提交
2aecb956
authored
2月 12, 2025
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
2月 13, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Allow decomposition methods in MvNormal
上级
2823dfca
隐藏空白字符变更
内嵌
并排
正在显示
6 个修改的文件
包含
113 行增加
和
18 行删除
+113
-18
random.py
pytensor/link/jax/dispatch/random.py
+14
-1
random.py
pytensor/link/numba/dispatch/random.py
+16
-3
basic.py
pytensor/tensor/random/basic.py
+27
-14
test_random.py
tests/link/jax/test_random.py
+6
-0
test_random.py
tests/link/numba/test_random.py
+6
-0
test_basic.py
tests/tensor/random/test_basic.py
+44
-0
没有找到文件。
pytensor/link/jax/dispatch/random.py
浏览文件 @
2aecb956
...
...
@@ -128,7 +128,6 @@ def jax_sample_fn(op, node):
@jax_sample_fn.register
(
ptr
.
BetaRV
)
@jax_sample_fn.register
(
ptr
.
DirichletRV
)
@jax_sample_fn.register
(
ptr
.
PoissonRV
)
@jax_sample_fn.register
(
ptr
.
MvNormalRV
)
def
jax_sample_fn_generic
(
op
,
node
):
"""Generic JAX implementation of random variables."""
name
=
op
.
name
...
...
@@ -173,6 +172,20 @@ def jax_sample_fn_loc_scale(op, node):
return
sample_fn
@jax_sample_fn.register
(
ptr
.
MvNormalRV
)
def
jax_sample_mvnormal
(
op
,
node
):
def
sample_fn
(
rng
,
size
,
dtype
,
mean
,
cov
):
rng_key
=
rng
[
"jax_state"
]
rng_key
,
sampling_key
=
jax
.
random
.
split
(
rng_key
,
2
)
sample
=
jax
.
random
.
multivariate_normal
(
sampling_key
,
mean
,
cov
,
shape
=
size
,
dtype
=
dtype
,
method
=
op
.
method
)
rng
[
"jax_state"
]
=
rng_key
return
(
rng
,
sample
)
return
sample_fn
@jax_sample_fn.register
(
ptr
.
BernoulliRV
)
def
jax_sample_fn_bernoulli
(
op
,
node
):
"""JAX implementation of `BernoulliRV`."""
...
...
pytensor/link/numba/dispatch/random.py
浏览文件 @
2aecb956
...
...
@@ -144,11 +144,24 @@ def core_CategoricalRV(op, node):
@numba_core_rv_funcify.register
(
ptr
.
MvNormalRV
)
def
core_MvNormalRV
(
op
,
node
):
method
=
op
.
method
@numba_basic.numba_njit
def
random_fn
(
rng
,
mean
,
cov
):
chol
=
np
.
linalg
.
cholesky
(
cov
)
stdnorm
=
rng
.
normal
(
size
=
cov
.
shape
[
-
1
])
return
np
.
dot
(
chol
,
stdnorm
)
+
mean
if
method
==
"cholesky"
:
A
=
np
.
linalg
.
cholesky
(
cov
)
elif
method
==
"svd"
:
A
,
s
,
_
=
np
.
linalg
.
svd
(
cov
)
A
*=
np
.
sqrt
(
s
)[
None
,
:]
else
:
w
,
A
=
np
.
linalg
.
eigh
(
cov
)
A
*=
np
.
sqrt
(
w
)[
None
,
:]
out
=
rng
.
normal
(
size
=
cov
.
shape
[
-
1
])
# out argument not working correctly: https://github.com/numba/numba/issues/9924
out
[:]
=
np
.
dot
(
A
,
out
)
out
+=
mean
return
out
random_fn
.
handles_out
=
True
return
random_fn
...
...
pytensor/tensor/random/basic.py
浏览文件 @
2aecb956
import
abc
import
warnings
from
typing
import
Literal
import
numpy
as
np
import
scipy.stats
as
stats
from
numpy
import
broadcast_shapes
as
np_broadcast_shapes
from
numpy
import
einsum
as
np_einsum
from
numpy
import
sqrt
as
np_sqrt
from
numpy.linalg
import
cholesky
as
np_cholesky
from
numpy.linalg
import
eigh
as
np_eigh
from
numpy.linalg
import
svd
as
np_svd
import
pytensor
from
pytensor.tensor
import
get_vector_length
,
specify_shape
from
pytensor.tensor.basic
import
as_tensor_variable
from
pytensor.tensor.math
import
sqrt
...
...
@@ -852,8 +855,17 @@ class MvNormalRV(RandomVariable):
signature
=
"(n),(n,n)->(n)"
dtype
=
"floatX"
_print_name
=
(
"MultivariateNormal"
,
"
\\
operatorname{MultivariateNormal}"
)
__props__
=
(
"name"
,
"signature"
,
"dtype"
,
"inplace"
,
"method"
)
def
__call__
(
self
,
mean
=
None
,
cov
=
None
,
size
=
None
,
**
kwargs
):
def
__init__
(
self
,
*
args
,
method
:
Literal
[
"cholesky"
,
"svd"
,
"eigh"
],
**
kwargs
):
super
()
.
__init__
(
*
args
,
**
kwargs
)
if
method
not
in
(
"cholesky"
,
"svd"
,
"eigh"
):
raise
ValueError
(
f
"Unknown method {method}. The method must be one of 'cholesky', 'svd', or 'eigh'."
)
self
.
method
=
method
def
__call__
(
self
,
mean
,
cov
,
size
=
None
,
**
kwargs
):
r""" "Draw samples from a multivariate normal distribution.
Signature
...
...
@@ -876,33 +888,34 @@ class MvNormalRV(RandomVariable):
is specified, a single `N`-dimensional sample is returned.
"""
dtype
=
pytensor
.
config
.
floatX
if
self
.
dtype
==
"floatX"
else
self
.
dtype
if
mean
is
None
:
mean
=
np
.
array
([
0.0
],
dtype
=
dtype
)
if
cov
is
None
:
cov
=
np
.
array
([[
1.0
]],
dtype
=
dtype
)
return
super
()
.
__call__
(
mean
,
cov
,
size
=
size
,
**
kwargs
)
@classmethod
def
rng_fn
(
cls
,
rng
,
mean
,
cov
,
size
):
def
rng_fn
(
self
,
rng
,
mean
,
cov
,
size
):
if
size
is
None
:
size
=
np_broadcast_shapes
(
mean
.
shape
[:
-
1
],
cov
.
shape
[:
-
2
])
chol
=
np_cholesky
(
cov
)
if
self
.
method
==
"cholesky"
:
A
=
np_cholesky
(
cov
)
elif
self
.
method
==
"svd"
:
A
,
s
,
_
=
np_svd
(
cov
)
A
*=
np_sqrt
(
s
,
out
=
s
)[
...
,
None
,
:]
else
:
w
,
A
=
np_eigh
(
cov
)
A
*=
np_sqrt
(
w
,
out
=
w
)[
...
,
None
,
:]
out
=
rng
.
normal
(
size
=
(
*
size
,
mean
.
shape
[
-
1
]))
np_einsum
(
"...ij,...j->...i"
,
# numpy doesn't have a batch matrix-vector product
chol
,
A
,
out
,
out
=
out
,
optimize
=
False
,
# Nothing to optimize with two operands, skip costly setup
out
=
out
,
)
out
+=
mean
return
out
multivariate_normal
=
MvNormalRV
()
multivariate_normal
=
MvNormalRV
(
method
=
"cholesky"
)
class
DirichletRV
(
RandomVariable
):
...
...
tests/link/jax/test_random.py
浏览文件 @
2aecb956
...
...
@@ -18,6 +18,7 @@ from tests.tensor.random.test_basic import (
batched_permutation_tester
,
batched_unweighted_choice_without_replacement_tester
,
batched_weighted_choice_without_replacement_tester
,
create_mvnormal_cov_decomposition_method_test
,
)
...
...
@@ -547,6 +548,11 @@ def test_random_mvnormal():
np
.
testing
.
assert_allclose
(
samples
.
mean
(
axis
=
0
),
mu
,
atol
=
0.1
)
test_mvnormal_cov_decomposition_method
=
create_mvnormal_cov_decomposition_method_test
(
"JAX"
)
@pytest.mark.parametrize
(
"parameter, size"
,
[
...
...
tests/link/numba/test_random.py
浏览文件 @
2aecb956
...
...
@@ -22,6 +22,7 @@ from tests.tensor.random.test_basic import (
batched_permutation_tester
,
batched_unweighted_choice_without_replacement_tester
,
batched_weighted_choice_without_replacement_tester
,
create_mvnormal_cov_decomposition_method_test
,
)
...
...
@@ -147,6 +148,11 @@ def test_multivariate_normal():
)
test_mvnormal_cov_decomposition_method
=
create_mvnormal_cov_decomposition_method_test
(
"NUMBA"
)
@pytest.mark.parametrize
(
"rv_op, dist_args, size"
,
[
...
...
tests/tensor/random/test_basic.py
浏览文件 @
2aecb956
...
...
@@ -19,6 +19,7 @@ from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from
pytensor.tensor
import
ones
,
stack
from
pytensor.tensor.random.basic
import
(
ChoiceWithoutReplacement
,
MvNormalRV
,
PermutationRV
,
_gamma
,
bernoulli
,
...
...
@@ -686,6 +687,49 @@ def test_mvnormal_ShapeFeature():
assert
s4
.
get_test_value
()
==
3
def
create_mvnormal_cov_decomposition_method_test
(
mode
):
@pytest.mark.parametrize
(
"psd"
,
(
True
,
False
))
@pytest.mark.parametrize
(
"method"
,
(
"cholesky"
,
"svd"
,
"eigh"
))
def
test_mvnormal_cov_decomposition_method
(
method
,
psd
):
mean
=
2
**
np
.
arange
(
3
)
if
psd
:
cov
=
[
[
1
,
0.5
,
-
1
],
[
0.5
,
2
,
0
],
[
-
1
,
0
,
3
],
]
else
:
cov
=
[
[
1
,
0.5
,
0
],
[
0.5
,
2
,
0
],
[
0
,
0
,
0
],
]
rng
=
shared
(
np
.
random
.
default_rng
(
675
))
draws
=
MvNormalRV
(
method
=
method
)(
mean
,
cov
,
rng
=
rng
,
size
=
(
10
_000
,))
assert
draws
.
owner
.
op
.
method
==
method
# JAX doesn't raise errors at runtime
if
not
psd
and
method
==
"cholesky"
:
if
mode
==
"JAX"
:
# JAX doesn't raise errors at runtime, instead it returns nan
np
.
isnan
(
draws
.
eval
(
mode
=
mode
))
.
all
()
else
:
with
pytest
.
raises
(
np
.
linalg
.
LinAlgError
):
draws
.
eval
(
mode
=
mode
)
else
:
draws_eval
=
draws
.
eval
(
mode
=
mode
)
np
.
testing
.
assert_allclose
(
np
.
mean
(
draws_eval
,
axis
=
0
),
mean
,
rtol
=
0.02
)
np
.
testing
.
assert_allclose
(
np
.
cov
(
draws_eval
,
rowvar
=
False
),
cov
,
atol
=
0.1
)
return
test_mvnormal_cov_decomposition_method
test_mvnormal_cov_decomposition_method
=
create_mvnormal_cov_decomposition_method_test
(
None
)
@pytest.mark.parametrize
(
"alphas, size"
,
[
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论