提交 0f802ab2 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Add mvnormal logp dlogp benchmark test

上级 9c06de2b
......@@ -9,10 +9,10 @@ from pytensor import config, function
from pytensor.gradient import grad
from pytensor.graph import Apply, Op
from pytensor.graph.replace import vectorize_node
from pytensor.tensor import tensor
from pytensor.tensor import diagonal, log, tensor
from pytensor.tensor.blockwise import Blockwise, _parse_gufunc_signature
from pytensor.tensor.nlinalg import MatrixInverse
from pytensor.tensor.slinalg import Cholesky, Solve
from pytensor.tensor.slinalg import Cholesky, Solve, cholesky, solve_triangular
def test_vectorize_blockwise():
......@@ -320,3 +320,41 @@ class TestSolveVector(BlockwiseOpTester):
class TestSolveMatrix(BlockwiseOpTester):
core_op = Solve(lower=True, b_ndim=2)
signature = "(m, m),(m, n) -> (m, n)"
@pytest.mark.parametrize(
"mu_batch_shape", [(), (1000,), (4, 1000)], ids=lambda arg: f"mu:{arg}"
)
@pytest.mark.parametrize(
"cov_batch_shape", [(), (1000,), (4, 1000)], ids=lambda arg: f"cov:{arg}"
)
def test_batched_mvnormal_logp_and_dlogp(mu_batch_shape, cov_batch_shape, benchmark):
rng = np.random.default_rng(sum(map(ord, "batched_mvnormal")))
value_batch_shape = mu_batch_shape
if len(cov_batch_shape) > len(mu_batch_shape):
value_batch_shape = cov_batch_shape
value = tensor("value", shape=(*value_batch_shape, 10))
mu = tensor("mu", shape=(*mu_batch_shape, 10))
cov = tensor("cov", shape=(*cov_batch_shape, 10, 10))
test_values = [
rng.normal(size=value.type.shape),
rng.normal(size=mu.type.shape),
np.eye(cov.type.shape[-1]) * np.abs(rng.normal(size=cov.type.shape)),
]
chol_cov = cholesky(cov, lower=True, on_error="raise")
delta_trans = solve_triangular(chol_cov, value - mu, b_ndim=1)
quaddist = (delta_trans**2).sum(axis=-1)
diag = diagonal(chol_cov, axis1=-2, axis2=-1)
logdet = log(diag).sum(axis=-1)
k = value.shape[-1]
norm = -0.5 * k * (np.log(2 * np.pi))
logp = norm - 0.5 * quaddist - logdet
dlogp = grad(logp.sum(), wrt=[value, mu, cov])
fn = pytensor.function([value, mu, cov], [logp, *dlogp])
benchmark(fn, *test_values)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论