Commit 4b897162, authored by jessegrabowski, committed by Ricardo Vieira

Implement `Pack` and `Unpack`

Parent: a0be97e8
Diff collapsed.
import numpy as np
import pytest
import pytensor
from pytensor import config, function
from pytensor import tensor as pt
from pytensor.graph import vectorize_graph
from pytensor.graph import rewrite_graph, vectorize_graph
from pytensor.tensor.reshape import (
_analyze_axes_list,
join_dims,
pack,
split_dims,
unpack,
)
......@@ -95,3 +99,187 @@ def test_split_size_zero_shape():
x_split_value = fn(x_value)
np.testing.assert_allclose(x_split_value, x_value.squeeze(0))
def test_make_replacements_with_pack_unpack():
    """Pack several inputs into one flat vector, rebuild the loss graph on the
    packed input, and check the compiled function reproduces the original
    computation ``(x + y.sum() + z.sum()) ** 2`` on random values."""
    rng = np.random.default_rng()

    x = pt.tensor("x", shape=())
    y = pt.tensor("y", shape=(5,))
    z = pt.tensor("z", shape=(3, 3))
    loss = (x + y.sum() + z.sum()) ** 2

    # Flatten-and-concatenate all inputs, then split them back out so the
    # original graph can be re-expressed in terms of a single flat input.
    flat_packed, packed_shapes = pack(x, y, z, axes=None)
    new_input = flat_packed.type()
    new_outputs = unpack(new_input, axes=None, packed_shapes=packed_shapes)
    loss = pytensor.graph.graph_replace(loss, dict(zip([x, y, z], new_outputs)))

    # rewrite_graph returns the rewritten variable; the previous code discarded
    # the result, so the requested rewrites were never actually applied.
    loss = rewrite_graph(loss, include=("ShapeOpt", "specialize"))

    fn = pytensor.function([new_input], loss, mode="FAST_COMPILE")

    # `val` instead of `input` to avoid shadowing the builtin.
    input_vals = [
        rng.normal(size=var.type.shape).astype(config.floatX) for var in [x, y, z]
    ]
    flat_inputs = np.concatenate([val.ravel() for val in input_vals], axis=0)
    output_val = fn(flat_inputs)

    assert np.allclose(output_val, sum(val.sum() for val in input_vals) ** 2)
