提交 96060bf9 authored 作者: Michael Osthege's avatar Michael Osthege

Fix tuple-related type hints

上级 c299bc77
...@@ -41,7 +41,7 @@ StorageMapType = Dict[Variable, StorageCellType] ...@@ -41,7 +41,7 @@ StorageMapType = Dict[Variable, StorageCellType]
ComputeMapType = Dict[Variable, List[bool]] ComputeMapType = Dict[Variable, List[bool]]
InputStorageType = List[StorageCellType] InputStorageType = List[StorageCellType]
OutputStorageType = List[StorageCellType] OutputStorageType = List[StorageCellType]
ParamsInputType = Optional[Tuple[Any]] ParamsInputType = Optional[Tuple[Any, ...]]
PerformMethodType = Callable[ PerformMethodType = Callable[
[Apply, List[Any], OutputStorageType, ParamsInputType], None [Apply, List[Any], OutputStorageType, ParamsInputType], None
] ]
......
import warnings import warnings
from abc import abstractmethod from abc import abstractmethod
from typing import Callable, Dict, List, Tuple, Union from typing import Callable, Dict, List, Tuple
from pytensor.graph.basic import Apply, Constant from pytensor.graph.basic import Apply, Constant
from pytensor.graph.utils import MethodNotDefined from pytensor.graph.utils import MethodNotDefined
...@@ -149,7 +149,7 @@ class CLinkerObject: ...@@ -149,7 +149,7 @@ class CLinkerObject:
"""Return a list of code snippets to be inserted in module initialization.""" """Return a list of code snippets to be inserted in module initialization."""
return [] return []
def c_code_cache_version(self) -> Union[Tuple[int, ...], Tuple]: def c_code_cache_version(self) -> Tuple[int, ...]:
"""Return a tuple of integers indicating the version of this `Op`. """Return a tuple of integers indicating the version of this `Op`.
An empty tuple indicates an "unversioned" `Op` that will not be cached An empty tuple indicates an "unversioned" `Op` that will not be cached
...@@ -566,7 +566,7 @@ class CLinkerType(CLinkerObject): ...@@ -566,7 +566,7 @@ class CLinkerType(CLinkerObject):
""" """
return "" return ""
def c_code_cache_version(self) -> Union[Tuple, Tuple[int]]: def c_code_cache_version(self) -> Tuple[int, ...]:
"""Return a tuple of integers indicating the version of this type. """Return a tuple of integers indicating the version of this type.
An empty tuple indicates an "unversioned" type that will not An empty tuple indicates an "unversioned" type that will not
......
...@@ -240,7 +240,7 @@ def lquote_macro(txt: str) -> str: ...@@ -240,7 +240,7 @@ def lquote_macro(txt: str) -> str:
return "\n".join(res) return "\n".join(res)
def get_sub_macros(sub: Dict[str, str]) -> Union[Tuple[str], Tuple[str, str]]: def get_sub_macros(sub: Dict[str, str]) -> Tuple[str, str]:
define_macros = [] define_macros = []
undef_macros = [] undef_macros = []
define_macros.append(f"#define FAIL {lquote_macro(sub['fail'])}") define_macros.append(f"#define FAIL {lquote_macro(sub['fail'])}")
...@@ -533,7 +533,7 @@ class ExternalCOp(COp): ...@@ -533,7 +533,7 @@ class ExternalCOp(COp):
def get_c_macros( def get_c_macros(
self, node: Apply, name: str, check_input: Optional[bool] = None self, node: Apply, name: str, check_input: Optional[bool] = None
) -> Union[Tuple[str], Tuple[str, str]]: ) -> Tuple[str, str]:
"Construct a pair of C ``#define`` and ``#undef`` code strings." "Construct a pair of C ``#define`` and ``#undef`` code strings."
define_template = "#define %s %s" define_template = "#define %s %s"
undef_template = "#undef %s" undef_template = "#undef %s"
......
...@@ -123,7 +123,7 @@ def numba_funcify_Scan(op, node, **kwargs): ...@@ -123,7 +123,7 @@ def numba_funcify_Scan(op, node, **kwargs):
# These outer-inputs are indexed without offsets or storage wrap-around # These outer-inputs are indexed without offsets or storage wrap-around
add_inner_in_expr(outer_in_name, 0, None) add_inner_in_expr(outer_in_name, 0, None)
inner_in_names_to_input_taps: Dict[str, Tuple[int]] = dict( inner_in_names_to_input_taps: Dict[str, Tuple[int, ...]] = dict(
zip( zip(
outer_in_mit_mot_names + outer_in_mit_sot_names + outer_in_sit_sot_names, outer_in_mit_mot_names + outer_in_mit_sot_names + outer_in_sit_sot_names,
op.info.mit_mot_in_slices op.info.mit_mot_in_slices
...@@ -157,7 +157,7 @@ def numba_funcify_Scan(op, node, **kwargs): ...@@ -157,7 +157,7 @@ def numba_funcify_Scan(op, node, **kwargs):
# storage array like a circular buffer, and that's why we need to track the # storage array like a circular buffer, and that's why we need to track the
# storage size along with the taps length/indexing offset. # storage size along with the taps length/indexing offset.
def add_output_storage_post_proc_stmt( def add_output_storage_post_proc_stmt(
outer_in_name: str, tap_sizes: Tuple[int], storage_size: str outer_in_name: str, tap_sizes: Tuple[int, ...], storage_size: str
): ):
tap_size = max(tap_sizes) tap_size = max(tap_sizes)
......
"""Symbolic Op for raising an exception.""" """Symbolic Op for raising an exception."""
from textwrap import indent from textwrap import indent
from typing import Tuple
import numpy as np import numpy as np
...@@ -63,7 +62,7 @@ class CheckAndRaise(COp): ...@@ -63,7 +62,7 @@ class CheckAndRaise(COp):
def __hash__(self): def __hash__(self):
return hash((self.msg, self.exc_type)) return hash((self.msg, self.exc_type))
def make_node(self, value: Variable, *conds: Tuple[Variable]): def make_node(self, value: Variable, *conds: Variable):
""" """
Parameters Parameters
......
import copy import copy
from typing import Tuple, Union from typing import Tuple
import numpy as np import numpy as np
...@@ -18,7 +18,7 @@ class MultinomialFromUniform(COp): ...@@ -18,7 +18,7 @@ class MultinomialFromUniform(COp):
TODO : need description for parameter 'odtype' TODO : need description for parameter 'odtype'
""" """
__props__: Union[Tuple[str], Tuple[str, str]] = ("odtype",) __props__: Tuple[str, ...] = ("odtype",)
def __init__(self, odtype): def __init__(self, odtype):
self.odtype = odtype self.odtype = odtype
......
...@@ -3998,7 +3998,7 @@ class Composite(ScalarOp, HasInnerGraph): ...@@ -3998,7 +3998,7 @@ class Composite(ScalarOp, HasInnerGraph):
""" """
init_param: Union[Tuple[str, str], Tuple[str]] = ("inputs", "outputs") init_param: Tuple[str, ...] = ("inputs", "outputs")
def __init__(self, inputs, outputs): def __init__(self, inputs, outputs):
# We need to clone the graph as sometimes its nodes already # We need to clone the graph as sometimes its nodes already
......
...@@ -137,7 +137,7 @@ try: ...@@ -137,7 +137,7 @@ try:
except ImportError: except ImportError:
pass pass
from typing import Tuple, Union from typing import Tuple
import pytensor.scalar import pytensor.scalar
from pytensor.compile.mode import optdb from pytensor.compile.mode import optdb
...@@ -522,7 +522,7 @@ class GemmRelated(COp): ...@@ -522,7 +522,7 @@ class GemmRelated(COp):
""" """
__props__: Union[Tuple, Tuple[str]] = () __props__: Tuple[str, ...] = ()
def c_support_code(self, **kwargs): def c_support_code(self, **kwargs):
# return cblas_header_text() # return cblas_header_text()
......
...@@ -1763,7 +1763,7 @@ def linspace(start, end, steps): ...@@ -1763,7 +1763,7 @@ def linspace(start, end, steps):
def broadcast_to( def broadcast_to(
x: TensorVariable, shape: Union[TensorVariable, Tuple[Variable]] x: TensorVariable, shape: Union[TensorVariable, Tuple[Variable, ...]]
) -> TensorVariable: ) -> TensorVariable:
"""Broadcast an array to a new shape. """Broadcast an array to a new shape.
......
from functools import partial from functools import partial
from typing import Tuple, Union from typing import Tuple
import numpy as np import numpy as np
...@@ -238,7 +238,7 @@ class Eig(Op): ...@@ -238,7 +238,7 @@ class Eig(Op):
""" """
_numop = staticmethod(np.linalg.eig) _numop = staticmethod(np.linalg.eig)
__props__: Union[Tuple, Tuple[str]] = () __props__: Tuple[str, ...] = ()
def make_node(self, x): def make_node(self, x):
x = as_tensor_variable(x) x = as_tensor_variable(x)
......
...@@ -28,7 +28,7 @@ def default_supp_shape_from_params( ...@@ -28,7 +28,7 @@ def default_supp_shape_from_params(
ndim_supp: int, ndim_supp: int,
dist_params: Sequence[Variable], dist_params: Sequence[Variable],
rep_param_idx: int = 0, rep_param_idx: int = 0,
param_shapes: Optional[Sequence[Tuple[ScalarVariable]]] = None, param_shapes: Optional[Sequence[Tuple[ScalarVariable, ...]]] = None,
) -> Union[TensorVariable, Tuple[ScalarVariable, ...]]: ) -> Union[TensorVariable, Tuple[ScalarVariable, ...]]:
"""Infer the dimensions for the output of a `RandomVariable`. """Infer the dimensions for the output of a `RandomVariable`.
......
...@@ -544,7 +544,7 @@ _specify_shape = SpecifyShape() ...@@ -544,7 +544,7 @@ _specify_shape = SpecifyShape()
def specify_shape( def specify_shape(
x: Union[np.ndarray, Number, Variable], x: Union[np.ndarray, Number, Variable],
shape: Union[ShapeValueType, List[ShapeValueType], Tuple[ShapeValueType]], shape: Union[ShapeValueType, List[ShapeValueType], Tuple[ShapeValueType, ...]],
): ):
"""Specify a fixed shape for a `Variable`. """Specify a fixed shape for a `Variable`.
......
...@@ -85,7 +85,7 @@ invalid_tensor_types = ( ...@@ -85,7 +85,7 @@ invalid_tensor_types = (
def indices_from_subtensor( def indices_from_subtensor(
op_indices: Iterable[ScalarConstant], op_indices: Iterable[ScalarConstant],
idx_list: Optional[List[Union[Type, slice, Variable]]], idx_list: Optional[List[Union[Type, slice, Variable]]],
) -> Tuple[Union[slice, Variable]]: ) -> Union[slice, Variable]:
"""Recreate the index tuple from which a ``*Subtensor**`` ``Op`` was created. """Recreate the index tuple from which a ``*Subtensor**`` ``Op`` was created.
Parameters Parameters
...@@ -411,7 +411,7 @@ def basic_shape(shape, indices): ...@@ -411,7 +411,7 @@ def basic_shape(shape, indices):
Parameters Parameters
---------- ----------
shape: Tuple[int] shape: Tuple[int, ...]
The shape of the array being indexed The shape of the array being indexed
indices: Sequence[Or[slice, NoneType]] indices: Sequence[Or[slice, NoneType]]
A sequence of basic indices used to index an array. A sequence of basic indices used to index an array.
...@@ -473,9 +473,9 @@ def indexed_result_shape(array_shape, indices, indices_are_shapes=False): ...@@ -473,9 +473,9 @@ def indexed_result_shape(array_shape, indices, indices_are_shapes=False):
Parameters Parameters
---------- ----------
array_shape: Tuple[Variable] array_shape: Tuple[Variable, ...]
Shape of the array being indexed. Shape of the array being indexed.
indices: Sequence[Union[TensorVariable, Tuple[Union[None, slice, Variable]]]] indices: Sequence[Union[TensorVariable, Tuple[Union[None, slice, Variable], ...]]]
Either the indices themselves or the shapes of each index--depending Either the indices themselves or the shapes of each index--depending
on the value of `indices_are_shapes`. on the value of `indices_are_shapes`.
indices_are_shapes: bool (Optional) indices_are_shapes: bool (Optional)
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论