Commit 56637af8 authored by Ricardo Vieira, committed by Luciano Paz

Implement vectorize_node dispatch for some forms of Join

Parent: caa580bb
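In short: when a `Join` is vectorized over leading batch dimensions, the join axis can simply be shifted by the number of batch dimensions, provided the axis is a constant that vectorization left unchanged and all inputs gained the same batch dimensions with the same broadcastable pattern. Only when those conditions fail does the graph need a `Blockwise(Join)`. A minimal sketch of the fast path (the shapes and the inline lambda are illustrative, not from the commit):

```python
from pytensor.tensor import join, tensor, vectorize

x = tensor(shape=(7, 4, 2, 5))  # one batch dimension (7) over (4, 2, 5) cores
y = tensor(shape=(7, 4, 3, 5))

# Vectorize a join along core axis 1 over the leading batch dimension.
out = vectorize(
    lambda a, b: join(1, a, b),
    signature="(a,b1,c),(a,b2,c)->(a,b,c)",
)(x, y)
# With this dispatch the resulting graph is a plain Join on axis 2
# (the original axis 1 shifted by batch_ndim == 1), not a Blockwise(Join).
```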
@@ -23,7 +23,7 @@ from pytensor import compile, config, printing
 from pytensor import scalar as ps
 from pytensor.gradient import DisconnectedType, grad_undefined
 from pytensor.graph import RewriteDatabaseQuery
-from pytensor.graph.basic import Apply, Constant, Variable
+from pytensor.graph.basic import Apply, Constant, Variable, equal_computations
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.op import Op
 from pytensor.graph.replace import _vectorize_node
@@ -42,7 +42,7 @@ from pytensor.tensor import (
     as_tensor_variable,
     get_vector_length,
 )
-from pytensor.tensor.blockwise import Blockwise
+from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback
 from pytensor.tensor.elemwise import (
     DimShuffle,
     Elemwise,
@@ -2662,6 +2662,36 @@ def join(axis, *tensors_list):
     return join_(axis, *tensors_list)
 
 
+@_vectorize_node.register(Join)
+def vectorize_join(op: Join, node, batch_axis, *batch_inputs):
+    original_axis, *old_inputs = node.inputs
+    # We can vectorize join as a shifted axis on the batch inputs if:
+    # 1. The batch axis is a constant and has not changed
+    # 2. All inputs are batched with the same broadcastable pattern
+    if (
+        original_axis.type.ndim == 0
+        and isinstance(original_axis, Constant)
+        and equal_computations([original_axis], [batch_axis])
+    ):
+        batch_ndims = {
+            batch_input.type.ndim - old_input.type.ndim
+            for batch_input, old_input in zip(batch_inputs, old_inputs)
+        }
+        if len(batch_ndims) == 1:
+            [batch_ndim] = batch_ndims
+            batch_bcast = batch_inputs[0].type.broadcastable[:batch_ndim]
+            if all(
+                batch_input.type.broadcastable[:batch_ndim] == batch_bcast
+                for batch_input in batch_inputs[1:]
+            ):
+                original_ndim = node.outputs[0].type.ndim
+                original_axis = normalize_axis_index(original_axis.data, original_ndim)
+                batch_axis = original_axis + batch_ndim
+                return op.make_node(batch_axis, *batch_inputs)
+
+    return vectorize_node_fallback(op, node, batch_axis, *batch_inputs)
+
+
 def roll(x, shift, axis=None):
     """
     Convenience function to roll TensorTypes along the given axis.
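When the fast path does not apply (the axis is not a constant, vectorization changed it, or the inputs were batched with different broadcastable patterns), the dispatch falls back to `vectorize_node_fallback`, which wraps the `Join` in a `Blockwise`. For example (an illustrative sketch, not part of the commit), a shared-variable axis cannot be shifted at graph-construction time:

```python
from pytensor import shared
from pytensor.tensor import join, tensor, vectorize
from pytensor.tensor.blockwise import Blockwise

axis = shared(1)  # not a Constant, so the shifted-axis shortcut is unavailable
x = tensor(shape=(7, 4, 2, 5))
y = tensor(shape=(7, 4, 3, 5))
out = vectorize(
    lambda a, b: join(axis, a, b),
    signature="(a,b1,c),(a,b2,c)->(a,b,c)",
)(x, y)
assert isinstance(out.owner.op, Blockwise)  # expected: Blockwise(Join) in the graph
```

The new test below exercises both paths.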
@@ -10,6 +10,7 @@ import pytensor.scalar as ps
 import pytensor.tensor.basic as ptb
 import pytensor.tensor.math as ptm
 from pytensor import compile, config, function, shared
+from pytensor.compile import SharedVariable
 from pytensor.compile.io import In, Out
 from pytensor.compile.mode import Mode, get_default_mode
 from pytensor.compile.ops import DeepCopyOp
@@ -4565,3 +4566,37 @@ def test_vectorize_extract_diag():
         vectorize_pt(x_test),
         vectorize_np(x_test),
     )
+
+
+@pytest.mark.parametrize("axis", [constant(1), constant(-2), shared(1)])
+@pytest.mark.parametrize("broadcasting_y", ["none", "implicit", "explicit"])
+@config.change_flags(cxx="")  # C code not needed
+def test_vectorize_join(axis, broadcasting_y):
+    # Signature for join along intermediate axis
+    signature = "(a,b1,c),(a,b2,c)->(a,b,c)"
+
+    def core_pt(x, y):
+        return join(axis, x, y)
+
+    def core_np(x, y):
+        return np.concatenate([x, y], axis=axis.eval())
+
+    x = tensor(shape=(4, 2, 3, 5))
+    y_shape = {"none": (4, 2, 3, 5), "implicit": (2, 3, 5), "explicit": (1, 2, 3, 5)}
+    y = tensor(shape=y_shape[broadcasting_y])
+
+    vectorize_pt = function([x, y], vectorize(core_pt, signature=signature)(x, y))
+
+    blockwise_needed = isinstance(axis, SharedVariable) or broadcasting_y != "none"
+    has_blockwise = any(
+        isinstance(node.op, Blockwise) for node in vectorize_pt.maker.fgraph.apply_nodes
+    )
+    assert has_blockwise == blockwise_needed
+
+    x_test = np.random.normal(size=x.type.shape).astype(x.type.dtype)
+    y_test = np.random.normal(size=y.type.shape).astype(y.type.dtype)
+    vectorize_np = np.vectorize(core_np, signature=signature)
+    np.testing.assert_allclose(
+        vectorize_pt(x_test, y_test),
+        vectorize_np(x_test, y_test),
+    )
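As a quick shape check of what the test verifies (illustrative, not part of the commit): in the "none" case both inputs are (4, 2, 3, 5), i.e. one batch dimension of size 4 over (2, 3, 5) cores, and joining along core axis 1 yields (4, 2, 6, 5):

```python
import numpy as np

def core_np(a, b):
    return np.concatenate([a, b], axis=1)

out = np.vectorize(core_np, signature="(a,b1,c),(a,b2,c)->(a,b,c)")(
    np.empty((4, 2, 3, 5)), np.empty((4, 2, 3, 5))
)
assert out.shape == (4, 2, 6, 5)
```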