提交 28572bfe authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Implement TODO test

上级 0462c7a1
...@@ -26,7 +26,7 @@ from pytensor.xtensor.random import ( ...@@ -26,7 +26,7 @@ from pytensor.xtensor.random import (
multivariate_normal, multivariate_normal,
normal, normal,
) )
from pytensor.xtensor.vectorization import XRV from pytensor.xtensor.vectorization import XRV, vectorize_graph
from tests.xtensor.util import check_vectorization from tests.xtensor.util import check_vectorization
...@@ -462,5 +462,14 @@ def test_xrv_vectorize(): ...@@ -462,5 +462,14 @@ def test_xrv_vectorize():
def test_xrv_batch_extra_dim_vectorize(): def test_xrv_batch_extra_dim_vectorize():
# TODO: Check it raises NotImplementedError when we try to batch the extra_dim of an xrv extra_size = xtensor("extra_size", dims=(), dtype=int)
pass mu = xtensor("mu", dims=("a",), shape=(3,))
out = normal(mu, 1, extra_dims={"extra": extra_size})
assert out.type.dims == ("extra", "a")
# Batching extra_size should raise NotImplementedError
batch_extra_size = xtensor(
"batch_extra_size", dims=("batch",), shape=(2,), dtype=int
)
with pytest.raises(NotImplementedError, match="batched extra_dim_lengths"):
vectorize_graph([out], {extra_size: batch_extra_size})
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论