class TestPack:
    """Tests for `pack`/`unpack` and the `_analyze_axes_list` helper."""

    @pytest.mark.parametrize(
        "axes, expected",
        [
            (None, [0, 0, 0]),  # '*'
            ([0, 1], [2, 0, 2]),  # 'i j *'
            ([-1], [0, 1, 1]),  # '* k'
            ([-2, -1], [0, 2, 2]),  # '* i j'
            ([0, -1], [1, 1, 2]),  # 'i * k'
            ([0, 1, 2, -1], [3, 1, 4]),  # 'i j k * l'
        ],
        ids=[
            "ravel_all",
            "keep_first_two",
            "keep_last",
            "ravel_start",
            "first_and_last",
            "complex_case",
        ],
    )
    def test_analyze_axes_list_valid(self, axes, expected):
        """A valid axes spec decomposes into (n_before, n_after, min_axes),
        mirroring the einops-style patterns noted in the parametrize cases."""
        outputs = _analyze_axes_list(axes)
        names = ["n_before", "n_after", "min_axes"]
        for out, exp, name in zip(outputs, expected, names, strict=True):
            assert out == exp, f"Expected {exp}, got {out} for {name}"

    def test_analyze_axes_list_invalid(self):
        """Every malformed axes spec raises a descriptive ValueError."""
        # Positive only but not contiguous
        with pytest.raises(ValueError, match="Positive axes must be contiguous"):
            _analyze_axes_list([1, 3])

        # Negative only but not contiguous
        with pytest.raises(ValueError, match="Negative axes must be contiguous"):
            _analyze_axes_list([-3, -1])

        # Mixed up positive and negative
        with pytest.raises(ValueError, match="Negative axes must come after positive"):
            _analyze_axes_list([0, 1, -2, 4])

        # Duplicate axes
        with pytest.raises(ValueError, match="axes must have no duplicates"):
            _analyze_axes_list([0, 0])

        # Not monotonic
        with pytest.raises(ValueError, match="Axes must be strictly increasing"):
            _analyze_axes_list([0, 2, 1])

        # Negative before positive
        with pytest.raises(ValueError, match="Negative axes must come after positive"):
            _analyze_axes_list([-1, 0])

    def test_pack_basic(self):
        """Exercise pack over progressively more complex axes specs, including
        the error paths for incompatible input ranks and shapes."""
        x = pt.tensor("x", shape=())
        y = pt.tensor("y", shape=(5,))
        z = pt.tensor("z", shape=(3, 3))

        input_dict = {
            variable.name: np.zeros(variable.type.shape, dtype=config.floatX)
            for variable in [x, y, z]
        }

        # Simple case, reduce all axes, equivalent to einops '*'
        packed_tensor, packed_shapes = pack(x, y, z, axes=None)
        assert packed_tensor.type.shape == (15,)

        for tensor, packed_shape in zip([x, y, z], packed_shapes, strict=True):
            assert packed_shape.type.shape == (tensor.ndim,)
            np.testing.assert_allclose(
                packed_shape.eval(input_dict, on_unused_input="ignore"),
                tensor.type.shape,
            )

        # To preserve an axis, all inputs need at least one dimension, and the preserved axis has to agree.
        # x is scalar, so pack will raise:
        with pytest.raises(
            ValueError,
            match=r"Input 0 \(zero indexed\) to pack has 0 dimensions, but axes=0 assumes at least 1 dimension\.",
        ):
            pack(x, y, z, axes=0)

        # With valid x, pack should still raise, because the axis of concatenation doesn't agree across all inputs
        x = pt.tensor("x", shape=(3,))
        input_dict["x"] = np.zeros((3,), dtype=config.floatX)
        with pytest.raises(
            ValueError,
            match=r"all input array dimensions other than the specified `axis` \(1\) must match exactly, or be unknown "
            r"\(None\), but along dimension 0, the inputs shapes are incompatible: \[3 5 3\]",
        ):
            packed_tensor, packed_shapes = pack(x, y, z, axes=0)
            packed_tensor.eval(input_dict)

        # Valid case, preserve first axis, equivalent to einops 'i *'
        y = pt.tensor("y", shape=(3, 5))
        z = pt.tensor("z", shape=(3, 3, 3))
        packed_tensor, packed_shapes = pack(x, y, z, axes=0)

        input_dict = {
            variable.name: np.zeros(variable.type.shape, dtype=config.floatX)
            for variable in [x, y, z]
        }

        assert packed_tensor.type.shape == (3, 15)
        for tensor, packed_shape in zip([x, y, z], packed_shapes, strict=True):
            assert packed_shape.type.shape == (tensor.ndim - 1,)
            np.testing.assert_allclose(
                packed_shape.eval(input_dict, on_unused_input="ignore"),
                tensor.type.shape[1:],
            )

        # More complex case, preserve last axis implicitly, equivalent to einops 'i * k'. This introduces a max
        # dimension condition on the input shapes
        x = pt.tensor("x", shape=(3, 2))
        y = pt.tensor("y", shape=(3, 5, 2))
        z = pt.tensor("z", shape=(3, 1, 7, 5, 2))
        with pytest.raises(
            ValueError,
            match=r"Positive axes must be contiguous",
        ):
            pack(x, y, z, axes=[0, 3])

        z = pt.tensor("z", shape=(3, 1, 7, 2))
        packed_tensor, packed_shapes = pack(x, y, z, axes=[0, -1])

        input_dict = {
            variable.name: np.zeros(variable.type.shape, dtype=config.floatX)
            for variable in [x, y, z]
        }

        assert packed_tensor.type.shape == (3, 13, 2)
        for tensor, packed_shape in zip([x, y, z], packed_shapes, strict=True):
            assert packed_shape.type.shape == (tensor.ndim - 2,)
            np.testing.assert_allclose(
                packed_shape.eval(input_dict, on_unused_input="ignore"),
                tensor.type.shape[1:-1],
            )

    @pytest.mark.parametrize("axes", [-1])
    def test_pack_unpack_round_trip(self, axes):
        """unpack(pack(...)) should return the original tensors unchanged."""
        rng = np.random.default_rng()
        x = pt.tensor("x", shape=(3, 5))
        y = pt.tensor("y", shape=(3, 3, 5))
        z = pt.tensor("z", shape=(1, 3, 5))

        flat_packed, packed_shapes = pack(x, y, z, axes=axes)
        new_outputs = unpack(flat_packed, axes=axes, packed_shapes=packed_shapes)
        fn = pytensor.function([x, y, z], new_outputs, mode="FAST_COMPILE")

        input_dict = {
            var.name: rng.normal(size=var.type.shape).astype(config.floatX)
            for var in [x, y, z]
        }
        output_vals = fn(**input_dict)
        for input_val, output_val in zip(input_dict.values(), output_vals, strict=True):
            np.testing.assert_allclose(input_val, output_val)
Markdown format supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment