Commit 87470065 authored by ricardoV94, committed by Ricardo Vieira

Implement dim-aware vectorize_graph

Parent 9d99267c
......@@ -283,6 +283,13 @@ def vectorize_graph(
# [array([-10., -11.]), array([10., 11.])]
"""
# TODO: Move this to tensor.vectorize, and make this helper type agnostic.
#
# This helper may dispatch to tensor.vectorize_graph or xtensor.vectorize_graph depending on the replacement types
# The behavior is distinct, because tensor vectorization depends on axis-position while xtensor depends on dimension labels
#
# xtensor.vectorize_graph will be able to handle batched inner tensor operations, while tensor.vectorize_graph won't,
# as it is by design unaware of xtensors and their semantics.
if isinstance(outputs, Sequence):
seq_outputs = outputs
else:
......
import warnings
from collections.abc import Collection, Iterable
from collections.abc import Collection, Iterable, Sequence
from textwrap import dedent
import numpy as np
......@@ -1926,7 +1926,7 @@ def logspace(
def broadcast_to(
x: TensorVariable, shape: TensorVariable | tuple[Variable, ...]
x: TensorLike, shape: TensorLike | Sequence[TensorLike]
) -> TensorVariable:
"""Broadcast an array to a new shape.
......
......@@ -18,7 +18,9 @@ class XOp(Op):
def do_constant_folding(self, fgraph, node):
return False
def vectorize_node(self, node, *new_inputs) -> Sequence[Variable]:
def vectorize_node(
self, node, *new_inputs, new_dim: str | None
) -> Sequence[Variable]:
raise NotImplementedError(f"Vectorized node not implemented for {self}")
......@@ -31,7 +33,9 @@ class XTypeCastOp(TypeCastingOp):
def infer_shape(self, fgraph, node, input_shapes):
return input_shapes
def vectorize_node(self, node, *new_inputs) -> Sequence[Variable]:
def vectorize_node(
self, node, *new_inputs, new_dim: str | None
) -> Sequence[Variable]:
raise NotImplementedError(f"Vectorized node not implemented for {self}")
......@@ -49,12 +53,13 @@ class TensorFromXTensor(XTypeCastOp):
[g_out] = g_outs
return [xtensor_from_tensor(g_out, dims=x.type.dims)]
def vectorize_node(self, node, new_x):
def vectorize_node(self, node, new_x, new_dim):
[old_x] = node.inputs
if (new_x.ndim - old_x.ndim) > 1:
raise NotImplementedError(
f"Vectorization of {self} cannot guarantee correct placement of multiple batch dimensions. "
"You can call vectorize_graph one batch dimension at a time."
"You can call vectorize_graph one batch dimension at a time, "
"or pytensor.xtensor.vectorization.vectorize_graph instead."
)
new_x = new_x.transpose(..., *old_x.dims)
return [self(new_x)]
......@@ -80,13 +85,16 @@ class XTensorFromTensor(XTypeCastOp):
[g_out] = g_outs
return [tensor_from_xtensor(g_out)]
def vectorize_node(self, node, new_x):
def vectorize_node(self, node, new_x, new_dim):
[old_x] = node.inputs
if new_x.ndim != old_x.ndim:
if new_dim is None:
raise NotImplementedError(
f"Vectorization of {self} with batched inputs not implemented, "
"as it can't infer new dimension labels"
f"Vectorization of {self} cannot infer the new dimension labels. "
"Use pytensor.xtensor.vectorization.vectorize_graph instead."
)
return [type(self)(dims=(new_dim, *self.dims))(new_x)]
else:
return [self(new_x)]
......@@ -111,7 +119,7 @@ class Rename(XTypeCastOp):
[g_out] = g_outs
return [rename(g_out, dims=x.type.dims)]
def vectorize_node(self, node, new_x):
def vectorize_node(self, node, new_x, new_dim):
[old_x] = node.inputs
old_dim_mapping = dict(zip(old_x.dims, self.new_dims, strict=True))
......
......@@ -197,7 +197,7 @@ class Index(XOp):
output = xtensor(dtype=x.type.dtype, shape=out_shape, dims=out_dims)
return Apply(self, [x, *idxs], [output])
def vectorize_node(self, node, new_x, *new_idxs):
def vectorize_node(self, node, new_x, *new_idxs, new_dim):
# new_x may have dims in different order
# we pair each pre-existing dim to the respective index
# with new dims having simply a slice(None)
......@@ -237,7 +237,7 @@ class IndexUpdate(XOp):
out = x.type()
return Apply(self, [x, y, *idxs], [out])
def vectorize_node(self, node, *new_inputs):
def vectorize_node(self, node, *new_inputs, new_dim):
# If y or the indices have new dimensions we need to broadcast_x
exclude: set[str] = set(
chain.from_iterable(
......
......@@ -46,7 +46,7 @@ class XReduce(XOp):
output = xtensor(dtype=x.type.dtype, shape=out_shape, dims=out_dims)
return Apply(self, [x], [output])
def vectorize_node(self, node, new_x):
def vectorize_node(self, node, new_x, new_dim):
return [self(new_x)]
......@@ -120,7 +120,7 @@ class XCumReduce(XOp):
out = x.type()
return Apply(self, [x], [out])
def vectorize_node(self, node, new_x):
def vectorize_node(self, node, new_x, new_dim):
return [self(new_x)]
......
......@@ -68,7 +68,7 @@ class Stack(XOp):
)
return Apply(self, [x], [output])
def vectorize_node(self, node, new_x):
def vectorize_node(self, node, new_x, new_dim):
return [self(new_x)]
......@@ -149,7 +149,7 @@ class UnStack(XOp):
)
return Apply(self, [x, *unstacked_lengths], [output])
def vectorize_node(self, node, new_x, *new_unstacked_length):
def vectorize_node(self, node, new_x, *new_unstacked_length, new_dim):
new_unstacked_length = [ul.squeeze() for ul in new_unstacked_length]
if not all(ul.type.ndim == 0 for ul in new_unstacked_length):
raise NotImplementedError(
......@@ -200,7 +200,7 @@ class Transpose(XOp):
)
return Apply(self, [x], [output])
def vectorize_node(self, node, new_x):
def vectorize_node(self, node, new_x, new_dim):
old_dims = self.dims
new_dims = tuple(dim for dim in new_x.dims if dim not in old_dims)
return [type(self)(dims=(*new_dims, *old_dims))(new_x)]
......@@ -318,7 +318,7 @@ class Concat(XOp):
output = xtensor(dtype=dtype, dims=dims, shape=shape)
return Apply(self, inputs, [output])
def vectorize_node(self, node, *new_inputs):
def vectorize_node(self, node, *new_inputs, new_dim):
return [self(*new_inputs)]
......@@ -402,7 +402,7 @@ class Squeeze(XOp):
)
return Apply(self, [x], [out])
def vectorize_node(self, node, new_x):
def vectorize_node(self, node, new_x, new_dim):
return [self(new_x)]
......@@ -464,7 +464,7 @@ class ExpandDims(XOp):
)
return Apply(self, [x, size], [out])
def vectorize_node(self, node, new_x, new_size):
def vectorize_node(self, node, new_x, new_size, new_dim):
new_size = new_size.squeeze()
if new_size.type.ndim != 0:
raise NotImplementedError(
......@@ -567,7 +567,7 @@ class Broadcast(XOp):
return Apply(self, inputs, outputs)
def vectorize_node(self, node, *new_inputs):
def vectorize_node(self, node, *new_inputs, new_dim):
if exclude_set := set(self.exclude):
for new_x, old_x in zip(node.inputs, new_inputs, strict=True):
if invalid_excluded := (
......
......@@ -3,10 +3,12 @@ import pytest
pytest.importorskip("xarray")
import re
import numpy as np
from pytensor import function
from pytensor.graph import vectorize_graph
from pytensor.graph import vectorize_graph as tensor_vectorize_graph
from pytensor.tensor import matrix, vector
from pytensor.xtensor.basic import (
Rename,
......@@ -14,10 +16,9 @@ from pytensor.xtensor.basic import (
tensor_from_xtensor,
xtensor_from_tensor,
)
from pytensor.xtensor.type import xtensor
from pytensor.xtensor.type import as_xtensor, xtensor
from pytensor.xtensor.vectorization import vectorize_graph
from tests.unittest_tools import assert_equal_computations
# from pytensor.xtensor.vectorization import vectorize_graph
from tests.xtensor.util import check_vectorization
......@@ -53,9 +54,15 @@ def test_xtensor_from_tensor_vectorize():
t_batched = matrix("t_batched")
with pytest.raises(
NotImplementedError, match=r"Vectorization of .* not implemented"
NotImplementedError,
match=re.escape(
"cannot infer the new dimension labels. Use pytensor.xtensor.vectorization.vectorize_graph instead."
),
):
vectorize_graph([x], {t: t_batched})
tensor_vectorize_graph(x, {t: t_batched})
vec_x = vectorize_graph(x, {t: t_batched}, new_tensor_dims=("b",))
assert_equal_computations([vec_x], [as_xtensor(t_batched, dims=("b", "a"))])
def test_tensor_from_xtensor_vectorize():
......@@ -64,7 +71,7 @@ def test_tensor_from_xtensor_vectorize():
x_batched = xtensor("x", dims=("a", "b"), shape=(3, 5))
y_batched = vectorize_graph(y, {x: x_batched})
y_batched = tensor_vectorize_graph(y, {x: x_batched})
# vectorize_graph should place output batch dimension on the left
assert y_batched.type.shape == (5, 3)
assert_equal_computations([y_batched], [x_batched.transpose("b", ...).values])
......@@ -72,4 +79,9 @@ def test_tensor_from_xtensor_vectorize():
x_batched = xtensor("x", dims=("c", "a", "b"), shape=(7, 3, 5))
# vectorize_graph can't handle multiple batch dimensions safely
with pytest.raises(NotImplementedError):
vectorize_graph(y, {x: x_batched})
tensor_vectorize_graph(y, {x: x_batched})
# xtensor vectorize_graph can handle this graph safely
y_batched = vectorize_graph(y, {x: x_batched})
assert y_batched.type.shape == (7, 5, 3)
assert_equal_computations([y_batched], [x_batched.transpose("c", "b", "a").values])
......@@ -15,7 +15,6 @@ from xarray import full_like as xr_full_like
from xarray import ones_like as xr_ones_like
from xarray import zeros_like as xr_zeros_like
from pytensor.graph import vectorize_graph
from pytensor.tensor import scalar, vector
from pytensor.xtensor.shape import (
broadcast,
......@@ -27,6 +26,7 @@ from pytensor.xtensor.shape import (
zeros_like,
)
from pytensor.xtensor.type import as_xtensor, xtensor
from pytensor.xtensor.vectorization import vectorize_graph
from tests.xtensor.util import (
check_vectorization,
xr_arange_like,
......@@ -874,7 +874,7 @@ def test_expand_dims_batch_length_vectorize():
with pytest.raises(
NotImplementedError, match=r"Vectorization of .* not implemented"
):
vectorize_graph([y], {x: x_batch, l: l_batch})
vectorize_graph([y], {x: x_batch, l: l_batch}, new_tensor_dims=["batch"])
def test_unstack_batch_length_vectorize():
......@@ -888,4 +888,4 @@ def test_unstack_batch_length_vectorize():
with pytest.raises(
NotImplementedError, match=r"Vectorization of .* not implemented"
):
vectorize_graph([y], {x: x_batch, l: l_batch})
vectorize_graph([y], {x: x_batch, l: l_batch}, new_tensor_dims=["batch"])
import numpy as np
import pytest
from pytensor.tensor import TensorVariable, broadcast_to, tensor
from pytensor.xtensor.basic import xtensor_from_tensor
from pytensor.xtensor.type import XTensorType, as_xtensor, xtensor
from pytensor.xtensor.vectorization import vectorize_graph
from tests.unittest_tools import assert_equal_computations
class TestVectorizeGraph:
    """Tests for the dim-aware ``pytensor.xtensor.vectorization.vectorize_graph``.

    Unlike the positional ``pytensor.graph.vectorize_graph``, this variant
    identifies batch dimensions by label. Pure xtensor replacements carry their
    new dims in the type; raw-tensor replacements need the ``new_tensor_dims``
    keyword to name the dimensions they add.
    """

    def test_pure_xtensor_graph(self):
        # Replacement is an XTensorVariable: new dims ("c", "b") are read from
        # its type, no new_tensor_dims needed.
        x = xtensor("x", dims=("a",))
        out = x + 1

        x_new = xtensor("x_new", dims=("c", "a", "b"))
        [out_vec] = vectorize_graph([out], {x: x_new})
        assert isinstance(out_vec.type, XTensorType)
        # New dims are placed on the left, pre-existing dim "a" on the right.
        assert out_vec.type.dims == ("c", "b", "a")
        expected = x_new.transpose("c", "b", "a") + 1
        assert_equal_computations([out_vec], [expected])

    def test_pure_tensor_graph(self):
        # Raw-tensor replacement: the new leading dimension must be labeled
        # explicitly via new_tensor_dims.
        x = tensor("x", shape=())
        out = x + 1
        x_new = tensor("x_new", shape=(5,))
        [out_vec] = vectorize_graph([out], {x: x_new}, new_tensor_dims=["b"])
        assert isinstance(out_vec, TensorVariable)
        assert out_vec.ndim == 1
        # A pure tensor graph vectorizes to the plain broadcasted expression.
        expected = x_new + 1
        assert_equal_computations([out_vec], [expected])

    def test_intermediate_tensor_graph(self):
        # xtensor -> tensor -> xtensor round trip inside the graph: the batch
        # dim "b" must be threaded through the tensor segment correctly.
        x = xtensor("x", dims=("a",))
        t = x.values  # Convert to TensorVariable
        t2 = t + np.ones(1)
        out = xtensor_from_tensor(t2, dims=("a",))

        x_new = xtensor("x_new", dims=("a", "b"))
        [out_vec] = vectorize_graph([out], {x: x_new})
        assert isinstance(out_vec.type, XTensorType)
        assert out_vec.type.dims == ("b", "a")
        # The tensor segment sees the batch dim moved to the front.
        expected = as_xtensor(
            x_new.transpose("b", "a").values + np.ones(1), dims=("b", "a")
        )
        assert_equal_computations([out_vec], [expected])

    def test_intermediate_tensor_multiple_inputs_graph(self):
        x = xtensor("x", dims=("a",))
        y = xtensor("y", dims=("a",))
        t = x.values + y.values
        out = xtensor_from_tensor(t, dims=("a",))

        x_new = xtensor("x_new", dims=("a", "c"))

        # Both inputs have the same batch dims
        y_new = xtensor("y_new", dims=("c", "a"))
        [out_vec] = vectorize_graph([out], {x: x_new, y: y_new})
        assert isinstance(out_vec.type, XTensorType)
        assert out_vec.type.dims == ("c", "a")
        # Shared batch dim "c" is aligned, not duplicated.
        expected = as_xtensor(
            (x_new.transpose("c", "a").values + y_new.transpose("c", "a").values),
            dims=("c", "a"),
        )
        assert_equal_computations([out_vec], [expected])

        # Inputs have different batch dims
        y_new = xtensor("y_new", dims=("b", "a"))
        [out_vec] = vectorize_graph([out], {x: x_new, y: y_new})
        assert isinstance(out_vec.type, XTensorType)
        assert out_vec.type.dims == ("c", "b", "a")
        # Distinct batch dims "c" and "b" are outer-broadcast against each
        # other via inserted length-1 axes in the tensor segment.
        expected = as_xtensor(
            (
                x_new.transpose("c", "a").values[:, None]
                + y_new.transpose("b", "a").values[None, :]
            ),
            dims=("c", "b", "a"),
        )
        assert_equal_computations([out_vec], [expected])

    def test_intermediate_xtensor_graph(self):
        # tensor -> xtensor -> tensor round trip: new_tensor_dims names the
        # batch dim so the inner xtensor segment can be labeled.
        x = tensor("x", shape=(3,))
        t = as_xtensor(x, dims=("a",))
        t2 = t + 1
        out = t2.values

        x_new = tensor("x_new", shape=(5, 3))
        [out_vec] = vectorize_graph([out], {x: x_new}, new_tensor_dims=["b"])
        assert isinstance(out_vec, TensorVariable)
        assert out_vec.ndim == 2
        expected = (as_xtensor(x_new, dims=("b", "a")) + 1).values
        assert_equal_computations([out_vec], [expected])

    def test_mixed_type_inputs(self):
        x = xtensor("x", dims=("a",), shape=(3,))
        y = tensor("y", shape=(5,))
        out = as_xtensor(y[2:], dims=("b",)) + x

        x_new = xtensor("x_new", dims=("a", "d"), shape=(3, 7))
        y_new = tensor("y_new", shape=(7, 5))

        # New dimension of y is aligned with the new dimension of x
        [out_vec] = vectorize_graph([out], {x: x_new, y: y_new}, new_tensor_dims=["d"])
        assert isinstance(out_vec.type, XTensorType)
        assert out_vec.type.dims == ("d", "b", "a")
        expected = as_xtensor(y_new[:, 2:], dims=("d", "b")) + x_new.transpose("d", "a")
        assert_equal_computations([out_vec], [expected])

        # New dimension of y is distinct from that of x
        [out_vec] = vectorize_graph([out], {x: x_new, y: y_new}, new_tensor_dims=["c"])
        assert isinstance(out_vec.type, XTensorType)
        assert out_vec.type.dims == ("d", "c", "b", "a")
        # x introduced a new dimension "d" which causes y to be broadcasted
        y_broadcasted = broadcast_to(
            y_new, (x_new.sizes["d"], y_new.shape[0], y_new.shape[1])
        )
        expected = as_xtensor(
            y_broadcasted[:, :, 2:], dims=("d", "c", "b")
        ) + x_new.transpose("d", "a")
        assert_equal_computations([out_vec], [expected])

    def test_mixed_type_inputs_complex_broadcasting(self):
        a = xtensor("a", dims=("a",), shape=(3,))
        # NOTE(review): dims=("b") is a plain string, not a 1-tuple —
        # presumably xtensor normalizes a string to a single dim; confirm.
        b = xtensor("b", dims=("b"), shape=(5,))
        y = tensor("y", shape=(7,))
        z = tensor("z", shape=(11,))
        out = a + b + y.sum() + z.sum()
        assert out.dims == ("a", "b")

        a_new = xtensor("a_new", dims=("a*", "a"), shape=(33, 3))
        b_new = xtensor("b_new", dims=("b*", "b"), shape=(55, 5))
        # y/z carry length-1 axes that must broadcast against the batch dims
        # introduced by the other replacements.
        y_new = tensor("y_new", shape=(1, 55, 2, 1, 7))
        z_new = tensor("z_new", shape=(33, 1, 1, 2, 11))
        [out_vec] = vectorize_graph(
            [out],
            {a: a_new, b: b_new, y: y_new, z: z_new},
            new_tensor_dims=["a*", "b*", "y*", "z*"],
        )
        assert isinstance(out_vec.type, XTensorType)
        assert out_vec.type.dims == ("a*", "b*", "y*", "z*", "a", "b")

        # Expected batch shape: each dim's size comes from whichever input
        # has a non-broadcastable (non-1) length along it.
        batch_shape_truth = (
            a_new.sizes["a*"],
            b_new.sizes["b*"],
            y_new.shape[2],
            z_new.shape[3],
        )
        y_new_bcast = broadcast_to(y_new, (*batch_shape_truth, y_new.shape[4]))
        z_new_bcast = broadcast_to(z_new, (*batch_shape_truth, z_new.shape[4]))
        expected_out = (
            (a_new + b_new)
            + as_xtensor(y_new_bcast.sum(axis=-1), dims=("a*", "b*", "y*", "z*"))
            + as_xtensor(z_new_bcast.sum(axis=-1), dims=("a*", "b*", "y*", "z*"))
        ).transpose("a*", "b*", "y*", "z*", ...)
        assert_equal_computations([out_vec], [expected_out])

    def test_invalid_cases(self):
        x = xtensor("x", dims=("a",))
        out = x + 1

        # Missing xtensor dims
        x_bad = xtensor("x_bad", dims=("b",))  # Missing "a"
        with pytest.raises(ValueError, match="missing pre-existing dims"):
            vectorize_graph([out], {x: x_bad})

        # New xtensor dims that were present in original graph
        y = xtensor("y", dims=("b",))
        out2 = x + y
        x_new_conflict = xtensor("x_new", dims=("a", "b"))
        # "b" is new to x, but present in graph (in y)
        with pytest.raises(ValueError, match="new dimensions that were present"):
            vectorize_graph([out2], {x: x_new_conflict})

        # Missing tensor dims
        t = tensor("t", shape=(3,))
        out_t = t + 1
        # Replacement has fewer dims (rank 0)
        t_bad_rank = tensor("t_bad", shape=())
        with pytest.raises(ValueError, match="missing pre-existing dims"):
            vectorize_graph([out_t], {t: t_bad_rank})

        # Missing new_tensor_dims
        t_new = tensor("t_new", shape=(5, 5, 3))
        with pytest.raises(ValueError, match="You must specify `new_tensor_dims`"):
            vectorize_graph([out_t], {t: t_new})
        with pytest.raises(ValueError, match=r"but only .* were specified"):
            vectorize_graph([out_t], {t: t_new}, new_tensor_dims=["a"])

        # Excess new_tensor_dims
        # Replacement adds 1 dim, but 2 are specified
        t_new_1dim = tensor("t_new_1dim", shape=(5, 3))
        with pytest.raises(ValueError, match="tensor dims were specified, but only"):
            vectorize_graph([out_t], {t: t_new_1dim}, new_tensor_dims=["a", "b"])
......@@ -10,8 +10,8 @@ from xarray import DataArray
from xarray.testing import assert_allclose
from pytensor import function
from pytensor.graph import vectorize_graph
from pytensor.xtensor.type import XTensorType, as_xtensor
from pytensor.xtensor.vectorization import vectorize_graph
def xr_function(*args, **kwargs):
......
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this comment first!
Register or sign in to comment