提交 906e1424 authored 作者: Michael Osthege's avatar Michael Osthege 提交者: Ricardo Vieira

Fix type hints here and there

上级 7a00b885
...@@ -108,7 +108,7 @@ def as_symbolic(x: Any, name: str | None = None, **kwargs) -> Variable: ...@@ -108,7 +108,7 @@ def as_symbolic(x: Any, name: str | None = None, **kwargs) -> Variable:
@singledispatch @singledispatch
def _as_symbolic(x, **kwargs) -> Variable: def _as_symbolic(x: Any, **kwargs) -> Variable:
from pytensor.tensor import as_tensor_variable from pytensor.tensor import as_tensor_variable
return as_tensor_variable(x, **kwargs) return as_tensor_variable(x, **kwargs)
......
...@@ -1302,8 +1302,8 @@ def clone_node_and_cache( ...@@ -1302,8 +1302,8 @@ def clone_node_and_cache(
def clone_get_equiv( def clone_get_equiv(
inputs: Sequence[Variable], inputs: Iterable[Variable],
outputs: Sequence[Variable], outputs: Reversible[Variable],
copy_inputs: bool = True, copy_inputs: bool = True,
copy_orphans: bool = True, copy_orphans: bool = True,
memo: dict[Union[Apply, Variable, "Op"], Union[Apply, Variable, "Op"]] memo: dict[Union[Apply, Variable, "Op"], Union[Apply, Variable, "Op"]]
......
import warnings import warnings
from collections.abc import Sequence from collections.abc import Sequence
from copy import copy from copy import copy
from typing import cast from typing import Any, cast
import numpy as np import numpy as np
...@@ -218,6 +218,7 @@ class RandomVariable(Op): ...@@ -218,6 +218,7 @@ class RandomVariable(Op):
from pytensor.tensor.extra_ops import broadcast_shape_iter from pytensor.tensor.extra_ops import broadcast_shape_iter
supp_shape: tuple[Any]
if self.ndim_supp == 0: if self.ndim_supp == 0:
supp_shape = () supp_shape = ()
else: else:
......
...@@ -147,7 +147,9 @@ def explicit_expand_dims( ...@@ -147,7 +147,9 @@ def explicit_expand_dims(
return new_params return new_params
def compute_batch_shape(params, ndims_params: Sequence[int]) -> TensorVariable: def compute_batch_shape(
params: Sequence[TensorVariable], ndims_params: Sequence[int]
) -> TensorVariable:
params = explicit_expand_dims(params, ndims_params) params = explicit_expand_dims(params, ndims_params)
batch_params = [ batch_params = [
param[(..., *(0,) * core_ndim)] param[(..., *(0,) * core_ndim)]
......
...@@ -144,14 +144,14 @@ class Shape(COp): ...@@ -144,14 +144,14 @@ class Shape(COp):
_shape = Shape() _shape = Shape()
def shape(x: np.ndarray | Number | Variable) -> Variable: def shape(x: np.ndarray | Number | Variable) -> TensorVariable:
"""Return the shape of `x`.""" """Return the shape of `x`."""
if not isinstance(x, Variable): if not isinstance(x, Variable):
# The following is a type error in Python 3.9 but not 3.12. # The following is a type error in Python 3.9 but not 3.12.
# Thus we need to ignore unused-ignore on 3.12. # Thus we need to ignore unused-ignore on 3.12.
x = ptb.as_tensor_variable(x) # type: ignore[arg-type,unused-ignore] x = ptb.as_tensor_variable(x) # type: ignore[arg-type,unused-ignore]
return cast(Variable, _shape(x)) return cast(TensorVariable, _shape(x))
@_get_vector_length.register(Shape) # type: ignore @_get_vector_length.register(Shape) # type: ignore
...@@ -195,7 +195,7 @@ def shape_tuple(x: TensorVariable) -> tuple[Variable, ...]: ...@@ -195,7 +195,7 @@ def shape_tuple(x: TensorVariable) -> tuple[Variable, ...]:
# TODO: Why not use uint64? # TODO: Why not use uint64?
res += (pytensor.scalar.ScalarConstant(pytensor.scalar.int64, shape_val),) res += (pytensor.scalar.ScalarConstant(pytensor.scalar.int64, shape_val),)
else: else:
res += (symbolic_shape[i],) # type: ignore res += (symbolic_shape[i],)
return res return res
......
...@@ -138,7 +138,7 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape): ...@@ -138,7 +138,7 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
shape = self.shape shape = self.shape
return type(self)(dtype, shape, name=self.name) return type(self)(dtype, shape, name=self.name)
def filter(self, data, strict=False, allow_downcast=None): def filter(self, data, strict=False, allow_downcast=None) -> np.ndarray:
"""Convert `data` to something which can be associated to a `TensorVariable`. """Convert `data` to something which can be associated to a `TensorVariable`.
This function is not meant to be called in user code. It is for This function is not meant to be called in user code. It is for
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论