提交 8ae14c20 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Implement vectorize_node for CheckAndRaise Op

上级 31a4df60
...@@ -6,6 +6,7 @@ import numpy as np ...@@ -6,6 +6,7 @@ import numpy as np
from pytensor.gradient import DisconnectedType from pytensor.gradient import DisconnectedType
from pytensor.graph.basic import Apply, Variable from pytensor.graph.basic import Apply, Variable
from pytensor.graph.replace import _vectorize_node
from pytensor.link.c.op import COp from pytensor.link.c.op import COp
from pytensor.link.c.params_type import ParamsType from pytensor.link.c.params_type import ParamsType
from pytensor.link.c.type import Generic from pytensor.link.c.type import Generic
...@@ -198,3 +199,21 @@ class Assert(CheckAndRaise): ...@@ -198,3 +199,21 @@ class Assert(CheckAndRaise):
assert_op = Assert() assert_op = Assert()
@_vectorize_node.register(CheckAndRaise)
def vectorize_check_and_raise(op, node, batch_x, batch_cond):
    """Vectorize a ``CheckAndRaise`` node over batched inputs.

    The check itself is never batched: a batched condition is collapsed with
    ``all()`` into a single scalar check, and any broadcasting the condition
    would impose on ``x`` is applied *after* the Check Op, so the broadcast
    can be removed more easily if it turns out to be unneeded.
    """
    # Local imports avoid a circular dependency at module load time.
    from pytensor.tensor.extra_ops import broadcast_arrays
    from pytensor.tensor.shape import shape_padright

    # Scalar (unbatched) condition: the original node construction applies.
    if batch_cond.type.ndim == 0:
        return op.make_node(batch_x, batch_cond)

    # Reduce the batched condition to one scalar check.
    checked = op(batch_x, batch_cond.all())
    # Align the condition with the core dims of x, then let it broadcast
    # the batch dims of the checked output.
    core_ndim = node.inputs[0].type.ndim
    broadcasted, _ = broadcast_arrays(checked, shape_padright(batch_cond, core_ndim))
    return broadcasted.owner
...@@ -5,10 +5,13 @@ import scipy.sparse ...@@ -5,10 +5,13 @@ import scipy.sparse
import pytensor import pytensor
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor.compile.mode import OPT_FAST_RUN, Mode from pytensor.compile.mode import OPT_FAST_RUN, Mode
from pytensor.graph import vectorize_graph
from pytensor.graph.basic import Constant, equal_computations from pytensor.graph.basic import Constant, equal_computations
from pytensor.raise_op import Assert, CheckAndRaise, assert_op from pytensor.raise_op import Assert, CheckAndRaise, assert_op
from pytensor.scalar.basic import ScalarType, float64 from pytensor.scalar.basic import ScalarType, float64
from pytensor.sparse import as_sparse_variable from pytensor.sparse import as_sparse_variable
from pytensor.tensor.basic import second
from pytensor.tensor.elemwise import DimShuffle
from tests import unittest_tools as utt from tests import unittest_tools as utt
...@@ -184,3 +187,68 @@ def test_CheckAndRaise_sparse_variable(): ...@@ -184,3 +187,68 @@ def test_CheckAndRaise_sparse_variable():
a2 = check_and_raise(aspe1, aspe2.sum() > 2) a2 = check_and_raise(aspe1, aspe2.sum() > 2)
with pytest.raises(ValueError, match="sparse_check"): with pytest.raises(ValueError, match="sparse_check"):
a2.sum().eval() a2.sum().eval()
@pytensor.config.change_flags(cxx="") # For speed-up
def test_vectorize():
floatX = pytensor.config.floatX
x = pt.vector("x")
y = pt.vector("y")
cond = pt.all(y >= 0)
out = assert_op(x, cond)
batch_x = pt.matrix("batch_x", shape=(2, None))
batch_y = pt.matrix("batch_y", shape=(2, None))
test_x = np.arange(3).astype(floatX)
test_y = np.arange(4).astype(floatX)
test_batch_x = np.arange(6).reshape(2, 3).astype(floatX)
test_batch_y = np.arange(8).reshape(2, 4).astype(floatX)
# Only x is batched
vect_out = vectorize_graph(out, {x: batch_x, y: y})
assert vect_out.type.shape == (2, None)
assert isinstance(vect_out.owner.op, CheckAndRaise)
np.testing.assert_array_equal(
vect_out.eval({batch_x: test_batch_x, y: test_y}),
test_batch_x,
)
with pytest.raises(AssertionError):
vect_out.eval({batch_x: test_batch_x, y: -test_y})
# Only y is batched
vect_out = vectorize_graph(out, {x: x, y: batch_y})
assert vect_out.type.shape == (2, None)
assert vect_out.owner.op == second # broadcast
assert isinstance(vect_out.owner.inputs[1].owner.op, DimShuffle)
assert isinstance(vect_out.owner.inputs[1].owner.inputs[0].owner.op, CheckAndRaise)
np.testing.assert_array_equal(
vect_out.eval({x: test_x, batch_y: test_batch_y}),
np.broadcast_to(test_x, (2, *test_x.shape)),
)
with pytest.raises(AssertionError):
vect_out.eval({x: test_x, batch_y: -test_batch_y})
# Both x, and y are batched
vect_out = vectorize_graph(out, {x: batch_x, y: batch_y})
assert vect_out.type.shape == (2, None)
assert vect_out.owner.op == second
assert isinstance(vect_out.owner.inputs[1].owner.op, CheckAndRaise)
np.testing.assert_array_equal(
vect_out.eval({batch_x: test_batch_x, batch_y: test_batch_y}),
test_batch_x,
)
with pytest.raises(AssertionError):
vect_out.eval({batch_x: test_batch_x, batch_y: -test_batch_y})
# Both x, and y are batched and broadcast each other
vect_out = vectorize_graph(out, {x: batch_x[:, None, :], y: batch_y[None, :, :]})
assert vect_out.type.shape == (2, 2, None)
assert vect_out.owner.op == second
assert isinstance(vect_out.owner.inputs[1].owner.op, CheckAndRaise)
np.testing.assert_array_equal(
vect_out.eval({batch_x: test_batch_x, batch_y: test_batch_y}),
np.broadcast_to(test_batch_x[:, None, :], (2, *test_batch_x.shape)),
)
with pytest.raises(AssertionError):
vect_out.eval({batch_x: test_batch_x, batch_y: -test_batch_y})
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论