提交 87470065 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Implement dim-aware vectorize_graph

上级 9d99267c
...@@ -283,6 +283,13 @@ def vectorize_graph( ...@@ -283,6 +283,13 @@ def vectorize_graph(
# [array([-10., -11.]), array([10., 11.])] # [array([-10., -11.]), array([10., 11.])]
""" """
# TODO: Move this to tensor.vectorize, and make this helper type agnostic.
#
# This helper may dispatch to tensor.vectorize_graph or xtensor.vectorize_graph depending on the replacement types
# The behavior is distinct, because tensor vectorization depends on axis-position while xtensor depends on dimension labels
#
# xtensor.vectorize_graph will be able to handle batched inner tensor operations, while tensor.vectorize_graph won't,
# as it is by design unaware of xtensors and their semantics.
if isinstance(outputs, Sequence): if isinstance(outputs, Sequence):
seq_outputs = outputs seq_outputs = outputs
else: else:
......
import warnings import warnings
from collections.abc import Collection, Iterable from collections.abc import Collection, Iterable, Sequence
from textwrap import dedent from textwrap import dedent
import numpy as np import numpy as np
...@@ -1926,7 +1926,7 @@ def logspace( ...@@ -1926,7 +1926,7 @@ def logspace(
def broadcast_to( def broadcast_to(
x: TensorVariable, shape: TensorVariable | tuple[Variable, ...] x: TensorLike, shape: TensorLike | Sequence[TensorLike]
) -> TensorVariable: ) -> TensorVariable:
"""Broadcast an array to a new shape. """Broadcast an array to a new shape.
......
...@@ -18,7 +18,9 @@ class XOp(Op): ...@@ -18,7 +18,9 @@ class XOp(Op):
def do_constant_folding(self, fgraph, node): def do_constant_folding(self, fgraph, node):
return False return False
def vectorize_node(self, node, *new_inputs) -> Sequence[Variable]: def vectorize_node(
self, node, *new_inputs, new_dim: str | None
) -> Sequence[Variable]:
raise NotImplementedError(f"Vectorized node not implemented for {self}") raise NotImplementedError(f"Vectorized node not implemented for {self}")
...@@ -31,7 +33,9 @@ class XTypeCastOp(TypeCastingOp): ...@@ -31,7 +33,9 @@ class XTypeCastOp(TypeCastingOp):
def infer_shape(self, fgraph, node, input_shapes): def infer_shape(self, fgraph, node, input_shapes):
return input_shapes return input_shapes
def vectorize_node(self, node, *new_inputs) -> Sequence[Variable]: def vectorize_node(
self, node, *new_inputs, new_dim: str | None
) -> Sequence[Variable]:
raise NotImplementedError(f"Vectorized node not implemented for {self}") raise NotImplementedError(f"Vectorized node not implemented for {self}")
...@@ -49,12 +53,13 @@ class TensorFromXTensor(XTypeCastOp): ...@@ -49,12 +53,13 @@ class TensorFromXTensor(XTypeCastOp):
[g_out] = g_outs [g_out] = g_outs
return [xtensor_from_tensor(g_out, dims=x.type.dims)] return [xtensor_from_tensor(g_out, dims=x.type.dims)]
def vectorize_node(self, node, new_x): def vectorize_node(self, node, new_x, new_dim):
[old_x] = node.inputs [old_x] = node.inputs
if (new_x.ndim - old_x.ndim) > 1: if (new_x.ndim - old_x.ndim) > 1:
raise NotImplementedError( raise NotImplementedError(
f"Vectorization of {self} cannot guarantee correct placement of multiple batch dimensions. " f"Vectorization of {self} cannot guarantee correct placement of multiple batch dimensions. "
"You can call vectorize_graph one batch dimension at a time." "You can call vectorize_graph one batch dimension at a time, "
"or pytensor.xtensor.vectorization.vectorize_graph instead."
) )
new_x = new_x.transpose(..., *old_x.dims) new_x = new_x.transpose(..., *old_x.dims)
return [self(new_x)] return [self(new_x)]
...@@ -80,13 +85,16 @@ class XTensorFromTensor(XTypeCastOp): ...@@ -80,13 +85,16 @@ class XTensorFromTensor(XTypeCastOp):
[g_out] = g_outs [g_out] = g_outs
return [tensor_from_xtensor(g_out)] return [tensor_from_xtensor(g_out)]
def vectorize_node(self, node, new_x): def vectorize_node(self, node, new_x, new_dim):
[old_x] = node.inputs [old_x] = node.inputs
if new_x.ndim != old_x.ndim: if new_x.ndim != old_x.ndim:
if new_dim is None:
raise NotImplementedError( raise NotImplementedError(
f"Vectorization of {self} with batched inputs not implemented, " f"Vectorization of {self} cannot infer the new dimension labels. "
"as it can't infer new dimension labels" "Use pytensor.xtensor.vectorization.vectorize_graph instead."
) )
return [type(self)(dims=(new_dim, *self.dims))(new_x)]
else:
return [self(new_x)] return [self(new_x)]
...@@ -111,7 +119,7 @@ class Rename(XTypeCastOp): ...@@ -111,7 +119,7 @@ class Rename(XTypeCastOp):
[g_out] = g_outs [g_out] = g_outs
return [rename(g_out, dims=x.type.dims)] return [rename(g_out, dims=x.type.dims)]
def vectorize_node(self, node, new_x): def vectorize_node(self, node, new_x, new_dim):
[old_x] = node.inputs [old_x] = node.inputs
old_dim_mapping = dict(zip(old_x.dims, self.new_dims, strict=True)) old_dim_mapping = dict(zip(old_x.dims, self.new_dims, strict=True))
......
...@@ -197,7 +197,7 @@ class Index(XOp): ...@@ -197,7 +197,7 @@ class Index(XOp):
output = xtensor(dtype=x.type.dtype, shape=out_shape, dims=out_dims) output = xtensor(dtype=x.type.dtype, shape=out_shape, dims=out_dims)
return Apply(self, [x, *idxs], [output]) return Apply(self, [x, *idxs], [output])
def vectorize_node(self, node, new_x, *new_idxs): def vectorize_node(self, node, new_x, *new_idxs, new_dim):
# new_x may have dims in different order # new_x may have dims in different order
# we pair each pre-existing dim to the respective index # we pair each pre-existing dim to the respective index
# with new dims having simply a slice(None) # with new dims having simply a slice(None)
...@@ -237,7 +237,7 @@ class IndexUpdate(XOp): ...@@ -237,7 +237,7 @@ class IndexUpdate(XOp):
out = x.type() out = x.type()
return Apply(self, [x, y, *idxs], [out]) return Apply(self, [x, y, *idxs], [out])
def vectorize_node(self, node, *new_inputs): def vectorize_node(self, node, *new_inputs, new_dim):
# If y or the indices have new dimensions we need to broadcast_x # If y or the indices have new dimensions we need to broadcast_x
exclude: set[str] = set( exclude: set[str] = set(
chain.from_iterable( chain.from_iterable(
......
...@@ -46,7 +46,7 @@ class XReduce(XOp): ...@@ -46,7 +46,7 @@ class XReduce(XOp):
output = xtensor(dtype=x.type.dtype, shape=out_shape, dims=out_dims) output = xtensor(dtype=x.type.dtype, shape=out_shape, dims=out_dims)
return Apply(self, [x], [output]) return Apply(self, [x], [output])
def vectorize_node(self, node, new_x): def vectorize_node(self, node, new_x, new_dim):
return [self(new_x)] return [self(new_x)]
...@@ -120,7 +120,7 @@ class XCumReduce(XOp): ...@@ -120,7 +120,7 @@ class XCumReduce(XOp):
out = x.type() out = x.type()
return Apply(self, [x], [out]) return Apply(self, [x], [out])
def vectorize_node(self, node, new_x): def vectorize_node(self, node, new_x, new_dim):
return [self(new_x)] return [self(new_x)]
......
...@@ -68,7 +68,7 @@ class Stack(XOp): ...@@ -68,7 +68,7 @@ class Stack(XOp):
) )
return Apply(self, [x], [output]) return Apply(self, [x], [output])
def vectorize_node(self, node, new_x): def vectorize_node(self, node, new_x, new_dim):
return [self(new_x)] return [self(new_x)]
...@@ -149,7 +149,7 @@ class UnStack(XOp): ...@@ -149,7 +149,7 @@ class UnStack(XOp):
) )
return Apply(self, [x, *unstacked_lengths], [output]) return Apply(self, [x, *unstacked_lengths], [output])
def vectorize_node(self, node, new_x, *new_unstacked_length): def vectorize_node(self, node, new_x, *new_unstacked_length, new_dim):
new_unstacked_length = [ul.squeeze() for ul in new_unstacked_length] new_unstacked_length = [ul.squeeze() for ul in new_unstacked_length]
if not all(ul.type.ndim == 0 for ul in new_unstacked_length): if not all(ul.type.ndim == 0 for ul in new_unstacked_length):
raise NotImplementedError( raise NotImplementedError(
...@@ -200,7 +200,7 @@ class Transpose(XOp): ...@@ -200,7 +200,7 @@ class Transpose(XOp):
) )
return Apply(self, [x], [output]) return Apply(self, [x], [output])
def vectorize_node(self, node, new_x): def vectorize_node(self, node, new_x, new_dim):
old_dims = self.dims old_dims = self.dims
new_dims = tuple(dim for dim in new_x.dims if dim not in old_dims) new_dims = tuple(dim for dim in new_x.dims if dim not in old_dims)
return [type(self)(dims=(*new_dims, *old_dims))(new_x)] return [type(self)(dims=(*new_dims, *old_dims))(new_x)]
...@@ -318,7 +318,7 @@ class Concat(XOp): ...@@ -318,7 +318,7 @@ class Concat(XOp):
output = xtensor(dtype=dtype, dims=dims, shape=shape) output = xtensor(dtype=dtype, dims=dims, shape=shape)
return Apply(self, inputs, [output]) return Apply(self, inputs, [output])
def vectorize_node(self, node, *new_inputs): def vectorize_node(self, node, *new_inputs, new_dim):
return [self(*new_inputs)] return [self(*new_inputs)]
...@@ -402,7 +402,7 @@ class Squeeze(XOp): ...@@ -402,7 +402,7 @@ class Squeeze(XOp):
) )
return Apply(self, [x], [out]) return Apply(self, [x], [out])
def vectorize_node(self, node, new_x): def vectorize_node(self, node, new_x, new_dim):
return [self(new_x)] return [self(new_x)]
...@@ -464,7 +464,7 @@ class ExpandDims(XOp): ...@@ -464,7 +464,7 @@ class ExpandDims(XOp):
) )
return Apply(self, [x, size], [out]) return Apply(self, [x, size], [out])
def vectorize_node(self, node, new_x, new_size): def vectorize_node(self, node, new_x, new_size, new_dim):
new_size = new_size.squeeze() new_size = new_size.squeeze()
if new_size.type.ndim != 0: if new_size.type.ndim != 0:
raise NotImplementedError( raise NotImplementedError(
...@@ -567,7 +567,7 @@ class Broadcast(XOp): ...@@ -567,7 +567,7 @@ class Broadcast(XOp):
return Apply(self, inputs, outputs) return Apply(self, inputs, outputs)
def vectorize_node(self, node, *new_inputs): def vectorize_node(self, node, *new_inputs, new_dim):
if exclude_set := set(self.exclude): if exclude_set := set(self.exclude):
for new_x, old_x in zip(node.inputs, new_inputs, strict=True): for new_x, old_x in zip(node.inputs, new_inputs, strict=True):
if invalid_excluded := ( if invalid_excluded := (
......
...@@ -3,10 +3,12 @@ import pytest ...@@ -3,10 +3,12 @@ import pytest
pytest.importorskip("xarray") pytest.importorskip("xarray")
import re
import numpy as np import numpy as np
from pytensor import function from pytensor import function
from pytensor.graph import vectorize_graph from pytensor.graph import vectorize_graph as tensor_vectorize_graph
from pytensor.tensor import matrix, vector from pytensor.tensor import matrix, vector
from pytensor.xtensor.basic import ( from pytensor.xtensor.basic import (
Rename, Rename,
...@@ -14,10 +16,9 @@ from pytensor.xtensor.basic import ( ...@@ -14,10 +16,9 @@ from pytensor.xtensor.basic import (
tensor_from_xtensor, tensor_from_xtensor,
xtensor_from_tensor, xtensor_from_tensor,
) )
from pytensor.xtensor.type import xtensor from pytensor.xtensor.type import as_xtensor, xtensor
from pytensor.xtensor.vectorization import vectorize_graph
from tests.unittest_tools import assert_equal_computations from tests.unittest_tools import assert_equal_computations
# from pytensor.xtensor.vectorization import vectorize_graph
from tests.xtensor.util import check_vectorization from tests.xtensor.util import check_vectorization
...@@ -53,9 +54,15 @@ def test_xtensor_from_tensor_vectorize(): ...@@ -53,9 +54,15 @@ def test_xtensor_from_tensor_vectorize():
t_batched = matrix("t_batched") t_batched = matrix("t_batched")
with pytest.raises( with pytest.raises(
NotImplementedError, match=r"Vectorization of .* not implemented" NotImplementedError,
match=re.escape(
"cannot infer the new dimension labels. Use pytensor.xtensor.vectorization.vectorize_graph instead."
),
): ):
vectorize_graph([x], {t: t_batched}) tensor_vectorize_graph(x, {t: t_batched})
vec_x = vectorize_graph(x, {t: t_batched}, new_tensor_dims=("b",))
assert_equal_computations([vec_x], [as_xtensor(t_batched, dims=("b", "a"))])
def test_tensor_from_xtensor_vectorize(): def test_tensor_from_xtensor_vectorize():
...@@ -64,7 +71,7 @@ def test_tensor_from_xtensor_vectorize(): ...@@ -64,7 +71,7 @@ def test_tensor_from_xtensor_vectorize():
x_batched = xtensor("x", dims=("a", "b"), shape=(3, 5)) x_batched = xtensor("x", dims=("a", "b"), shape=(3, 5))
y_batched = vectorize_graph(y, {x: x_batched}) y_batched = tensor_vectorize_graph(y, {x: x_batched})
# vectorize_graph should place output batch dimension on the left # vectorize_graph should place output batch dimension on the left
assert y_batched.type.shape == (5, 3) assert y_batched.type.shape == (5, 3)
assert_equal_computations([y_batched], [x_batched.transpose("b", ...).values]) assert_equal_computations([y_batched], [x_batched.transpose("b", ...).values])
...@@ -72,4 +79,9 @@ def test_tensor_from_xtensor_vectorize(): ...@@ -72,4 +79,9 @@ def test_tensor_from_xtensor_vectorize():
x_batched = xtensor("x", dims=("c", "a", "b"), shape=(7, 3, 5)) x_batched = xtensor("x", dims=("c", "a", "b"), shape=(7, 3, 5))
# vectorize_graph can't handle multiple batch dimensions safely # vectorize_graph can't handle multiple batch dimensions safely
with pytest.raises(NotImplementedError): with pytest.raises(NotImplementedError):
vectorize_graph(y, {x: x_batched}) tensor_vectorize_graph(y, {x: x_batched})
# xtensor vectorize_graph can handle this graph safely
y_batched = vectorize_graph(y, {x: x_batched})
assert y_batched.type.shape == (7, 5, 3)
assert_equal_computations([y_batched], [x_batched.transpose("c", "b", "a").values])
...@@ -15,7 +15,6 @@ from xarray import full_like as xr_full_like ...@@ -15,7 +15,6 @@ from xarray import full_like as xr_full_like
from xarray import ones_like as xr_ones_like from xarray import ones_like as xr_ones_like
from xarray import zeros_like as xr_zeros_like from xarray import zeros_like as xr_zeros_like
from pytensor.graph import vectorize_graph
from pytensor.tensor import scalar, vector from pytensor.tensor import scalar, vector
from pytensor.xtensor.shape import ( from pytensor.xtensor.shape import (
broadcast, broadcast,
...@@ -27,6 +26,7 @@ from pytensor.xtensor.shape import ( ...@@ -27,6 +26,7 @@ from pytensor.xtensor.shape import (
zeros_like, zeros_like,
) )
from pytensor.xtensor.type import as_xtensor, xtensor from pytensor.xtensor.type import as_xtensor, xtensor
from pytensor.xtensor.vectorization import vectorize_graph
from tests.xtensor.util import ( from tests.xtensor.util import (
check_vectorization, check_vectorization,
xr_arange_like, xr_arange_like,
...@@ -874,7 +874,7 @@ def test_expand_dims_batch_length_vectorize(): ...@@ -874,7 +874,7 @@ def test_expand_dims_batch_length_vectorize():
with pytest.raises( with pytest.raises(
NotImplementedError, match=r"Vectorization of .* not implemented" NotImplementedError, match=r"Vectorization of .* not implemented"
): ):
vectorize_graph([y], {x: x_batch, l: l_batch}) vectorize_graph([y], {x: x_batch, l: l_batch}, new_tensor_dims=["batch"])
def test_unstack_batch_length_vectorize(): def test_unstack_batch_length_vectorize():
...@@ -888,4 +888,4 @@ def test_unstack_batch_length_vectorize(): ...@@ -888,4 +888,4 @@ def test_unstack_batch_length_vectorize():
with pytest.raises( with pytest.raises(
NotImplementedError, match=r"Vectorization of .* not implemented" NotImplementedError, match=r"Vectorization of .* not implemented"
): ):
vectorize_graph([y], {x: x_batch, l: l_batch}) vectorize_graph([y], {x: x_batch, l: l_batch}, new_tensor_dims=["batch"])
import numpy as np
import pytest
from pytensor.tensor import TensorVariable, broadcast_to, tensor
from pytensor.xtensor.basic import xtensor_from_tensor
from pytensor.xtensor.type import XTensorType, as_xtensor, xtensor
from pytensor.xtensor.vectorization import vectorize_graph
from tests.unittest_tools import assert_equal_computations
class TestVectorizeGraph:
def test_pure_xtensor_graph(self):
x = xtensor("x", dims=("a",))
out = x + 1
x_new = xtensor("x_new", dims=("c", "a", "b"))
[out_vec] = vectorize_graph([out], {x: x_new})
assert isinstance(out_vec.type, XTensorType)
assert out_vec.type.dims == ("c", "b", "a")
expected = x_new.transpose("c", "b", "a") + 1
assert_equal_computations([out_vec], [expected])
def test_pure_tensor_graph(self):
x = tensor("x", shape=())
out = x + 1
x_new = tensor("x_new", shape=(5,))
[out_vec] = vectorize_graph([out], {x: x_new}, new_tensor_dims=["b"])
assert isinstance(out_vec, TensorVariable)
assert out_vec.ndim == 1
expected = x_new + 1
assert_equal_computations([out_vec], [expected])
def test_intermediate_tensor_graph(self):
x = xtensor("x", dims=("a",))
t = x.values # Convert to TensorVariable
t2 = t + np.ones(1)
out = xtensor_from_tensor(t2, dims=("a",))
x_new = xtensor("x_new", dims=("a", "b"))
[out_vec] = vectorize_graph([out], {x: x_new})
assert isinstance(out_vec.type, XTensorType)
assert out_vec.type.dims == ("b", "a")
expected = as_xtensor(
x_new.transpose("b", "a").values + np.ones(1), dims=("b", "a")
)
assert_equal_computations([out_vec], [expected])
def test_intermediate_tensor_multiple_inputs_graph(self):
x = xtensor("x", dims=("a",))
y = xtensor("y", dims=("a",))
t = x.values + y.values
out = xtensor_from_tensor(t, dims=("a",))
x_new = xtensor("x_new", dims=("a", "c"))
# Both inputs have the same batch dims
y_new = xtensor("y_new", dims=("c", "a"))
[out_vec] = vectorize_graph([out], {x: x_new, y: y_new})
assert isinstance(out_vec.type, XTensorType)
assert out_vec.type.dims == ("c", "a")
expected = as_xtensor(
(x_new.transpose("c", "a").values + y_new.transpose("c", "a").values),
dims=("c", "a"),
)
assert_equal_computations([out_vec], [expected])
# Inputs have different batch dims
y_new = xtensor("y_new", dims=("b", "a"))
[out_vec] = vectorize_graph([out], {x: x_new, y: y_new})
assert isinstance(out_vec.type, XTensorType)
assert out_vec.type.dims == ("c", "b", "a")
expected = as_xtensor(
(
x_new.transpose("c", "a").values[:, None]
+ y_new.transpose("b", "a").values[None, :]
),
dims=("c", "b", "a"),
)
assert_equal_computations([out_vec], [expected])
def test_intermediate_xtensor_graph(self):
x = tensor("x", shape=(3,))
t = as_xtensor(x, dims=("a",))
t2 = t + 1
out = t2.values
x_new = tensor("x_new", shape=(5, 3))
[out_vec] = vectorize_graph([out], {x: x_new}, new_tensor_dims=["b"])
assert isinstance(out_vec, TensorVariable)
assert out_vec.ndim == 2
expected = (as_xtensor(x_new, dims=("b", "a")) + 1).values
assert_equal_computations([out_vec], [expected])
def test_mixed_type_inputs(self):
x = xtensor("x", dims=("a",), shape=(3,))
y = tensor("y", shape=(5,))
out = as_xtensor(y[2:], dims=("b",)) + x
x_new = xtensor("x_new", dims=("a", "d"), shape=(3, 7))
y_new = tensor("y_new", shape=(7, 5))
# New dimension of y is aligned with the new dimension of x
[out_vec] = vectorize_graph([out], {x: x_new, y: y_new}, new_tensor_dims=["d"])
assert isinstance(out_vec.type, XTensorType)
assert out_vec.type.dims == ("d", "b", "a")
expected = as_xtensor(y_new[:, 2:], dims=("d", "b")) + x_new.transpose("d", "a")
assert_equal_computations([out_vec], [expected])
# New dimension of y is distinct from that of x
[out_vec] = vectorize_graph([out], {x: x_new, y: y_new}, new_tensor_dims=["c"])
assert isinstance(out_vec.type, XTensorType)
assert out_vec.type.dims == ("d", "c", "b", "a")
# x introduced a new dimension "d" which causes y to be broadcasted
y_broadcasted = broadcast_to(
y_new, (x_new.sizes["d"], y_new.shape[0], y_new.shape[1])
)
expected = as_xtensor(
y_broadcasted[:, :, 2:], dims=("d", "c", "b")
) + x_new.transpose("d", "a")
assert_equal_computations([out_vec], [expected])
def test_mixed_type_inputs_complex_broadcasting(self):
a = xtensor("a", dims=("a",), shape=(3,))
b = xtensor("b", dims=("b"), shape=(5,))
y = tensor("y", shape=(7,))
z = tensor("z", shape=(11,))
out = a + b + y.sum() + z.sum()
assert out.dims == ("a", "b")
a_new = xtensor("a_new", dims=("a*", "a"), shape=(33, 3))
b_new = xtensor("b_new", dims=("b*", "b"), shape=(55, 5))
y_new = tensor("y_new", shape=(1, 55, 2, 1, 7))
z_new = tensor("z_new", shape=(33, 1, 1, 2, 11))
[out_vec] = vectorize_graph(
[out],
{a: a_new, b: b_new, y: y_new, z: z_new},
new_tensor_dims=["a*", "b*", "y*", "z*"],
)
assert isinstance(out_vec.type, XTensorType)
assert out_vec.type.dims == ("a*", "b*", "y*", "z*", "a", "b")
batch_shape_truth = (
a_new.sizes["a*"],
b_new.sizes["b*"],
y_new.shape[2],
z_new.shape[3],
)
y_new_bcast = broadcast_to(y_new, (*batch_shape_truth, y_new.shape[4]))
z_new_bcast = broadcast_to(z_new, (*batch_shape_truth, z_new.shape[4]))
expected_out = (
(a_new + b_new)
+ as_xtensor(y_new_bcast.sum(axis=-1), dims=("a*", "b*", "y*", "z*"))
+ as_xtensor(z_new_bcast.sum(axis=-1), dims=("a*", "b*", "y*", "z*"))
).transpose("a*", "b*", "y*", "z*", ...)
assert_equal_computations([out_vec], [expected_out])
def test_invalid_cases(self):
x = xtensor("x", dims=("a",))
out = x + 1
# Missing xtensor dims
x_bad = xtensor("x_bad", dims=("b",)) # Missing "a"
with pytest.raises(ValueError, match="missing pre-existing dims"):
vectorize_graph([out], {x: x_bad})
# New xtensor dims that were present in original graph
y = xtensor("y", dims=("b",))
out2 = x + y
x_new_conflict = xtensor("x_new", dims=("a", "b"))
# "b" is new to x, but present in graph (in y)
with pytest.raises(ValueError, match="new dimensions that were present"):
vectorize_graph([out2], {x: x_new_conflict})
# Missing tensor dims
t = tensor("t", shape=(3,))
out_t = t + 1
# Replacement has fewer dims (rank 0)
t_bad_rank = tensor("t_bad", shape=())
with pytest.raises(ValueError, match="missing pre-existing dims"):
vectorize_graph([out_t], {t: t_bad_rank})
# Missing new_tensor_dims
t_new = tensor("t_new", shape=(5, 5, 3))
with pytest.raises(ValueError, match="You must specify `new_tensor_dims`"):
vectorize_graph([out_t], {t: t_new})
with pytest.raises(ValueError, match=r"but only .* were specified"):
vectorize_graph([out_t], {t: t_new}, new_tensor_dims=["a"])
# Excess new_tensor_dims
# Replacement adds 1 dim, but 2 are specified
t_new_1dim = tensor("t_new_1dim", shape=(5, 3))
with pytest.raises(ValueError, match="tensor dims were specified, but only"):
vectorize_graph([out_t], {t: t_new_1dim}, new_tensor_dims=["a", "b"])
...@@ -10,8 +10,8 @@ from xarray import DataArray ...@@ -10,8 +10,8 @@ from xarray import DataArray
from xarray.testing import assert_allclose from xarray.testing import assert_allclose
from pytensor import function from pytensor import function
from pytensor.graph import vectorize_graph
from pytensor.xtensor.type import XTensorType, as_xtensor from pytensor.xtensor.type import XTensorType, as_xtensor
from pytensor.xtensor.vectorization import vectorize_graph
def xr_function(*args, **kwargs): def xr_function(*args, **kwargs):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论