提交 861f95c2 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Thomas Wiecki

Add test for Blockwise of COp with params

上级 35f0df96
...@@ -9,10 +9,10 @@ from pytensor import config, function ...@@ -9,10 +9,10 @@ from pytensor import config, function
from pytensor.gradient import grad from pytensor.gradient import grad
from pytensor.graph import Apply, Op from pytensor.graph import Apply, Op
from pytensor.graph.replace import vectorize_node from pytensor.graph.replace import vectorize_node
from pytensor.raise_op import assert_op
from pytensor.tensor import diagonal, log, tensor from pytensor.tensor import diagonal, log, tensor
from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.nlinalg import MatrixInverse from pytensor.tensor.nlinalg import MatrixInverse
from pytensor.tensor.shape import Shape
from pytensor.tensor.slinalg import Cholesky, Solve, cholesky, solve_triangular from pytensor.tensor.slinalg import Cholesky, Solve, cholesky, solve_triangular
from pytensor.tensor.utils import _parse_gufunc_signature from pytensor.tensor.utils import _parse_gufunc_signature
...@@ -362,11 +362,20 @@ def test_batched_mvnormal_logp_and_dlogp(mu_batch_shape, cov_batch_shape, benchm ...@@ -362,11 +362,20 @@ def test_batched_mvnormal_logp_and_dlogp(mu_batch_shape, cov_batch_shape, benchm
benchmark(fn, *test_values) benchmark(fn, *test_values)
def test_op_with_params(): def test_cop_with_params():
matrix_shape_blockwise = Blockwise(core_op=Shape(), signature="(x1,x2)->(s)") matrix_assert = Blockwise(core_op=assert_op, signature="(x1,x2),()->(x1,x2)")
x = tensor("x", shape=(5, None, None), dtype="float64") x = tensor("x", shape=(5, None, None), dtype="float64")
x_shape = matrix_shape_blockwise(x) x_shape = matrix_assert(x, (x >= 0).all())
fn = pytensor.function([x], x_shape) fn = pytensor.function([x], x_shape)
pytensor.dprint(fn) [fn_out] = fn.maker.fgraph.outputs
# Assert blockwise assert fn_out.owner.op == matrix_assert, "Blockwise should be in final graph"
print(fn(np.zeros((5, 3, 2))))
np.testing.assert_allclose(
fn(np.zeros((5, 3, 2))),
np.zeros((5, 3, 2)),
)
with pytest.raises(AssertionError):
fn(np.zeros((5, 3, 2)) - 1)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论