提交 87470065 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Implement dim-aware vectorize_graph

上级 9d99267c
...@@ -283,6 +283,13 @@ def vectorize_graph( ...@@ -283,6 +283,13 @@ def vectorize_graph(
# [array([-10., -11.]), array([10., 11.])] # [array([-10., -11.]), array([10., 11.])]
""" """
# TODO: Move this to tensor.vectorize, and make this helper type agnostic.
#
# This helper may dispatch to tensor.vectorize_graph or xtensor.vectorize_graph depending on the replacement types
# The behavior is distinct, because tensor vectorization depends on axis-position while xtensor depends on dimension labels
#
# xtensor.vectorize_graph will be able to handle batched inner tensor operations, while tensor.vectorize_graph won't,
# as it is by design unaware of xtensors and their semantics.
if isinstance(outputs, Sequence): if isinstance(outputs, Sequence):
seq_outputs = outputs seq_outputs = outputs
else: else:
......
import warnings import warnings
from collections.abc import Collection, Iterable from collections.abc import Collection, Iterable, Sequence
from textwrap import dedent from textwrap import dedent
import numpy as np import numpy as np
...@@ -1926,7 +1926,7 @@ def logspace( ...@@ -1926,7 +1926,7 @@ def logspace(
def broadcast_to( def broadcast_to(
x: TensorVariable, shape: TensorVariable | tuple[Variable, ...] x: TensorLike, shape: TensorLike | Sequence[TensorLike]
) -> TensorVariable: ) -> TensorVariable:
"""Broadcast an array to a new shape. """Broadcast an array to a new shape.
......
...@@ -18,7 +18,9 @@ class XOp(Op): ...@@ -18,7 +18,9 @@ class XOp(Op):
def do_constant_folding(self, fgraph, node): def do_constant_folding(self, fgraph, node):
return False return False
def vectorize_node(self, node, *new_inputs) -> Sequence[Variable]: def vectorize_node(
self, node, *new_inputs, new_dim: str | None
) -> Sequence[Variable]:
raise NotImplementedError(f"Vectorized node not implemented for {self}") raise NotImplementedError(f"Vectorized node not implemented for {self}")
...@@ -31,7 +33,9 @@ class XTypeCastOp(TypeCastingOp): ...@@ -31,7 +33,9 @@ class XTypeCastOp(TypeCastingOp):
def infer_shape(self, fgraph, node, input_shapes): def infer_shape(self, fgraph, node, input_shapes):
return input_shapes return input_shapes
def vectorize_node(self, node, *new_inputs) -> Sequence[Variable]: def vectorize_node(
self, node, *new_inputs, new_dim: str | None
) -> Sequence[Variable]:
raise NotImplementedError(f"Vectorized node not implemented for {self}") raise NotImplementedError(f"Vectorized node not implemented for {self}")
...@@ -49,12 +53,13 @@ class TensorFromXTensor(XTypeCastOp): ...@@ -49,12 +53,13 @@ class TensorFromXTensor(XTypeCastOp):
[g_out] = g_outs [g_out] = g_outs
return [xtensor_from_tensor(g_out, dims=x.type.dims)] return [xtensor_from_tensor(g_out, dims=x.type.dims)]
def vectorize_node(self, node, new_x): def vectorize_node(self, node, new_x, new_dim):
[old_x] = node.inputs [old_x] = node.inputs
if (new_x.ndim - old_x.ndim) > 1: if (new_x.ndim - old_x.ndim) > 1:
raise NotImplementedError( raise NotImplementedError(
f"Vectorization of {self} cannot guarantee correct placement of multiple batch dimensions. " f"Vectorization of {self} cannot guarantee correct placement of multiple batch dimensions. "
"You can call vectorize_graph one batch dimension at a time." "You can call vectorize_graph one batch dimension at a time, "
"or pytensor.xtensor.vectorization.vectorize_graph instead."
) )
new_x = new_x.transpose(..., *old_x.dims) new_x = new_x.transpose(..., *old_x.dims)
return [self(new_x)] return [self(new_x)]
...@@ -80,13 +85,16 @@ class XTensorFromTensor(XTypeCastOp): ...@@ -80,13 +85,16 @@ class XTensorFromTensor(XTypeCastOp):
[g_out] = g_outs [g_out] = g_outs
return [tensor_from_xtensor(g_out)] return [tensor_from_xtensor(g_out)]
def vectorize_node(self, node, new_x): def vectorize_node(self, node, new_x, new_dim):
[old_x] = node.inputs [old_x] = node.inputs
if new_x.ndim != old_x.ndim: if new_x.ndim != old_x.ndim:
if new_dim is None:
raise NotImplementedError( raise NotImplementedError(
f"Vectorization of {self} with batched inputs not implemented, " f"Vectorization of {self} cannot infer the new dimension labels. "
"as it can't infer new dimension labels" "Use pytensor.xtensor.vectorization.vectorize_graph instead."
) )
return [type(self)(dims=(new_dim, *self.dims))(new_x)]
else:
return [self(new_x)] return [self(new_x)]
...@@ -111,7 +119,7 @@ class Rename(XTypeCastOp): ...@@ -111,7 +119,7 @@ class Rename(XTypeCastOp):
[g_out] = g_outs [g_out] = g_outs
return [rename(g_out, dims=x.type.dims)] return [rename(g_out, dims=x.type.dims)]
def vectorize_node(self, node, new_x): def vectorize_node(self, node, new_x, new_dim):
[old_x] = node.inputs [old_x] = node.inputs
old_dim_mapping = dict(zip(old_x.dims, self.new_dims, strict=True)) old_dim_mapping = dict(zip(old_x.dims, self.new_dims, strict=True))
......
...@@ -197,7 +197,7 @@ class Index(XOp): ...@@ -197,7 +197,7 @@ class Index(XOp):
output = xtensor(dtype=x.type.dtype, shape=out_shape, dims=out_dims) output = xtensor(dtype=x.type.dtype, shape=out_shape, dims=out_dims)
return Apply(self, [x, *idxs], [output]) return Apply(self, [x, *idxs], [output])
def vectorize_node(self, node, new_x, *new_idxs): def vectorize_node(self, node, new_x, *new_idxs, new_dim):
# new_x may have dims in different order # new_x may have dims in different order
# we pair each pre-existing dim to the respective index # we pair each pre-existing dim to the respective index
# with new dims having simply a slice(None) # with new dims having simply a slice(None)
...@@ -237,7 +237,7 @@ class IndexUpdate(XOp): ...@@ -237,7 +237,7 @@ class IndexUpdate(XOp):
out = x.type() out = x.type()
return Apply(self, [x, y, *idxs], [out]) return Apply(self, [x, y, *idxs], [out])
def vectorize_node(self, node, *new_inputs): def vectorize_node(self, node, *new_inputs, new_dim):
# If y or the indices have new dimensions we need to broadcast_x # If y or the indices have new dimensions we need to broadcast_x
exclude: set[str] = set( exclude: set[str] = set(
chain.from_iterable( chain.from_iterable(
......
...@@ -46,7 +46,7 @@ class XReduce(XOp): ...@@ -46,7 +46,7 @@ class XReduce(XOp):
output = xtensor(dtype=x.type.dtype, shape=out_shape, dims=out_dims) output = xtensor(dtype=x.type.dtype, shape=out_shape, dims=out_dims)
return Apply(self, [x], [output]) return Apply(self, [x], [output])
def vectorize_node(self, node, new_x): def vectorize_node(self, node, new_x, new_dim):
return [self(new_x)] return [self(new_x)]
...@@ -120,7 +120,7 @@ class XCumReduce(XOp): ...@@ -120,7 +120,7 @@ class XCumReduce(XOp):
out = x.type() out = x.type()
return Apply(self, [x], [out]) return Apply(self, [x], [out])
def vectorize_node(self, node, new_x): def vectorize_node(self, node, new_x, new_dim):
return [self(new_x)] return [self(new_x)]
......
...@@ -68,7 +68,7 @@ class Stack(XOp): ...@@ -68,7 +68,7 @@ class Stack(XOp):
) )
return Apply(self, [x], [output]) return Apply(self, [x], [output])
def vectorize_node(self, node, new_x): def vectorize_node(self, node, new_x, new_dim):
return [self(new_x)] return [self(new_x)]
...@@ -149,7 +149,7 @@ class UnStack(XOp): ...@@ -149,7 +149,7 @@ class UnStack(XOp):
) )
return Apply(self, [x, *unstacked_lengths], [output]) return Apply(self, [x, *unstacked_lengths], [output])
def vectorize_node(self, node, new_x, *new_unstacked_length): def vectorize_node(self, node, new_x, *new_unstacked_length, new_dim):
new_unstacked_length = [ul.squeeze() for ul in new_unstacked_length] new_unstacked_length = [ul.squeeze() for ul in new_unstacked_length]
if not all(ul.type.ndim == 0 for ul in new_unstacked_length): if not all(ul.type.ndim == 0 for ul in new_unstacked_length):
raise NotImplementedError( raise NotImplementedError(
...@@ -200,7 +200,7 @@ class Transpose(XOp): ...@@ -200,7 +200,7 @@ class Transpose(XOp):
) )
return Apply(self, [x], [output]) return Apply(self, [x], [output])
def vectorize_node(self, node, new_x): def vectorize_node(self, node, new_x, new_dim):
old_dims = self.dims old_dims = self.dims
new_dims = tuple(dim for dim in new_x.dims if dim not in old_dims) new_dims = tuple(dim for dim in new_x.dims if dim not in old_dims)
return [type(self)(dims=(*new_dims, *old_dims))(new_x)] return [type(self)(dims=(*new_dims, *old_dims))(new_x)]
...@@ -318,7 +318,7 @@ class Concat(XOp): ...@@ -318,7 +318,7 @@ class Concat(XOp):
output = xtensor(dtype=dtype, dims=dims, shape=shape) output = xtensor(dtype=dtype, dims=dims, shape=shape)
return Apply(self, inputs, [output]) return Apply(self, inputs, [output])
def vectorize_node(self, node, *new_inputs): def vectorize_node(self, node, *new_inputs, new_dim):
return [self(*new_inputs)] return [self(*new_inputs)]
...@@ -402,7 +402,7 @@ class Squeeze(XOp): ...@@ -402,7 +402,7 @@ class Squeeze(XOp):
) )
return Apply(self, [x], [out]) return Apply(self, [x], [out])
def vectorize_node(self, node, new_x): def vectorize_node(self, node, new_x, new_dim):
return [self(new_x)] return [self(new_x)]
...@@ -464,7 +464,7 @@ class ExpandDims(XOp): ...@@ -464,7 +464,7 @@ class ExpandDims(XOp):
) )
return Apply(self, [x, size], [out]) return Apply(self, [x, size], [out])
def vectorize_node(self, node, new_x, new_size): def vectorize_node(self, node, new_x, new_size, new_dim):
new_size = new_size.squeeze() new_size = new_size.squeeze()
if new_size.type.ndim != 0: if new_size.type.ndim != 0:
raise NotImplementedError( raise NotImplementedError(
...@@ -567,7 +567,7 @@ class Broadcast(XOp): ...@@ -567,7 +567,7 @@ class Broadcast(XOp):
return Apply(self, inputs, outputs) return Apply(self, inputs, outputs)
def vectorize_node(self, node, *new_inputs): def vectorize_node(self, node, *new_inputs, new_dim):
if exclude_set := set(self.exclude): if exclude_set := set(self.exclude):
for new_x, old_x in zip(node.inputs, new_inputs, strict=True): for new_x, old_x in zip(node.inputs, new_inputs, strict=True):
if invalid_excluded := ( if invalid_excluded := (
......
from collections.abc import Sequence from collections.abc import Mapping, Sequence
from functools import singledispatch
from itertools import chain from itertools import chain
from typing import Literal
from typing import cast as typing_cast
import numpy as np import numpy as np
...@@ -8,18 +11,22 @@ from pytensor import shared ...@@ -8,18 +11,22 @@ from pytensor import shared
from pytensor.graph import Apply, Op from pytensor.graph import Apply, Op
from pytensor.graph.basic import Variable from pytensor.graph.basic import Variable
from pytensor.graph.replace import _vectorize_node from pytensor.graph.replace import _vectorize_node
from pytensor.graph.traversal import toposort, truncated_graph_inputs
from pytensor.graph.type import HasShape
from pytensor.scalar import discrete_dtypes from pytensor.scalar import discrete_dtypes
from pytensor.tensor import tensor from pytensor.tensor import (
TensorVariable,
broadcast_shape,
broadcast_to,
tensor,
)
from pytensor.tensor.random.op import RNGConsumerOp from pytensor.tensor.random.op import RNGConsumerOp
from pytensor.tensor.random.type import RandomType from pytensor.tensor.random.type import RandomType
from pytensor.tensor.utils import ( from pytensor.tensor.utils import (
get_static_shape_from_size_variables, get_static_shape_from_size_variables,
) )
from pytensor.utils import unzip from pytensor.utils import unzip
from pytensor.xtensor.basic import ( from pytensor.xtensor.basic import XOp, XTypeCastOp
XOp,
XTypeCastOp,
)
from pytensor.xtensor.type import XTensorType, XTensorVariable, as_xtensor, xtensor from pytensor.xtensor.type import XTensorType, XTensorVariable, as_xtensor, xtensor
...@@ -79,7 +86,7 @@ class XElemwise(XOp): ...@@ -79,7 +86,7 @@ class XElemwise(XOp):
] ]
return Apply(self, inputs, outputs) return Apply(self, inputs, outputs)
def vectorize_node(self, node, *new_inputs): def vectorize_node(self, node, *new_inputs, new_dim):
return self(*new_inputs, return_list=True) return self(*new_inputs, return_list=True)
...@@ -149,7 +156,7 @@ class XBlockwise(XOp): ...@@ -149,7 +156,7 @@ class XBlockwise(XOp):
] ]
return Apply(self, inputs, outputs) return Apply(self, inputs, outputs)
def vectorize_node(self, node, *new_inputs): def vectorize_node(self, node, *new_inputs, new_dim):
return self(*new_inputs, return_list=True) return self(*new_inputs, return_list=True)
...@@ -300,7 +307,7 @@ class XRV(XOp, RNGConsumerOp): ...@@ -300,7 +307,7 @@ class XRV(XOp, RNGConsumerOp):
return Apply(self, [rng, *extra_dim_lengths, *params], [rng.type(), out]) return Apply(self, [rng, *extra_dim_lengths, *params], [rng.type(), out])
def vectorize_node(self, node, *new_inputs): def vectorize_node(self, node, *new_inputs, new_dim):
new_rng, *new_extra_dim_lengths_and_params = new_inputs new_rng, *new_extra_dim_lengths_and_params = new_inputs
k = len(self.extra_dims) k = len(self.extra_dims)
new_extra_dim_lengths, new_params = ( new_extra_dim_lengths, new_params = (
...@@ -319,24 +326,36 @@ class XRV(XOp, RNGConsumerOp): ...@@ -319,24 +326,36 @@ class XRV(XOp, RNGConsumerOp):
@_vectorize_node.register(XOp) @_vectorize_node.register(XOp)
@_vectorize_node.register(XTypeCastOp) @_vectorize_node.register(XTypeCastOp)
def vectorize_xop(op: XOp, node, *new_inputs) -> Sequence[Variable]: def vectorize_xop(op, node, *new_inputs) -> Sequence[Variable]:
old_inp_dims = [ # This gets called by regular graph_replace, which isn't aware of xtensor and doesn't have a concept of `new_dim`
inp.dims for inp in node.inputs if isinstance(inp.type, XTensorType) return _vectorize_xnode(node.op, node, *new_inputs, new_dim=None)
]
old_out_dims = [
out.dims for out in node.outputs if isinstance(out.type, XTensorType) @singledispatch
] def _vectorize_xnode(
all_old_dims_set = set(chain.from_iterable((*old_inp_dims, old_out_dims))) op: XOp | XTypeCastOp,
node: Apply,
for new_inp, old_inp in zip(new_inputs, node.inputs, strict=True): *batched_inputs: Variable,
new_dim: str | None = None,
) -> Sequence[Variable]:
"""Returns vectorized version of node with new batched inputs."""
all_old_dims_set = set(
chain.from_iterable(
x.type.dims
for x in (*node.inputs, *node.outputs)
if isinstance(x.type, XTensorType)
)
)
for new_inp, old_inp in zip(batched_inputs, node.inputs, strict=True):
if not ( if not (
isinstance(new_inp.type, XTensorType) isinstance(new_inp.type, XTensorType)
and isinstance(old_inp.type, XTensorType) and isinstance(old_inp.type, XTensorType)
): ):
continue continue
old_dims_set = set(old_inp.dims) old_dims_set = set(old_inp.type.dims)
new_dims_set = set(new_inp.dims) new_dims_set = set(new_inp.type.dims)
# Validate that new inputs didn't drop pre-existing dims # Validate that new inputs didn't drop pre-existing dims
if missing_dims := old_dims_set - new_dims_set: if missing_dims := old_dims_set - new_dims_set:
...@@ -349,4 +368,297 @@ def vectorize_xop(op: XOp, node, *new_inputs) -> Sequence[Variable]: ...@@ -349,4 +368,297 @@ def vectorize_xop(op: XOp, node, *new_inputs) -> Sequence[Variable]:
f"Vectorized input {new_inp} has new dimensions that were present in the original graph: {new_core_dims}" f"Vectorized input {new_inp} has new dimensions that were present in the original graph: {new_core_dims}"
) )
return op.vectorize_node(node, *new_inputs) return op.vectorize_node(node, *batched_inputs, new_dim=new_dim)
def _vectorize_single_dim(outputs, replace, new_dim: str):
    """Vectorize ``outputs`` along exactly one new batch dimension.

    Walks the graph between the replaced variables and ``outputs`` in
    topological order, batching each node as it goes. XOp / XTypeCastOp nodes
    are dispatched through ``_vectorize_xnode`` (which receives ``new_dim``),
    while all other nodes go through the generic ``_vectorize_node``.
    Returns the batched counterparts of ``outputs``, in the same order.
    """
    graph_inputs = truncated_graph_inputs(outputs, ancestors_to_include=replace.keys())
    # Map original variable -> batched variable; seeded with the user replacements
    batched = {inp: replace.get(inp, inp) for inp in graph_inputs}

    for node in toposort(outputs, blockers=graph_inputs):
        batched_node_inputs = [batched.get(inp, inp) for inp in node.inputs]
        if isinstance(node.op, XOp | XTypeCastOp):
            batched_outs = _vectorize_xnode(
                node.op, node, *batched_node_inputs, new_dim=new_dim
            )
        else:
            result = _vectorize_node(node.op, node, *batched_node_inputs)
            # Old API compatibility: some dispatches return an Apply, not outputs
            batched_outs = result.outputs if isinstance(result, Apply) else result
        for old_out, new_out in zip(node.outputs, batched_outs, strict=True):
            # Some outputs of a multi-output node may already have a user-provided
            # replacement; never overwrite those with the freshly vectorized ones.
            if old_out not in batched:
                batched[old_out] = new_out

    return [batched[out] for out in outputs]
def vectorize_graph(
    outputs: Variable | Sequence[Variable],
    replace: Mapping[Variable, Variable],
    *,
    new_tensor_dims: Sequence[str] = (),
) -> Variable | Sequence[Variable]:
    """Dimension-aware vectorize_graph.

    This is an extension to :func:`pytensor.graph.replace.vectorize_graph` that correctly handles
    mixed XTensor/TensorVariable graphs.

    Vectorization rule for batch TensorVariables works like regular ``vectorize_graph``,
    with batched axes assumed to be aligned positionally and present on the left of the new inputs.
    They must be given labels with ``new_tensor_dims`` argument (left to right),
    for correct interaction with XTensorVariables (and even if there are no XTensorVariables in the graph).

    Batched XTensorVariables may contain new dimensions anywhere.
    These can include dimensions in ``new_tensor_dims``, as well as other new dimensions
    implied by the variable's ``dims``. New dimensions for a given input should not have
    existed in the original graph.

    The vectorized outputs will have the new dimensions on the left.
    The order of new dimensions is:

    1. New dimensions introduced by XTensorVariables (that are not in ``new_tensor_dims``).
    2. Dimensions specified in ``new_tensor_dims``.

    Parameters
    ----------
    outputs: Variable or Sequence of Variable
        The output variable(s) of the graph to be vectorized.
    replace: Mapping of Variable to Variable
        A dictionary mapping original variables to their vectorized counterparts.
    new_tensor_dims: Sequence of str, optional
        A sequence of string labels for the new batch dimensions introduced by ``TensorVariable``
        replacements. These dimensions correspond to the leading axes of the new tensor variables.
        This argument is required if any ``TensorVariable`` replacements introduce new dimensions.

    Returns
    -------
    vectorized_outputs: Variable or Sequence of Variable
        Vectorized output variable(s). A sequence is returned iff ``outputs`` was a sequence.

    Raises
    ------
    ValueError
        If a replacement is not a Variable, drops pre-existing dims, or introduces
        tensor batch dimensions not covered by ``new_tensor_dims``.
    NotImplementedError
        If a replacement of an unsupported shaped type changes ndim.

    Examples
    --------
    Vectorize a graph with XTensor variables:

    .. testcode:: python

        from pytensor.xtensor import xtensor
        from pytensor.xtensor.vectorization import vectorize_graph

        x = xtensor("x", dims=("a",))
        y = xtensor("y", dims=("a",))
        out = x + y

        # We want to vectorize over new dimensions "c" and "b"
        # For XTensor, new dimensions can be anywhere
        x_new = xtensor("x_new", dims=("c", "a"))
        y_new = xtensor("y_new", dims=("a", "b"))
        out_vec = vectorize_graph(out, {x: x_new, y: y_new})
        # Output batch dimensions are always on the left
        assert out_vec.type.dims == ("c", "b", "a")

    Vectorize a graph with standard Tensor variables:

    .. testcode:: python

        from pytensor.tensor import tensor, TensorVariable
        from pytensor.xtensor.vectorization import vectorize_graph

        x = tensor("x", shape=(3,))
        y = tensor("y", shape=(3,))
        out = x + y

        # We vectorize over new dimension of "a", and "b".
        # These must be on the left and broadcast correctly
        x_new = tensor("x_new", shape=(5, 3))
        y_new = tensor("y_new", shape=(7, 1, 3))
        out_vec = vectorize_graph(out, {x: x_new, y: y_new}, new_tensor_dims=["a", "b"])
        assert isinstance(out_vec, TensorVariable)
        assert out_vec.type.shape == (7, 5, 3)

    Vectorize a mixed graph:

    .. testcode:: python

        from pytensor.tensor import tensor
        from pytensor.xtensor import as_xtensor, xtensor
        from pytensor.xtensor.vectorization import vectorize_graph

        x = xtensor("x", shape=(5,), dims=("a",))
        y = tensor("y", shape=(5,))
        out = x + as_xtensor(y, dims=("a",))

        # Vectorize over a new dimension "c"
        x_new = xtensor("x_new", dims=("a", "c"), shape=(5, 3))
        y_new = tensor("y_new", shape=(3, 5))  # Leading dim corresponds to "c" (size 3)
        out_vec = vectorize_graph(out, {x: x_new, y: y_new}, new_tensor_dims=["c"])
        assert out_vec.type.dims == ("c", "a")

        # Treat the new dimension of y_new as being "b" (size 3)
        # x_new introduces "c" (size 3)
        # Result has XTensor-only new dims first ("c"), then new_tensor_dims ("b")
        out_vec = vectorize_graph(out, {x: x_new, y: y_new}, new_tensor_dims=["b"])
        assert out_vec.type.dims == ("c", "b", "a")
    """
    # Normalize to a sequence; the original (tuple vs single) form is restored at return
    seq_outputs = outputs if isinstance(outputs, Sequence) else (outputs,)

    if not all(
        isinstance(key, Variable) and isinstance(value, Variable)
        for key, value in replace.items()
    ):
        raise ValueError(f"Some of the replaced items are not Variables: {replace}")

    # Collect new dimensions and sizes, and validate
    # - XTensor replacements: any dim not present on the old variable is a new batch dim
    # - Tensor replacements: extra leading axes are new batch dims, labeled by new_tensor_dims
    new_xtensor_sizes: dict[str, TensorVariable] = {}
    new_tensor_dim_lengths: list[tuple[TensorVariable | Literal[1], ...]] = []
    for old, new in replace.items():
        if isinstance(new, XTensorVariable):
            old_var_dims_set = set(old.type.dims)
            new_var_dims_set = set(new.type.dims)
            if missing_dims := old_var_dims_set - new_var_dims_set:
                raise ValueError(
                    f"Vectorized input {new} is missing pre-existing dims: {sorted(missing_dims)}"
                )
            new_xtensor_sizes.update(
                {d: s for d, s in new.sizes.items() if d not in old_var_dims_set}
            )
        elif isinstance(new, TensorVariable):
            n_new_dims = new.type.ndim - old.type.ndim
            if n_new_dims < 0:
                raise ValueError(
                    f"Vectorized input {new} is missing pre-existing dims {new.type.ndim=}, {old.type.ndim=}"
                )
            if n_new_dims > len(new_tensor_dims):
                if not new_tensor_dims:
                    raise ValueError(
                        f"TensorVariable replacement {new} has {n_new_dims} batch dimensions. "
                        f"You must specify `new_tensor_dims` to label these."
                    )
                else:
                    raise ValueError(
                        f"TensorVariable replacement {new} has {n_new_dims} batch dimensions "
                        f"but only {new_tensor_dims=} were specified. "
                    )
            # Record the leading (batch) lengths, with broadcastable axes marked as 1
            # so broadcast_shape can reconcile them across replacements below
            new_tensor_dim_lengths.append(
                tuple(
                    1 if b else s
                    for s, b in zip(
                        tuple(new.shape)[:n_new_dims],
                        new.type.broadcastable[:n_new_dims],
                    )
                )
            )
        elif isinstance(new.type, HasShape) and new.type.ndim != old.type.ndim:
            raise NotImplementedError(
                f"vectorize_graph does not know how to handle batched input {new} of type {new.type}"
            )

    # Align xtensor batch dimensions on the left, and broadcast tensor batch dimensions
    # Final new-dim order: xtensor-only dims first, then the labeled tensor dims
    new_dims = (
        *(dim for dim in new_xtensor_sizes if dim not in new_tensor_dims),
        *new_tensor_dims,
    )

    # Create a mapping from new_tensor_dims -> broadcasted shape from tensors
    new_tensor_sizes: dict[str, Variable] = {}
    if new_tensor_dims:
        new_tensor_bcast_dim_lengths = broadcast_shape(
            *new_tensor_dim_lengths, arrays_are_shapes=True
        )
        del new_tensor_dim_lengths
        if len(new_tensor_bcast_dim_lengths) != len(new_tensor_dims):
            raise ValueError(
                f"{len(new_tensor_dims)} tensor dims were specified, but only {len(new_tensor_bcast_dim_lengths)} were found in the new inputs"
            )
        new_tensor_sizes = dict(zip(new_tensor_dims, new_tensor_bcast_dim_lengths))

    # Give preference to tensor sizes to avoid unnecessary broadcasting (Alloc)
    # XTensor sizes are implicitly handled by transpose and dim names, so they don't need strict size equality
    new_sizes = tuple(
        new_xtensor_sizes.get(dim, new_tensor_sizes.get(dim, 1)) for dim in new_dims
    )

    # Align batch dimensions on the left (*xtensor_unique_batch_dims, *tensor_batch_dims, ...)
    # We broadcast tensor batch dims as they may have been length 1
    aligned_replace = {}
    for old, new in replace.items():
        if isinstance(new, XTensorVariable):
            new = new.transpose(*new_dims, ..., missing_dims="ignore")
        elif isinstance(new, TensorVariable):
            n_existing_batch_dims = new.type.ndim - old.type.ndim
            if n_existing_batch_dims < len(new_dims) or any(
                new.type.broadcastable[: len(new_dims)]
            ):
                new = broadcast_to(
                    new,
                    shape=(*new_sizes, *tuple(new.shape)[n_existing_batch_dims:]),
                )
        aligned_replace[old] = new
    del replace

    # Vectorize one new dimension at a time, right-to-left (pop from the end),
    # so the innermost new dim is introduced first and the leftmost last.
    seq_vect_outputs = seq_outputs
    remaining_new_dims = list(new_dims)
    while remaining_new_dims:
        new_dim = remaining_new_dims.pop()
        if remaining_new_dims:
            # We need to use a dummy inputs to batch graph once at a time
            # We drop all the dims that are still in `remaining_new_dims`
            # Create a mapping: original -> intermediate_batched
            single_dim_replace = {}
            for old, new in aligned_replace.items():
                n_remaining_dims = len(remaining_new_dims)
                if isinstance(new, XTensorVariable):
                    intermediate_dims, intermediate_shape = unzip(
                        (
                            (d, s)
                            for d, s in zip(new.type.dims, new.type.shape)
                            if d not in remaining_new_dims
                        ),
                        n=2,
                    )
                    intermediate_type = new.type.clone(
                        dims=intermediate_dims, shape=intermediate_shape
                    )
                elif isinstance(new, TensorVariable):
                    # Tensor batch dims are leading and aligned, so dropping the
                    # still-pending dims is just slicing off leading axes
                    intermediate_type = new.type.clone(
                        shape=new.type.shape[n_remaining_dims:]
                    )
                else:
                    intermediate_type = new.type
                single_dim_replace[old] = intermediate_type()
            # Updated aligned replace mapping: intermediate_batched -> final_batched
            aligned_replace = dict(
                zip(single_dim_replace.values(), aligned_replace.values())
            )
        else:
            single_dim_replace = aligned_replace
        seq_vect_outputs = _vectorize_single_dim(
            seq_vect_outputs, single_dim_replace, new_dim
        )

    # Transpose xtensor outputs so all new dims sit on the left, in `new_dims` order
    aligned_seq_vect_outputs = [
        new.transpose(*new_dims, *typing_cast(XTensorVariable, old).dims)
        if isinstance(new, XTensorVariable)
        else new
        for new, old in zip(seq_vect_outputs, seq_outputs)
    ]
    return (
        aligned_seq_vect_outputs
        if isinstance(outputs, Sequence)
        else aligned_seq_vect_outputs[0]
    )
