Unverified commit 0b439c0f, authored by Jesse Grabowski, committed by GitHub

Implement `L_Op` for `join_dims` and `split_dims` (#1812)

* Implement `L_Op` for `join_dims` and `split_dims`
* Improve type hints for `join_dims` and `split_dims`
* Feedback
Parent 0f864a33
from collections.abc import Iterable, Sequence
from itertools import pairwise
from typing import cast as type_cast
import numpy as np
from numpy.lib._array_utils_impl import normalize_axis_tuple
from pytensor import Variable
from pytensor.gradient import DisconnectedType
from pytensor.graph import Apply
from pytensor.graph.op import Op
from pytensor.graph.replace import _vectorize_node
from pytensor.tensor import TensorLike, as_tensor_variable
from pytensor.tensor.basic import expand_dims, infer_static_shape, join, split
from pytensor.tensor.extra_ops import squeeze
from pytensor.tensor.math import prod
-from pytensor.tensor.shape import ShapeValueType
+from pytensor.tensor.shape import ShapeValueType, shape
from pytensor.tensor.type import tensor
from pytensor.tensor.variable import TensorVariable
@@ -80,6 +81,19 @@ class JoinDims(Op):
        out[0] = x.reshape(output_shape)

    def L_op(
        self,
        inputs: Sequence[Variable],
        outputs: Sequence[Variable],
        output_grads: Sequence[Variable],
    ) -> list[Variable]:
        (x,) = inputs
        (g_out,) = output_grads

        x_shape = shape(x)
        packed_shape = [x_shape[i] for i in self.axis_range]

        return [split_dims(g_out, shape=packed_shape, axis=self.start_axis)]
@_vectorize_node.register(JoinDims)
def _vectorize_joindims(op, node, x):
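For context, a minimal sketch of the gradient this `L_op` enables, assuming standard `pytensor.grad` usage; the import path is a hypothetical illustration, not part of this diff:

```python
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.reshape import join_dims  # import path assumed for illustration

x = pt.tensor("x", shape=(10, 3, 5), dtype="float64")
y = join_dims(x, axis=(1, 2))  # axes 1 and 2 collapsed -> shape (10, 15)

# The L_op above splits the (10, 15) output gradient back into (10, 3, 5)
g = pytensor.grad(y.sum(), x)
assert g.eval({x: np.zeros((10, 3, 5))}).shape == (10, 3, 5)
```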
@@ -97,14 +111,14 @@ def join_dims(x: TensorLike, axis: Sequence[int] | int | None = None) -> TensorV
    Parameters
    ----------
-    x : Variable
+    x : TensorLike
        The input tensor.
    axis : int or sequence of int, optional
        The dimensions to join. If None, all dimensions are joined.

    Returns
    -------
-    joined_x : Variable
+    joined_x : TensorVariable
        The reshaped tensor with joined dimensions.

    Examples
@@ -137,10 +151,7 @@ def join_dims(x: TensorLike, axis: Sequence[int] | int | None = None) -> TensorV
    start_axis = min(axis)
    n_axes = len(axis)

-    return type_cast(
-        TensorVariable,
-        JoinDims(start_axis=start_axis, n_axes=n_axes)(x),
-    )
+    return JoinDims(start_axis=start_axis, n_axes=n_axes)(x)  # type: ignore[return-value]
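A short usage sketch of the resulting behavior; the shapes follow the docstring semantics, and the import path is again an assumption:

```python
import numpy as np
import pytensor.tensor as pt
from pytensor.tensor.reshape import join_dims  # import path assumed for illustration

x = pt.tensor("x", shape=(2, 3, 4), dtype="float64")
x_val = np.zeros((2, 3, 4))

assert join_dims(x, axis=(0, 1)).eval({x: x_val}).shape == (6, 4)  # merge axes 0 and 1
assert join_dims(x).eval({x: x_val}).shape == (24,)  # axis=None joins all dimensions
```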
class SplitDims(Op):
@@ -191,6 +202,23 @@ class SplitDims(Op):
        out[0] = x.reshape(output_shape)

    def connection_pattern(self, node):
        return [[True], [False]]

    def L_op(
        self,
        inputs: Sequence[Variable],
        outputs: Sequence[Variable],
        output_grads: Sequence[Variable],
    ) -> list[Variable]:
        (x, _) = inputs
        (g_out,) = output_grads

        n_axes = g_out.ndim - x.ndim + 1  # type: ignore[attr-defined]
        axis_range = list(range(self.axis, self.axis + n_axes))

        return [join_dims(g_out, axis=axis_range), DisconnectedType()()]
@_vectorize_node.register(SplitDims)
def _vectorize_splitdims(op, node, x, shape):
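A matching sketch for the reverse direction, again with the import path and `pytensor.grad` usage assumed for illustration:

```python
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.reshape import split_dims  # import path assumed for illustration

x = pt.tensor("x", shape=(10,), dtype="float64")
y = split_dims(x, shape=(5, 2), axis=0)  # shape (10,) -> (5, 2)

# The L_op joins the (5, 2) output gradient back into shape (10,); the
# symbolic `shape` input is disconnected, per connection_pattern above.
g = pytensor.grad(y.sum(), x)
assert g.eval({x: np.zeros(10)}).shape == (10,)
```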
@@ -224,7 +252,7 @@ def split_dims(
    Returns
    -------
-    split_x : Variable
+    split_x : TensorVariable
        The reshaped tensor with split dimensions.

    Examples
@@ -253,13 +281,12 @@ def split_dims(
        # If we get an empty shape, there is potentially a dummy dimension at the requested axis. This happens for
        # example when splitting a packed tensor that had its dims expanded before packing (e.g. when packing shapes
        # (3,) and (3, 3) to (3, 4)).
-        return type_cast(TensorVariable, x.squeeze(axis=axis))
+        return squeeze(x, axis=axis)  # type: ignore[no-any-return]

    [axis] = normalize_axis_tuple(axis, x.ndim)  # type: ignore[misc]
    shape = as_tensor_variable(shape, dtype="int64", ndim=1)  # type: ignore[arg-type]

-    split_op = SplitDims(axis=axis)
-    return type_cast(TensorVariable, split_op(x, shape))
+    return SplitDims(axis=axis)(x, shape)  # type: ignore[return-value]
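For reference, a small sketch exercising both code paths above, with the import path assumed and the empty-shape behavior inferred from the comment in the diff:

```python
import numpy as np
import pytensor.tensor as pt
from pytensor.tensor.reshape import split_dims  # import path assumed for illustration

x = pt.tensor("x", shape=(6, 4), dtype="float64")
assert split_dims(x, shape=(2, 3), axis=0).eval({x: np.zeros((6, 4))}).shape == (2, 3, 4)

# An empty `shape` takes the early squeeze path: the requested axis is treated
# as a dummy (length-1) dimension and dropped.
z = pt.tensor("z", shape=(1, 4), dtype="float64")
assert split_dims(z, shape=(), axis=0).eval({z: np.zeros((1, 4))}).shape == (4,)
```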
def _analyze_axes_list(axes) -> tuple[int, int, int]:
@@ -358,7 +385,7 @@ def pack(
    Returns
    -------
-    packed_tensor : TensorLike
+    packed_tensor : TensorVariable
        The packed tensor with specified axes preserved and others raveled.
    packed_shapes : list of ShapeValueType
        A list containing the shapes of the raveled dimensions for each input tensor.
@@ -430,7 +457,7 @@ def pack(
    n_before, n_after, min_axes = _analyze_axes_list(axes)

-    reshaped_tensors: list[TensorVariable] = []
+    reshaped_tensors: list[Variable] = []
    packed_shapes: list[ShapeValueType] = []

    for i, input_tensor in enumerate(tensor_list):
@@ -488,7 +515,7 @@ def unpack(
    Returns
    -------
-    unpacked_tensors : list of TensorLike
+    unpacked_tensors : list of TensorVariable
        A list of unpacked tensors with their original shapes restored.
    """
    packed_input = as_tensor_variable(packed_input)
......
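Bridging to the tests below, a hedged round-trip sketch of `pack`/`unpack`; the exact `pack` signature is inferred from the docstring and the test usage, and the import path is an assumption:

```python
import numpy as np
import pytensor.tensor as pt
from pytensor.tensor.reshape import pack, unpack  # import path assumed for illustration

x = pt.tensor("x", shape=(2, 3), dtype="float64")
y = pt.tensor("y", shape=(4,), dtype="float64")

# With axes=None every input is fully raveled before concatenation: 2*3 + 4 == 10
packed, packed_shapes = pack([x, y], axes=None)
x2, y2 = unpack(packed, axes=None, packed_shapes=packed_shapes)

# unpack restores the original shapes recorded in packed_shapes
feed = {x: np.zeros((2, 3)), y: np.zeros(4)}
assert x2.eval(feed).shape == (2, 3)
assert y2.eval(feed).shape == (4,)
```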
@@ -2,6 +2,7 @@ import numpy as np
import pytest
import pytensor
import tests.unittest_tools as utt
from pytensor import config, function
from pytensor import tensor as pt
from pytensor.graph import rewrite_graph, vectorize_graph
@@ -52,6 +53,8 @@ def test_join_dims():
    x_batched_val = rng.normal(size=(10, 3, 5)).astype(config.floatX)
    assert x_joined_batched.eval({x_batched: x_batched_val}).shape == (10, 15)

    utt.verify_grad(lambda x: join_dims(x, axis=(1, 2)), [x_value])
@pytest.mark.parametrize(
"axis, shape, expected_shape",
@@ -77,6 +80,8 @@ def test_split_dims(axis, shape, expected_shape):
    x_split_value = fn(x_value)
    np.testing.assert_allclose(x_split_value, x_value.reshape(expected_shape))

    utt.verify_grad(lambda x: split_dims(x, shape=shape, axis=axis), [x_value])

    x = pt.tensor("x", shape=(10,))
    x_split = split_dims(x, shape=(5, 2), axis=0)
    x_batched = pt.tensor("x_batched", shape=(3, 10))
@@ -115,7 +120,7 @@ def test_make_replacements_with_pack_unpack():
    new_outputs = unpack(new_input, axes=None, packed_shapes=packed_shapes)
    loss = pytensor.graph.graph_replace(loss, dict(zip([x, y, z], new_outputs)))

-    rewrite_graph(loss, include=("ShapeOpt", "specialize"))
+    rewrite_graph(loss, include=("ShapeOpt", "canonicalize"))
    fn = pytensor.function([new_input], loss, mode="FAST_COMPILE")
......