Unverified 提交 0b439c0f authored 作者: Jesse Grabowski's avatar Jesse Grabowski 提交者: GitHub

Implement `L_Op` for `join_dims` and `split_dims` (#1812)

* Implement `L_Op` for `join_dims` and `split_dims`
* Improve type hints for `join_dims` and `split_dims`
* Feedback
上级 0f864a33
from collections.abc import Iterable, Sequence from collections.abc import Iterable, Sequence
from itertools import pairwise from itertools import pairwise
from typing import cast as type_cast
import numpy as np import numpy as np
from numpy.lib._array_utils_impl import normalize_axis_tuple from numpy.lib._array_utils_impl import normalize_axis_tuple
from pytensor import Variable from pytensor import Variable
from pytensor.gradient import DisconnectedType
from pytensor.graph import Apply from pytensor.graph import Apply
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.graph.replace import _vectorize_node from pytensor.graph.replace import _vectorize_node
from pytensor.tensor import TensorLike, as_tensor_variable from pytensor.tensor import TensorLike, as_tensor_variable
from pytensor.tensor.basic import expand_dims, infer_static_shape, join, split from pytensor.tensor.basic import expand_dims, infer_static_shape, join, split
from pytensor.tensor.extra_ops import squeeze
from pytensor.tensor.math import prod from pytensor.tensor.math import prod
from pytensor.tensor.shape import ShapeValueType from pytensor.tensor.shape import ShapeValueType, shape
from pytensor.tensor.type import tensor from pytensor.tensor.type import tensor
from pytensor.tensor.variable import TensorVariable from pytensor.tensor.variable import TensorVariable
...@@ -80,6 +81,19 @@ class JoinDims(Op): ...@@ -80,6 +81,19 @@ class JoinDims(Op):
out[0] = x.reshape(output_shape) out[0] = x.reshape(output_shape)
def L_op(
    self,
    inputs: Sequence[Variable],
    outputs: Sequence[Variable],
    output_grads: Sequence[Variable],
) -> list[Variable]:
    """Build the gradient graph with respect to the single input.

    The incoming output gradient is passed to ``split_dims`` at
    ``self.start_axis``, using the input's symbolic shape entries at the
    positions listed in ``self.axis_range`` as the target shape.

    Parameters
    ----------
    inputs : sequence of Variable
        The single input of the node.
    outputs : sequence of Variable
        The outputs of the node (unused here).
    output_grads : sequence of Variable
        The single gradient flowing in from the output.

    Returns
    -------
    list of Variable
        A one-element list holding the gradient for the input.
    """
    (joined_input,) = inputs
    (upstream_grad,) = output_grads

    # Use the symbolic shape so the sizes need not be known statically.
    input_shape = shape(joined_input)
    # NOTE(review): axis_range is presumably the run of axes that were joined — confirm.
    original_sizes = [input_shape[position] for position in self.axis_range]

    input_grad = split_dims(
        upstream_grad,
        shape=original_sizes,
        axis=self.start_axis,
    )
    return [input_grad]
@_vectorize_node.register(JoinDims) @_vectorize_node.register(JoinDims)
def _vectorize_joindims(op, node, x): def _vectorize_joindims(op, node, x):
...@@ -97,14 +111,14 @@ def join_dims(x: TensorLike, axis: Sequence[int] | int | None = None) -> TensorV ...@@ -97,14 +111,14 @@ def join_dims(x: TensorLike, axis: Sequence[int] | int | None = None) -> TensorV
Parameters Parameters
---------- ----------
x : Variable x : TensorLike
The input tensor. The input tensor.
axis : int or sequence of int, optional axis : int or sequence of int, optional
The dimensions to join. If None, all dimensions are joined. The dimensions to join. If None, all dimensions are joined.
Returns Returns
------- -------
joined_x : Variable joined_x : TensorVariable
The reshaped tensor with joined dimensions. The reshaped tensor with joined dimensions.
Examples Examples
...@@ -137,10 +151,7 @@ def join_dims(x: TensorLike, axis: Sequence[int] | int | None = None) -> TensorV ...@@ -137,10 +151,7 @@ def join_dims(x: TensorLike, axis: Sequence[int] | int | None = None) -> TensorV
start_axis = min(axis) start_axis = min(axis)
n_axes = len(axis) n_axes = len(axis)
return type_cast( return JoinDims(start_axis=start_axis, n_axes=n_axes)(x) # type: ignore[return-value]
TensorVariable,
JoinDims(start_axis=start_axis, n_axes=n_axes)(x),
)
class SplitDims(Op): class SplitDims(Op):
...@@ -191,6 +202,23 @@ class SplitDims(Op): ...@@ -191,6 +202,23 @@ class SplitDims(Op):
out[0] = x.reshape(output_shape) out[0] = x.reshape(output_shape)
def connection_pattern(self, node):
    """Return the fixed connection pattern ``[[True], [False]]``.

    NOTE(review): presumably one row per input (the tensor, then the
    shape) and one column per output, which would mark the shape input as
    disconnected from the output — confirm against the Op API.
    """
    first_input_row = [True]
    second_input_row = [False]
    return [first_input_row, second_input_row]
def L_op(
    self,
    inputs: Sequence[Variable],
    outputs: Sequence[Variable],
    output_grads: Sequence[Variable],
) -> list[Variable]:
    """Build the gradient graph with respect to the two inputs.

    The output gradient is passed to ``join_dims`` over a contiguous run
    of axes starting at ``self.axis``. The second input receives a
    ``DisconnectedType`` gradient.

    Parameters
    ----------
    inputs : sequence of Variable
        The two inputs of the node; only the first one is used here.
    outputs : sequence of Variable
        The outputs of the node (unused here).
    output_grads : sequence of Variable
        The single gradient flowing in from the output.

    Returns
    -------
    list of Variable
        The gradient for the first input, followed by a disconnected
        gradient for the second.
    """
    data, _unused_shape = inputs
    (upstream_grad,) = output_grads

    # The number of axes to join back is recovered from the rank
    # difference between the output gradient and the first input.
    n_joined = upstream_grad.ndim - data.ndim + 1  # type: ignore[attr-defined]
    first_axis = self.axis
    axes_to_join = list(range(first_axis, first_axis + n_joined))

    data_grad = join_dims(upstream_grad, axis=axes_to_join)
    shape_grad = DisconnectedType()()
    return [data_grad, shape_grad]
@_vectorize_node.register(SplitDims) @_vectorize_node.register(SplitDims)
def _vectorize_splitdims(op, node, x, shape): def _vectorize_splitdims(op, node, x, shape):
...@@ -224,7 +252,7 @@ def split_dims( ...@@ -224,7 +252,7 @@ def split_dims(
Returns Returns
------- -------
split_x : Variable split_x : TensorVariable
The reshaped tensor with split dimensions. The reshaped tensor with split dimensions.
Examples Examples
...@@ -253,13 +281,12 @@ def split_dims( ...@@ -253,13 +281,12 @@ def split_dims(
# If we get an empty shape, there is potentially a dummy dimension at the requested axis. This happens for # If we get an empty shape, there is potentially a dummy dimension at the requested axis. This happens for
# example when splitting a packed tensor that had its dims expanded before packing (e.g. when packing shapes # example when splitting a packed tensor that had its dims expanded before packing (e.g. when packing shapes
# (3, ) and (3, 3) to (3, 4)) # (3, ) and (3, 3) to (3, 4))
return type_cast(TensorVariable, x.squeeze(axis=axis)) return squeeze(x, axis=axis) # type: ignore[no-any-return]
[axis] = normalize_axis_tuple(axis, x.ndim) # type: ignore[misc] [axis] = normalize_axis_tuple(axis, x.ndim) # type: ignore[misc]
shape = as_tensor_variable(shape, dtype="int64", ndim=1) # type: ignore[arg-type] shape = as_tensor_variable(shape, dtype="int64", ndim=1) # type: ignore[arg-type]
split_op = SplitDims(axis=axis) return SplitDims(axis=axis)(x, shape) # type: ignore[return-value]
return type_cast(TensorVariable, split_op(x, shape))
def _analyze_axes_list(axes) -> tuple[int, int, int]: def _analyze_axes_list(axes) -> tuple[int, int, int]:
...@@ -358,7 +385,7 @@ def pack( ...@@ -358,7 +385,7 @@ def pack(
Returns Returns
------- -------
packed_tensor : TensorLike packed_tensor : TensorVariable
The packed tensor with specified axes preserved and others raveled. The packed tensor with specified axes preserved and others raveled.
packed_shapes : list of ShapeValueType packed_shapes : list of ShapeValueType
A list containing the shapes of the raveled dimensions for each input tensor. A list containing the shapes of the raveled dimensions for each input tensor.
...@@ -430,7 +457,7 @@ def pack( ...@@ -430,7 +457,7 @@ def pack(
n_before, n_after, min_axes = _analyze_axes_list(axes) n_before, n_after, min_axes = _analyze_axes_list(axes)
reshaped_tensors: list[TensorVariable] = [] reshaped_tensors: list[Variable] = []
packed_shapes: list[ShapeValueType] = [] packed_shapes: list[ShapeValueType] = []
for i, input_tensor in enumerate(tensor_list): for i, input_tensor in enumerate(tensor_list):
...@@ -488,7 +515,7 @@ def unpack( ...@@ -488,7 +515,7 @@ def unpack(
Returns Returns
------- -------
unpacked_tensors : list of TensorLike unpacked_tensors : list of TensorVariable
A list of unpacked tensors with their original shapes restored. A list of unpacked tensors with their original shapes restored.
""" """
packed_input = as_tensor_variable(packed_input) packed_input = as_tensor_variable(packed_input)
......
...@@ -2,6 +2,7 @@ import numpy as np ...@@ -2,6 +2,7 @@ import numpy as np
import pytest import pytest
import pytensor import pytensor
import tests.unittest_tools as utt
from pytensor import config, function from pytensor import config, function
from pytensor import tensor as pt from pytensor import tensor as pt
from pytensor.graph import rewrite_graph, vectorize_graph from pytensor.graph import rewrite_graph, vectorize_graph
...@@ -52,6 +53,8 @@ def test_join_dims(): ...@@ -52,6 +53,8 @@ def test_join_dims():
x_batched_val = rng.normal(size=(10, 3, 5)).astype(config.floatX) x_batched_val = rng.normal(size=(10, 3, 5)).astype(config.floatX)
assert x_joined_batched.eval({x_batched: x_batched_val}).shape == (10, 15) assert x_joined_batched.eval({x_batched: x_batched_val}).shape == (10, 15)
utt.verify_grad(lambda x: join_dims(x, axis=(1, 2)), [x_value])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"axis, shape, expected_shape", "axis, shape, expected_shape",
...@@ -77,6 +80,8 @@ def test_split_dims(axis, shape, expected_shape): ...@@ -77,6 +80,8 @@ def test_split_dims(axis, shape, expected_shape):
x_split_value = fn(x_value) x_split_value = fn(x_value)
np.testing.assert_allclose(x_split_value, x_value.reshape(expected_shape)) np.testing.assert_allclose(x_split_value, x_value.reshape(expected_shape))
utt.verify_grad(lambda x: split_dims(x, shape=shape, axis=axis), [x_value])
x = pt.tensor("x", shape=(10,)) x = pt.tensor("x", shape=(10,))
x_split = split_dims(x, shape=(5, 2), axis=0) x_split = split_dims(x, shape=(5, 2), axis=0)
x_batched = pt.tensor("x_batched", shape=(3, 10)) x_batched = pt.tensor("x_batched", shape=(3, 10))
...@@ -115,7 +120,7 @@ def test_make_replacements_with_pack_unpack(): ...@@ -115,7 +120,7 @@ def test_make_replacements_with_pack_unpack():
new_outputs = unpack(new_input, axes=None, packed_shapes=packed_shapes) new_outputs = unpack(new_input, axes=None, packed_shapes=packed_shapes)
loss = pytensor.graph.graph_replace(loss, dict(zip([x, y, z], new_outputs))) loss = pytensor.graph.graph_replace(loss, dict(zip([x, y, z], new_outputs)))
rewrite_graph(loss, include=("ShapeOpt", "specialize")) rewrite_graph(loss, include=("ShapeOpt", "canonicalize"))
fn = pytensor.function([new_input], loss, mode="FAST_COMPILE") fn = pytensor.function([new_input], loss, mode="FAST_COMPILE")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论