Commit b1678fd2 — authored by Ricardo Vieira, committed by Ricardo Vieira

Implement join_dims as mirror of split_dims

Mainly, joining 0 axes is equivalent to inserting a new dimension. This is the mirror of how splitting a single axis into an empty shape is equivalent to squeezing it.
Parent commit: c04185db
...@@ -12,7 +12,7 @@ from pytensor.graph.op import Op ...@@ -12,7 +12,7 @@ from pytensor.graph.op import Op
from pytensor.graph.replace import _vectorize_node from pytensor.graph.replace import _vectorize_node
from pytensor.scalar import ScalarVariable from pytensor.scalar import ScalarVariable
from pytensor.tensor import TensorLike, as_tensor_variable from pytensor.tensor import TensorLike, as_tensor_variable
from pytensor.tensor.basic import expand_dims, infer_static_shape, join, split from pytensor.tensor.basic import infer_static_shape, join, split
from pytensor.tensor.math import prod from pytensor.tensor.math import prod
from pytensor.tensor.type import tensor from pytensor.tensor.type import tensor
from pytensor.tensor.variable import TensorVariable from pytensor.tensor.variable import TensorVariable
...@@ -24,10 +24,7 @@ ShapeValueType: TypeAlias = ( ...@@ -24,10 +24,7 @@ ShapeValueType: TypeAlias = (
class JoinDims(Op): class JoinDims(Op):
__props__ = ( __props__ = ("start_axis", "n_axes")
"start_axis",
"n_axes",
)
view_map = {0: [0]} view_map = {0: [0]}
def __init__(self, start_axis: int, n_axes: int): def __init__(self, start_axis: int, n_axes: int):
...@@ -55,6 +52,11 @@ class JoinDims(Op): ...@@ -55,6 +52,11 @@ class JoinDims(Op):
static_shapes = x.type.shape static_shapes = x.type.shape
axis_range = self.axis_range axis_range = self.axis_range
if (self.start_axis + self.n_axes) > x.type.ndim:
raise ValueError(
f"JoinDims was asked to join dimensions {self.start_axis} to {self.n_axes}, "
f"but input {x} has only {x.type.ndim} dimensions."
)
joined_shape = ( joined_shape = (
int(np.prod([static_shapes[i] for i in axis_range])) int(np.prod([static_shapes[i] for i in axis_range]))
...@@ -69,9 +71,7 @@ class JoinDims(Op): ...@@ -69,9 +71,7 @@ class JoinDims(Op):
def infer_shape(self, fgraph, node, shapes): def infer_shape(self, fgraph, node, shapes):
[input_shape] = shapes [input_shape] = shapes
axis_range = self.axis_range joined_shape = prod([input_shape[i] for i in self.axis_range], dtype=int)
joined_shape = prod([input_shape[i] for i in axis_range])
return [self.output_shapes(input_shape, joined_shape)] return [self.output_shapes(input_shape, joined_shape)]
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
...@@ -98,23 +98,24 @@ class JoinDims(Op): ...@@ -98,23 +98,24 @@ class JoinDims(Op):
@_vectorize_node.register(JoinDims) @_vectorize_node.register(JoinDims)
def _vectorize_joindims(op, node, x): def _vectorize_joindims(op, node, x):
[old_x] = node.inputs [old_x] = node.inputs
batched_ndims = x.type.ndim - old_x.type.ndim batched_ndims = x.type.ndim - old_x.type.ndim
start_axis = op.start_axis return JoinDims(op.start_axis + batched_ndims, op.n_axes).make_node(x)
n_axes = op.n_axes
return JoinDims(start_axis + batched_ndims, n_axes).make_node(x)
def join_dims(x: TensorLike, axis: Sequence[int] | int | None = None) -> TensorVariable: def join_dims(
x: TensorLike, start_axis: int = 0, n_axes: int | None = None
) -> TensorVariable:
"""Join consecutive dimensions of a tensor into a single dimension. """Join consecutive dimensions of a tensor into a single dimension.
Parameters Parameters
---------- ----------
x : TensorLike x : TensorLike
The input tensor. The input tensor.
axis : int or sequence of int, optional start_axis : int, default 0
The dimensions to join. If None, all dimensions are joined. The axis from which to start joining dimensions
n_axes: int, optional.
The number of axis to join after `axis`. If `None` joins all remaining axis.
If 0, it inserts a new dimension of length 1.
Returns Returns
------- -------
...@@ -125,33 +126,32 @@ def join_dims(x: TensorLike, axis: Sequence[int] | int | None = None) -> TensorV ...@@ -125,33 +126,32 @@ def join_dims(x: TensorLike, axis: Sequence[int] | int | None = None) -> TensorV
-------- --------
>>> import pytensor.tensor as pt >>> import pytensor.tensor as pt
>>> x = pt.tensor("x", shape=(2, 3, 4, 5)) >>> x = pt.tensor("x", shape=(2, 3, 4, 5))
>>> y = pt.join_dims(x, axis=(1, 2)) >>> y = pt.join_dims(x)
>>> y.type.shape
(120,)
>>> y = pt.join_dims(x, start_axis=1)
>>> y.type.shape
(2, 60)
>>> y = pt.join_dims(x, start_axis=1, n_axes=2)
>>> y.type.shape >>> y.type.shape
(2, 12, 5) (2, 12, 5)
""" """
x = as_tensor_variable(x) x = as_tensor_variable(x)
ndim = x.type.ndim
if axis is None: if start_axis < 0:
axis = list(range(x.ndim)) # We treat scalars as if they had a single axis
elif isinstance(axis, int): start_axis += max(1, ndim)
axis = [axis]
elif not isinstance(axis, list | tuple):
raise TypeError("axis must be an int, a list/tuple of ints, or None")
axis = normalize_axis_tuple(axis, x.ndim)
if len(axis) <= 1:
return x # type: ignore[unreachable]
if np.diff(axis).max() > 1: if not 0 <= start_axis <= ndim:
raise ValueError( raise IndexError(
f"join_dims axis must be consecutive, got normalized axis: {axis}" f"Axis {start_axis} is out of bounds for array of dimension {ndim}"
) )
start_axis = min(axis) if n_axes is None:
n_axes = len(axis) n_axes = ndim - start_axis
return JoinDims(start_axis=start_axis, n_axes=n_axes)(x) # type: ignore[return-value] return JoinDims(start_axis, n_axes)(x) # type: ignore[return-value]
class SplitDims(Op): class SplitDims(Op):
...@@ -213,11 +213,11 @@ class SplitDims(Op): ...@@ -213,11 +213,11 @@ class SplitDims(Op):
def L_op(self, inputs, outputs, output_grads): def L_op(self, inputs, outputs, output_grads):
(x, _) = inputs (x, _) = inputs
(g_out,) = output_grads (g_out,) = output_grads
n_axes = g_out.ndim - x.ndim + 1 n_axes = g_out.ndim - x.ndim + 1
axis_range = list(range(self.axis, self.axis + n_axes)) return [
join_dims(g_out, start_axis=self.axis, n_axes=n_axes),
return [join_dims(g_out, axis=axis_range), disconnected_type()] disconnected_type(),
]
@_vectorize_node.register(SplitDims) @_vectorize_node.register(SplitDims)
...@@ -230,14 +230,13 @@ def _vectorize_splitdims(op, node, x, shape): ...@@ -230,14 +230,13 @@ def _vectorize_splitdims(op, node, x, shape):
if as_tensor_variable(shape).type.ndim != 1: if as_tensor_variable(shape).type.ndim != 1:
return vectorize_node_fallback(op, node, x, shape) return vectorize_node_fallback(op, node, x, shape)
axis = op.axis return SplitDims(axis=op.axis + batched_ndims).make_node(x, shape)
return SplitDims(axis=axis + batched_ndims).make_node(x, shape)
def split_dims( def split_dims(
x: TensorLike, x: TensorLike,
shape: ShapeValueType | Sequence[ShapeValueType], shape: ShapeValueType | Sequence[ShapeValueType],
axis: int | None = None, axis: int = 0,
) -> TensorVariable: ) -> TensorVariable:
"""Split a dimension of a tensor into multiple dimensions. """Split a dimension of a tensor into multiple dimensions.
...@@ -247,8 +246,8 @@ def split_dims( ...@@ -247,8 +246,8 @@ def split_dims(
The input tensor. The input tensor.
shape : int or sequence of int shape : int or sequence of int
The new shape to split the specified dimension into. The new shape to split the specified dimension into.
axis : int, optional axis : int, default 0
The dimension to split. If None, the input is assumed to be 1D and axis 0 is used. The dimension to split.
Returns Returns
------- -------
...@@ -259,22 +258,18 @@ def split_dims( ...@@ -259,22 +258,18 @@ def split_dims(
-------- --------
>>> import pytensor.tensor as pt >>> import pytensor.tensor as pt
>>> x = pt.tensor("x", shape=(6, 4, 6)) >>> x = pt.tensor("x", shape=(6, 4, 6))
>>> y = pt.split_dims(x, shape=(2, 3), axis=0) >>> y = pt.split_dims(x, shape=(2, 3))
>>> y.type.shape >>> y.type.shape
(2, 3, 4, 6) (2, 3, 4, 6)
>>> y = pt.split_dims(x, shape=(2, 3), axis=-1)
>>> y.type.shape
(6, 4, 2, 3)
""" """
x = as_tensor_variable(x) x = as_tensor_variable(x)
axis = normalize_axis_index(axis, x.ndim)
if axis is None:
if x.type.ndim != 1:
raise ValueError(
"split_dims can only be called with axis=None for 1d inputs"
)
axis = 0
else:
axis = normalize_axis_index(axis, x.ndim)
# Convert scalar shape to 1d tuple (shape,) # Convert scalar shape to 1d tuple (shape,)
# Which is basically a specify_shape
if not isinstance(shape, Sequence): if not isinstance(shape, Sequence):
if isinstance(shape, TensorVariable | np.ndarray): if isinstance(shape, TensorVariable | np.ndarray):
if shape.ndim == 0: if shape.ndim == 0:
...@@ -313,8 +308,6 @@ def _analyze_axes_list(axes) -> tuple[int, int, int]: ...@@ -313,8 +308,6 @@ def _analyze_axes_list(axes) -> tuple[int, int, int]:
elif not isinstance(axes, Iterable): elif not isinstance(axes, Iterable):
raise TypeError("axes must be an int, an iterable of ints, or None") raise TypeError("axes must be an int, an iterable of ints, or None")
axes = tuple(axes)
if len(axes) == 0: if len(axes) == 0:
raise ValueError("axes=[] is ambiguous; use None to ravel all") raise ValueError("axes=[] is ambiguous; use None to ravel all")
...@@ -465,22 +458,10 @@ def pack( ...@@ -465,22 +458,10 @@ def pack(
f"Input {i} (zero indexed) to pack has {n_dim} dimensions, " f"Input {i} (zero indexed) to pack has {n_dim} dimensions, "
f"but {keep_axes=} assumes at least {min_axes} dimension{'s' if min_axes != 1 else ''}." f"but {keep_axes=} assumes at least {min_axes} dimension{'s' if min_axes != 1 else ''}."
) )
n_after_packed = n_dim - n_after
packed_shapes.append(input_tensor.shape[n_before:n_after_packed]) n_packed = n_dim - n_after - n_before
packed_shapes.append(input_tensor.shape[n_before : n_before + n_packed])
if n_dim == min_axes: joined = join_dims(input_tensor, n_before, n_packed)
# If an input has the minimum number of axes, pack implicitly inserts a new axis based on the pattern
# implied by the axes.
input_tensor = expand_dims(input_tensor, axis=n_before)
reshaped_tensors.append(input_tensor)
continue
# The reshape we want is (shape[:before], -1, shape[n_after_packed:]). join_dims does (shape[:min(axes)], -1,
# shape[max(axes)+1:]). So this will work if we choose axes=(n_before, n_after_packed - 1). Because of the
# rules on the axes input, we will always have n_before <= n_after_packed - 1. A set is used here to cover the
# corner case when n_before == n_after_packed - 1 (i.e., when there is only one axis to ravel --> do nothing).
join_axes = range(n_before, n_after_packed)
joined = join_dims(input_tensor, tuple(join_axes))
reshaped_tensors.append(joined) reshaped_tensors.append(joined)
return join(n_before, *reshaped_tensors), packed_shapes return join(n_before, *reshaped_tensors), packed_shapes
......
...@@ -34,8 +34,9 @@ def local_join_dims_to_reshape(fgraph, node): ...@@ -34,8 +34,9 @@ def local_join_dims_to_reshape(fgraph, node):
""" """
(x,) = node.inputs (x,) = node.inputs
start_axis = node.op.start_axis op = node.op
n_axes = node.op.n_axes start_axis = op.start_axis
n_axes = op.n_axes
output_shape = [ output_shape = [
*x.shape[:start_axis], *x.shape[:start_axis],
......
...@@ -21,7 +21,7 @@ def test_local_split_dims_to_reshape(): ...@@ -21,7 +21,7 @@ def test_local_split_dims_to_reshape():
def test_local_join_dims_to_reshape(): def test_local_join_dims_to_reshape():
x = tensor("x", shape=(2, 2, 5, 1, 3)) x = tensor("x", shape=(2, 2, 5, 1, 3))
x_join = join_dims(x, axis=(1, 2, 3)) x_join = join_dims(x, start_axis=1, n_axes=3)
fg = FunctionGraph(inputs=[x], outputs=[x_join]) fg = FunctionGraph(inputs=[x], outputs=[x_join])
......
...@@ -20,32 +20,37 @@ def test_join_dims(): ...@@ -20,32 +20,37 @@ def test_join_dims():
rng = np.random.default_rng() rng = np.random.default_rng()
x = pt.tensor("x", shape=(2, 3, 4, 5)) x = pt.tensor("x", shape=(2, 3, 4, 5))
assert join_dims(x, axis=(0, 1)).type.shape == (6, 4, 5) assert join_dims(x).type.shape == (120,)
assert join_dims(x, axis=(1, 2)).type.shape == (2, 12, 5) assert join_dims(x, n_axes=1).type.shape == (2, 3, 4, 5)
assert join_dims(x, axis=(-1, -2)).type.shape == (2, 3, 20) assert join_dims(x, n_axes=0).type.shape == (1, 2, 3, 4, 5)
assert join_dims(x, axis=()).type.shape == (2, 3, 4, 5) assert join_dims(x, n_axes=2).type.shape == (6, 4, 5)
assert join_dims(x, axis=(2,)).type.shape == (2, 3, 4, 5) assert join_dims(x, start_axis=1, n_axes=2).type.shape == (2, 12, 5)
assert join_dims(x, start_axis=-3, n_axes=2).type.shape == (2, 12, 5)
assert join_dims(x, start_axis=2).type.shape == (2, 3, 20)
with pytest.raises(
IndexError,
match=r"Axis 5 is out of bounds for array of dimension 4",
):
join_dims(x, start_axis=5)
with pytest.raises( with pytest.raises(
ValueError, ValueError,
match=r"join_dims axis must be consecutive, got normalized axis: \(0, 2\)", match=r"JoinDims was asked to join dimensions 0 to 5, but input x has only 4 dimensions.",
): ):
_ = join_dims(x, axis=(0, 2)).type.shape == (8, 3, 5) join_dims(x, n_axes=5)
x_joined = join_dims(x, axis=(1, 2))
x_value = rng.normal(size=(2, 3, 4, 5)).astype(config.floatX) x_value = rng.normal(size=(2, 3, 4, 5)).astype(config.floatX)
np.testing.assert_allclose(
fn = function([x], x_joined, mode="FAST_COMPILE") join_dims(x, start_axis=1, n_axes=2).eval({x: x_value}),
x_value.reshape(2, 12, 5),
x_joined_value = fn(x_value) )
np.testing.assert_allclose(x_joined_value, x_value.reshape(2, 12, 5)) assert join_dims(x, 1, n_axes=1).eval({x: x_value}).shape == (2, 3, 4, 5)
assert join_dims(x, 1, n_axes=0).eval({x: x_value}).shape == (2, 1, 3, 4, 5)
assert join_dims(x, axis=(1,)).eval({x: x_value}).shape == (2, 3, 4, 5)
assert join_dims(x, axis=()).eval({x: x_value}).shape == (2, 3, 4, 5)
x = pt.tensor("x", shape=(3, 5)) x = pt.tensor("x", shape=(3, 5))
x_joined = join_dims(x, axis=(0, 1)) x_joined = join_dims(x)
x_batched = pt.tensor("x_batched", shape=(10, 3, 5)) x_batched = pt.tensor("x_batched", shape=(10, 3, 5))
x_joined_batched = vectorize_graph(x_joined, {x: x_batched}) x_joined_batched = vectorize_graph(x_joined, {x: x_batched})
...@@ -54,7 +59,7 @@ def test_join_dims(): ...@@ -54,7 +59,7 @@ def test_join_dims():
x_batched_val = rng.normal(size=(10, 3, 5)).astype(config.floatX) x_batched_val = rng.normal(size=(10, 3, 5)).astype(config.floatX)
assert x_joined_batched.eval({x_batched: x_batched_val}).shape == (10, 15) assert x_joined_batched.eval({x_batched: x_batched_val}).shape == (10, 15)
utt.verify_grad(lambda x: join_dims(x, axis=(1, 2)), [x_value]) utt.verify_grad(lambda x: join_dims(x, start_axis=1, n_axes=2), [x_value])
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -289,7 +294,7 @@ class TestPack: ...@@ -289,7 +294,7 @@ class TestPack:
output_vals = fn(**input_dict) output_vals = fn(**input_dict)
for input_val, output_val in zip(input_dict.values(), output_vals, strict=True): for input_val, output_val in zip(input_dict.values(), output_vals, strict=True):
np.testing.assert_allclose(input_val, output_val) np.testing.assert_allclose(input_val, output_val, strict=True)
def test_single_input(self): def test_single_input(self):
x = pt.matrix("x", shape=(2, 5)) x = pt.matrix("x", shape=(2, 5))
......
Markdown formatting is supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment