Unverified 提交 f12bea6c authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: GitHub

Handle MvNormal method in Op call (#1252)

上级 7f031251
...@@ -865,7 +865,7 @@ class MvNormalRV(RandomVariable): ...@@ -865,7 +865,7 @@ class MvNormalRV(RandomVariable):
) )
self.method = method self.method = method
def __call__(self, mean, cov, size=None, **kwargs): def __call__(self, mean, cov, size=None, method=None, **kwargs):
r""" "Draw samples from a multivariate normal distribution. r""" "Draw samples from a multivariate normal distribution.
Signature Signature
...@@ -888,6 +888,12 @@ class MvNormalRV(RandomVariable): ...@@ -888,6 +888,12 @@ class MvNormalRV(RandomVariable):
is specified, a single `N`-dimensional sample is returned. is specified, a single `N`-dimensional sample is returned.
""" """
if method is not None and method != self.method:
# Recreate Op with the new method
props = self._props_dict()
props["method"] = method
new_op = type(self)(**props)
return new_op.__call__(mean, cov, size=size, method=method, **kwargs)
return super().__call__(mean, cov, size=size, **kwargs) return super().__call__(mean, cov, size=size, **kwargs)
def rng_fn(self, rng, mean, cov, size): def rng_fn(self, rng, mean, cov, size):
......
...@@ -19,7 +19,6 @@ from pytensor.graph.rewriting.db import RewriteDatabaseQuery ...@@ -19,7 +19,6 @@ from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.tensor import ones, stack from pytensor.tensor import ones, stack
from pytensor.tensor.random.basic import ( from pytensor.tensor.random.basic import (
ChoiceWithoutReplacement, ChoiceWithoutReplacement,
MvNormalRV,
PermutationRV, PermutationRV,
_gamma, _gamma,
bernoulli, bernoulli,
...@@ -707,7 +706,7 @@ def create_mvnormal_cov_decomposition_method_test(mode): ...@@ -707,7 +706,7 @@ def create_mvnormal_cov_decomposition_method_test(mode):
[0, 0, 0], [0, 0, 0],
] ]
rng = shared(np.random.default_rng(675)) rng = shared(np.random.default_rng(675))
draws = MvNormalRV(method=method)(mean, cov, rng=rng, size=(10_000,)) draws = multivariate_normal(mean, cov, method=method, size=(10_000,), rng=rng)
assert draws.owner.op.method == method assert draws.owner.op.method == method
# JAX doesn't raise errors at runtime # JAX doesn't raise errors at runtime
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论