提交 55df19f0 authored 作者: Michael Osthege's avatar Michael Osthege 提交者: Brandon T. Willard

Add and fix type hints

The changes in this commit don't affect runtime code.
上级 2a371ab5
...@@ -14,6 +14,7 @@ from io import StringIO ...@@ -14,6 +14,7 @@ from io import StringIO
from itertools import chain from itertools import chain
from itertools import product as itertools_product from itertools import product as itertools_product
from logging import Logger from logging import Logger
from typing import Optional
from warnings import warn from warnings import warn
import numpy as np import numpy as np
...@@ -1362,14 +1363,14 @@ class _Linker(LocalLinker): ...@@ -1362,14 +1363,14 @@ class _Linker(LocalLinker):
self.maker = maker self.maker = maker
super().__init__(scheduler=schedule) super().__init__(scheduler=schedule)
def accept(self, fgraph, no_recycling=None, profile=None): def accept(self, fgraph, no_recycling: Optional[list] = None, profile=None):
if no_recycling is None: if no_recycling is None:
no_recycling = [] no_recycling = []
if self.fgraph is not None and self.fgraph is not fgraph: if self.fgraph is not None and self.fgraph is not fgraph:
assert type(self) is _Linker assert type(self) is _Linker
return type(self)(maker=self.maker).accept(fgraph, no_recycling, profile) return type(self)(maker=self.maker).accept(fgraph, no_recycling, profile)
self.fgraph = fgraph self.fgraph = fgraph
self.no_recycling = no_recycling self.no_recycling: list = no_recycling
return self return self
def make_all( def make_all(
...@@ -1401,7 +1402,7 @@ class _Linker(LocalLinker): ...@@ -1401,7 +1402,7 @@ class _Linker(LocalLinker):
# check_preallocated_output even on the output of the function. # check_preallocated_output even on the output of the function.
# no_recycling in individual thunks does not really matter, since # no_recycling in individual thunks does not really matter, since
# the function's outputs will always be freshly allocated. # the function's outputs will always be freshly allocated.
no_recycling = [] no_recycling: list = []
input_storage, output_storage, storage_map = map_storage( input_storage, output_storage, storage_map = map_storage(
fgraph, order, input_storage_, output_storage_, storage_map fgraph, order, input_storage_, output_storage_, storage_map
...@@ -1492,6 +1493,7 @@ class _Linker(LocalLinker): ...@@ -1492,6 +1493,7 @@ class _Linker(LocalLinker):
# Use self.no_recycling (that was passed in accept()) to always # Use self.no_recycling (that was passed in accept()) to always
# use new memory storage when it is needed, in particular for the # use new memory storage when it is needed, in particular for the
# function's outputs. no_recycling_map will be used in f() below. # function's outputs. no_recycling_map will be used in f() below.
no_recycling_map: list = []
if self.no_recycling is True: if self.no_recycling is True:
no_recycling_map = list(storage_map.values()) no_recycling_map = list(storage_map.values())
no_recycling_map = difference(no_recycling_map, input_storage) no_recycling_map = difference(no_recycling_map, input_storage)
......
...@@ -36,6 +36,7 @@ from aesara.gpuarray.type import ( ...@@ -36,6 +36,7 @@ from aesara.gpuarray.type import (
ContextNotDefined, ContextNotDefined,
GpuArrayConstant, GpuArrayConstant,
GpuArrayType, GpuArrayType,
GpuContextType,
get_context, get_context,
gpu_context_type, gpu_context_type,
) )
...@@ -307,7 +308,7 @@ class GpuKernelBase: ...@@ -307,7 +308,7 @@ class GpuKernelBase:
""" """
params_type: Union[ParamsType, gpu_context_type] = gpu_context_type params_type: Union[ParamsType, GpuContextType] = gpu_context_type
def get_params(self, node): def get_params(self, node):
# Default implementation, suitable for most sub-classes. # Default implementation, suitable for most sub-classes.
......
...@@ -983,7 +983,7 @@ Py_INCREF(%(name)s); ...@@ -983,7 +983,7 @@ Py_INCREF(%(name)s);
Instance of :class:`GpuContextType` to use for the context_type Instance of :class:`GpuContextType` to use for the context_type
declaration of an operation. declaration of an operation.
""" """
gpu_context_type = GpuContextType() gpu_context_type: GpuContextType = GpuContextType()
# THIS WORKS But GpuArray instances don't compare equal to one # THIS WORKS But GpuArray instances don't compare equal to one
......
...@@ -104,7 +104,7 @@ class Apply(Node): ...@@ -104,7 +104,7 @@ class Apply(Node):
""" """
self.op = op self.op = op
self.inputs = [] self.inputs: List[Variable] = []
self.tag = Scratchpad() self.tag = Scratchpad()
if not isinstance(inputs, (list, tuple)): if not isinstance(inputs, (list, tuple)):
...@@ -121,7 +121,7 @@ class Apply(Node): ...@@ -121,7 +121,7 @@ class Apply(Node):
raise TypeError( raise TypeError(
f"The 'inputs' argument to Apply must contain Variable instances, not {input}" f"The 'inputs' argument to Apply must contain Variable instances, not {input}"
) )
self.outputs = [] self.outputs: List[Variable] = []
# filter outputs to make sure each element is a Variable # filter outputs to make sure each element is a Variable
for i, output in enumerate(outputs): for i, output in enumerate(outputs):
if isinstance(output, Variable): if isinstance(output, Variable):
...@@ -630,7 +630,7 @@ def walk( ...@@ -630,7 +630,7 @@ def walk(
bfs: bool = True, bfs: bool = True,
return_children: bool = False, return_children: bool = False,
hash_fn: Callable[[T], Hashable] = id, hash_fn: Callable[[T], Hashable] = id,
) -> Generator[T, None, Dict[T, List[T]]]: ) -> Generator[Union[T, Tuple[T, Optional[Sequence[T]]]], None, None]:
"""Walk through a graph, either breadth- or depth-first. """Walk through a graph, either breadth- or depth-first.
Parameters Parameters
...@@ -870,7 +870,7 @@ def clone_get_equiv( ...@@ -870,7 +870,7 @@ def clone_get_equiv(
copy_inputs: bool = True, copy_inputs: bool = True,
copy_orphans: bool = True, copy_orphans: bool = True,
memo: Optional[Dict[Variable, Variable]] = None, memo: Optional[Dict[Variable, Variable]] = None,
): ) -> Dict[Variable, Variable]:
""" """
Return a dictionary that maps from `Variable` and `Apply` nodes in the Return a dictionary that maps from `Variable` and `Apply` nodes in the
original graph to a new node (a clone) in a new graph. original graph to a new node (a clone) in a new graph.
...@@ -1050,7 +1050,7 @@ def general_toposort( ...@@ -1050,7 +1050,7 @@ def general_toposort(
if deps_cache is None: if deps_cache is None:
raise ValueError("deps_cache cannot be None") raise ValueError("deps_cache cannot be None")
search_res: List[T, Optional[List[T]]] = list( search_res: List[Tuple[T, Optional[List[T]]]] = list(
walk(outputs, compute_deps_cache, bfs=False, return_children=True) walk(outputs, compute_deps_cache, bfs=False, return_children=True)
) )
...@@ -1088,8 +1088,8 @@ def general_toposort( ...@@ -1088,8 +1088,8 @@ def general_toposort(
def io_toposort( def io_toposort(
inputs: List[Variable], inputs: Iterable[Variable],
outputs: List[Variable], outputs: Iterable[Variable],
orderings: Optional[Dict[Apply, List[Apply]]] = None, orderings: Optional[Dict[Apply, List[Apply]]] = None,
clients: Optional[Dict[Variable, List[Variable]]] = None, clients: Optional[Dict[Variable, List[Variable]]] = None,
) -> List[Apply]: ) -> List[Apply]:
...@@ -1586,7 +1586,7 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None): ...@@ -1586,7 +1586,7 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None):
def get_var_by_name( def get_var_by_name(
graphs: Iterable[Variable], target_var_id: str, ids: str = "CHAR" graphs: Iterable[Variable], target_var_id: str, ids: str = "CHAR"
) -> Tuple[Variable]: ) -> Tuple[Variable, ...]:
r"""Get variables in a graph using their names. r"""Get variables in a graph using their names.
Parameters Parameters
...@@ -1613,7 +1613,7 @@ def get_var_by_name( ...@@ -1613,7 +1613,7 @@ def get_var_by_name(
return res return res
results = () results: Tuple[Variable, ...] = ()
for var in walk(graphs, expand, False): for var in walk(graphs, expand, False):
if target_var_id == var.name or target_var_id == var.auto_name: if target_var_id == var.name or target_var_id == var.auto_name:
results += (var,) results += (var,)
......
"""A container for specifying and manipulating a graph with distinct inputs and outputs.""" """A container for specifying and manipulating a graph with distinct inputs and outputs."""
import time import time
from collections import OrderedDict from collections import OrderedDict
from typing import Any, Dict, List, Optional, Tuple, Union from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, Tuple, Union
import aesara import aesara
from aesara.configdefaults import config from aesara.configdefaults import config
...@@ -47,9 +47,9 @@ class FunctionGraph(MetaObject): ...@@ -47,9 +47,9 @@ class FunctionGraph(MetaObject):
def __init__( def __init__(
self, self,
inputs: Optional[List[Variable]] = None, inputs: Optional[Sequence[Variable]] = None,
outputs: Optional[List[Variable]] = None, outputs: Optional[Sequence[Variable]] = None,
features: Optional[List[Feature]] = None, features: Optional[Sequence[Feature]] = None,
clone: bool = True, clone: bool = True,
update_mapping: Optional[Dict[Variable, Variable]] = None, update_mapping: Optional[Dict[Variable, Variable]] = None,
memo: Optional[Dict[Variable, Variable]] = None, memo: Optional[Dict[Variable, Variable]] = None,
...@@ -98,25 +98,25 @@ class FunctionGraph(MetaObject): ...@@ -98,25 +98,25 @@ class FunctionGraph(MetaObject):
inputs = [memo[i] for i in inputs] inputs = [memo[i] for i in inputs]
self.execute_callbacks_time = 0 self.execute_callbacks_time = 0
self.execute_callbacks_times = {} self.execute_callbacks_times: Dict[Feature, float] = {}
if features is None: if features is None:
features = [] features = []
self._features = [] self._features: List[Feature] = []
# All apply nodes in the subgraph defined by inputs and # All apply nodes in the subgraph defined by inputs and
# outputs are cached in this field # outputs are cached in this field
self.apply_nodes = set() self.apply_nodes: Set[Apply] = set()
# Ditto for variable nodes. # Ditto for variable nodes.
# It must contain all fgraph.inputs and all apply_nodes # It must contain all fgraph.inputs and all apply_nodes
# outputs even if they aren't used in the graph. # outputs even if they aren't used in the graph.
self.variables = set() self.variables: Set[Variable] = set()
self.inputs = [] self.inputs: List[Variable] = []
self.outputs = list(outputs) self.outputs: List[Variable] = list(outputs)
self.clients = {} self.clients: Dict[Variable, List[Tuple[Union[Apply, str], int]]] = {}
for f in features: for f in features:
self.attach_feature(f) self.attach_feature(f)
...@@ -487,7 +487,7 @@ class FunctionGraph(MetaObject): ...@@ -487,7 +487,7 @@ class FunctionGraph(MetaObject):
node, i, new_var, reason=reason, import_missing=import_missing node, i, new_var, reason=reason, import_missing=import_missing
) )
def replace_all(self, pairs: List[Tuple[Variable, Variable]], **kwargs) -> None: def replace_all(self, pairs: Iterable[Tuple[Variable, Variable]], **kwargs) -> None:
"""Replace variables in the `FunctionGraph` according to ``(var, new_var)`` pairs in a list.""" """Replace variables in the `FunctionGraph` according to ``(var, new_var)`` pairs in a list."""
for var, new_var in pairs: for var, new_var in pairs:
self.replace(var, new_var, **kwargs) self.replace(var, new_var, **kwargs)
...@@ -604,7 +604,7 @@ class FunctionGraph(MetaObject): ...@@ -604,7 +604,7 @@ class FunctionGraph(MetaObject):
""" """
assert isinstance(self._features, list) assert isinstance(self._features, list)
all_orderings = [] all_orderings: List[OrderedDict] = []
for feature in self._features: for feature in self._features:
if hasattr(feature, "orderings"): if hasattr(feature, "orderings"):
...@@ -630,7 +630,7 @@ class FunctionGraph(MetaObject): ...@@ -630,7 +630,7 @@ class FunctionGraph(MetaObject):
return all_orderings[0].copy() return all_orderings[0].copy()
else: else:
# If there is more than 1 ordering, combine them. # If there is more than 1 ordering, combine them.
ords = OrderedDict() ords: Dict[Apply, List[Apply]] = OrderedDict()
for orderings in all_orderings: for orderings in all_orderings:
for node, prereqs in orderings.items(): for node, prereqs in orderings.items():
ords.setdefault(node, []).extend(prereqs) ords.setdefault(node, []).extend(prereqs)
...@@ -695,7 +695,7 @@ class FunctionGraph(MetaObject): ...@@ -695,7 +695,7 @@ class FunctionGraph(MetaObject):
def clone_get_equiv( def clone_get_equiv(
self, check_integrity: bool = True, attach_feature: bool = True self, check_integrity: bool = True, attach_feature: bool = True
) -> Union["FunctionGraph", Dict[Variable, Variable]]: ) -> Tuple["FunctionGraph", Dict[Variable, Variable]]:
"""Clone the graph and return a ``dict`` that maps old nodes to new nodes. """Clone the graph and return a ``dict`` that maps old nodes to new nodes.
Parameters Parameters
......
...@@ -17,6 +17,7 @@ from typing import ( ...@@ -17,6 +17,7 @@ from typing import (
Any, Any,
Callable, Callable,
ClassVar, ClassVar,
Collection,
Dict, Dict,
List, List,
Optional, Optional,
...@@ -47,8 +48,8 @@ if TYPE_CHECKING: ...@@ -47,8 +48,8 @@ if TYPE_CHECKING:
from aesara.compile.function.types import Function from aesara.compile.function.types import Function
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
StorageMapType = List[Optional[List[Any]]] StorageMapType = Dict[Variable, List[Optional[List[Any]]]]
ComputeMapType = List[bool] ComputeMapType = Dict[Variable, List[bool]]
OutputStorageType = List[Optional[List[Any]]] OutputStorageType = List[Optional[List[Any]]]
ParamsInputType = Optional[Tuple[Any]] ParamsInputType = Optional[Tuple[Any]]
PerformMethodType = Callable[ PerformMethodType = Callable[
...@@ -613,7 +614,7 @@ class COp(Op, CLinkerOp): ...@@ -613,7 +614,7 @@ class COp(Op, CLinkerOp):
node: Apply, node: Apply,
storage_map: StorageMapType, storage_map: StorageMapType,
compute_map: ComputeMapType, compute_map: ComputeMapType,
no_recycling: bool, no_recycling: Collection[Apply],
) -> ThunkType: ) -> ThunkType:
"""Create a thunk for a C implementation. """Create a thunk for a C implementation.
...@@ -1073,7 +1074,7 @@ class ExternalCOp(COp): ...@@ -1073,7 +1074,7 @@ class ExternalCOp(COp):
f"No valid section marker was found in file {func_files[i]}" f"No valid section marker was found in file {func_files[i]}"
) )
def __get_op_params(self) -> Union[List[Text], List[Tuple[str, Any]]]: def __get_op_params(self) -> List[Tuple[str, Any]]:
"""Construct name, value pairs that will be turned into macros for use within the `Op`'s code. """Construct name, value pairs that will be turned into macros for use within the `Op`'s code.
The names must be strings that are not a C keyword and the The names must be strings that are not a C keyword and the
...@@ -1089,9 +1090,10 @@ class ExternalCOp(COp): ...@@ -1089,9 +1090,10 @@ class ExternalCOp(COp):
associated to ``key``. associated to ``key``.
""" """
params: List[Tuple[str, Any]] = []
if hasattr(self, "params_type") and isinstance(self.params_type, ParamsType): if hasattr(self, "params_type") and isinstance(self.params_type, ParamsType):
wrapper = self.params_type wrapper = self.params_type
params = [("PARAMS_TYPE", wrapper.name)] params.append(("PARAMS_TYPE", wrapper.name))
for i in range(wrapper.length): for i in range(wrapper.length):
c_type = wrapper.types[i].c_element_type() c_type = wrapper.types[i].c_element_type()
if c_type: if c_type:
...@@ -1105,8 +1107,7 @@ class ExternalCOp(COp): ...@@ -1105,8 +1107,7 @@ class ExternalCOp(COp):
c_type, c_type,
) )
) )
return params return params
return []
def c_code_cache_version(self): def c_code_cache_version(self):
version = (hash(tuple(self.func_codes)),) version = (hash(tuple(self.func_codes)),)
......
...@@ -17,7 +17,7 @@ from collections import UserList, defaultdict, deque ...@@ -17,7 +17,7 @@ from collections import UserList, defaultdict, deque
from collections.abc import Iterable from collections.abc import Iterable
from functools import partial, reduce from functools import partial, reduce
from itertools import chain from itertools import chain
from typing import Dict, List, Optional, Tuple, Union from typing import Dict, List, Optional, Sequence, Tuple, Union
import aesara import aesara
from aesara.configdefaults import config from aesara.configdefaults import config
...@@ -1068,7 +1068,7 @@ class FromFunctionLocalOptimizer(LocalOptimizer): ...@@ -1068,7 +1068,7 @@ class FromFunctionLocalOptimizer(LocalOptimizer):
def local_optimizer( def local_optimizer(
tracks: Optional[List[Union[Op, type]]], tracks: Optional[Sequence[Union[Op, type]]],
inplace: bool = False, inplace: bool = False,
requirements: Optional[Tuple[type, ...]] = (), requirements: Optional[Tuple[type, ...]] = (),
): ):
...@@ -1184,7 +1184,10 @@ class LocalOptGroup(LocalOptimizer): ...@@ -1184,7 +1184,10 @@ class LocalOptGroup(LocalOptimizer):
""" """
def __init__( def __init__(
self, *optimizers, apply_all_opts: bool = False, profile: bool = False self,
*optimizers: Sequence[Rewriter],
apply_all_opts: bool = False,
profile: bool = False,
): ):
""" """
...@@ -1205,7 +1208,7 @@ class LocalOptGroup(LocalOptimizer): ...@@ -1205,7 +1208,7 @@ class LocalOptGroup(LocalOptimizer):
if len(optimizers) == 1 and isinstance(optimizers[0], list): if len(optimizers) == 1 and isinstance(optimizers[0], list):
# This happen when created by LocalGroupDB. # This happen when created by LocalGroupDB.
optimizers = tuple(optimizers[0]) optimizers = tuple(optimizers[0])
self.opts = optimizers self.opts: Sequence[Rewriter] = optimizers
assert isinstance(self.opts, tuple) assert isinstance(self.opts, tuple)
self.reentrant = any(getattr(opt, "reentrant", True) for opt in optimizers) self.reentrant = any(getattr(opt, "reentrant", True) for opt in optimizers)
...@@ -1217,10 +1220,10 @@ class LocalOptGroup(LocalOptimizer): ...@@ -1217,10 +1220,10 @@ class LocalOptGroup(LocalOptimizer):
self.profile = profile self.profile = profile
if self.profile: if self.profile:
self.time_opts = {} self.time_opts: Dict[Rewriter, float] = {}
self.process_count = {} self.process_count: Dict[Rewriter, int] = {}
self.applied_true = {} self.applied_true: Dict[Rewriter, int] = {}
self.node_created = {} self.node_created: Dict[Rewriter, int] = {}
self.tracker = LocalOptTracker() self.tracker = LocalOptTracker()
......
...@@ -6,6 +6,8 @@ import re ...@@ -6,6 +6,8 @@ import re
from abc import abstractmethod from abc import abstractmethod
from typing import Any, Optional, Text, TypeVar, Union from typing import Any, Optional, Text, TypeVar, Union
from typing_extensions import TypeAlias
from aesara.graph import utils from aesara.graph import utils
from aesara.graph.basic import Constant, Variable from aesara.graph.basic import Constant, Variable
from aesara.graph.utils import MetaObject from aesara.graph.utils import MetaObject
...@@ -33,12 +35,12 @@ class Type(MetaObject): ...@@ -33,12 +35,12 @@ class Type(MetaObject):
""" """
Variable = Variable Variable: TypeAlias = Variable
""" """
The `Type` that will be created by a call to `Type.make_variable`. The `Type` that will be created by a call to `Type.make_variable`.
""" """
Constant = Constant Constant: TypeAlias = Constant
""" """
The `Type` that will be created by a call to `Type.make_constant`. The `Type` that will be created by a call to `Type.make_constant`.
""" """
...@@ -109,7 +111,7 @@ class Type(MetaObject): ...@@ -109,7 +111,7 @@ class Type(MetaObject):
storage: Any, storage: Any,
strict: bool = False, strict: bool = False,
allow_downcast: Optional[bool] = None, allow_downcast: Optional[bool] = None,
) -> None: ):
"""Return data or an appropriately wrapped/converted data by converting it in-place. """Return data or an appropriately wrapped/converted data by converting it in-place.
This method allows one to reuse old allocated memory. If this method This method allows one to reuse old allocated memory. If this method
......
...@@ -3,10 +3,12 @@ import sys ...@@ -3,10 +3,12 @@ import sys
import traceback import traceback
from abc import ABCMeta from abc import ABCMeta
from io import StringIO from io import StringIO
from typing import List from typing import List, Optional, Sequence, Tuple
def simple_extract_stack(f=None, limit=None, skips=None): def simple_extract_stack(
f=None, limit: Optional[int] = None, skips: Optional[Sequence[str]] = None
) -> List[Tuple[Optional[str], int, str, str]]:
"""This is traceback.extract_stack from python 2.7 with this change: """This is traceback.extract_stack from python 2.7 with this change:
- Comment the update of the cache. - Comment the update of the cache.
...@@ -33,7 +35,7 @@ def simple_extract_stack(f=None, limit=None, skips=None): ...@@ -33,7 +35,7 @@ def simple_extract_stack(f=None, limit=None, skips=None):
if limit is None: if limit is None:
if hasattr(sys, "tracebacklimit"): if hasattr(sys, "tracebacklimit"):
limit = sys.tracebacklimit limit = sys.tracebacklimit
trace = [] trace: List[Tuple[Optional[str], int, str, str]] = []
n = 0 n = 0
while f is not None and (limit is None or n < limit): while f is not None and (limit is None or n < limit):
lineno = f.f_lineno lineno = f.f_lineno
...@@ -67,7 +69,7 @@ def simple_extract_stack(f=None, limit=None, skips=None): ...@@ -67,7 +69,7 @@ def simple_extract_stack(f=None, limit=None, skips=None):
return trace return trace
def add_tag_trace(thing, user_line=None): def add_tag_trace(thing, user_line: Optional[int] = None):
"""Add tag.trace to a node or variable. """Add tag.trace to a node or variable.
The argument is returned after being affected (inplace). The argument is returned after being affected (inplace).
......
...@@ -13,7 +13,7 @@ is a global operation with a scalar condition. ...@@ -13,7 +13,7 @@ is a global operation with a scalar condition.
import logging import logging
from copy import deepcopy from copy import deepcopy
from typing import List, Union from typing import List, Sequence, Union
import numpy as np import numpy as np
...@@ -311,7 +311,7 @@ def ifelse( ...@@ -311,7 +311,7 @@ def ifelse(
then_branch: Union[Variable, List[Variable]], then_branch: Union[Variable, List[Variable]],
else_branch: Union[Variable, List[Variable]], else_branch: Union[Variable, List[Variable]],
name: str = None, name: str = None,
) -> Union[Variable, List[Variable]]: ) -> Union[Variable, Sequence[Variable]]:
""" """
This function corresponds to an if statement, returning (and evaluating) This function corresponds to an if statement, returning (and evaluating)
inputs in the ``then_branch`` if ``condition`` evaluates to True or inputs in the ``then_branch`` if ``condition`` evaluates to True or
...@@ -340,13 +340,13 @@ def ifelse( ...@@ -340,13 +340,13 @@ def ifelse(
Returns Returns
======= =======
A list of aesara variables or a single variable (depending on the A sequence of aesara variables or a single variable (depending on the
nature of the ``then_branch`` and ``else_branch``). More exactly if nature of the ``then_branch`` and ``else_branch``). More exactly if
``then_branch`` and ``else_branch`` is a tensor, then ``then_branch`` and ``else_branch`` is a tensor, then
the return variable will be just a single variable, otherwise a the return variable will be just a single variable, otherwise a
list. The value returns correspond either to the values in the sequence. The value returns correspond either to the values in the
``then_branch`` or in the ``else_branch`` depending on the value of ``then_branch`` or in the ``else_branch`` depending on the value of
``cond``. ``condition``.
""" """
rval_type = None rval_type = None
......
...@@ -51,7 +51,7 @@ class Container: ...@@ -51,7 +51,7 @@ class Container:
def __init__( def __init__(
self, self,
r: MetaObject, r: MetaObject,
storage: Any, storage: List[Any],
*, *,
readonly: bool = False, readonly: bool = False,
strict: bool = False, strict: bool = False,
......
...@@ -148,7 +148,7 @@ class CLinkerObject: ...@@ -148,7 +148,7 @@ class CLinkerObject:
"""Return a list of code snippets to be inserted in module initialization.""" """Return a list of code snippets to be inserted in module initialization."""
return [] return []
def c_code_cache_version(self) -> Union[Tuple[int], Tuple]: def c_code_cache_version(self) -> Union[Tuple[int, ...], Tuple]:
"""Return a tuple of integers indicating the version of this `Op`. """Return a tuple of integers indicating the version of this `Op`.
An empty tuple indicates an "unversioned" `Op` that will not be cached An empty tuple indicates an "unversioned" `Op` that will not be cached
...@@ -211,7 +211,7 @@ class CLinkerOp(CLinkerObject): ...@@ -211,7 +211,7 @@ class CLinkerOp(CLinkerObject):
""" """
raise NotImplementedError() raise NotImplementedError()
def c_code_cache_version_apply(self, node: Apply) -> Tuple[int]: def c_code_cache_version_apply(self, node: Apply) -> Tuple[int, ...]:
"""Return a tuple of integers indicating the version of this `Op`. """Return a tuple of integers indicating the version of this `Op`.
An empty tuple indicates an "unversioned" `Op` that will not be An empty tuple indicates an "unversioned" `Op` that will not be
......
...@@ -24,7 +24,7 @@ def map_storage( ...@@ -24,7 +24,7 @@ def map_storage(
order: Iterable[Apply], order: Iterable[Apply],
input_storage: Optional[List], input_storage: Optional[List],
output_storage: Optional[List], output_storage: Optional[List],
storage_map: Dict = None, storage_map: Optional[Dict] = None,
) -> Tuple[List, List, Dict]: ) -> Tuple[List, List, Dict]:
"""Ensure there is storage (a length-1 list) for inputs, outputs, and interior nodes. """Ensure there is storage (a length-1 list) for inputs, outputs, and interior nodes.
......
...@@ -13,7 +13,7 @@ from contextlib import contextmanager ...@@ -13,7 +13,7 @@ from contextlib import contextmanager
from copy import copy from copy import copy
from functools import reduce from functools import reduce
from io import IOBase, StringIO from io import IOBase, StringIO
from typing import Callable, Dict, Iterable, List, Optional, Union from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
import numpy as np import numpy as np
...@@ -168,9 +168,9 @@ def debugprint( ...@@ -168,9 +168,9 @@ def debugprint(
used_ids = dict() used_ids = dict()
results_to_print = [] results_to_print = []
profile_list = [] profile_list: List[Optional[Any]] = []
order = [] # Toposort order: List[Optional[List[Apply]]] = [] # Toposort
smap = [] # storage_map smap: List[Optional[StorageMapType]] = [] # storage_map
if isinstance(obj, (list, tuple, set)): if isinstance(obj, (list, tuple, set)):
lobj = obj lobj = obj
...@@ -881,8 +881,8 @@ default_printer = DefaultPrinter() ...@@ -881,8 +881,8 @@ default_printer = DefaultPrinter()
class PPrinter(Printer): class PPrinter(Printer):
def __init__(self): def __init__(self):
self.printers = [] self.printers: List[Tuple[Union[Op, type, Callable], Printer]] = []
self.printers_dict = {} self.printers_dict: Dict[Union[Op, type, Callable], Printer] = {}
def assign(self, condition: Union[Op, type, Callable], printer: Printer): def assign(self, condition: Union[Op, type, Callable], printer: Printer):
if isinstance(condition, (Op, type)): if isinstance(condition, (Op, type)):
...@@ -999,7 +999,7 @@ else: ...@@ -999,7 +999,7 @@ else:
) )
pprint = PPrinter() pprint: PPrinter = PPrinter()
pprint.assign(lambda pstate, r: True, default_printer) pprint.assign(lambda pstate, r: True, default_printer)
pprint.assign(lambda pstate, r: isinstance(r, Constant), constant_printer) pprint.assign(lambda pstate, r: isinstance(r, Constant), constant_printer)
...@@ -1705,7 +1705,7 @@ def get_node_by_id( ...@@ -1705,7 +1705,7 @@ def get_node_by_id(
if isinstance(graphs, Variable): if isinstance(graphs, Variable):
graphs = (graphs,) graphs = (graphs,)
used_ids = dict() used_ids: Dict[Variable, str] = {}
_ = debugprint(graphs, file="str", used_ids=used_ids, ids=ids) _ = debugprint(graphs, file="str", used_ids=used_ids, ids=ids)
......
...@@ -19,6 +19,7 @@ from textwrap import dedent ...@@ -19,6 +19,7 @@ from textwrap import dedent
from typing import Dict, Mapping, Optional, Tuple, Type, Union from typing import Dict, Mapping, Optional, Tuple, Type, Union
import numpy as np import numpy as np
from typing_extensions import TypeAlias
import aesara import aesara
from aesara import printing from aesara import printing
...@@ -106,47 +107,6 @@ def as_common_dtype(*vars): ...@@ -106,47 +107,6 @@ def as_common_dtype(*vars):
return (v.astype(dtype) for v in vars) return (v.astype(dtype) for v in vars)
def get_scalar_type(dtype):
"""
Return a Scalar(dtype) object.
This caches objects to save allocation and run time.
"""
if dtype not in get_scalar_type.cache:
get_scalar_type.cache[dtype] = Scalar(dtype=dtype)
return get_scalar_type.cache[dtype]
get_scalar_type.cache = {}
def as_scalar(x, name=None):
from aesara.tensor.basic import scalar_from_tensor
from aesara.tensor.type import TensorType
if isinstance(x, Apply):
if len(x.outputs) != 1:
raise ValueError(
"It is ambiguous which output of a multi-output"
" Op has to be fetched.",
x,
)
else:
x = x.outputs[0]
if isinstance(x, Variable):
if isinstance(x.type, Scalar):
return x
elif isinstance(x.type, TensorType) and x.ndim == 0:
return scalar_from_tensor(x)
else:
raise TypeError("Variable type field must be a Scalar.", x, x.type)
try:
return constant(x)
except TypeError:
raise TypeError(f"Cannot convert {x} to Scalar", type(x))
class NumpyAutocaster: class NumpyAutocaster:
""" """
This class is used to cast python ints and floats to numpy arrays. This class is used to cast python ints and floats to numpy arrays.
...@@ -314,12 +274,6 @@ def convert(x, dtype=None): ...@@ -314,12 +274,6 @@ def convert(x, dtype=None):
return x_ return x_
def constant(x, name=None, dtype=None):
x = convert(x, dtype=dtype)
assert x.ndim == 0
return ScalarConstant(get_scalar_type(str(x.dtype)), x, name=name)
class Scalar(CType): class Scalar(CType):
""" """
...@@ -722,6 +676,21 @@ class Scalar(CType): ...@@ -722,6 +676,21 @@ class Scalar(CType):
return shape_info return shape_info
def get_scalar_type(dtype) -> Scalar:
"""
Return a Scalar(dtype) object.
This caches objects to save allocation and run time.
"""
if dtype not in get_scalar_type.cache:
get_scalar_type.cache[dtype] = Scalar(dtype=dtype)
return get_scalar_type.cache[dtype]
get_scalar_type.cache = {}
# Register C code for ViewOp on Scalars. # Register C code for ViewOp on Scalars.
aesara.compile.register_view_op_c_code( aesara.compile.register_view_op_c_code(
Scalar, Scalar,
...@@ -732,30 +701,31 @@ aesara.compile.register_view_op_c_code( ...@@ -732,30 +701,31 @@ aesara.compile.register_view_op_c_code(
) )
bool = get_scalar_type("bool") bool: Scalar = get_scalar_type("bool")
int8 = get_scalar_type("int8") int8: Scalar = get_scalar_type("int8")
int16 = get_scalar_type("int16") int16: Scalar = get_scalar_type("int16")
int32 = get_scalar_type("int32") int32: Scalar = get_scalar_type("int32")
int64 = get_scalar_type("int64") int64: Scalar = get_scalar_type("int64")
uint8 = get_scalar_type("uint8") uint8: Scalar = get_scalar_type("uint8")
uint16 = get_scalar_type("uint16") uint16: Scalar = get_scalar_type("uint16")
uint32 = get_scalar_type("uint32") uint32: Scalar = get_scalar_type("uint32")
uint64 = get_scalar_type("uint64") uint64: Scalar = get_scalar_type("uint64")
float16 = get_scalar_type("float16") float16: Scalar = get_scalar_type("float16")
float32 = get_scalar_type("float32") float32: Scalar = get_scalar_type("float32")
float64 = get_scalar_type("float64") float64: Scalar = get_scalar_type("float64")
complex64 = get_scalar_type("complex64") complex64: Scalar = get_scalar_type("complex64")
complex128 = get_scalar_type("complex128") complex128: Scalar = get_scalar_type("complex128")
int_types = int8, int16, int32, int64 _ScalarTypes: TypeAlias = Tuple[Scalar, ...]
uint_types = uint8, uint16, uint32, uint64 int_types: _ScalarTypes = (int8, int16, int32, int64)
float_types = float16, float32, float64 uint_types: _ScalarTypes = (uint8, uint16, uint32, uint64)
complex_types = complex64, complex128 float_types: _ScalarTypes = (float16, float32, float64)
complex_types: _ScalarTypes = (complex64, complex128)
integer_types = int_types + uint_types
discrete_types = (bool,) + integer_types integer_types: _ScalarTypes = int_types + uint_types
continuous_types = float_types + complex_types discrete_types: _ScalarTypes = (bool,) + integer_types
all_types = discrete_types + continuous_types continuous_types: _ScalarTypes = float_types + complex_types
all_types: _ScalarTypes = discrete_types + continuous_types
discrete_dtypes = tuple(t.dtype for t in discrete_types) discrete_dtypes = tuple(t.dtype for t in discrete_types)
...@@ -885,6 +855,38 @@ class ScalarConstant(ScalarVariable, Constant): ...@@ -885,6 +855,38 @@ class ScalarConstant(ScalarVariable, Constant):
Scalar.Constant = ScalarConstant Scalar.Constant = ScalarConstant
def constant(x, name=None, dtype=None) -> ScalarConstant:
    """Build a 0-dimensional `ScalarConstant` from the value `x`.

    Parameters
    ----------
    x
        The value to wrap; converted via `convert` (which yields a 0-d array).
    name
        Optional name for the resulting constant.
    dtype
        Optional dtype override passed through to `convert`.

    Returns
    -------
    ScalarConstant
        A constant whose type matches the converted value's dtype.
    """
    value = convert(x, dtype=dtype)
    # `convert` must produce a scalar (0-d) value for a ScalarConstant.
    assert value.ndim == 0
    scalar_type = get_scalar_type(str(value.dtype))
    return ScalarConstant(scalar_type, value, name=name)
def as_scalar(x, name=None) -> ScalarVariable:
    """Convert `x` to a `ScalarVariable`.

    Accepts an `Apply` node with a single output, a `Variable` whose type is
    either a `Scalar` or a 0-d `TensorType`, or any value convertible by
    `constant`.

    Parameters
    ----------
    x
        The object to convert.
    name
        Currently unused; kept for interface compatibility.

    Returns
    -------
    ScalarVariable
        Note: the return annotation was previously `ScalarConstant`, which was
        too narrow — when `x` is already a scalar-typed `Variable` it is
        returned unchanged, and 0-d tensors are converted via
        `scalar_from_tensor`; neither is necessarily a constant.

    Raises
    ------
    ValueError
        If `x` is a multi-output `Apply` node (ambiguous which output to use).
    TypeError
        If `x` is a `Variable` of an unsupported type, or cannot be converted
        to a scalar constant.
    """
    from aesara.tensor.basic import scalar_from_tensor
    from aesara.tensor.type import TensorType

    if isinstance(x, Apply):
        if len(x.outputs) != 1:
            raise ValueError(
                "It is ambiguous which output of a multi-output"
                " Op has to be fetched.",
                x,
            )
        else:
            x = x.outputs[0]
    if isinstance(x, Variable):
        if isinstance(x.type, Scalar):
            return x
        elif isinstance(x.type, TensorType) and x.ndim == 0:
            return scalar_from_tensor(x)
        else:
            raise TypeError("Variable type field must be a Scalar.", x, x.type)
    try:
        return constant(x)
    except TypeError as err:
        # Chain the original conversion failure for easier debugging
        # (the bare re-raise previously discarded the cause).
        raise TypeError(f"Cannot convert {x} to Scalar", type(x)) from err
# Easy constructors # Easy constructors
ints = apply_across_args(int64) ints = apply_across_args(int64)
...@@ -1276,9 +1278,9 @@ class BinaryScalarOp(ScalarOp): ...@@ -1276,9 +1278,9 @@ class BinaryScalarOp(ScalarOp):
# One may define in subclasses the following fields: # One may define in subclasses the following fields:
# - `commutative`: whether op(a, b) == op(b, a) # - `commutative`: whether op(a, b) == op(b, a)
# - `associative`: whether op(op(a, b), c) == op(a, op(b, c)) # - `associative`: whether op(op(a, b), c) == op(a, op(b, c))
commutative: Optional[bool] = None commutative: Optional[builtins.bool] = None
associative: Optional[bool] = None associative: Optional[builtins.bool] = None
identity: Optional[bool] = None identity: Optional[builtins.bool] = None
""" """
For an associative operation, the identity object corresponds to the neutral For an associative operation, the identity object corresponds to the neutral
element. For instance, it will be ``0`` for addition, ``1`` for multiplication, element. For instance, it will be ``0`` for addition, ``1`` for multiplication,
...@@ -2501,20 +2503,20 @@ class Cast(UnaryScalarOp): ...@@ -2501,20 +2503,20 @@ class Cast(UnaryScalarOp):
return s return s
convert_to_bool = Cast(bool, name="convert_to_bool") convert_to_bool: Cast = Cast(bool, name="convert_to_bool")
convert_to_int8 = Cast(int8, name="convert_to_int8") convert_to_int8: Cast = Cast(int8, name="convert_to_int8")
convert_to_int16 = Cast(int16, name="convert_to_int16") convert_to_int16: Cast = Cast(int16, name="convert_to_int16")
convert_to_int32 = Cast(int32, name="convert_to_int32") convert_to_int32: Cast = Cast(int32, name="convert_to_int32")
convert_to_int64 = Cast(int64, name="convert_to_int64") convert_to_int64: Cast = Cast(int64, name="convert_to_int64")
convert_to_uint8 = Cast(uint8, name="convert_to_uint8") convert_to_uint8: Cast = Cast(uint8, name="convert_to_uint8")
convert_to_uint16 = Cast(uint16, name="convert_to_uint16") convert_to_uint16: Cast = Cast(uint16, name="convert_to_uint16")
convert_to_uint32 = Cast(uint32, name="convert_to_uint32") convert_to_uint32: Cast = Cast(uint32, name="convert_to_uint32")
convert_to_uint64 = Cast(uint64, name="convert_to_uint64") convert_to_uint64: Cast = Cast(uint64, name="convert_to_uint64")
convert_to_float16 = Cast(float16, name="convert_to_float16") convert_to_float16: Cast = Cast(float16, name="convert_to_float16")
convert_to_float32 = Cast(float32, name="convert_to_float32") convert_to_float32: Cast = Cast(float32, name="convert_to_float32")
convert_to_float64 = Cast(float64, name="convert_to_float64") convert_to_float64: Cast = Cast(float64, name="convert_to_float64")
convert_to_complex64 = Cast(complex64, name="convert_to_complex64") convert_to_complex64: Cast = Cast(complex64, name="convert_to_complex64")
convert_to_complex128 = Cast(complex128, name="convert_to_complex128") convert_to_complex128: Cast = Cast(complex128, name="convert_to_complex128")
_cast_mapping = { _cast_mapping = {
"bool": convert_to_bool, "bool": convert_to_bool,
......
...@@ -683,7 +683,7 @@ def push_out_inner_vars( ...@@ -683,7 +683,7 @@ def push_out_inner_vars(
outer_vars = [None] * len(inner_vars) outer_vars = [None] * len(inner_vars)
new_scan_node = old_scan_node new_scan_node = old_scan_node
new_scan_args = old_scan_args new_scan_args = old_scan_args
replacements = {} replacements: Dict[Variable, Variable] = {}
# For the inner_vars that already exist in the outer graph, # For the inner_vars that already exist in the outer graph,
# simply obtain a reference to them # simply obtain a reference to them
......
...@@ -1016,7 +1016,7 @@ class ScanArgs: ...@@ -1016,7 +1016,7 @@ class ScanArgs:
def find_among_fields( def find_among_fields(
self, i: Variable, field_filter: Callable[[str], bool] = default_filter self, i: Variable, field_filter: Callable[[str], bool] = default_filter
) -> Optional[Tuple[str, int, int, int]]: ) -> Optional[FieldInfo]:
"""Find the type and indices of the field containing a given element. """Find the type and indices of the field containing a given element.
NOTE: This only returns the *first* field containing the given element. NOTE: This only returns the *first* field containing the given element.
...@@ -1158,14 +1158,14 @@ class ScanArgs: ...@@ -1158,14 +1158,14 @@ class ScanArgs:
def remove_from_fields( def remove_from_fields(
self, i: Variable, rm_dependents: bool = True self, i: Variable, rm_dependents: bool = True
) -> List[FieldInfo]: ) -> List[Tuple[Variable, Optional[FieldInfo]]]:
if rm_dependents: if rm_dependents:
vars_to_remove = self.get_dependent_nodes(i) | {i} vars_to_remove = self.get_dependent_nodes(i) | {i}
else: else:
vars_to_remove = {i} vars_to_remove = {i}
rm_info = [] rm_info: List[Tuple[Variable, Optional[FieldInfo]]] = []
for v in vars_to_remove: for v in vars_to_remove:
dependent_rm_info = self._remove_from_fields(v) dependent_rm_info = self._remove_from_fields(v)
rm_info.append((v, dependent_rm_info)) rm_info.append((v, dependent_rm_info))
......
...@@ -9,7 +9,7 @@ from aesara.graph.op import Op ...@@ -9,7 +9,7 @@ from aesara.graph.op import Op
def as_tensor_variable( def as_tensor_variable(
x: Any, name: Optional[str] = None, ndim: Optional[int] = None, **kwargs x: Any, name: Optional[str] = None, ndim: Optional[int] = None, **kwargs
) -> Variable: ) -> "TensorVariable":
"""Convert `x` into an equivalent `TensorVariable`. """Convert `x` into an equivalent `TensorVariable`.
This function can be used to turn ndarrays, numbers, `Scalar` instances, This function can be used to turn ndarrays, numbers, `Scalar` instances,
...@@ -45,7 +45,7 @@ def as_tensor_variable( ...@@ -45,7 +45,7 @@ def as_tensor_variable(
@singledispatch @singledispatch
def _as_tensor_variable( def _as_tensor_variable(
x, name: Optional[str], ndim: Optional[int], **kwargs x, name: Optional[str], ndim: Optional[int], **kwargs
) -> NoReturn: ) -> "TensorVariable":
raise NotImplementedError(f"Cannot convert {x} to a tensor variable.") raise NotImplementedError(f"Cannot convert {x} to a tensor variable.")
...@@ -80,7 +80,7 @@ def get_vector_length(v: Any): ...@@ -80,7 +80,7 @@ def get_vector_length(v: Any):
@singledispatch @singledispatch
def _get_vector_length(op: Union[Op, Variable], var: Variable) -> NoReturn: def _get_vector_length(op: Union[Op, Variable], var: Variable):
"""`Op`-based dispatch for `get_vector_length`.""" """`Op`-based dispatch for `get_vector_length`."""
raise ValueError(f"Length of {var} cannot be determined") raise ValueError(f"Length of {var} cannot be determined")
......
...@@ -184,7 +184,7 @@ def _as_tensor_bool(x, name, ndim, **kwargs): ...@@ -184,7 +184,7 @@ def _as_tensor_bool(x, name, ndim, **kwargs):
as_tensor = as_tensor_variable as_tensor = as_tensor_variable
def constant(x, name=None, ndim=None, dtype=None): def constant(x, name=None, ndim=None, dtype=None) -> TensorConstant:
"""Return a `TensorConstant` with value `x`. """Return a `TensorConstant` with value `x`.
Raises Raises
...@@ -795,7 +795,7 @@ register_rebroadcast_c_code( ...@@ -795,7 +795,7 @@ register_rebroadcast_c_code(
# to be removed as we get the epydoc routine-documenting thing going # to be removed as we get the epydoc routine-documenting thing going
# -JB 20080924 # -JB 20080924
def _conversion(real_value, name): def _conversion(real_value: Op, name: str) -> Op:
__oplist_tag(real_value, "casting") __oplist_tag(real_value, "casting")
real_value.__module__ = "tensor.basic" real_value.__module__ = "tensor.basic"
pprint.assign(real_value, printing.FunctionPrinter([name])) pprint.assign(real_value, printing.FunctionPrinter([name]))
...@@ -807,46 +807,50 @@ def _conversion(real_value, name): ...@@ -807,46 +807,50 @@ def _conversion(real_value, name):
# what types you are casting to what. That logic is implemented by the # what types you are casting to what. That logic is implemented by the
# `cast()` function below. # `cast()` function below.
_convert_to_bool = _conversion(Elemwise(aes.convert_to_bool), "bool") _convert_to_bool: Elemwise = _conversion(Elemwise(aes.convert_to_bool), "bool")
"""Cast to boolean""" """Cast to boolean"""
_convert_to_int8 = _conversion(Elemwise(aes.convert_to_int8), "int8") _convert_to_int8: Elemwise = _conversion(Elemwise(aes.convert_to_int8), "int8")
"""Cast to 8-bit integer""" """Cast to 8-bit integer"""
_convert_to_int16 = _conversion(Elemwise(aes.convert_to_int16), "int16") _convert_to_int16: Elemwise = _conversion(Elemwise(aes.convert_to_int16), "int16")
"""Cast to 16-bit integer""" """Cast to 16-bit integer"""
_convert_to_int32 = _conversion(Elemwise(aes.convert_to_int32), "int32") _convert_to_int32: Elemwise = _conversion(Elemwise(aes.convert_to_int32), "int32")
"""Cast to 32-bit integer""" """Cast to 32-bit integer"""
_convert_to_int64 = _conversion(Elemwise(aes.convert_to_int64), "int64") _convert_to_int64: Elemwise = _conversion(Elemwise(aes.convert_to_int64), "int64")
"""Cast to 64-bit integer""" """Cast to 64-bit integer"""
_convert_to_uint8 = _conversion(Elemwise(aes.convert_to_uint8), "uint8") _convert_to_uint8: Elemwise = _conversion(Elemwise(aes.convert_to_uint8), "uint8")
"""Cast to unsigned 8-bit integer""" """Cast to unsigned 8-bit integer"""
_convert_to_uint16 = _conversion(Elemwise(aes.convert_to_uint16), "uint16") _convert_to_uint16: Elemwise = _conversion(Elemwise(aes.convert_to_uint16), "uint16")
"""Cast to unsigned 16-bit integer""" """Cast to unsigned 16-bit integer"""
_convert_to_uint32 = _conversion(Elemwise(aes.convert_to_uint32), "uint32") _convert_to_uint32: Elemwise = _conversion(Elemwise(aes.convert_to_uint32), "uint32")
"""Cast to unsigned 32-bit integer""" """Cast to unsigned 32-bit integer"""
_convert_to_uint64 = _conversion(Elemwise(aes.convert_to_uint64), "uint64") _convert_to_uint64: Elemwise = _conversion(Elemwise(aes.convert_to_uint64), "uint64")
"""Cast to unsigned 64-bit integer""" """Cast to unsigned 64-bit integer"""
_convert_to_float16 = _conversion(Elemwise(aes.convert_to_float16), "float16") _convert_to_float16: Elemwise = _conversion(Elemwise(aes.convert_to_float16), "float16")
"""Cast to half-precision floating point""" """Cast to half-precision floating point"""
_convert_to_float32 = _conversion(Elemwise(aes.convert_to_float32), "float32") _convert_to_float32: Elemwise = _conversion(Elemwise(aes.convert_to_float32), "float32")
"""Cast to single-precision floating point""" """Cast to single-precision floating point"""
_convert_to_float64 = _conversion(Elemwise(aes.convert_to_float64), "float64") _convert_to_float64: Elemwise = _conversion(Elemwise(aes.convert_to_float64), "float64")
"""Cast to double-precision floating point""" """Cast to double-precision floating point"""
_convert_to_complex64 = _conversion(Elemwise(aes.convert_to_complex64), "complex64") _convert_to_complex64: Elemwise = _conversion(
Elemwise(aes.convert_to_complex64), "complex64"
)
"""Cast to single-precision complex""" """Cast to single-precision complex"""
_convert_to_complex128 = _conversion(Elemwise(aes.convert_to_complex128), "complex128") _convert_to_complex128: Elemwise = _conversion(
Elemwise(aes.convert_to_complex128), "complex128"
)
"""Cast to double-precision complex""" """Cast to double-precision complex"""
_cast_mapping = { _cast_mapping = {
...@@ -867,7 +871,7 @@ _cast_mapping = { ...@@ -867,7 +871,7 @@ _cast_mapping = {
} }
def cast(x, dtype): def cast(x, dtype: Union[str, np.dtype]) -> TensorVariable:
"""Symbolically cast `x` to a Tensor of type `dtype`.""" """Symbolically cast `x` to a Tensor of type `dtype`."""
if isinstance(dtype, str) and dtype == "floatX": if isinstance(dtype, str) and dtype == "floatX":
......
from copy import copy from copy import copy
from typing import List, Optional, Sequence, Tuple from typing import Optional, Sequence, Tuple, Union
import numpy as np import numpy as np
...@@ -27,9 +27,9 @@ from aesara.tensor.var import TensorVariable ...@@ -27,9 +27,9 @@ from aesara.tensor.var import TensorVariable
def default_supp_shape_from_params( def default_supp_shape_from_params(
ndim_supp: int, ndim_supp: int,
dist_params: Sequence[Variable], dist_params: Sequence[Variable],
rep_param_idx: Optional[int] = 0, rep_param_idx: int = 0,
param_shapes: Optional[Sequence[Tuple[ScalarVariable]]] = None, param_shapes: Optional[Sequence[Tuple[ScalarVariable]]] = None,
) -> Tuple[int, ...]: ) -> Union[TensorVariable, Tuple[ScalarVariable, ...]]:
"""Infer the dimensions for the output of a `RandomVariable`. """Infer the dimensions for the output of a `RandomVariable`.
This is a function that derives a random variable's support This is a function that derives a random variable's support
...@@ -171,10 +171,10 @@ class RandomVariable(Op): ...@@ -171,10 +171,10 @@ class RandomVariable(Op):
def _infer_shape( def _infer_shape(
self, self,
size: Tuple[TensorVariable], size: TensorVariable,
dist_params: List[TensorVariable], dist_params: Sequence[TensorVariable],
param_shapes: Optional[List[Tuple[TensorVariable]]] = None, param_shapes: Optional[Sequence[Tuple[Variable, ...]]] = None,
) -> Tuple[ScalarVariable]: ) -> Union[TensorVariable, Tuple[ScalarVariable, ...]]:
"""Compute the output shape given the size and distribution parameters. """Compute the output shape given the size and distribution parameters.
Parameters Parameters
......
from collections.abc import Sequence from collections.abc import Sequence
from functools import wraps from functools import wraps
from itertools import zip_longest from itertools import zip_longest
from typing import Optional, Union
import numpy as np import numpy as np
...@@ -111,7 +112,9 @@ def broadcast_params(params, ndims_params): ...@@ -111,7 +112,9 @@ def broadcast_params(params, ndims_params):
return bcast_params return bcast_params
def normalize_size_param(size): def normalize_size_param(
size: Optional[Union[int, np.ndarray, Variable, Sequence]]
) -> Variable:
"""Create an Aesara value for a ``RandomVariable`` ``size`` parameter.""" """Create an Aesara value for a ``RandomVariable`` ``size`` parameter."""
if size is None: if size is None:
size = constant([], dtype="int64") size = constant([], dtype="int64")
......
...@@ -16,7 +16,7 @@ from aesara.tensor import basic as at ...@@ -16,7 +16,7 @@ from aesara.tensor import basic as at
from aesara.tensor import get_vector_length from aesara.tensor import get_vector_length
from aesara.tensor.exceptions import NotScalarConstantError from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.type import TensorType, int_dtypes, tensor from aesara.tensor.type import TensorType, int_dtypes, tensor
from aesara.tensor.var import TensorConstant from aesara.tensor.var import TensorConstant, TensorVariable
def register_shape_c_code(type, code, version=()): def register_shape_c_code(type, code, version=()):
...@@ -155,7 +155,7 @@ def _get_vector_length_Shape(op, var): ...@@ -155,7 +155,7 @@ def _get_vector_length_Shape(op, var):
return var.owner.inputs[0].type.ndim return var.owner.inputs[0].type.ndim
def shape_tuple(x: Variable) -> Tuple[Variable]: def shape_tuple(x: TensorVariable) -> Tuple[Variable, ...]:
"""Get a tuple of symbolic shape values. """Get a tuple of symbolic shape values.
This will return a `ScalarConstant` with the value ``1`` wherever This will return a `ScalarConstant` with the value ``1`` wherever
......
...@@ -135,8 +135,8 @@ def as_index_constant( ...@@ -135,8 +135,8 @@ def as_index_constant(
def as_index_literal( def as_index_literal(
idx: Union[Variable, slice, type(np.newaxis)] idx: Optional[Union[Variable, slice]]
) -> Union[int, slice, type(np.newaxis)]: ) -> Optional[Union[int, slice]]:
"""Convert a symbolic index element to its Python equivalent. """Convert a symbolic index element to its Python equivalent.
This is like the inverse of `as_index_constant` This is like the inverse of `as_index_constant`
......
...@@ -48,6 +48,7 @@ dependencies: ...@@ -48,6 +48,7 @@ dependencies:
# developer tools # developer tools
- pre-commit - pre-commit
- packaging - packaging
- typing_extensions
# optional # optional
- sympy - sympy
- cython - cython
...@@ -52,6 +52,7 @@ install_requires = [ ...@@ -52,6 +52,7 @@ install_requires = [
"logical-unification", "logical-unification",
"miniKanren", "miniKanren",
"cons", "cons",
"typing_extensions",
] ]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论