提交 2ecf8523 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Fix type hints in reshape.py

Remove cases where type-hints are better than bad type-hints
上级 54a85007
from collections.abc import Iterable, Sequence from collections.abc import Iterable, Sequence
from itertools import pairwise from itertools import pairwise
from typing import TypeAlias
import numpy as np import numpy as np
from numpy.lib._array_utils_impl import normalize_axis_tuple from numpy.lib._array_utils_impl import normalize_axis_tuple
...@@ -9,15 +10,20 @@ from pytensor.gradient import DisconnectedType ...@@ -9,15 +10,20 @@ from pytensor.gradient import DisconnectedType
from pytensor.graph import Apply from pytensor.graph import Apply
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.graph.replace import _vectorize_node from pytensor.graph.replace import _vectorize_node
from pytensor.scalar import ScalarVariable
from pytensor.tensor import TensorLike, as_tensor_variable from pytensor.tensor import TensorLike, as_tensor_variable
from pytensor.tensor.basic import expand_dims, infer_static_shape, join, split from pytensor.tensor.basic import expand_dims, infer_static_shape, join, split
from pytensor.tensor.extra_ops import squeeze from pytensor.tensor.extra_ops import squeeze
from pytensor.tensor.math import prod from pytensor.tensor.math import prod
from pytensor.tensor.shape import ShapeValueType, shape
from pytensor.tensor.type import tensor from pytensor.tensor.type import tensor
from pytensor.tensor.variable import TensorVariable from pytensor.tensor.variable import TensorVariable
ShapeValueType: TypeAlias = (
int | np.integer | ScalarVariable | TensorVariable | np.ndarray
)
class JoinDims(Op): class JoinDims(Op):
__props__ = ( __props__ = (
"start_axis", "start_axis",
...@@ -81,16 +87,11 @@ class JoinDims(Op): ...@@ -81,16 +87,11 @@ class JoinDims(Op):
out[0] = x.reshape(output_shape) out[0] = x.reshape(output_shape)
def L_op( def L_op(self, inputs, outputs, output_grads):
self,
inputs: Sequence[Variable],
outputs: Sequence[Variable],
output_grads: Sequence[Variable],
) -> list[Variable]:
(x,) = inputs (x,) = inputs
(g_out,) = output_grads (g_out,) = output_grads
x_shape = shape(x) x_shape = x.shape
packed_shape = [x_shape[i] for i in self.axis_range] packed_shape = [x_shape[i] for i in self.axis_range]
return [split_dims(g_out, shape=packed_shape, axis=self.start_axis)] return [split_dims(g_out, shape=packed_shape, axis=self.start_axis)]
...@@ -163,7 +164,7 @@ class SplitDims(Op): ...@@ -163,7 +164,7 @@ class SplitDims(Op):
raise ValueError("SplitDims axis must be non-negative") raise ValueError("SplitDims axis must be non-negative")
self.axis = axis self.axis = axis
def make_node(self, x: Variable, shape: Variable) -> Apply: # type: ignore[override] def make_node(self, x, shape):
x = as_tensor_variable(x) x = as_tensor_variable(x)
shape = as_tensor_variable(shape, dtype=int, ndim=1) shape = as_tensor_variable(shape, dtype=int, ndim=1)
...@@ -205,16 +206,11 @@ class SplitDims(Op): ...@@ -205,16 +206,11 @@ class SplitDims(Op):
def connection_pattern(self, node): def connection_pattern(self, node):
return [[True], [False]] return [[True], [False]]
def L_op( def L_op(self, inputs, outputs, output_grads):
self,
inputs: Sequence[Variable],
outputs: Sequence[Variable],
output_grads: Sequence[Variable],
) -> list[Variable]:
(x, _) = inputs (x, _) = inputs
(g_out,) = output_grads (g_out,) = output_grads
n_axes = g_out.ndim - x.ndim + 1 # type: ignore[attr-defined] n_axes = g_out.ndim - x.ndim + 1
axis_range = list(range(self.axis, self.axis + n_axes)) axis_range = list(range(self.axis, self.axis + n_axes))
return [join_dims(g_out, axis=axis_range), DisconnectedType()()] return [join_dims(g_out, axis=axis_range), DisconnectedType()()]
...@@ -372,7 +368,7 @@ def _analyze_axes_list(axes) -> tuple[int, int, int]: ...@@ -372,7 +368,7 @@ def _analyze_axes_list(axes) -> tuple[int, int, int]:
def pack( def pack(
*tensors: TensorLike, axes: Sequence[int] | int | None = None *tensors: TensorLike, axes: Sequence[int] | int | None = None
) -> tuple[TensorVariable, list[ShapeValueType]]: ) -> tuple[TensorVariable, list[TensorVariable]]:
""" """
Combine multiple tensors by preserving the specified axes and raveling the rest into a single axis. Combine multiple tensors by preserving the specified axes and raveling the rest into a single axis.
...@@ -458,7 +454,7 @@ def pack( ...@@ -458,7 +454,7 @@ def pack(
n_before, n_after, min_axes = _analyze_axes_list(axes) n_before, n_after, min_axes = _analyze_axes_list(axes)
reshaped_tensors: list[Variable] = [] reshaped_tensors: list[Variable] = []
packed_shapes: list[ShapeValueType] = [] packed_shapes: list[TensorVariable] = []
for i, input_tensor in enumerate(tensor_list): for i, input_tensor in enumerate(tensor_list):
n_dim = input_tensor.ndim n_dim = input_tensor.ndim
...@@ -492,7 +488,7 @@ def pack( ...@@ -492,7 +488,7 @@ def pack(
def unpack( def unpack(
packed_input: TensorLike, packed_input: TensorLike,
axes: int | Sequence[int] | None, axes: int | Sequence[int] | None,
packed_shapes: list[ShapeValueType], packed_shapes: Sequence[ShapeValueType],
) -> list[TensorVariable]: ) -> list[TensorVariable]:
""" """
Unpack a packed tensor into multiple tensors by splitting along the specified axes and reshaping. Unpack a packed tensor into multiple tensors by splitting along the specified axes and reshaping.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论