提交 906e1424 authored 作者: Michael Osthege's avatar Michael Osthege 提交者: Ricardo Vieira

Fix type hints here and there

上级 7a00b885
......@@ -108,7 +108,7 @@ def as_symbolic(x: Any, name: str | None = None, **kwargs) -> Variable:
@singledispatch
def _as_symbolic(x: Any, **kwargs) -> Variable:
    """Default `singledispatch` hook behind `as_symbolic`.

    Converts `x` to a symbolic `Variable` by delegating to
    `pytensor.tensor.as_tensor_variable`; registered overloads handle
    more specific input types.

    Parameters
    ----------
    x
        The object to convert to a symbolic variable.
    **kwargs
        Forwarded verbatim to `as_tensor_variable`.
    """
    # Local import avoids a circular dependency between the graph core
    # and the tensor module at import time.
    from pytensor.tensor import as_tensor_variable

    return as_tensor_variable(x, **kwargs)
......
......@@ -1302,8 +1302,8 @@ def clone_node_and_cache(
def clone_get_equiv(
inputs: Sequence[Variable],
outputs: Sequence[Variable],
inputs: Iterable[Variable],
outputs: Reversible[Variable],
copy_inputs: bool = True,
copy_orphans: bool = True,
memo: dict[Union[Apply, Variable, "Op"], Union[Apply, Variable, "Op"]]
......
import warnings
from collections.abc import Sequence
from copy import copy
from typing import cast
from typing import Any, cast
import numpy as np
......@@ -218,6 +218,7 @@ class RandomVariable(Op):
from pytensor.tensor.extra_ops import broadcast_shape_iter
supp_shape: tuple[Any]
if self.ndim_supp == 0:
supp_shape = ()
else:
......
......@@ -147,7 +147,9 @@ def explicit_expand_dims(
return new_params
def compute_batch_shape(params, ndims_params: Sequence[int]) -> TensorVariable:
def compute_batch_shape(
params: Sequence[TensorVariable], ndims_params: Sequence[int]
) -> TensorVariable:
params = explicit_expand_dims(params, ndims_params)
batch_params = [
param[(..., *(0,) * core_ndim)]
......
......@@ -144,14 +144,14 @@ class Shape(COp):
_shape = Shape()
def shape(x: np.ndarray | Number | Variable) -> TensorVariable:
    """Return the symbolic shape of `x`.

    Parameters
    ----------
    x
        A NumPy array, a plain number, or a symbolic `Variable`.
        Non-`Variable` inputs are first converted via
        `ptb.as_tensor_variable`.

    Returns
    -------
    TensorVariable
        The result of applying the `Shape` op to the (converted) input.
    """
    if not isinstance(x, Variable):
        # The following is a type error in Python 3.9 but not 3.12.
        # Thus we need to ignore unused-ignore on 3.12.
        x = ptb.as_tensor_variable(x)  # type: ignore[arg-type,unused-ignore]
    return cast(TensorVariable, _shape(x))
@_get_vector_length.register(Shape) # type: ignore
......@@ -195,7 +195,7 @@ def shape_tuple(x: TensorVariable) -> tuple[Variable, ...]:
# TODO: Why not use uint64?
res += (pytensor.scalar.ScalarConstant(pytensor.scalar.int64, shape_val),)
else:
res += (symbolic_shape[i],) # type: ignore
res += (symbolic_shape[i],)
return res
......
......@@ -138,7 +138,7 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
shape = self.shape
return type(self)(dtype, shape, name=self.name)
def filter(self, data, strict=False, allow_downcast=None):
def filter(self, data, strict=False, allow_downcast=None) -> np.ndarray:
"""Convert `data` to something which can be associated to a `TensorVariable`.
This function is not meant to be called in user code. It is for
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论