...@@ -3,10 +3,12 @@ import pytest ...@@ -3,10 +3,12 @@ import pytest
pytest.importorskip("xarray") pytest.importorskip("xarray")
import re
import numpy as np import numpy as np
from pytensor import function from pytensor import function
from pytensor.graph import vectorize_graph from pytensor.graph import vectorize_graph as tensor_vectorize_graph
from pytensor.tensor import matrix, vector from pytensor.tensor import matrix, vector
from pytensor.xtensor.basic import ( from pytensor.xtensor.basic import (
Rename, Rename,
...@@ -14,10 +16,9 @@ from pytensor.xtensor.basic import ( ...@@ -14,10 +16,9 @@ from pytensor.xtensor.basic import (
tensor_from_xtensor, tensor_from_xtensor,
xtensor_from_tensor, xtensor_from_tensor,
) )
from pytensor.xtensor.type import xtensor from pytensor.xtensor.type import as_xtensor, xtensor
from pytensor.xtensor.vectorization import vectorize_graph
from tests.unittest_tools import assert_equal_computations from tests.unittest_tools import assert_equal_computations
# from pytensor.xtensor.vectorization import vectorize_graph
from tests.xtensor.util import check_vectorization from tests.xtensor.util import check_vectorization
...@@ -53,9 +54,15 @@ def test_xtensor_from_tensor_vectorize(): ...@@ -53,9 +54,15 @@ def test_xtensor_from_tensor_vectorize():
t_batched = matrix("t_batched") t_batched = matrix("t_batched")
with pytest.raises( with pytest.raises(
NotImplementedError, match=r"Vectorization of .* not implemented" NotImplementedError,
match=re.escape(
"cannot infer the new dimension labels. Use pytensor.xtensor.vectorization.vectorize_graph instead."
),
): ):
vectorize_graph([x], {t: t_batched}) tensor_vectorize_graph(x, {t: t_batched})
vec_x = vectorize_graph(x, {t: t_batched}, new_tensor_dims=("b",))
assert_equal_computations([vec_x], [as_xtensor(t_batched, dims=("b", "a"))])
def test_tensor_from_xtensor_vectorize(): def test_tensor_from_xtensor_vectorize():
...@@ -64,7 +71,7 @@ def test_tensor_from_xtensor_vectorize(): ...@@ -64,7 +71,7 @@ def test_tensor_from_xtensor_vectorize():
x_batched = xtensor("x", dims=("a", "b"), shape=(3, 5)) x_batched = xtensor("x", dims=("a", "b"), shape=(3, 5))
y_batched = vectorize_graph(y, {x: x_batched}) y_batched = tensor_vectorize_graph(y, {x: x_batched})
# vectorize_graph should place output batch dimension on the left # vectorize_graph should place output batch dimension on the left
assert y_batched.type.shape == (5, 3) assert y_batched.type.shape == (5, 3)
assert_equal_computations([y_batched], [x_batched.transpose("b", ...).values]) assert_equal_computations([y_batched], [x_batched.transpose("b", ...).values])
...@@ -72,4 +79,9 @@ def test_tensor_from_xtensor_vectorize(): ...@@ -72,4 +79,9 @@ def test_tensor_from_xtensor_vectorize():
x_batched = xtensor("x", dims=("c", "a", "b"), shape=(7, 3, 5)) x_batched = xtensor("x", dims=("c", "a", "b"), shape=(7, 3, 5))
# vectorize_graph can't handle multiple batch dimensions safely # vectorize_graph can't handle multiple batch dimensions safely
with pytest.raises(NotImplementedError): with pytest.raises(NotImplementedError):
vectorize_graph(y, {x: x_batched}) tensor_vectorize_graph(y, {x: x_batched})
# xtensor vectorize_graph can handle this graph safely
y_batched = vectorize_graph(y, {x: x_batched})
assert y_batched.type.shape == (7, 5, 3)
assert_equal_computations([y_batched], [x_batched.transpose("c", "b", "a").values])
...@@ -15,7 +15,6 @@ from xarray import full_like as xr_full_like ...@@ -15,7 +15,6 @@ from xarray import full_like as xr_full_like
from xarray import ones_like as xr_ones_like from xarray import ones_like as xr_ones_like
from xarray import zeros_like as xr_zeros_like from xarray import zeros_like as xr_zeros_like
from pytensor.graph import vectorize_graph
from pytensor.tensor import scalar, vector from pytensor.tensor import scalar, vector
from pytensor.xtensor.shape import ( from pytensor.xtensor.shape import (
broadcast, broadcast,
...@@ -27,6 +26,7 @@ from pytensor.xtensor.shape import ( ...@@ -27,6 +26,7 @@ from pytensor.xtensor.shape import (
zeros_like, zeros_like,
) )
from pytensor.xtensor.type import as_xtensor, xtensor from pytensor.xtensor.type import as_xtensor, xtensor
from pytensor.xtensor.vectorization import vectorize_graph
from tests.xtensor.util import ( from tests.xtensor.util import (
check_vectorization, check_vectorization,
xr_arange_like, xr_arange_like,
...@@ -874,7 +874,7 @@ def test_expand_dims_batch_length_vectorize(): ...@@ -874,7 +874,7 @@ def test_expand_dims_batch_length_vectorize():
with pytest.raises( with pytest.raises(
NotImplementedError, match=r"Vectorization of .* not implemented" NotImplementedError, match=r"Vectorization of .* not implemented"
): ):
vectorize_graph([y], {x: x_batch, l: l_batch}) vectorize_graph([y], {x: x_batch, l: l_batch}, new_tensor_dims=["batch"])
def test_unstack_batch_length_vectorize(): def test_unstack_batch_length_vectorize():
...@@ -888,4 +888,4 @@ def test_unstack_batch_length_vectorize(): ...@@ -888,4 +888,4 @@ def test_unstack_batch_length_vectorize():
with pytest.raises( with pytest.raises(
NotImplementedError, match=r"Vectorization of .* not implemented" NotImplementedError, match=r"Vectorization of .* not implemented"
): ):
vectorize_graph([y], {x: x_batch, l: l_batch}) vectorize_graph([y], {x: x_batch, l: l_batch}, new_tensor_dims=["batch"])
import numpy as np
import pytest
from pytensor.tensor import TensorVariable, broadcast_to, tensor
from pytensor.xtensor.basic import xtensor_from_tensor
from pytensor.xtensor.type import XTensorType, as_xtensor, xtensor
from pytensor.xtensor.vectorization import vectorize_graph
from tests.unittest_tools import assert_equal_computations
class TestVectorizeGraph:
    """Tests for the dim-aware ``vectorize_graph`` helper.

    Covers four regimes:
      * pure xtensor graphs, where batch dims are inferred from dim labels;
      * pure tensor graphs, where new leading dims must be named via
        ``new_tensor_dims`` (tensors are positional, not labeled);
      * graphs that mix xtensor and tensor variables (in either direction);
      * invalid replacement specs, which must raise ``ValueError``.
    """

    def test_pure_xtensor_graph(self):
        x = xtensor("x", dims=("a",))
        out = x + 1

        # Replacement introduces new dims "c" and "b"; their labels (not
        # their positions) determine the batching.
        x_new = xtensor("x_new", dims=("c", "a", "b"))
        [out_vec] = vectorize_graph([out], {x: x_new})
        assert isinstance(out_vec.type, XTensorType)
        assert out_vec.type.dims == ("c", "b", "a")

        expected = x_new.transpose("c", "b", "a") + 1
        assert_equal_computations([out_vec], [expected])

    def test_pure_tensor_graph(self):
        x = tensor("x", shape=())
        out = x + 1

        # Plain tensors carry no dim labels, so the new leading dim must be
        # named explicitly through `new_tensor_dims`.
        x_new = tensor("x_new", shape=(5,))
        [out_vec] = vectorize_graph([out], {x: x_new}, new_tensor_dims=["b"])
        assert isinstance(out_vec, TensorVariable)
        assert out_vec.ndim == 1

        expected = x_new + 1
        assert_equal_computations([out_vec], [expected])

    def test_intermediate_tensor_graph(self):
        # xtensor input, tensor intermediate, xtensor output.
        x = xtensor("x", dims=("a",))
        t = x.values  # Convert to TensorVariable
        t2 = t + np.ones(1)
        out = xtensor_from_tensor(t2, dims=("a",))

        x_new = xtensor("x_new", dims=("a", "b"))
        [out_vec] = vectorize_graph([out], {x: x_new})
        assert isinstance(out_vec.type, XTensorType)
        assert out_vec.type.dims == ("b", "a")

        expected = as_xtensor(
            x_new.transpose("b", "a").values + np.ones(1), dims=("b", "a")
        )
        assert_equal_computations([out_vec], [expected])

    def test_intermediate_tensor_multiple_inputs_graph(self):
        x = xtensor("x", dims=("a",))
        y = xtensor("y", dims=("a",))
        t = x.values + y.values
        out = xtensor_from_tensor(t, dims=("a",))

        x_new = xtensor("x_new", dims=("a", "c"))

        # Both inputs have the same batch dims
        y_new = xtensor("y_new", dims=("c", "a"))
        [out_vec] = vectorize_graph([out], {x: x_new, y: y_new})
        assert isinstance(out_vec.type, XTensorType)
        assert out_vec.type.dims == ("c", "a")

        expected = as_xtensor(
            (x_new.transpose("c", "a").values + y_new.transpose("c", "a").values),
            dims=("c", "a"),
        )
        assert_equal_computations([out_vec], [expected])

        # Inputs have different batch dims: the inner tensor addition must
        # broadcast via explicit new axes ([:, None] / [None, :]).
        y_new = xtensor("y_new", dims=("b", "a"))
        [out_vec] = vectorize_graph([out], {x: x_new, y: y_new})
        assert isinstance(out_vec.type, XTensorType)
        assert out_vec.type.dims == ("c", "b", "a")

        expected = as_xtensor(
            (
                x_new.transpose("c", "a").values[:, None]
                + y_new.transpose("b", "a").values[None, :]
            ),
            dims=("c", "b", "a"),
        )
        assert_equal_computations([out_vec], [expected])

    def test_intermediate_xtensor_graph(self):
        # tensor input, xtensor intermediate, tensor output.
        x = tensor("x", shape=(3,))
        t = as_xtensor(x, dims=("a",))
        t2 = t + 1
        out = t2.values

        x_new = tensor("x_new", shape=(5, 3))
        [out_vec] = vectorize_graph([out], {x: x_new}, new_tensor_dims=["b"])
        assert isinstance(out_vec, TensorVariable)
        assert out_vec.ndim == 2

        expected = (as_xtensor(x_new, dims=("b", "a")) + 1).values
        assert_equal_computations([out_vec], [expected])

    def test_mixed_type_inputs(self):
        x = xtensor("x", dims=("a",), shape=(3,))
        y = tensor("y", shape=(5,))
        out = as_xtensor(y[2:], dims=("b",)) + x

        x_new = xtensor("x_new", dims=("a", "d"), shape=(3, 7))
        y_new = tensor("y_new", shape=(7, 5))

        # New dimension of y is aligned with the new dimension of x
        [out_vec] = vectorize_graph([out], {x: x_new, y: y_new}, new_tensor_dims=["d"])
        assert isinstance(out_vec.type, XTensorType)
        assert out_vec.type.dims == ("d", "b", "a")

        expected = as_xtensor(y_new[:, 2:], dims=("d", "b")) + x_new.transpose("d", "a")
        assert_equal_computations([out_vec], [expected])

        # New dimension of y is distinct from that of x
        [out_vec] = vectorize_graph([out], {x: x_new, y: y_new}, new_tensor_dims=["c"])
        assert isinstance(out_vec.type, XTensorType)
        assert out_vec.type.dims == ("d", "c", "b", "a")

        # x introduced a new dimension "d" which causes y to be broadcasted
        y_broadcasted = broadcast_to(
            y_new, (x_new.sizes["d"], y_new.shape[0], y_new.shape[1])
        )
        expected = as_xtensor(
            y_broadcasted[:, :, 2:], dims=("d", "c", "b")
        ) + x_new.transpose("d", "a")
        assert_equal_computations([out_vec], [expected])

    def test_mixed_type_inputs_complex_broadcasting(self):
        a = xtensor("a", dims=("a",), shape=(3,))
        # NOTE: was dims=("b") — a plain string, not a tuple — which only
        # worked by accident for a single-character dim name.
        b = xtensor("b", dims=("b",), shape=(5,))
        y = tensor("y", shape=(7,))
        z = tensor("z", shape=(11,))
        out = a + b + y.sum() + z.sum()
        assert out.dims == ("a", "b")

        a_new = xtensor("a_new", dims=("a*", "a"), shape=(33, 3))
        b_new = xtensor("b_new", dims=("b*", "b"), shape=(55, 5))
        # Size-1 axes line up so the four batch dims (a*, b*, y*, z*)
        # broadcast against each other.
        y_new = tensor("y_new", shape=(1, 55, 2, 1, 7))
        z_new = tensor("z_new", shape=(33, 1, 1, 2, 11))
        [out_vec] = vectorize_graph(
            [out],
            {a: a_new, b: b_new, y: y_new, z: z_new},
            new_tensor_dims=["a*", "b*", "y*", "z*"],
        )
        assert isinstance(out_vec.type, XTensorType)
        assert out_vec.type.dims == ("a*", "b*", "y*", "z*", "a", "b")

        # Each batch size is taken from the input that is not broadcasted
        # (size != 1) along that axis.
        batch_shape_truth = (
            a_new.sizes["a*"],
            b_new.sizes["b*"],
            y_new.shape[2],
            z_new.shape[3],
        )
        y_new_bcast = broadcast_to(y_new, (*batch_shape_truth, y_new.shape[4]))
        z_new_bcast = broadcast_to(z_new, (*batch_shape_truth, z_new.shape[4]))
        expected_out = (
            (a_new + b_new)
            + as_xtensor(y_new_bcast.sum(axis=-1), dims=("a*", "b*", "y*", "z*"))
            + as_xtensor(z_new_bcast.sum(axis=-1), dims=("a*", "b*", "y*", "z*"))
        ).transpose("a*", "b*", "y*", "z*", ...)
        assert_equal_computations([out_vec], [expected_out])

    def test_invalid_cases(self):
        x = xtensor("x", dims=("a",))
        out = x + 1

        # Missing xtensor dims
        x_bad = xtensor("x_bad", dims=("b",))  # Missing "a"
        with pytest.raises(ValueError, match="missing pre-existing dims"):
            vectorize_graph([out], {x: x_bad})

        # New xtensor dims that were present in original graph
        y = xtensor("y", dims=("b",))
        out2 = x + y
        x_new_conflict = xtensor("x_new", dims=("a", "b"))
        # "b" is new to x, but present in graph (in y)
        with pytest.raises(ValueError, match="new dimensions that were present"):
            vectorize_graph([out2], {x: x_new_conflict})

        # Missing tensor dims
        t = tensor("t", shape=(3,))
        out_t = t + 1
        # Replacement has fewer dims (rank 0)
        t_bad_rank = tensor("t_bad", shape=())
        with pytest.raises(ValueError, match="missing pre-existing dims"):
            vectorize_graph([out_t], {t: t_bad_rank})

        # Missing new_tensor_dims
        t_new = tensor("t_new", shape=(5, 5, 3))
        with pytest.raises(ValueError, match="You must specify `new_tensor_dims`"):
            vectorize_graph([out_t], {t: t_new})
        with pytest.raises(ValueError, match=r"but only .* were specified"):
            vectorize_graph([out_t], {t: t_new}, new_tensor_dims=["a"])

        # Excess new_tensor_dims
        # Replacement adds 1 dim, but 2 are specified
        t_new_1dim = tensor("t_new_1dim", shape=(5, 3))
        with pytest.raises(ValueError, match="tensor dims were specified, but only"):
            vectorize_graph([out_t], {t: t_new_1dim}, new_tensor_dims=["a", "b"])
...@@ -10,8 +10,8 @@ from xarray import DataArray ...@@ -10,8 +10,8 @@ from xarray import DataArray
from xarray.testing import assert_allclose from xarray.testing import assert_allclose
from pytensor import function from pytensor import function
from pytensor.graph import vectorize_graph
from pytensor.xtensor.type import XTensorType, as_xtensor from pytensor.xtensor.type import XTensorType, as_xtensor
from pytensor.xtensor.vectorization import vectorize_graph
def xr_function(*args, **kwargs): def xr_function(*args, **kwargs):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论