提交 947b9409 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Make reshape ndim keyword only

上级 fe8804fa
import warnings
from collections.abc import Sequence
from numbers import Number
from textwrap import dedent
from typing import cast
from typing import TYPE_CHECKING, Union, cast
from typing import cast as typing_cast
import numpy as np
from numpy.core.numeric import normalize_axis_tuple # type: ignore
......@@ -24,6 +26,9 @@ from pytensor.tensor.type_other import NoneConst, NoneTypeT
from pytensor.tensor.variable import TensorConstant, TensorVariable
if TYPE_CHECKING:
from pytensor.tensor import TensorLike
ShapeValueType = None | np.integer | int | Variable
......@@ -842,9 +847,14 @@ def _vectorize_reshape(op, node, x, shape):
return reshape(x, new_shape, ndim=len(new_shape)).owner
def reshape(x, newshape, ndim=None):
def reshape(
x: "TensorLike",
newshape: Union["TensorLike", Sequence["TensorLike"]],
*,
ndim: int | None = None,
) -> TensorVariable:
if ndim is None:
newshape = ptb.as_tensor_variable(newshape)
newshape = ptb.as_tensor_variable(newshape) # type: ignore
if newshape.type.ndim != 1:
raise TypeError(
"New shape in reshape must be a vector or a list/tuple of"
......@@ -862,7 +872,7 @@ def reshape(x, newshape, ndim=None):
)
op = Reshape(ndim)
rval = op(x, newshape)
return rval
return typing_cast(TensorVariable, rval)
def shape_padleft(t, n_ones=1):
......
......@@ -918,7 +918,7 @@ def _direct_solve_discrete_lyapunov(
vec_Q = Q.ravel()
vec_X = solve(eye - AxA, vec_Q, b_ndim=1)
return cast(TensorVariable, reshape(vec_X, A.shape))
return reshape(vec_X, A.shape)
def solve_discrete_lyapunov(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论