提交 861f95c2 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Thomas Wiecki

Add test for Blockwise of COp with params

上级 35f0df96
......@@ -9,10 +9,10 @@ from pytensor import config, function
from pytensor.gradient import grad
from pytensor.graph import Apply, Op
from pytensor.graph.replace import vectorize_node
from pytensor.raise_op import assert_op
from pytensor.tensor import diagonal, log, tensor
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.nlinalg import MatrixInverse
from pytensor.tensor.shape import Shape
from pytensor.tensor.slinalg import Cholesky, Solve, cholesky, solve_triangular
from pytensor.tensor.utils import _parse_gufunc_signature
......@@ -362,11 +362,20 @@ def test_batched_mvnormal_logp_and_dlogp(mu_batch_shape, cov_batch_shape, benchm
benchmark(fn, *test_values)
def test_op_with_params():
matrix_shape_blockwise = Blockwise(core_op=Shape(), signature="(x1,x2)->(s)")
def test_cop_with_params():
matrix_assert = Blockwise(core_op=assert_op, signature="(x1,x2),()->(x1,x2)")
x = tensor("x", shape=(5, None, None), dtype="float64")
x_shape = matrix_shape_blockwise(x)
x_shape = matrix_assert(x, (x >= 0).all())
fn = pytensor.function([x], x_shape)
pytensor.dprint(fn)
# Assert blockwise
print(fn(np.zeros((5, 3, 2))))
[fn_out] = fn.maker.fgraph.outputs
assert fn_out.owner.op == matrix_assert, "Blockwise should be in final graph"
np.testing.assert_allclose(
fn(np.zeros((5, 3, 2))),
np.zeros((5, 3, 2)),
)
with pytest.raises(AssertionError):
fn(np.zeros((5, 3, 2)) - 1)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论