提交 c04185db authored 作者: jessegrabowski's avatar jessegrabowski 提交者: Ricardo Vieira

Rename (un)pack axes argument to keep_axes.

Also: * Allow default `None` on unpack
上级 d64f5962
......@@ -367,7 +367,7 @@ def _analyze_axes_list(axes) -> tuple[int, int, int]:
def pack(
*tensors: TensorLike, axes: Sequence[int] | int | None = None
*tensors: TensorLike, keep_axes: Sequence[int] | int | None = None
) -> tuple[TensorVariable, list[TensorVariable]]:
"""
Combine multiple tensors by preserving the specified axes and raveling the rest into a single axis.
......@@ -401,8 +401,8 @@ def pack(
Examples
--------
The easiest way to understand pack is through examples. The simplest case is using axes=None, which is equivalent
to ``join(0, *[t.ravel() for t in tensors])``:
The easiest way to understand pack is through examples.
The simplest case is using the default keep_axes=None, which is equivalent to ``concatenate([t.ravel() for t in tensors])``:
.. code-block:: python
import pytensor.tensor as pt
......@@ -410,19 +410,20 @@ def pack(
x = pt.tensor("x", shape=(2, 3))
y = pt.tensor("y", shape=(4, 5, 6))
packed_tensor, packed_shapes = pt.pack(x, y, axes=None)
packed_tensor, packed_shapes = pt.pack(x, y)
# packed_tensor has shape (6 + 120,) == (126,)
# packed_shapes is [(2, 3), (4, 5, 6)]
If we want to preserve a single axis, we can use either positive or negative indexing. Notice that all tensors
must have the same size along the preserved axis. For example, using axes=0:
If we want to preserve a single axis, we can use either positive or negative indexing.
Notice that all tensors must have the same size along the preserved axis.
For example, using keep_axes=0:
.. code-block:: python
import pytensor.tensor as pt
x = pt.tensor("x", shape=(2, 3))
y = pt.tensor("y", shape=(2, 5, 6))
packed_tensor, packed_shapes = pt.pack(x, y, axes=0)
packed_tensor, packed_shapes = pt.pack(x, y, keep_axes=0)
# packed_tensor has shape (2, 3 + 30) == (2, 33)
# packed_shapes is [(3,), (5, 6)]
......@@ -434,7 +435,7 @@ def pack(
x = pt.tensor("x", shape=(4, 2, 3))
y = pt.tensor("y", shape=(5, 2, 3))
packed_tensor, packed_shapes = pt.pack(x, y, axes=(-2, -1))
packed_tensor, packed_shapes = pt.pack(x, y, keep_axes=(-2, -1))
# packed_tensor has shape (4 + 5, 2, 3) == (9, 2, 3)
# packed_shapes is [(4,), (5,)]
......@@ -445,13 +446,13 @@ def pack(
x = pt.tensor("x", shape=(2, 4, 3))
y = pt.tensor("y", shape=(2, 5, 3))
packed_tensor, packed_shapes = pt.pack(x, y, axes=(0, -1))
packed_tensor, packed_shapes = pt.pack(x, y, keep_axes=(0, -1))
# packed_tensor has shape (2, 4 + 5, 3) == (2, 9, 3)
# packed_shapes is [(4,), (5,)]
"""
tensor_list = [as_tensor_variable(t) for t in tensors]
n_before, n_after, min_axes = _analyze_axes_list(axes)
n_before, n_after, min_axes = _analyze_axes_list(keep_axes)
reshaped_tensors: list[Variable] = []
packed_shapes: list[TensorVariable] = []
......@@ -462,7 +463,7 @@ def pack(
if n_dim < min_axes:
raise ValueError(
f"Input {i} (zero indexed) to pack has {n_dim} dimensions, "
f"but axes={axes} assumes at least {min_axes} dimension{'s' if min_axes != 1 else ''}."
f"but {keep_axes=} assumes at least {min_axes} dimension{'s' if min_axes != 1 else ''}."
)
n_after_packed = n_dim - n_after
packed_shapes.append(input_tensor.shape[n_before:n_after_packed])
......@@ -487,8 +488,8 @@ def pack(
def unpack(
packed_input: TensorLike,
axes: int | Sequence[int] | None,
packed_shapes: Sequence[ShapeValueType],
keep_axes: int | Sequence[int] | None = None,
) -> list[TensorVariable]:
"""
Unpack a packed tensor into multiple tensors by splitting along the specified axes and reshaping.
......@@ -504,10 +505,10 @@ def unpack(
----------
packed_input : TensorLike
The packed tensor to be unpacked.
axes : int, sequence of int, or None
Axes that were preserved during packing. If None, the input is assumed to be 1D and axis 0 is used.
packed_shapes : list of ShapeValueType
A list containing the shapes of the raveled dimensions for each output tensor.
keep_axes : int, sequence of int, optional
Axes that were preserved during packing. If None (the default), the packed input must be 1D and is split along axis 0.
Returns
-------
......@@ -515,26 +516,30 @@ def unpack(
A list of unpacked tensors with their original shapes restored.
"""
packed_input = as_tensor_variable(packed_input)
if axes is None:
if keep_axes is None:
if packed_input.ndim != 1:
raise ValueError(
"unpack can only be called with keep_axes=None for 1d inputs"
)
split_axis = 0
else:
axes = normalize_axis_tuple(axes, ndim=packed_input.ndim)
keep_axes = normalize_axis_tuple(keep_axes, ndim=packed_input.ndim)
try:
[split_axis] = (i for i in range(packed_input.ndim) if i not in axes)
[split_axis] = (i for i in range(packed_input.ndim) if i not in keep_axes)
except ValueError as err:
raise ValueError(
"Unpack must have exactly one more dimension than implied by axes"
f"unpack input must have exactly one more dimension than implied by keep_axes. "
f"{packed_input} has {packed_input.type.ndim} dimensions, expected {len(keep_axes) + 1}"
) from err
n_splits = len(packed_shapes)
if n_splits == 1:
# If there is only one tensor to unpack, no need to split
split_inputs = [packed_input]
else:
split_inputs = split(
packed_input,
splits_size=[prod(shape, dtype=int) for shape in packed_shapes],
n_splits=len(packed_shapes),
axis=split_axis,
)
......
......@@ -117,9 +117,9 @@ def test_make_replacements_with_pack_unpack():
loss = (x + y.sum() + z.sum()) ** 2
flat_packed, packed_shapes = pack(x, y, z, axes=None)
flat_packed, packed_shapes = pack(x, y, z)
new_input = flat_packed.type()
new_outputs = unpack(new_input, axes=None, packed_shapes=packed_shapes)
new_outputs = unpack(new_input, packed_shapes=packed_shapes)
loss = pytensor.graph.graph_replace(loss, dict(zip([x, y, z], new_outputs)))
rewrite_graph(loss, include=("ShapeOpt", "canonicalize"))
......@@ -198,7 +198,7 @@ class TestPack:
}
# Simple case, reduce all axes, equivalent to einops '*'
packed_tensor, packed_shapes = pack(x, y, z, axes=None)
packed_tensor, packed_shapes = pack(x, y, z)
assert packed_tensor.type.shape == (15,)
for tensor, packed_shape in zip([x, y, z], packed_shapes):
assert packed_shape.type.shape == (tensor.ndim,)
......@@ -211,9 +211,9 @@ class TestPack:
# x is scalar, so pack will raise:
with pytest.raises(
ValueError,
match=r"Input 0 \(zero indexed\) to pack has 0 dimensions, but axes=0 assumes at least 1 dimension\.",
match=r"Input 0 \(zero indexed\) to pack has 0 dimensions, but keep_axes=0 assumes at least 1 dimension\.",
):
pack(x, y, z, axes=0)
pack(x, y, z, keep_axes=0)
# With valid x, pack should still raise, because the axis of concatenation doesn't agree across all inputs
x = pt.tensor("x", shape=(3,))
......@@ -224,13 +224,13 @@ class TestPack:
match=r"all input array dimensions other than the specified `axis` \(1\) must match exactly, or be unknown "
r"\(None\), but along dimension 0, the inputs shapes are incompatible: \[3 5 3\]",
):
packed_tensor, packed_shapes = pack(x, y, z, axes=0)
packed_tensor, packed_shapes = pack(x, y, z, keep_axes=0)
packed_tensor.eval(input_dict)
# Valid case, preserve first axis, equivalent to einops 'i *'
y = pt.tensor("y", shape=(3, 5))
z = pt.tensor("z", shape=(3, 3, 3))
packed_tensor, packed_shapes = pack(x, y, z, axes=0)
packed_tensor, packed_shapes = pack(x, y, z, keep_axes=0)
input_dict = {
variable.name: np.zeros(variable.type.shape, dtype=config.floatX)
for variable in [x, y, z]
......@@ -253,10 +253,10 @@ class TestPack:
ValueError,
match=r"Positive axes must be contiguous",
):
pack(x, y, z, axes=[0, 3])
pack(x, y, z, keep_axes=[0, 3])
z = pt.tensor("z", shape=(3, 1, 7, 2))
packed_tensor, packed_shapes = pack(x, y, z, axes=[0, -1])
packed_tensor, packed_shapes = pack(x, y, z, keep_axes=[0, -1])
input_dict = {
variable.name: np.zeros(variable.type.shape, dtype=config.floatX)
for variable in [x, y, z]
......@@ -277,8 +277,8 @@ class TestPack:
y = pt.tensor("y", shape=(3, 3, 5))
z = pt.tensor("z", shape=(1, 3, 5))
flat_packed, packed_shapes = pack(x, y, z, axes=axes)
new_outputs = unpack(flat_packed, axes=axes, packed_shapes=packed_shapes)
flat_packed, packed_shapes = pack(x, y, z, keep_axes=axes)
new_outputs = unpack(flat_packed, packed_shapes=packed_shapes, keep_axes=axes)
fn = pytensor.function([x, y, z], new_outputs, mode="FAST_COMPILE")
......@@ -291,11 +291,17 @@ class TestPack:
for input_val, output_val in zip(input_dict.values(), output_vals, strict=True):
np.testing.assert_allclose(input_val, output_val)
def test_single_input(self):
x = pt.matrix("x", shape=(2, 5))
packed_x, packed_shapes = pt.pack(x)
assert packed_x.type.shape == (10,)
[x_again] = unpack(packed_x, packed_shapes)
assert x_again.type.shape == (2, 5)
def test_unpack_connection():
def test_unpack_connection(self):
x = pt.vector("x")
d0 = pt.scalar("d0", dtype=int)
d1 = pt.scalar("d1", dtype=int)
x0, x1 = pt.unpack(x, axes=None, packed_shapes=[d0, d1])
x0, x1 = pt.unpack(x, packed_shapes=[d0, d1])
out = x0.sum() + x1.sum()
assert io_connection_pattern([x, d0, d1], [out]) == [[True], [False], [False]]
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论