Commit 2ecf8523 authored by ricardoV94, committed by Ricardo Vieira

Fix type hints in reshape.py

Remove cases where no type-hints are better than bad type-hints
Parent 54a85007
 from collections.abc import Iterable, Sequence
 from itertools import pairwise
+from typing import TypeAlias
 
 import numpy as np
 from numpy.lib._array_utils_impl import normalize_axis_tuple
@@ -9,15 +10,20 @@ from pytensor.gradient import DisconnectedType
 from pytensor.graph import Apply
 from pytensor.graph.op import Op
 from pytensor.graph.replace import _vectorize_node
+from pytensor.scalar import ScalarVariable
 from pytensor.tensor import TensorLike, as_tensor_variable
 from pytensor.tensor.basic import expand_dims, infer_static_shape, join, split
 from pytensor.tensor.extra_ops import squeeze
 from pytensor.tensor.math import prod
-from pytensor.tensor.shape import ShapeValueType, shape
 from pytensor.tensor.type import tensor
 from pytensor.tensor.variable import TensorVariable
 
+ShapeValueType: TypeAlias = (
+    int | np.integer | ScalarVariable | TensorVariable | np.ndarray
+)
 
 class JoinDims(Op):
     __props__ = (
         "start_axis",
@@ -81,16 +87,11 @@ class JoinDims(Op):
         out[0] = x.reshape(output_shape)
 
-    def L_op(
-        self,
-        inputs: Sequence[Variable],
-        outputs: Sequence[Variable],
-        output_grads: Sequence[Variable],
-    ) -> list[Variable]:
+    def L_op(self, inputs, outputs, output_grads):
         (x,) = inputs
         (g_out,) = output_grads
 
-        x_shape = shape(x)
+        x_shape = x.shape
         packed_shape = [x_shape[i] for i in self.axis_range]
         return [split_dims(g_out, shape=packed_shape, axis=self.start_axis)]
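
For readers skimming the diff: JoinDims merges a contiguous run of axes into one, so its L_op must split the incoming gradient back apart using the per-axis sizes read off x's (now symbolic) shape. A minimal NumPy sketch of that shape bookkeeping, with names mirroring the code above; the (2, 3, 4, 5) example shape is an arbitrary assumption:

    import numpy as np

    # Forward pass of a JoinDims-like op: merge axes 1 and 2 of x.
    x = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)
    start_axis, n_axes = 1, 2
    axis_range = range(start_axis, start_axis + n_axes)
    out = x.reshape((*x.shape[:start_axis], -1, *x.shape[start_axis + n_axes:]))
    assert out.shape == (2, 12, 5)

    # Backward pass: the gradient arrives with the merged shape; the saved
    # per-axis sizes let us split it back, as split_dims does in L_op above.
    g_out = np.ones_like(out)
    packed_shape = [x.shape[i] for i in axis_range]  # [3, 4]
    g_x = g_out.reshape(
        (*g_out.shape[:start_axis], *packed_shape, *g_out.shape[start_axis + 1:])
    )
    assert g_x.shape == x.shape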
@@ -163,7 +164,7 @@ class SplitDims(Op):
             raise ValueError("SplitDims axis must be non-negative")
         self.axis = axis
 
-    def make_node(self, x: Variable, shape: Variable) -> Apply:  # type: ignore[override]
+    def make_node(self, x, shape):
         x = as_tensor_variable(x)
         shape = as_tensor_variable(shape, dtype=int, ndim=1)
@@ -205,16 +206,11 @@ class SplitDims(Op):
     def connection_pattern(self, node):
         return [[True], [False]]
 
-    def L_op(
-        self,
-        inputs: Sequence[Variable],
-        outputs: Sequence[Variable],
-        output_grads: Sequence[Variable],
-    ) -> list[Variable]:
+    def L_op(self, inputs, outputs, output_grads):
         (x, _) = inputs
         (g_out,) = output_grads
 
-        n_axes = g_out.ndim - x.ndim + 1  # type: ignore[attr-defined]
+        n_axes = g_out.ndim - x.ndim + 1
         axis_range = list(range(self.axis, self.axis + n_axes))
         return [join_dims(g_out, axis=axis_range), DisconnectedType()()]
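
The line that lost its type-ignore deserves a gloss: the forward op turned one axis of x into len(shape) output axes, so the gradient's rank satisfies g_out.ndim == x.ndim - 1 + n_axes, which rearranges to the n_axes expression above. A NumPy sketch of the arithmetic (the example shapes are an arbitrary assumption):

    import numpy as np

    # Forward pass of a SplitDims-like op: split axis 1 of x into (3, 4).
    x = np.arange(2 * 12 * 5).reshape(2, 12, 5)
    axis, split_shape = 1, (3, 4)
    out = x.reshape((*x.shape[:axis], *split_shape, *x.shape[axis + 1:]))

    # Backward pass: recover how many axes the split introduced from the
    # rank difference, then merge exactly that range, as join_dims does.
    g_out = np.ones_like(out)
    n_axes = g_out.ndim - x.ndim + 1               # 4 - 3 + 1 == 2
    axis_range = list(range(axis, axis + n_axes))  # [1, 2]
    g_x = g_out.reshape((*g_out.shape[:axis], -1, *g_out.shape[axis + n_axes:]))
    assert g_x.shape == x.shape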
@@ -372,7 +368,7 @@ def _analyze_axes_list(axes) -> tuple[int, int, int]:
 def pack(
     *tensors: TensorLike, axes: Sequence[int] | int | None = None
-) -> tuple[TensorVariable, list[ShapeValueType]]:
+) -> tuple[TensorVariable, list[TensorVariable]]:
     """
     Combine multiple tensors by preserving the specified axes and raveling the rest into a single axis.
@@ -458,7 +454,7 @@
     n_before, n_after, min_axes = _analyze_axes_list(axes)
 
     reshaped_tensors: list[Variable] = []
-    packed_shapes: list[ShapeValueType] = []
+    packed_shapes: list[TensorVariable] = []
 
     for i, input_tensor in enumerate(tensor_list):
         n_dim = input_tensor.ndim
@@ -492,7 +488,7 @@
 def unpack(
     packed_input: TensorLike,
     axes: int | Sequence[int] | None,
-    packed_shapes: list[ShapeValueType],
+    packed_shapes: Sequence[ShapeValueType],
 ) -> list[TensorVariable]:
     """
     Unpack a packed tensor into multiple tensors by splitting along the specified axes and reshaping.
......
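
Taken together, the retyped signatures say that pack returns symbolic TensorVariable shapes, while unpack accepts any ShapeValueType sequence, including the exact list pack produced. A round-trip usage sketch, assuming pack and unpack are importable from pytensor.tensor.reshape (the module this commit edits) and behave as their docstrings describe; the example shapes are arbitrary:

    import numpy as np
    import pytensor
    import pytensor.tensor as pt
    from pytensor.tensor.reshape import pack, unpack  # import path assumed

    # Preserve axis 0 of each input; ravel each input's remaining axes.
    x = pt.tensor("x", shape=(2, 3, 4))
    y = pt.tensor("y", shape=(2, 5))
    packed, packed_shapes = pack(x, y, axes=0)  # packed_shapes: TensorVariables

    # unpack inverts pack using the shapes it recorded.
    x_out, y_out = unpack(packed, 0, packed_shapes)

    fn = pytensor.function([x, y], [x_out, y_out])
    res_x, res_y = fn(np.zeros((2, 3, 4)), np.zeros((2, 5)))
    assert res_x.shape == (2, 3, 4) and res_y.shape == (2, 5)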