提交 55df19f0 authored 作者: Michael Osthege's avatar Michael Osthege 提交者: Brandon T. Willard

Add and fix type hints

The changes in this commit don't affect runtime code.
上级 2a371ab5
......@@ -14,6 +14,7 @@ from io import StringIO
from itertools import chain
from itertools import product as itertools_product
from logging import Logger
from typing import Optional
from warnings import warn
import numpy as np
......@@ -1362,14 +1363,14 @@ class _Linker(LocalLinker):
self.maker = maker
super().__init__(scheduler=schedule)
def accept(self, fgraph, no_recycling=None, profile=None):
def accept(self, fgraph, no_recycling: Optional[list] = None, profile=None):
if no_recycling is None:
no_recycling = []
if self.fgraph is not None and self.fgraph is not fgraph:
assert type(self) is _Linker
return type(self)(maker=self.maker).accept(fgraph, no_recycling, profile)
self.fgraph = fgraph
self.no_recycling = no_recycling
self.no_recycling: list = no_recycling
return self
def make_all(
......@@ -1401,7 +1402,7 @@ class _Linker(LocalLinker):
# check_preallocated_output even on the output of the function.
# no_recycling in individual thunks does not really matter, since
# the function's outputs will always be freshly allocated.
no_recycling = []
no_recycling: list = []
input_storage, output_storage, storage_map = map_storage(
fgraph, order, input_storage_, output_storage_, storage_map
......@@ -1492,6 +1493,7 @@ class _Linker(LocalLinker):
# Use self.no_recycling (that was passed in accept()) to always
# use new memory storage when it is needed, in particular for the
# function's outputs. no_recycling_map will be used in f() below.
no_recycling_map: list = []
if self.no_recycling is True:
no_recycling_map = list(storage_map.values())
no_recycling_map = difference(no_recycling_map, input_storage)
......
......@@ -36,6 +36,7 @@ from aesara.gpuarray.type import (
ContextNotDefined,
GpuArrayConstant,
GpuArrayType,
GpuContextType,
get_context,
gpu_context_type,
)
......@@ -307,7 +308,7 @@ class GpuKernelBase:
"""
params_type: Union[ParamsType, gpu_context_type] = gpu_context_type
params_type: Union[ParamsType, GpuContextType] = gpu_context_type
def get_params(self, node):
# Default implementation, suitable for most sub-classes.
......
......@@ -983,7 +983,7 @@ Py_INCREF(%(name)s);
Instance of :class:`GpuContextType` to use for the context_type
declaration of an operation.
"""
gpu_context_type = GpuContextType()
gpu_context_type: GpuContextType = GpuContextType()
# THIS WORKS But GpuArray instances don't compare equal to one
......
......@@ -104,7 +104,7 @@ class Apply(Node):
"""
self.op = op
self.inputs = []
self.inputs: List[Variable] = []
self.tag = Scratchpad()
if not isinstance(inputs, (list, tuple)):
......@@ -121,7 +121,7 @@ class Apply(Node):
raise TypeError(
f"The 'inputs' argument to Apply must contain Variable instances, not {input}"
)
self.outputs = []
self.outputs: List[Variable] = []
# filter outputs to make sure each element is a Variable
for i, output in enumerate(outputs):
if isinstance(output, Variable):
......@@ -630,7 +630,7 @@ def walk(
bfs: bool = True,
return_children: bool = False,
hash_fn: Callable[[T], Hashable] = id,
) -> Generator[T, None, Dict[T, List[T]]]:
) -> Generator[Union[T, Tuple[T, Optional[Sequence[T]]]], None, None]:
"""Walk through a graph, either breadth- or depth-first.
Parameters
......@@ -870,7 +870,7 @@ def clone_get_equiv(
copy_inputs: bool = True,
copy_orphans: bool = True,
memo: Optional[Dict[Variable, Variable]] = None,
):
) -> Dict[Variable, Variable]:
"""
Return a dictionary that maps from `Variable` and `Apply` nodes in the
original graph to a new node (a clone) in a new graph.
......@@ -1050,7 +1050,7 @@ def general_toposort(
if deps_cache is None:
raise ValueError("deps_cache cannot be None")
search_res: List[T, Optional[List[T]]] = list(
search_res: List[Tuple[T, Optional[List[T]]]] = list(
walk(outputs, compute_deps_cache, bfs=False, return_children=True)
)
......@@ -1088,8 +1088,8 @@ def general_toposort(
def io_toposort(
inputs: List[Variable],
outputs: List[Variable],
inputs: Iterable[Variable],
outputs: Iterable[Variable],
orderings: Optional[Dict[Apply, List[Apply]]] = None,
clients: Optional[Dict[Variable, List[Variable]]] = None,
) -> List[Apply]:
......@@ -1586,7 +1586,7 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None):
def get_var_by_name(
graphs: Iterable[Variable], target_var_id: str, ids: str = "CHAR"
) -> Tuple[Variable]:
) -> Tuple[Variable, ...]:
r"""Get variables in a graph using their names.
Parameters
......@@ -1613,7 +1613,7 @@ def get_var_by_name(
return res
results = ()
results: Tuple[Variable, ...] = ()
for var in walk(graphs, expand, False):
if target_var_id == var.name or target_var_id == var.auto_name:
results += (var,)
......
"""A container for specifying and manipulating a graph with distinct inputs and outputs."""
import time
from collections import OrderedDict
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, Tuple, Union
import aesara
from aesara.configdefaults import config
......@@ -47,9 +47,9 @@ class FunctionGraph(MetaObject):
def __init__(
self,
inputs: Optional[List[Variable]] = None,
outputs: Optional[List[Variable]] = None,
features: Optional[List[Feature]] = None,
inputs: Optional[Sequence[Variable]] = None,
outputs: Optional[Sequence[Variable]] = None,
features: Optional[Sequence[Feature]] = None,
clone: bool = True,
update_mapping: Optional[Dict[Variable, Variable]] = None,
memo: Optional[Dict[Variable, Variable]] = None,
......@@ -98,25 +98,25 @@ class FunctionGraph(MetaObject):
inputs = [memo[i] for i in inputs]
self.execute_callbacks_time = 0
self.execute_callbacks_times = {}
self.execute_callbacks_times: Dict[Feature, float] = {}
if features is None:
features = []
self._features = []
self._features: List[Feature] = []
# All apply nodes in the subgraph defined by inputs and
# outputs are cached in this field
self.apply_nodes = set()
self.apply_nodes: Set[Apply] = set()
# Ditto for variable nodes.
# It must contain all fgraph.inputs and all apply_nodes
# outputs even if they aren't used in the graph.
self.variables = set()
self.variables: Set[Variable] = set()
self.inputs = []
self.outputs = list(outputs)
self.clients = {}
self.inputs: List[Variable] = []
self.outputs: List[Variable] = list(outputs)
self.clients: Dict[Variable, List[Tuple[Union[Apply, str], int]]] = {}
for f in features:
self.attach_feature(f)
......@@ -487,7 +487,7 @@ class FunctionGraph(MetaObject):
node, i, new_var, reason=reason, import_missing=import_missing
)
def replace_all(self, pairs: List[Tuple[Variable, Variable]], **kwargs) -> None:
def replace_all(self, pairs: Iterable[Tuple[Variable, Variable]], **kwargs) -> None:
"""Replace variables in the `FunctionGraph` according to ``(var, new_var)`` pairs in a list."""
for var, new_var in pairs:
self.replace(var, new_var, **kwargs)
......@@ -604,7 +604,7 @@ class FunctionGraph(MetaObject):
"""
assert isinstance(self._features, list)
all_orderings = []
all_orderings: List[OrderedDict] = []
for feature in self._features:
if hasattr(feature, "orderings"):
......@@ -630,7 +630,7 @@ class FunctionGraph(MetaObject):
return all_orderings[0].copy()
else:
# If there is more than 1 ordering, combine them.
ords = OrderedDict()
ords: Dict[Apply, List[Apply]] = OrderedDict()
for orderings in all_orderings:
for node, prereqs in orderings.items():
ords.setdefault(node, []).extend(prereqs)
......@@ -695,7 +695,7 @@ class FunctionGraph(MetaObject):
def clone_get_equiv(
self, check_integrity: bool = True, attach_feature: bool = True
) -> Union["FunctionGraph", Dict[Variable, Variable]]:
) -> Tuple["FunctionGraph", Dict[Variable, Variable]]:
"""Clone the graph and return a ``dict`` that maps old nodes to new nodes.
Parameters
......
......@@ -17,6 +17,7 @@ from typing import (
Any,
Callable,
ClassVar,
Collection,
Dict,
List,
Optional,
......@@ -47,8 +48,8 @@ if TYPE_CHECKING:
from aesara.compile.function.types import Function
from aesara.graph.fg import FunctionGraph
StorageMapType = List[Optional[List[Any]]]
ComputeMapType = List[bool]
StorageMapType = Dict[Variable, List[Optional[List[Any]]]]
ComputeMapType = Dict[Variable, List[bool]]
OutputStorageType = List[Optional[List[Any]]]
ParamsInputType = Optional[Tuple[Any]]
PerformMethodType = Callable[
......@@ -613,7 +614,7 @@ class COp(Op, CLinkerOp):
node: Apply,
storage_map: StorageMapType,
compute_map: ComputeMapType,
no_recycling: bool,
no_recycling: Collection[Apply],
) -> ThunkType:
"""Create a thunk for a C implementation.
......@@ -1073,7 +1074,7 @@ class ExternalCOp(COp):
f"No valid section marker was found in file {func_files[i]}"
)
def __get_op_params(self) -> Union[List[Text], List[Tuple[str, Any]]]:
def __get_op_params(self) -> List[Tuple[str, Any]]:
"""Construct name, value pairs that will be turned into macros for use within the `Op`'s code.
The names must be strings that are not a C keyword and the
......@@ -1089,9 +1090,10 @@ class ExternalCOp(COp):
associated to ``key``.
"""
params: List[Tuple[str, Any]] = []
if hasattr(self, "params_type") and isinstance(self.params_type, ParamsType):
wrapper = self.params_type
params = [("PARAMS_TYPE", wrapper.name)]
params.append(("PARAMS_TYPE", wrapper.name))
for i in range(wrapper.length):
c_type = wrapper.types[i].c_element_type()
if c_type:
......@@ -1105,8 +1107,7 @@ class ExternalCOp(COp):
c_type,
)
)
return params
return []
return params
def c_code_cache_version(self):
version = (hash(tuple(self.func_codes)),)
......
......@@ -17,7 +17,7 @@ from collections import UserList, defaultdict, deque
from collections.abc import Iterable
from functools import partial, reduce
from itertools import chain
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, List, Optional, Sequence, Tuple, Union
import aesara
from aesara.configdefaults import config
......@@ -1068,7 +1068,7 @@ class FromFunctionLocalOptimizer(LocalOptimizer):
def local_optimizer(
tracks: Optional[List[Union[Op, type]]],
tracks: Optional[Sequence[Union[Op, type]]],
inplace: bool = False,
requirements: Optional[Tuple[type, ...]] = (),
):
......@@ -1184,7 +1184,10 @@ class LocalOptGroup(LocalOptimizer):
"""
def __init__(
self, *optimizers, apply_all_opts: bool = False, profile: bool = False
self,
*optimizers: Sequence[Rewriter],
apply_all_opts: bool = False,
profile: bool = False,
):
"""
......@@ -1205,7 +1208,7 @@ class LocalOptGroup(LocalOptimizer):
if len(optimizers) == 1 and isinstance(optimizers[0], list):
# This happen when created by LocalGroupDB.
optimizers = tuple(optimizers[0])
self.opts = optimizers
self.opts: Sequence[Rewriter] = optimizers
assert isinstance(self.opts, tuple)
self.reentrant = any(getattr(opt, "reentrant", True) for opt in optimizers)
......@@ -1217,10 +1220,10 @@ class LocalOptGroup(LocalOptimizer):
self.profile = profile
if self.profile:
self.time_opts = {}
self.process_count = {}
self.applied_true = {}
self.node_created = {}
self.time_opts: Dict[Rewriter, float] = {}
self.process_count: Dict[Rewriter, int] = {}
self.applied_true: Dict[Rewriter, int] = {}
self.node_created: Dict[Rewriter, int] = {}
self.tracker = LocalOptTracker()
......
......@@ -6,6 +6,8 @@ import re
from abc import abstractmethod
from typing import Any, Optional, Text, TypeVar, Union
from typing_extensions import TypeAlias
from aesara.graph import utils
from aesara.graph.basic import Constant, Variable
from aesara.graph.utils import MetaObject
......@@ -33,12 +35,12 @@ class Type(MetaObject):
"""
Variable = Variable
Variable: TypeAlias = Variable
"""
The `Type` that will be created by a call to `Type.make_variable`.
"""
Constant = Constant
Constant: TypeAlias = Constant
"""
The `Type` that will be created by a call to `Type.make_constant`.
"""
......@@ -109,7 +111,7 @@ class Type(MetaObject):
storage: Any,
strict: bool = False,
allow_downcast: Optional[bool] = None,
) -> None:
):
"""Return data or an appropriately wrapped/converted data by converting it in-place.
This method allows one to reuse old allocated memory. If this method
......
......@@ -3,10 +3,12 @@ import sys
import traceback
from abc import ABCMeta
from io import StringIO
from typing import List
from typing import List, Optional, Sequence, Tuple
def simple_extract_stack(f=None, limit=None, skips=None):
def simple_extract_stack(
f=None, limit: Optional[int] = None, skips: Optional[Sequence[str]] = None
) -> List[Tuple[Optional[str], int, str, str]]:
"""This is traceback.extract_stack from python 2.7 with this change:
- Comment the update of the cache.
......@@ -33,7 +35,7 @@ def simple_extract_stack(f=None, limit=None, skips=None):
if limit is None:
if hasattr(sys, "tracebacklimit"):
limit = sys.tracebacklimit
trace = []
trace: List[Tuple[Optional[str], int, str, str]] = []
n = 0
while f is not None and (limit is None or n < limit):
lineno = f.f_lineno
......@@ -67,7 +69,7 @@ def simple_extract_stack(f=None, limit=None, skips=None):
return trace
def add_tag_trace(thing, user_line=None):
def add_tag_trace(thing, user_line: Optional[int] = None):
"""Add tag.trace to a node or variable.
The argument is returned after being affected (inplace).
......
......@@ -13,7 +13,7 @@ is a global operation with a scalar condition.
import logging
from copy import deepcopy
from typing import List, Union
from typing import List, Sequence, Union
import numpy as np
......@@ -311,7 +311,7 @@ def ifelse(
then_branch: Union[Variable, List[Variable]],
else_branch: Union[Variable, List[Variable]],
name: str = None,
) -> Union[Variable, List[Variable]]:
) -> Union[Variable, Sequence[Variable]]:
"""
This function corresponds to an if statement, returning (and evaluating)
inputs in the ``then_branch`` if ``condition`` evaluates to True or
......@@ -340,13 +340,13 @@ def ifelse(
Returns
=======
A list of aesara variables or a single variable (depending on the
A sequence of aesara variables or a single variable (depending on the
nature of the ``then_branch`` and ``else_branch``). More exactly if
``then_branch`` and ``else_branch`` is a tensor, then
the return variable will be just a single variable, otherwise a
list. The value returns correspond either to the values in the
sequence. The value returns correspond either to the values in the
``then_branch`` or in the ``else_branch`` depending on the value of
``cond``.
``condition``.
"""
rval_type = None
......
......@@ -51,7 +51,7 @@ class Container:
def __init__(
self,
r: MetaObject,
storage: Any,
storage: List[Any],
*,
readonly: bool = False,
strict: bool = False,
......
......@@ -148,7 +148,7 @@ class CLinkerObject:
"""Return a list of code snippets to be inserted in module initialization."""
return []
def c_code_cache_version(self) -> Union[Tuple[int], Tuple]:
def c_code_cache_version(self) -> Union[Tuple[int, ...], Tuple]:
"""Return a tuple of integers indicating the version of this `Op`.
An empty tuple indicates an "unversioned" `Op` that will not be cached
......@@ -211,7 +211,7 @@ class CLinkerOp(CLinkerObject):
"""
raise NotImplementedError()
def c_code_cache_version_apply(self, node: Apply) -> Tuple[int]:
def c_code_cache_version_apply(self, node: Apply) -> Tuple[int, ...]:
"""Return a tuple of integers indicating the version of this `Op`.
An empty tuple indicates an "unversioned" `Op` that will not be
......
......@@ -24,7 +24,7 @@ def map_storage(
order: Iterable[Apply],
input_storage: Optional[List],
output_storage: Optional[List],
storage_map: Dict = None,
storage_map: Optional[Dict] = None,
) -> Tuple[List, List, Dict]:
"""Ensure there is storage (a length-1 list) for inputs, outputs, and interior nodes.
......
......@@ -13,7 +13,7 @@ from contextlib import contextmanager
from copy import copy
from functools import reduce
from io import IOBase, StringIO
from typing import Callable, Dict, Iterable, List, Optional, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
import numpy as np
......@@ -168,9 +168,9 @@ def debugprint(
used_ids = dict()
results_to_print = []
profile_list = []
order = [] # Toposort
smap = [] # storage_map
profile_list: List[Optional[Any]] = []
order: List[Optional[List[Apply]]] = [] # Toposort
smap: List[Optional[StorageMapType]] = [] # storage_map
if isinstance(obj, (list, tuple, set)):
lobj = obj
......@@ -881,8 +881,8 @@ default_printer = DefaultPrinter()
class PPrinter(Printer):
def __init__(self):
self.printers = []
self.printers_dict = {}
self.printers: List[Tuple[Union[Op, type, Callable], Printer]] = []
self.printers_dict: Dict[Union[Op, type, Callable], Printer] = {}
def assign(self, condition: Union[Op, type, Callable], printer: Printer):
if isinstance(condition, (Op, type)):
......@@ -999,7 +999,7 @@ else:
)
pprint = PPrinter()
pprint: PPrinter = PPrinter()
pprint.assign(lambda pstate, r: True, default_printer)
pprint.assign(lambda pstate, r: isinstance(r, Constant), constant_printer)
......@@ -1705,7 +1705,7 @@ def get_node_by_id(
if isinstance(graphs, Variable):
graphs = (graphs,)
used_ids = dict()
used_ids: Dict[Variable, str] = {}
_ = debugprint(graphs, file="str", used_ids=used_ids, ids=ids)
......
......@@ -19,6 +19,7 @@ from textwrap import dedent
from typing import Dict, Mapping, Optional, Tuple, Type, Union
import numpy as np
from typing_extensions import TypeAlias
import aesara
from aesara import printing
......@@ -106,47 +107,6 @@ def as_common_dtype(*vars):
return (v.astype(dtype) for v in vars)
def get_scalar_type(dtype):
"""
Return a Scalar(dtype) object.
This caches objects to save allocation and run time.
"""
if dtype not in get_scalar_type.cache:
get_scalar_type.cache[dtype] = Scalar(dtype=dtype)
return get_scalar_type.cache[dtype]
get_scalar_type.cache = {}
def as_scalar(x, name=None):
from aesara.tensor.basic import scalar_from_tensor
from aesara.tensor.type import TensorType
if isinstance(x, Apply):
if len(x.outputs) != 1:
raise ValueError(
"It is ambiguous which output of a multi-output"
" Op has to be fetched.",
x,
)
else:
x = x.outputs[0]
if isinstance(x, Variable):
if isinstance(x.type, Scalar):
return x
elif isinstance(x.type, TensorType) and x.ndim == 0:
return scalar_from_tensor(x)
else:
raise TypeError("Variable type field must be a Scalar.", x, x.type)
try:
return constant(x)
except TypeError:
raise TypeError(f"Cannot convert {x} to Scalar", type(x))
class NumpyAutocaster:
"""
This class is used to cast python ints and floats to numpy arrays.
......@@ -314,12 +274,6 @@ def convert(x, dtype=None):
return x_
def constant(x, name=None, dtype=None):
x = convert(x, dtype=dtype)
assert x.ndim == 0
return ScalarConstant(get_scalar_type(str(x.dtype)), x, name=name)
class Scalar(CType):
"""
......@@ -722,6 +676,21 @@ class Scalar(CType):
return shape_info
def get_scalar_type(dtype) -> Scalar:
"""
Return a Scalar(dtype) object.
This caches objects to save allocation and run time.
"""
if dtype not in get_scalar_type.cache:
get_scalar_type.cache[dtype] = Scalar(dtype=dtype)
return get_scalar_type.cache[dtype]
get_scalar_type.cache = {}
# Register C code for ViewOp on Scalars.
aesara.compile.register_view_op_c_code(
Scalar,
......@@ -732,30 +701,31 @@ aesara.compile.register_view_op_c_code(
)
bool = get_scalar_type("bool")
int8 = get_scalar_type("int8")
int16 = get_scalar_type("int16")
int32 = get_scalar_type("int32")
int64 = get_scalar_type("int64")
uint8 = get_scalar_type("uint8")
uint16 = get_scalar_type("uint16")
uint32 = get_scalar_type("uint32")
uint64 = get_scalar_type("uint64")
float16 = get_scalar_type("float16")
float32 = get_scalar_type("float32")
float64 = get_scalar_type("float64")
complex64 = get_scalar_type("complex64")
complex128 = get_scalar_type("complex128")
int_types = int8, int16, int32, int64
uint_types = uint8, uint16, uint32, uint64
float_types = float16, float32, float64
complex_types = complex64, complex128
integer_types = int_types + uint_types
discrete_types = (bool,) + integer_types
continuous_types = float_types + complex_types
all_types = discrete_types + continuous_types
bool: Scalar = get_scalar_type("bool")
int8: Scalar = get_scalar_type("int8")
int16: Scalar = get_scalar_type("int16")
int32: Scalar = get_scalar_type("int32")
int64: Scalar = get_scalar_type("int64")
uint8: Scalar = get_scalar_type("uint8")
uint16: Scalar = get_scalar_type("uint16")
uint32: Scalar = get_scalar_type("uint32")
uint64: Scalar = get_scalar_type("uint64")
float16: Scalar = get_scalar_type("float16")
float32: Scalar = get_scalar_type("float32")
float64: Scalar = get_scalar_type("float64")
complex64: Scalar = get_scalar_type("complex64")
complex128: Scalar = get_scalar_type("complex128")
_ScalarTypes: TypeAlias = Tuple[Scalar, ...]
int_types: _ScalarTypes = (int8, int16, int32, int64)
uint_types: _ScalarTypes = (uint8, uint16, uint32, uint64)
float_types: _ScalarTypes = (float16, float32, float64)
complex_types: _ScalarTypes = (complex64, complex128)
integer_types: _ScalarTypes = int_types + uint_types
discrete_types: _ScalarTypes = (bool,) + integer_types
continuous_types: _ScalarTypes = float_types + complex_types
all_types: _ScalarTypes = discrete_types + continuous_types
discrete_dtypes = tuple(t.dtype for t in discrete_types)
......@@ -885,6 +855,38 @@ class ScalarConstant(ScalarVariable, Constant):
Scalar.Constant = ScalarConstant
def constant(x, name=None, dtype=None) -> ScalarConstant:
x = convert(x, dtype=dtype)
assert x.ndim == 0
return ScalarConstant(get_scalar_type(str(x.dtype)), x, name=name)
def as_scalar(x, name=None) -> ScalarConstant:
from aesara.tensor.basic import scalar_from_tensor
from aesara.tensor.type import TensorType
if isinstance(x, Apply):
if len(x.outputs) != 1:
raise ValueError(
"It is ambiguous which output of a multi-output"
" Op has to be fetched.",
x,
)
else:
x = x.outputs[0]
if isinstance(x, Variable):
if isinstance(x.type, Scalar):
return x
elif isinstance(x.type, TensorType) and x.ndim == 0:
return scalar_from_tensor(x)
else:
raise TypeError("Variable type field must be a Scalar.", x, x.type)
try:
return constant(x)
except TypeError:
raise TypeError(f"Cannot convert {x} to Scalar", type(x))
# Easy constructors
ints = apply_across_args(int64)
......@@ -1276,9 +1278,9 @@ class BinaryScalarOp(ScalarOp):
# One may define in subclasses the following fields:
# - `commutative`: whether op(a, b) == op(b, a)
# - `associative`: whether op(op(a, b), c) == op(a, op(b, c))
commutative: Optional[bool] = None
associative: Optional[bool] = None
identity: Optional[bool] = None
commutative: Optional[builtins.bool] = None
associative: Optional[builtins.bool] = None
identity: Optional[builtins.bool] = None
"""
For an associative operation, the identity object corresponds to the neutral
element. For instance, it will be ``0`` for addition, ``1`` for multiplication,
......@@ -2501,20 +2503,20 @@ class Cast(UnaryScalarOp):
return s
convert_to_bool = Cast(bool, name="convert_to_bool")
convert_to_int8 = Cast(int8, name="convert_to_int8")
convert_to_int16 = Cast(int16, name="convert_to_int16")
convert_to_int32 = Cast(int32, name="convert_to_int32")
convert_to_int64 = Cast(int64, name="convert_to_int64")
convert_to_uint8 = Cast(uint8, name="convert_to_uint8")
convert_to_uint16 = Cast(uint16, name="convert_to_uint16")
convert_to_uint32 = Cast(uint32, name="convert_to_uint32")
convert_to_uint64 = Cast(uint64, name="convert_to_uint64")
convert_to_float16 = Cast(float16, name="convert_to_float16")
convert_to_float32 = Cast(float32, name="convert_to_float32")
convert_to_float64 = Cast(float64, name="convert_to_float64")
convert_to_complex64 = Cast(complex64, name="convert_to_complex64")
convert_to_complex128 = Cast(complex128, name="convert_to_complex128")
convert_to_bool: Cast = Cast(bool, name="convert_to_bool")
convert_to_int8: Cast = Cast(int8, name="convert_to_int8")
convert_to_int16: Cast = Cast(int16, name="convert_to_int16")
convert_to_int32: Cast = Cast(int32, name="convert_to_int32")
convert_to_int64: Cast = Cast(int64, name="convert_to_int64")
convert_to_uint8: Cast = Cast(uint8, name="convert_to_uint8")
convert_to_uint16: Cast = Cast(uint16, name="convert_to_uint16")
convert_to_uint32: Cast = Cast(uint32, name="convert_to_uint32")
convert_to_uint64: Cast = Cast(uint64, name="convert_to_uint64")
convert_to_float16: Cast = Cast(float16, name="convert_to_float16")
convert_to_float32: Cast = Cast(float32, name="convert_to_float32")
convert_to_float64: Cast = Cast(float64, name="convert_to_float64")
convert_to_complex64: Cast = Cast(complex64, name="convert_to_complex64")
convert_to_complex128: Cast = Cast(complex128, name="convert_to_complex128")
_cast_mapping = {
"bool": convert_to_bool,
......
......@@ -683,7 +683,7 @@ def push_out_inner_vars(
outer_vars = [None] * len(inner_vars)
new_scan_node = old_scan_node
new_scan_args = old_scan_args
replacements = {}
replacements: Dict[Variable, Variable] = {}
# For the inner_vars that already exist in the outer graph,
# simply obtain a reference to them
......
......@@ -1016,7 +1016,7 @@ class ScanArgs:
def find_among_fields(
self, i: Variable, field_filter: Callable[[str], bool] = default_filter
) -> Optional[Tuple[str, int, int, int]]:
) -> Optional[FieldInfo]:
"""Find the type and indices of the field containing a given element.
NOTE: This only returns the *first* field containing the given element.
......@@ -1158,14 +1158,14 @@ class ScanArgs:
def remove_from_fields(
self, i: Variable, rm_dependents: bool = True
) -> List[FieldInfo]:
) -> List[Tuple[Variable, Optional[FieldInfo]]]:
if rm_dependents:
vars_to_remove = self.get_dependent_nodes(i) | {i}
else:
vars_to_remove = {i}
rm_info = []
rm_info: List[Tuple[Variable, Optional[FieldInfo]]] = []
for v in vars_to_remove:
dependent_rm_info = self._remove_from_fields(v)
rm_info.append((v, dependent_rm_info))
......
......@@ -9,7 +9,7 @@ from aesara.graph.op import Op
def as_tensor_variable(
x: Any, name: Optional[str] = None, ndim: Optional[int] = None, **kwargs
) -> Variable:
) -> "TensorVariable":
"""Convert `x` into an equivalent `TensorVariable`.
This function can be used to turn ndarrays, numbers, `Scalar` instances,
......@@ -45,7 +45,7 @@ def as_tensor_variable(
@singledispatch
def _as_tensor_variable(
x, name: Optional[str], ndim: Optional[int], **kwargs
) -> NoReturn:
) -> "TensorVariable":
raise NotImplementedError(f"Cannot convert {x} to a tensor variable.")
......@@ -80,7 +80,7 @@ def get_vector_length(v: Any):
@singledispatch
def _get_vector_length(op: Union[Op, Variable], var: Variable) -> NoReturn:
def _get_vector_length(op: Union[Op, Variable], var: Variable):
"""`Op`-based dispatch for `get_vector_length`."""
raise ValueError(f"Length of {var} cannot be determined")
......
......@@ -184,7 +184,7 @@ def _as_tensor_bool(x, name, ndim, **kwargs):
as_tensor = as_tensor_variable
def constant(x, name=None, ndim=None, dtype=None):
def constant(x, name=None, ndim=None, dtype=None) -> TensorConstant:
"""Return a `TensorConstant` with value `x`.
Raises
......@@ -795,7 +795,7 @@ register_rebroadcast_c_code(
# to be removed as we get the epydoc routine-documenting thing going
# -JB 20080924
def _conversion(real_value, name):
def _conversion(real_value: Op, name: str) -> Op:
__oplist_tag(real_value, "casting")
real_value.__module__ = "tensor.basic"
pprint.assign(real_value, printing.FunctionPrinter([name]))
......@@ -807,46 +807,50 @@ def _conversion(real_value, name):
# what types you are casting to what. That logic is implemented by the
# `cast()` function below.
_convert_to_bool = _conversion(Elemwise(aes.convert_to_bool), "bool")
_convert_to_bool: Elemwise = _conversion(Elemwise(aes.convert_to_bool), "bool")
"""Cast to boolean"""
_convert_to_int8 = _conversion(Elemwise(aes.convert_to_int8), "int8")
_convert_to_int8: Elemwise = _conversion(Elemwise(aes.convert_to_int8), "int8")
"""Cast to 8-bit integer"""
_convert_to_int16 = _conversion(Elemwise(aes.convert_to_int16), "int16")
_convert_to_int16: Elemwise = _conversion(Elemwise(aes.convert_to_int16), "int16")
"""Cast to 16-bit integer"""
_convert_to_int32 = _conversion(Elemwise(aes.convert_to_int32), "int32")
_convert_to_int32: Elemwise = _conversion(Elemwise(aes.convert_to_int32), "int32")
"""Cast to 32-bit integer"""
_convert_to_int64 = _conversion(Elemwise(aes.convert_to_int64), "int64")
_convert_to_int64: Elemwise = _conversion(Elemwise(aes.convert_to_int64), "int64")
"""Cast to 64-bit integer"""
_convert_to_uint8 = _conversion(Elemwise(aes.convert_to_uint8), "uint8")
_convert_to_uint8: Elemwise = _conversion(Elemwise(aes.convert_to_uint8), "uint8")
"""Cast to unsigned 8-bit integer"""
_convert_to_uint16 = _conversion(Elemwise(aes.convert_to_uint16), "uint16")
_convert_to_uint16: Elemwise = _conversion(Elemwise(aes.convert_to_uint16), "uint16")
"""Cast to unsigned 16-bit integer"""
_convert_to_uint32 = _conversion(Elemwise(aes.convert_to_uint32), "uint32")
_convert_to_uint32: Elemwise = _conversion(Elemwise(aes.convert_to_uint32), "uint32")
"""Cast to unsigned 32-bit integer"""
_convert_to_uint64 = _conversion(Elemwise(aes.convert_to_uint64), "uint64")
_convert_to_uint64: Elemwise = _conversion(Elemwise(aes.convert_to_uint64), "uint64")
"""Cast to unsigned 64-bit integer"""
_convert_to_float16 = _conversion(Elemwise(aes.convert_to_float16), "float16")
_convert_to_float16: Elemwise = _conversion(Elemwise(aes.convert_to_float16), "float16")
"""Cast to half-precision floating point"""
_convert_to_float32 = _conversion(Elemwise(aes.convert_to_float32), "float32")
_convert_to_float32: Elemwise = _conversion(Elemwise(aes.convert_to_float32), "float32")
"""Cast to single-precision floating point"""
_convert_to_float64 = _conversion(Elemwise(aes.convert_to_float64), "float64")
_convert_to_float64: Elemwise = _conversion(Elemwise(aes.convert_to_float64), "float64")
"""Cast to double-precision floating point"""
_convert_to_complex64 = _conversion(Elemwise(aes.convert_to_complex64), "complex64")
_convert_to_complex64: Elemwise = _conversion(
Elemwise(aes.convert_to_complex64), "complex64"
)
"""Cast to single-precision complex"""
_convert_to_complex128 = _conversion(Elemwise(aes.convert_to_complex128), "complex128")
_convert_to_complex128: Elemwise = _conversion(
Elemwise(aes.convert_to_complex128), "complex128"
)
"""Cast to double-precision complex"""
_cast_mapping = {
......@@ -867,7 +871,7 @@ _cast_mapping = {
}
def cast(x, dtype):
def cast(x, dtype: Union[str, np.dtype]) -> TensorVariable:
"""Symbolically cast `x` to a Tensor of type `dtype`."""
if isinstance(dtype, str) and dtype == "floatX":
......
from copy import copy
from typing import List, Optional, Sequence, Tuple
from typing import Optional, Sequence, Tuple, Union
import numpy as np
......@@ -27,9 +27,9 @@ from aesara.tensor.var import TensorVariable
def default_supp_shape_from_params(
ndim_supp: int,
dist_params: Sequence[Variable],
rep_param_idx: Optional[int] = 0,
rep_param_idx: int = 0,
param_shapes: Optional[Sequence[Tuple[ScalarVariable]]] = None,
) -> Tuple[int, ...]:
) -> Union[TensorVariable, Tuple[ScalarVariable, ...]]:
"""Infer the dimensions for the output of a `RandomVariable`.
This is a function that derives a random variable's support
......@@ -171,10 +171,10 @@ class RandomVariable(Op):
def _infer_shape(
self,
size: Tuple[TensorVariable],
dist_params: List[TensorVariable],
param_shapes: Optional[List[Tuple[TensorVariable]]] = None,
) -> Tuple[ScalarVariable]:
size: TensorVariable,
dist_params: Sequence[TensorVariable],
param_shapes: Optional[Sequence[Tuple[Variable, ...]]] = None,
) -> Union[TensorVariable, Tuple[ScalarVariable, ...]]:
"""Compute the output shape given the size and distribution parameters.
Parameters
......
from collections.abc import Sequence
from functools import wraps
from itertools import zip_longest
from typing import Optional, Union
import numpy as np
......@@ -111,7 +112,9 @@ def broadcast_params(params, ndims_params):
return bcast_params
def normalize_size_param(size):
def normalize_size_param(
size: Optional[Union[int, np.ndarray, Variable, Sequence]]
) -> Variable:
"""Create an Aesara value for a ``RandomVariable`` ``size`` parameter."""
if size is None:
size = constant([], dtype="int64")
......
......@@ -16,7 +16,7 @@ from aesara.tensor import basic as at
from aesara.tensor import get_vector_length
from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.type import TensorType, int_dtypes, tensor
from aesara.tensor.var import TensorConstant
from aesara.tensor.var import TensorConstant, TensorVariable
def register_shape_c_code(type, code, version=()):
......@@ -155,7 +155,7 @@ def _get_vector_length_Shape(op, var):
return var.owner.inputs[0].type.ndim
def shape_tuple(x: Variable) -> Tuple[Variable]:
def shape_tuple(x: TensorVariable) -> Tuple[Variable, ...]:
"""Get a tuple of symbolic shape values.
This will return a `ScalarConstant` with the value ``1`` wherever
......
......@@ -135,8 +135,8 @@ def as_index_constant(
def as_index_literal(
idx: Union[Variable, slice, type(np.newaxis)]
) -> Union[int, slice, type(np.newaxis)]:
idx: Optional[Union[Variable, slice]]
) -> Optional[Union[int, slice]]:
"""Convert a symbolic index element to its Python equivalent.
This is like the inverse of `as_index_constant`
......
......@@ -48,6 +48,7 @@ dependencies:
# developer tools
- pre-commit
- packaging
- typing_extensions
# optional
- sympy
- cython
......@@ -52,6 +52,7 @@ install_requires = [
"logical-unification",
"miniKanren",
"cons",
"typing_extensions",
]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论