提交 7f623fef authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Vectorize make_vector

上级 5fd729d0
......@@ -1890,6 +1890,23 @@ def _get_vector_length_MakeVector(op, var):
return len(var.owner.inputs)
@_vectorize_node.register
def vectorize_make_vector(op: MakeVector, node, *batch_inputs):
# We vectorize make_vector as a join along the last axis of the broadcasted inputs
from pytensor.tensor.extra_ops import broadcast_arrays
# Check if we need to broadcast at all
bcast_pattern = batch_inputs[0].type.broadcastable
if not all(
batch_input.type.broadcastable == bcast_pattern for batch_input in batch_inputs
):
batch_inputs = broadcast_arrays(*batch_inputs)
# Join along the last axis
new_out = stack(batch_inputs, axis=-1)
return new_out.owner
def transfer(var, target):
"""
Return a version of `var` transferred to `target`.
......@@ -2690,6 +2707,10 @@ def vectorize_join(op: Join, node, batch_axis, *batch_inputs):
# We can vectorize join as a shifted axis on the batch inputs if:
# 1. The batch axis is a constant and has not changed
# 2. All inputs are batched with the same broadcastable pattern
# TODO: We can relax the second condition by broadcasting the batch dimensions
# This can be done with `broadcast_arrays` if the tensors shape match at the axis or reduction
# Or otherwise by calling `broadcast_to` for each tensor that needs it
if (
original_axis.type.ndim == 0
and isinstance(original_axis, Constant)
......
......@@ -4577,6 +4577,46 @@ def test_vectorize_extract_diag():
)
@pytest.mark.parametrize(
"batch_shapes",
[
((3,),), # edge case of make_vector with a single input
((), (), ()), # Useless
((3,), (3,), (3,)), # No broadcasting needed
((3,), (5, 3), ()), # Broadcasting needed
],
)
def test_vectorize_make_vector(batch_shapes):
n_inputs = len(batch_shapes)
input_sig = ",".join(["()"] * n_inputs)
signature = f"{input_sig}->({n_inputs})" # Something like "(),(),()->(3)"
def core_pt(*scalars):
out = stack(scalars)
out.dprint()
return out
def core_np(*scalars):
return np.stack(scalars)
tensors = [tensor(shape=shape) for shape in batch_shapes]
vectorize_pt = function(tensors, vectorize(core_pt, signature=signature)(*tensors))
assert not any(
isinstance(node.op, Blockwise) for node in vectorize_pt.maker.fgraph.apply_nodes
)
test_values = [
np.random.normal(size=tensor.type.shape).astype(tensor.type.dtype)
for tensor in tensors
]
np.testing.assert_allclose(
vectorize_pt(*test_values),
np.vectorize(core_np, signature=signature)(*test_values),
)
@pytest.mark.parametrize("axis", [constant(1), constant(-2), shared(1)])
@pytest.mark.parametrize("broadcasting_y", ["none", "implicit", "explicit"])
@config.change_flags(cxx="") # C code not needed
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论