提交 96060bf9 authored 作者: Michael Osthege's avatar Michael Osthege

Fix tuple-related type hints

上级 c299bc77
......@@ -41,7 +41,7 @@ StorageMapType = Dict[Variable, StorageCellType]
ComputeMapType = Dict[Variable, List[bool]]
InputStorageType = List[StorageCellType]
OutputStorageType = List[StorageCellType]
ParamsInputType = Optional[Tuple[Any]]
ParamsInputType = Optional[Tuple[Any, ...]]
PerformMethodType = Callable[
[Apply, List[Any], OutputStorageType, ParamsInputType], None
]
......
import warnings
from abc import abstractmethod
from typing import Callable, Dict, List, Tuple, Union
from typing import Callable, Dict, List, Tuple
from pytensor.graph.basic import Apply, Constant
from pytensor.graph.utils import MethodNotDefined
......@@ -149,7 +149,7 @@ class CLinkerObject:
"""Return a list of code snippets to be inserted in module initialization."""
return []
def c_code_cache_version(self) -> Union[Tuple[int, ...], Tuple]:
def c_code_cache_version(self) -> Tuple[int, ...]:
"""Return a tuple of integers indicating the version of this `Op`.
An empty tuple indicates an "unversioned" `Op` that will not be cached
......@@ -566,7 +566,7 @@ class CLinkerType(CLinkerObject):
"""
return ""
def c_code_cache_version(self) -> Union[Tuple, Tuple[int]]:
def c_code_cache_version(self) -> Tuple[int, ...]:
"""Return a tuple of integers indicating the version of this type.
An empty tuple indicates an "unversioned" type that will not
......
......@@ -240,7 +240,7 @@ def lquote_macro(txt: str) -> str:
return "\n".join(res)
def get_sub_macros(sub: Dict[str, str]) -> Union[Tuple[str], Tuple[str, str]]:
def get_sub_macros(sub: Dict[str, str]) -> Tuple[str, str]:
define_macros = []
undef_macros = []
define_macros.append(f"#define FAIL {lquote_macro(sub['fail'])}")
......@@ -533,7 +533,7 @@ class ExternalCOp(COp):
def get_c_macros(
self, node: Apply, name: str, check_input: Optional[bool] = None
) -> Union[Tuple[str], Tuple[str, str]]:
) -> Tuple[str, str]:
"Construct a pair of C ``#define`` and ``#undef`` code strings."
define_template = "#define %s %s"
undef_template = "#undef %s"
......
......@@ -123,7 +123,7 @@ def numba_funcify_Scan(op, node, **kwargs):
# These outer-inputs are indexed without offsets or storage wrap-around
add_inner_in_expr(outer_in_name, 0, None)
inner_in_names_to_input_taps: Dict[str, Tuple[int]] = dict(
inner_in_names_to_input_taps: Dict[str, Tuple[int, ...]] = dict(
zip(
outer_in_mit_mot_names + outer_in_mit_sot_names + outer_in_sit_sot_names,
op.info.mit_mot_in_slices
......@@ -157,7 +157,7 @@ def numba_funcify_Scan(op, node, **kwargs):
# storage array like a circular buffer, and that's why we need to track the
# storage size along with the taps length/indexing offset.
def add_output_storage_post_proc_stmt(
outer_in_name: str, tap_sizes: Tuple[int], storage_size: str
outer_in_name: str, tap_sizes: Tuple[int, ...], storage_size: str
):
tap_size = max(tap_sizes)
......
"""Symbolic Op for raising an exception."""
from textwrap import indent
from typing import Tuple
import numpy as np
......@@ -63,7 +62,7 @@ class CheckAndRaise(COp):
def __hash__(self):
return hash((self.msg, self.exc_type))
def make_node(self, value: Variable, *conds: Tuple[Variable]):
def make_node(self, value: Variable, *conds: Variable):
"""
Parameters
......
import copy
from typing import Tuple, Union
from typing import Tuple
import numpy as np
......@@ -18,7 +18,7 @@ class MultinomialFromUniform(COp):
TODO : need description for parameter 'odtype'
"""
__props__: Union[Tuple[str], Tuple[str, str]] = ("odtype",)
__props__: Tuple[str, ...] = ("odtype",)
def __init__(self, odtype):
self.odtype = odtype
......
......@@ -3998,7 +3998,7 @@ class Composite(ScalarOp, HasInnerGraph):
"""
init_param: Union[Tuple[str, str], Tuple[str]] = ("inputs", "outputs")
init_param: Tuple[str, ...] = ("inputs", "outputs")
def __init__(self, inputs, outputs):
# We need to clone the graph as sometimes its nodes already
......
......@@ -137,7 +137,7 @@ try:
except ImportError:
pass
from typing import Tuple, Union
from typing import Tuple
import pytensor.scalar
from pytensor.compile.mode import optdb
......@@ -522,7 +522,7 @@ class GemmRelated(COp):
"""
__props__: Union[Tuple, Tuple[str]] = ()
__props__: Tuple[str, ...] = ()
def c_support_code(self, **kwargs):
# return cblas_header_text()
......
......@@ -1763,7 +1763,7 @@ def linspace(start, end, steps):
def broadcast_to(
x: TensorVariable, shape: Union[TensorVariable, Tuple[Variable]]
x: TensorVariable, shape: Union[TensorVariable, Tuple[Variable, ...]]
) -> TensorVariable:
"""Broadcast an array to a new shape.
......
from functools import partial
from typing import Tuple, Union
from typing import Tuple
import numpy as np
......@@ -238,7 +238,7 @@ class Eig(Op):
"""
_numop = staticmethod(np.linalg.eig)
__props__: Union[Tuple, Tuple[str]] = ()
__props__: Tuple[str, ...] = ()
def make_node(self, x):
x = as_tensor_variable(x)
......
......@@ -28,7 +28,7 @@ def default_supp_shape_from_params(
ndim_supp: int,
dist_params: Sequence[Variable],
rep_param_idx: int = 0,
param_shapes: Optional[Sequence[Tuple[ScalarVariable]]] = None,
param_shapes: Optional[Sequence[Tuple[ScalarVariable, ...]]] = None,
) -> Union[TensorVariable, Tuple[ScalarVariable, ...]]:
"""Infer the dimensions for the output of a `RandomVariable`.
......
......@@ -544,7 +544,7 @@ _specify_shape = SpecifyShape()
def specify_shape(
x: Union[np.ndarray, Number, Variable],
shape: Union[ShapeValueType, List[ShapeValueType], Tuple[ShapeValueType]],
shape: Union[ShapeValueType, List[ShapeValueType], Tuple[ShapeValueType, ...]],
):
"""Specify a fixed shape for a `Variable`.
......
......@@ -85,7 +85,7 @@ invalid_tensor_types = (
def indices_from_subtensor(
op_indices: Iterable[ScalarConstant],
idx_list: Optional[List[Union[Type, slice, Variable]]],
) -> Tuple[Union[slice, Variable]]:
) -> Union[slice, Variable]:
"""Recreate the index tuple from which a ``*Subtensor**`` ``Op`` was created.
Parameters
......@@ -411,7 +411,7 @@ def basic_shape(shape, indices):
Parameters
----------
shape: Tuple[int]
shape: Tuple[int, ...]
The shape of the array being indexed
indices: Sequence[Or[slice, NoneType]]
A sequence of basic indices used to index an array.
......@@ -473,9 +473,9 @@ def indexed_result_shape(array_shape, indices, indices_are_shapes=False):
Parameters
----------
array_shape: Tuple[Variable]
array_shape: Tuple[Variable, ...]
Shape of the array being indexed.
indices: Sequence[Union[TensorVariable, Tuple[Union[None, slice, Variable]]]]
indices: Sequence[Union[TensorVariable, Tuple[Union[None, slice, Variable], ...]]]
Either the indices themselves or the shapes of each index--depending
on the value of `indices_are_shapes`.
indices_are_shapes: bool (Optional)
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论