提交 f0392db7 authored 作者: LegrandNico's avatar LegrandNico 提交者: Thomas Wiecki

Fix type hint errors (272->168)

上级 f26a5086
...@@ -14,7 +14,7 @@ import os ...@@ -14,7 +14,7 @@ import os
import re import re
import subprocess import subprocess
import sys import sys
from typing import Dict
def get_keywords(): def get_keywords():
"""Get the keywords needed to look up the version information.""" """Get the keywords needed to look up the version information."""
...@@ -51,8 +51,8 @@ class NotThisMethod(Exception): ...@@ -51,8 +51,8 @@ class NotThisMethod(Exception):
"""Exception raised if a method is not valid for the current scenario.""" """Exception raised if a method is not valid for the current scenario."""
LONG_VERSION_PY = {} LONG_VERSION_PY: Dict = {}
HANDLERS = {} HANDLERS: Dict = {}
def register_vcs_handler(vcs, method): # decorator def register_vcs_handler(vcs, method): # decorator
......
...@@ -11,6 +11,7 @@ import pickle ...@@ -11,6 +11,7 @@ import pickle
import time import time
import warnings import warnings
from itertools import chain from itertools import chain
from typing import List
import numpy as np import numpy as np
...@@ -1877,7 +1878,7 @@ def _constructor_FunctionMaker(kwargs): ...@@ -1877,7 +1878,7 @@ def _constructor_FunctionMaker(kwargs):
return None return None
__checkers = [] __checkers: List = []
def check_equal(x, y): def check_equal(x, y):
......
...@@ -5,6 +5,7 @@ WRITEME ...@@ -5,6 +5,7 @@ WRITEME
import logging import logging
import warnings import warnings
from typing import Tuple, Union
import aesara import aesara
from aesara.compile.function.types import Supervisor from aesara.compile.function.types import Supervisor
...@@ -252,6 +253,8 @@ optdb.register("add_destroy_handler", AddDestroyHandler(), 49.5, "fast_run", "in ...@@ -252,6 +253,8 @@ optdb.register("add_destroy_handler", AddDestroyHandler(), 49.5, "fast_run", "in
# final pass just to make sure # final pass just to make sure
optdb.register("merge3", MergeOptimizer(), 100, "fast_run", "merge") optdb.register("merge3", MergeOptimizer(), 100, "fast_run", "merge")
_tags: Union[Tuple[str, str], Tuple]
if config.check_stack_trace in ["raise", "warn", "log"]: if config.check_stack_trace in ["raise", "warn", "log"]:
_tags = ("fast_run", "fast_compile") _tags = ("fast_run", "fast_compile")
......
...@@ -8,6 +8,7 @@ help make new Ops more rapidly. ...@@ -8,6 +8,7 @@ help make new Ops more rapidly.
import copy import copy
import pickle import pickle
import warnings import warnings
from typing import Dict, Tuple
from aesara.graph.basic import Apply from aesara.graph.basic import Apply
from aesara.graph.op import COp, Op from aesara.graph.op import COp, Op
...@@ -42,9 +43,9 @@ class ViewOp(COp): ...@@ -42,9 +43,9 @@ class ViewOp(COp):
# Mapping from Type to C code (and version) to use. # Mapping from Type to C code (and version) to use.
# In the C code, the name of the input variable is %(iname)s, # In the C code, the name of the input variable is %(iname)s,
# the output variable is %(oname)s. # the output variable is %(oname)s.
c_code_and_version = {} c_code_and_version: Dict = {}
__props__ = () __props__: Tuple = ()
_f16_ok = True _f16_ok: bool = True
def make_node(self, x): def make_node(self, x):
return Apply(self, [x], [x.type()]) return Apply(self, [x], [x.type()])
...@@ -148,11 +149,11 @@ class DeepCopyOp(COp): ...@@ -148,11 +149,11 @@ class DeepCopyOp(COp):
# Mapping from Type to C code (and version) to use. # Mapping from Type to C code (and version) to use.
# In the C code, the name of the input variable is %(iname)s, # In the C code, the name of the input variable is %(iname)s,
# the output variable is %(oname)s. # the output variable is %(oname)s.
c_code_and_version = {} c_code_and_version: Dict = {}
check_input = False check_input: bool = False
__props__ = () __props__: Tuple = ()
_f16_ok = True _f16_ok: bool = True
def __init__(self): def __init__(self):
pass pass
......
...@@ -17,6 +17,7 @@ import sys ...@@ -17,6 +17,7 @@ import sys
import time import time
import warnings import warnings
from collections import defaultdict from collections import defaultdict
from typing import Dict, List
import numpy as np import numpy as np
...@@ -37,7 +38,7 @@ total_fct_exec_time = 0.0 ...@@ -37,7 +38,7 @@ total_fct_exec_time = 0.0
total_graph_opt_time = 0.0 total_graph_opt_time = 0.0
total_time_linker = 0.0 total_time_linker = 0.0
_atexit_print_list = [] _atexit_print_list: List = []
_atexit_registered = False _atexit_registered = False
...@@ -242,15 +243,15 @@ class ProfileStats: ...@@ -242,15 +243,15 @@ class ProfileStats:
# pretty string to print in summary, to identify this output # pretty string to print in summary, to identify this output
# #
variable_shape = {} variable_shape: Dict = {}
# Variable -> shapes # Variable -> shapes
# #
variable_strides = {} variable_strides: Dict = {}
# Variable -> strides # Variable -> strides
# #
variable_offset = {} variable_offset: Dict = {}
# Variable -> offset # Variable -> offset
# #
...@@ -270,7 +271,7 @@ class ProfileStats: ...@@ -270,7 +271,7 @@ class ProfileStats:
linker_node_make_thunks = 0.0 linker_node_make_thunks = 0.0
linker_make_thunk_time = {} linker_make_thunk_time: Dict = {}
line_width = config.profiling__output_line_width line_width = config.profiling__output_line_width
......
...@@ -13,6 +13,7 @@ from configparser import ( ...@@ -13,6 +13,7 @@ from configparser import (
) )
from functools import wraps from functools import wraps
from io import StringIO from io import StringIO
from typing import Optional
from aesara.utils import deprecated, hash_from_code from aesara.utils import deprecated, hash_from_code
...@@ -94,7 +95,7 @@ class AesaraConfigParser: ...@@ -94,7 +95,7 @@ class AesaraConfigParser:
self._flags_dict = flags_dict self._flags_dict = flags_dict
self._aesara_cfg = aesara_cfg self._aesara_cfg = aesara_cfg
self._aesara_raw_cfg = aesara_raw_cfg self._aesara_raw_cfg = aesara_raw_cfg
self._config_var_dict = {} self._config_var_dict: typing.Dict = {}
super().__init__() super().__init__()
def __str__(self, print_doc=True): def __str__(self, print_doc=True):
...@@ -325,7 +326,7 @@ class ConfigParam: ...@@ -325,7 +326,7 @@ class ConfigParam:
return self._apply(value) return self._apply(value)
return value return value
def validate(self, value) -> None: def validate(self, value) -> Optional[bool]:
"""Validates that a parameter values falls into a supported set or range. """Validates that a parameter values falls into a supported set or range.
Raises Raises
......
...@@ -2,6 +2,7 @@ import copy ...@@ -2,6 +2,7 @@ import copy
import os import os
import re import re
from collections import deque from collections import deque
from typing import Union
import numpy as np import numpy as np
...@@ -306,7 +307,7 @@ class GpuKernelBase: ...@@ -306,7 +307,7 @@ class GpuKernelBase:
""" """
params_type = gpu_context_type params_type: Union[ParamsType, gpu_context_type] = gpu_context_type
def get_params(self, node): def get_params(self, node):
# Default implementation, suitable for most sub-classes. # Default implementation, suitable for most sub-classes.
......
import os import os
import sys import sys
from typing import Set
class PathParser: class PathParser:
...@@ -25,7 +26,7 @@ class PathParser: ...@@ -25,7 +26,7 @@ class PathParser:
""" """
paths = set() paths: Set = set()
def _add(self, path): def _add(self, path):
path = path.strip() path = path.strip()
......
...@@ -571,7 +571,7 @@ class Variable(Node): ...@@ -571,7 +571,7 @@ class Variable(Node):
return d return d
# refer to doc in nodes_constructed. # refer to doc in nodes_constructed.
construction_observers = [] construction_observers: List = []
@classmethod @classmethod
def append_construction_observer(cls, observer): def append_construction_observer(cls, observer):
...@@ -700,10 +700,11 @@ def walk( ...@@ -700,10 +700,11 @@ def walk(
rval_set: Set[T] = set() rval_set: Set[T] = set()
nodes_pop: Callable[[], T]
if bfs: if bfs:
nodes_pop: Callable[[], T] = nodes.popleft nodes_pop = nodes.popleft
else: else:
nodes_pop: Callable[[], T] = nodes.pop nodes_pop = nodes.pop
while nodes: while nodes:
node: T = nodes_pop() node: T = nodes_pop()
...@@ -714,7 +715,7 @@ def walk( ...@@ -714,7 +715,7 @@ def walk(
rval_set.add(node_hash) rval_set.add(node_hash)
new_nodes: Sequence[T] = expand(node) new_nodes: Optional[Sequence[T]] = expand(node)
if return_children: if return_children:
yield node, new_nodes yield node, new_nodes
...@@ -862,8 +863,8 @@ def applys_between( ...@@ -862,8 +863,8 @@ def applys_between(
def clone( def clone(
inputs: Collection[Variable], inputs: List[Variable],
outputs: Collection[Variable], outputs: List[Variable],
copy_inputs: bool = True, copy_inputs: bool = True,
copy_orphans: Optional[bool] = None, copy_orphans: Optional[bool] = None,
) -> Tuple[Collection[Variable], Collection[Variable]]: ) -> Tuple[Collection[Variable], Collection[Variable]]:
...@@ -902,8 +903,8 @@ def clone( ...@@ -902,8 +903,8 @@ def clone(
def clone_get_equiv( def clone_get_equiv(
inputs: Collection[Variable], inputs: List[Variable],
outputs: Collection[Variable], outputs: List[Variable],
copy_inputs: bool = True, copy_inputs: bool = True,
copy_orphans: bool = True, copy_orphans: bool = True,
memo: Optional[Dict[Variable, Variable]] = None, memo: Optional[Dict[Variable, Variable]] = None,
...@@ -1171,7 +1172,7 @@ def io_toposort( ...@@ -1171,7 +1172,7 @@ def io_toposort(
compute_deps = None compute_deps = None
compute_deps_cache = None compute_deps_cache = None
iset = set(inputs) iset = set(inputs)
deps_cache = {} deps_cache: Dict = {}
if not orderings: # ordering can be None or empty dict if not orderings: # ordering can be None or empty dict
# Specialized function that is faster when no ordering. # Specialized function that is faster when no ordering.
...@@ -1345,8 +1346,8 @@ def as_string( ...@@ -1345,8 +1346,8 @@ def as_string(
orph = list(orphans_between(i, outputs)) orph = list(orphans_between(i, outputs))
multi = set() multi: Set = set()
seen = set() seen: Set = set()
for output in outputs: for output in outputs:
op = output.owner op = output.owner
if op in seen: if op in seen:
...@@ -1362,8 +1363,8 @@ def as_string( ...@@ -1362,8 +1363,8 @@ def as_string(
multi.add(op2) multi.add(op2)
else: else:
seen.add(input.owner) seen.add(input.owner)
multi = [x for x in multi] multi: Set = [x for x in multi]
done = set() done: Set = set()
def multi_index(x): def multi_index(x):
return multi.index(x) + 1 return multi.index(x) + 1
......
...@@ -157,7 +157,7 @@ class Op(MetaObject): ...@@ -157,7 +157,7 @@ class Op(MetaObject):
""" """
default_output = None default_output: Optional[int] = None
""" """
An `int` that specifies which output `Op.__call__` should return. If An `int` that specifies which output `Op.__call__` should return. If
`None`, then all outputs are returned. `None`, then all outputs are returned.
...@@ -852,7 +852,7 @@ def lquote_macro(txt: Text) -> Text: ...@@ -852,7 +852,7 @@ def lquote_macro(txt: Text) -> Text:
return "\n".join(res) return "\n".join(res)
def get_sub_macros(sub: Dict[Text, Text]) -> Tuple[Text]: def get_sub_macros(sub: Dict[Text, Text]) -> Union[Tuple[Text], Tuple[Text, Text]]:
define_macros = [] define_macros = []
undef_macros = [] undef_macros = []
define_macros.append(f"#define FAIL {lquote_macro(sub['fail'])}") define_macros.append(f"#define FAIL {lquote_macro(sub['fail'])}")
...@@ -864,7 +864,9 @@ def get_sub_macros(sub: Dict[Text, Text]) -> Tuple[Text]: ...@@ -864,7 +864,9 @@ def get_sub_macros(sub: Dict[Text, Text]) -> Tuple[Text]:
return "\n".join(define_macros), "\n".join(undef_macros) return "\n".join(define_macros), "\n".join(undef_macros)
def get_io_macros(inputs: List[Text], outputs: List[Text]) -> Tuple[List[Text]]: def get_io_macros(
inputs: List[Text], outputs: List[Text]
) -> Union[Tuple[List[Text]], Tuple[str, str]]:
define_macros = [] define_macros = []
undef_macros = [] undef_macros = []
...@@ -1023,7 +1025,7 @@ class ExternalCOp(COp): ...@@ -1023,7 +1025,7 @@ class ExternalCOp(COp):
f"No valid section marker was found in file {func_files[i]}" f"No valid section marker was found in file {func_files[i]}"
) )
def __get_op_params(self) -> List[Text]: def __get_op_params(self) -> Union[List[Text], List[Tuple[str, Any]]]:
"""Construct name, value pairs that will be turned into macros for use within the `Op`'s code. """Construct name, value pairs that will be turned into macros for use within the `Op`'s code.
The names must be strings that are not a C keyword and the The names must be strings that are not a C keyword and the
...@@ -1130,7 +1132,7 @@ class ExternalCOp(COp): ...@@ -1130,7 +1132,7 @@ class ExternalCOp(COp):
def get_c_macros( def get_c_macros(
self, node: Apply, name: Text, check_input: Optional[bool] = None self, node: Apply, name: Text, check_input: Optional[bool] = None
) -> Tuple[Text]: ) -> Union[Tuple[str], Tuple[str, str]]:
"Construct a pair of C ``#define`` and ``#undef`` code strings." "Construct a pair of C ``#define`` and ``#undef`` code strings."
define_template = "#define %s %s" define_template = "#define %s %s"
undef_template = "#undef %s" undef_template = "#undef %s"
......
...@@ -2,6 +2,7 @@ import copy ...@@ -2,6 +2,7 @@ import copy
import math import math
import sys import sys
from io import StringIO from io import StringIO
from typing import Dict, Optional
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph import opt from aesara.graph import opt
...@@ -192,6 +193,7 @@ class Query: ...@@ -192,6 +193,7 @@ class Query:
self.exclude = exclude or OrderedSet() self.exclude = exclude or OrderedSet()
self.subquery = subquery or {} self.subquery = subquery or {}
self.position_cutoff = position_cutoff self.position_cutoff = position_cutoff
self.name: Optional[str] = None
if extra_optimizations is None: if extra_optimizations is None:
extra_optimizations = [] extra_optimizations = []
self.extra_optimizations = extra_optimizations self.extra_optimizations = extra_optimizations
...@@ -438,14 +440,18 @@ class LocalGroupDB(DB): ...@@ -438,14 +440,18 @@ class LocalGroupDB(DB):
""" """
def __init__( def __init__(
self, apply_all_opts=False, profile=False, local_opt=opt.LocalOptGroup self,
apply_all_opts: bool = False,
profile: bool = False,
local_opt=opt.LocalOptGroup,
): ):
super().__init__() super().__init__()
self.failure_callback = None self.failure_callback = None
self.apply_all_opts = apply_all_opts self.apply_all_opts = apply_all_opts
self.profile = profile self.profile = profile
self.__position__ = {} self.__position__: Dict = {}
self.local_opt = local_opt self.local_opt = local_opt
self.__name__: str = ""
def register(self, name, obj, *tags, **kwargs): def register(self, name, obj, *tags, **kwargs):
super().register(name, obj, *tags) super().register(name, obj, *tags)
......
...@@ -331,7 +331,7 @@ def unify_walk(a, b, U): ...@@ -331,7 +331,7 @@ def unify_walk(a, b, U):
return False return False
@comm_guard(FreeVariable, ANY_TYPE) # noqa @comm_guard(FreeVariable, ANY_TYPE) # type: ignore[no-redef] # noqa
def unify_walk(fv, o, U): def unify_walk(fv, o, U):
""" """
FreeV is unified to BoundVariable(other_object). FreeV is unified to BoundVariable(other_object).
...@@ -341,7 +341,7 @@ def unify_walk(fv, o, U): ...@@ -341,7 +341,7 @@ def unify_walk(fv, o, U):
return U.merge(v, fv) return U.merge(v, fv)
@comm_guard(BoundVariable, ANY_TYPE) # noqa @comm_guard(BoundVariable, ANY_TYPE) # type: ignore[no-redef] # noqa
def unify_walk(bv, o, U): def unify_walk(bv, o, U):
""" """
The unification succeed iff BV.value == other_object. The unification succeed iff BV.value == other_object.
...@@ -353,7 +353,7 @@ def unify_walk(bv, o, U): ...@@ -353,7 +353,7 @@ def unify_walk(bv, o, U):
return False return False
@comm_guard(OrVariable, ANY_TYPE) # noqa @comm_guard(OrVariable, ANY_TYPE) # type: ignore[no-redef] # noqa
def unify_walk(ov, o, U): def unify_walk(ov, o, U):
""" """
The unification succeeds iff other_object in OrV.options. The unification succeeds iff other_object in OrV.options.
...@@ -366,7 +366,7 @@ def unify_walk(ov, o, U): ...@@ -366,7 +366,7 @@ def unify_walk(ov, o, U):
return False return False
@comm_guard(NotVariable, ANY_TYPE) # noqa @comm_guard(NotVariable, ANY_TYPE) # type: ignore[no-redef] # noqa
def unify_walk(nv, o, U): def unify_walk(nv, o, U):
""" """
The unification succeeds iff other_object not in NV.not_options. The unification succeeds iff other_object not in NV.not_options.
...@@ -379,7 +379,7 @@ def unify_walk(nv, o, U): ...@@ -379,7 +379,7 @@ def unify_walk(nv, o, U):
return U.merge(v, nv) return U.merge(v, nv)
@comm_guard(FreeVariable, Variable) # noqa @comm_guard(FreeVariable, Variable) # type: ignore[no-redef] # noqa
def unify_walk(fv, v, U): def unify_walk(fv, v, U):
""" """
Both variables are unified. Both variables are unified.
...@@ -389,7 +389,7 @@ def unify_walk(fv, v, U): ...@@ -389,7 +389,7 @@ def unify_walk(fv, v, U):
return U.merge(v, fv) return U.merge(v, fv)
@comm_guard(BoundVariable, Variable) # noqa @comm_guard(BoundVariable, Variable) # type: ignore[no-redef] # noqa
def unify_walk(bv, v, U): def unify_walk(bv, v, U):
""" """
V is unified to BV.value. V is unified to BV.value.
...@@ -398,7 +398,7 @@ def unify_walk(bv, v, U): ...@@ -398,7 +398,7 @@ def unify_walk(bv, v, U):
return unify_walk(v, bv.value, U) return unify_walk(v, bv.value, U)
@comm_guard(OrVariable, OrVariable) # noqa @comm_guard(OrVariable, OrVariable) # type: ignore[no-redef] # noqa
def unify_walk(a, b, U): def unify_walk(a, b, U):
""" """
OrV(list1) == OrV(list2) == OrV(intersection(list1, list2)) OrV(list1) == OrV(list2) == OrV(intersection(list1, list2))
...@@ -414,7 +414,7 @@ def unify_walk(a, b, U): ...@@ -414,7 +414,7 @@ def unify_walk(a, b, U):
return U.merge(v, a, b) return U.merge(v, a, b)
@comm_guard(NotVariable, NotVariable) # noqa @comm_guard(NotVariable, NotVariable) # type: ignore[no-redef] # noqa
def unify_walk(a, b, U): def unify_walk(a, b, U):
""" """
NV(list1) == NV(list2) == NV(union(list1, list2)) NV(list1) == NV(list2) == NV(union(list1, list2))
...@@ -425,7 +425,7 @@ def unify_walk(a, b, U): ...@@ -425,7 +425,7 @@ def unify_walk(a, b, U):
return U.merge(v, a, b) return U.merge(v, a, b)
@comm_guard(OrVariable, NotVariable) # noqa @comm_guard(OrVariable, NotVariable) # type: ignore[no-redef] # noqa
def unify_walk(o, n, U): def unify_walk(o, n, U):
r""" r"""
OrV(list1) == NV(list2) == OrV(list1 \ list2) OrV(list1) == NV(list2) == OrV(list1 \ list2)
...@@ -441,7 +441,7 @@ def unify_walk(o, n, U): ...@@ -441,7 +441,7 @@ def unify_walk(o, n, U):
return U.merge(v, o, n) return U.merge(v, o, n)
@comm_guard(VariableInList, (list, tuple)) # noqa @comm_guard(VariableInList, (list, tuple)) # type: ignore[no-redef] # noqa
def unify_walk(vil, l, U): def unify_walk(vil, l, U):
""" """
Unifies VIL's inner Variable to OrV(list). Unifies VIL's inner Variable to OrV(list).
...@@ -452,7 +452,7 @@ def unify_walk(vil, l, U): ...@@ -452,7 +452,7 @@ def unify_walk(vil, l, U):
return unify_walk(v, ov, U) return unify_walk(v, ov, U)
@comm_guard((list, tuple), (list, tuple)) # noqa @comm_guard((list, tuple), (list, tuple)) # type: ignore[no-redef] # noqa
def unify_walk(l1, l2, U): def unify_walk(l1, l2, U):
""" """
Tries to unify each corresponding pair of elements from l1 and l2. Tries to unify each corresponding pair of elements from l1 and l2.
...@@ -467,7 +467,7 @@ def unify_walk(l1, l2, U): ...@@ -467,7 +467,7 @@ def unify_walk(l1, l2, U):
return U return U
@comm_guard(dict, dict) # noqa @comm_guard(dict, dict) # type: ignore[no-redef] # noqa
def unify_walk(d1, d2, U): def unify_walk(d1, d2, U):
""" """
Tries to unify values of corresponding keys. Tries to unify values of corresponding keys.
...@@ -481,7 +481,7 @@ def unify_walk(d1, d2, U): ...@@ -481,7 +481,7 @@ def unify_walk(d1, d2, U):
return U return U
@comm_guard(ANY_TYPE, ANY_TYPE) # noqa @comm_guard(ANY_TYPE, ANY_TYPE) # type: ignore[no-redef] # noqa
def unify_walk(a, b, U): def unify_walk(a, b, U):
""" """
Checks for the existence of the __unify_walk__ method for one of Checks for the existence of the __unify_walk__ method for one of
...@@ -498,7 +498,7 @@ def unify_walk(a, b, U): ...@@ -498,7 +498,7 @@ def unify_walk(a, b, U):
return FALL_THROUGH return FALL_THROUGH
@comm_guard(Variable, ANY_TYPE) # noqa @comm_guard(Variable, ANY_TYPE) # type: ignore[no-redef] # noqa
def unify_walk(v, o, U): def unify_walk(v, o, U):
""" """
This simply checks if the Var has an unification in U and uses it This simply checks if the Var has an unification in U and uses it
...@@ -528,27 +528,27 @@ def unify_merge(a, b, U): ...@@ -528,27 +528,27 @@ def unify_merge(a, b, U):
return a return a
@comm_guard(Variable, ANY_TYPE) # noqa @comm_guard(Variable, ANY_TYPE) # type: ignore[no-redef] # noqa
def unify_merge(v, o, U): def unify_merge(v, o, U):
return v return v
@comm_guard(BoundVariable, ANY_TYPE) # noqa @comm_guard(BoundVariable, ANY_TYPE) # type: ignore[no-redef] # noqa
def unify_merge(bv, o, U): def unify_merge(bv, o, U):
return bv.value return bv.value
@comm_guard(VariableInList, (list, tuple)) # noqa @comm_guard(VariableInList, (list, tuple)) # type: ignore[no-redef] # noqa
def unify_merge(vil, l, U): def unify_merge(vil, l, U):
return [unify_merge(x, x, U) for x in l] return [unify_merge(x, x, U) for x in l]
@comm_guard((list, tuple), (list, tuple)) # noqa @comm_guard((list, tuple), (list, tuple)) # type: ignore[no-redef] # noqa
def unify_merge(l1, l2, U): def unify_merge(l1, l2, U):
return [unify_merge(x1, x2, U) for x1, x2 in zip(l1, l2)] return [unify_merge(x1, x2, U) for x1, x2 in zip(l1, l2)]
@comm_guard(dict, dict) # noqa @comm_guard(dict, dict) # type: ignore[no-redef] # noqa
def unify_merge(d1, d2, U): def unify_merge(d1, d2, U):
d = d1.__class__() d = d1.__class__()
for k1, v1 in d1.items(): for k1, v1 in d1.items():
...@@ -562,12 +562,12 @@ def unify_merge(d1, d2, U): ...@@ -562,12 +562,12 @@ def unify_merge(d1, d2, U):
return d return d
@comm_guard(FVar, ANY_TYPE) # noqa @comm_guard(FVar, ANY_TYPE) # type: ignore[no-redef] # noqa
def unify_merge(vs, o, U): def unify_merge(vs, o, U):
return vs(U) return vs(U)
@comm_guard(ANY_TYPE, ANY_TYPE) # noqa @comm_guard(ANY_TYPE, ANY_TYPE) # type: ignore[no-redef] # noqa
def unify_merge(a, b, U): def unify_merge(a, b, U):
if ( if (
not isinstance(a, Variable) not isinstance(a, Variable)
...@@ -579,7 +579,7 @@ def unify_merge(a, b, U): ...@@ -579,7 +579,7 @@ def unify_merge(a, b, U):
return FALL_THROUGH return FALL_THROUGH
@comm_guard(Variable, ANY_TYPE) # noqa @comm_guard(Variable, ANY_TYPE) # type: ignore[no-redef] # noqa
def unify_merge(v, o, U): def unify_merge(v, o, U):
""" """
This simply checks if the Var has an unification in U and uses it This simply checks if the Var has an unification in U and uses it
......
...@@ -3,6 +3,7 @@ import sys ...@@ -3,6 +3,7 @@ import sys
import traceback import traceback
from abc import ABCMeta from abc import ABCMeta
from io import StringIO from io import StringIO
from typing import List
def simple_extract_stack(f=None, limit=None, skips=None): def simple_extract_stack(f=None, limit=None, skips=None):
...@@ -225,7 +226,7 @@ class MetaType(ABCMeta): ...@@ -225,7 +226,7 @@ class MetaType(ABCMeta):
class MetaObject(metaclass=MetaType): class MetaObject(metaclass=MetaType):
__slots__ = [] __slots__: List = []
def __ne__(self, other): def __ne__(self, other):
return not self == other return not self == other
......
...@@ -8,6 +8,7 @@ import sys ...@@ -8,6 +8,7 @@ import sys
from collections import defaultdict from collections import defaultdict
from copy import copy from copy import copy
from io import StringIO from io import StringIO
from typing import Dict
import numpy as np import numpy as np
...@@ -1795,7 +1796,7 @@ class OpWiseCLinker(LocalLinker): ...@@ -1795,7 +1796,7 @@ class OpWiseCLinker(LocalLinker):
""" """
__cache__ = {} __cache__: Dict = {}
def __init__( def __init__(
self, fallback_on_perform=True, allow_gc=None, nice_errors=True, schedule=None self, fallback_on_perform=True, allow_gc=None, nice_errors=True, schedule=None
......
...@@ -19,6 +19,7 @@ import textwrap ...@@ -19,6 +19,7 @@ import textwrap
import time import time
import warnings import warnings
from io import BytesIO, StringIO from io import BytesIO, StringIO
from typing import Dict, List, Set
import numpy.distutils import numpy.distutils
...@@ -631,38 +632,38 @@ class ModuleCache: ...@@ -631,38 +632,38 @@ class ModuleCache:
""" """
dirname = "" dirname: str = ""
""" """
The working directory that is managed by this interface. The working directory that is managed by this interface.
""" """
module_from_name = {} module_from_name: Dict = {}
""" """
Maps a module filename to the loaded module object. Maps a module filename to the loaded module object.
""" """
entry_from_key = {} entry_from_key: Dict = {}
""" """
Maps keys to the filename of a .so/.pyd. Maps keys to the filename of a .so/.pyd.
""" """
similar_keys = {} similar_keys: Dict = {}
""" """
Maps a part-of-key to all keys that share this same part. Maps a part-of-key to all keys that share this same part.
""" """
module_hash_to_key_data = {} module_hash_to_key_data: Dict = {}
""" """
Maps a module hash to its corresponding KeyData object. Maps a module hash to its corresponding KeyData object.
""" """
stats = [] stats: List = []
""" """
A list with counters for the number of hits, loads, compiles issued by A list with counters for the number of hits, loads, compiles issued by
module_from_key(). module_from_key().
""" """
loaded_key_pkl = set() loaded_key_pkl: Set = set()
""" """
Set of all key.pkl files that have been loaded. Set of all key.pkl files that have been loaded.
......
from abc import abstractmethod from abc import abstractmethod
from typing import Callable, Dict, List, Text, Tuple from typing import Callable, Dict, List, Text, Tuple, Union
from aesara.graph.basic import Apply, Constant from aesara.graph.basic import Apply, Constant
from aesara.graph.utils import MethodNotDefined from aesara.graph.utils import MethodNotDefined
...@@ -129,7 +129,7 @@ class CLinkerObject: ...@@ -129,7 +129,7 @@ class CLinkerObject:
"""Return a list of code snippets to be inserted in module initialization.""" """Return a list of code snippets to be inserted in module initialization."""
return [] return []
def c_code_cache_version(self) -> Tuple[int]: def c_code_cache_version(self) -> Union[Tuple[int], Tuple]:
"""Return a tuple of integers indicating the version of this `Op`. """Return a tuple of integers indicating the version of this `Op`.
An empty tuple indicates an 'unversioned' `Op` that will not be cached An empty tuple indicates an 'unversioned' `Op` that will not be cached
...@@ -551,7 +551,7 @@ class CLinkerType(CLinkerObject): ...@@ -551,7 +551,7 @@ class CLinkerType(CLinkerObject):
""" """
return "" return ""
def c_code_cache_version(self) -> Tuple[int]: def c_code_cache_version(self) -> Union[Tuple, Tuple[int]]:
"""Return a tuple of integers indicating the version of this type. """Return a tuple of integers indicating the version of this type.
An empty tuple indicates an 'unversioned' type that will not An empty tuple indicates an 'unversioned' type that will not
......
...@@ -3,7 +3,7 @@ import sys ...@@ -3,7 +3,7 @@ import sys
import traceback import traceback
import warnings import warnings
from operator import itemgetter from operator import itemgetter
from typing import Callable, Dict, Iterable, List, NoReturn, Optional, Tuple from typing import Callable, Dict, Iterable, List, NoReturn, Optional, Tuple, Union
import numpy as np import numpy as np
...@@ -315,6 +315,10 @@ def raise_with_op( ...@@ -315,6 +315,10 @@ def raise_with_op(
types = [getattr(ipt, "type", "No type") for ipt in node.inputs] types = [getattr(ipt, "type", "No type") for ipt in node.inputs]
detailed_err_msg += f"\nInputs types: {types}\n" detailed_err_msg += f"\nInputs types: {types}\n"
shapes: Union[List, str]
strides: Union[List, str]
scalar_values: Union[List, str]
if thunk is not None: if thunk is not None:
if hasattr(thunk, "inputs"): if hasattr(thunk, "inputs"):
shapes = [getattr(ipt[0], "shape", "No shapes") for ipt in thunk.inputs] shapes = [getattr(ipt[0], "shape", "No shapes") for ipt in thunk.inputs]
......
import os import os
import pickle import pickle
import sys import sys
from typing import Dict
from aesara.configdefaults import config from aesara.configdefaults import config
...@@ -15,8 +16,8 @@ if len(sys.argv) > 1: ...@@ -15,8 +16,8 @@ if len(sys.argv) > 1:
else: else:
dirs = os.listdir(config.compiledir) dirs = os.listdir(config.compiledir)
dirs = [os.path.join(config.compiledir, d) for d in dirs] dirs = [os.path.join(config.compiledir, d) for d in dirs]
keys = {} # key -> nb seen keys: Dict = {} # key -> nb seen
mods = {} mods: Dict = {}
for dir in dirs: for dir in dirs:
key = None key = None
...@@ -48,12 +49,12 @@ if DISPLAY_DUPLICATE_KEYS: ...@@ -48,12 +49,12 @@ if DISPLAY_DUPLICATE_KEYS:
if v > 1: if v > 1:
print("Duplicate key (%i copies): %s" % (v, pickle.loads(k))) print("Duplicate key (%i copies): %s" % (v, pickle.loads(k)))
nbs_keys = {} # nb seen -> now many key nbs_keys: Dict = {} # nb seen -> now many key
for val in keys.values(): for val in keys.values():
nbs_keys.setdefault(val, 0) nbs_keys.setdefault(val, 0)
nbs_keys[val] += 1 nbs_keys[val] += 1
nbs_mod = {} # nb seen -> how many key nbs_mod: Dict = {} # nb seen -> how many key
nbs_mod_to_key = {} # nb seen -> keys nbs_mod_to_key = {} # nb seen -> keys
more_than_one = 0 more_than_one = 0
for mod, kk in mods.items(): for mod, kk in mods.items():
......
import copy import copy
import warnings import warnings
from typing import Tuple, Union
import numpy as np import numpy as np
...@@ -18,7 +19,7 @@ class MultinomialFromUniform(COp): ...@@ -18,7 +19,7 @@ class MultinomialFromUniform(COp):
TODO : need description for parameter 'odtype' TODO : need description for parameter 'odtype'
""" """
__props__ = ("odtype",) __props__: Union[Tuple[str], Tuple[str, str]] = ("odtype",)
def __init__(self, odtype): def __init__(self, odtype):
self.odtype = odtype self.odtype = odtype
......
...@@ -15,7 +15,7 @@ from collections.abc import Callable ...@@ -15,7 +15,7 @@ from collections.abc import Callable
from copy import copy from copy import copy
from itertools import chain from itertools import chain
from textwrap import dedent from textwrap import dedent
from typing import Optional from typing import Dict, Mapping, Optional, Tuple, Type, Union
import numpy as np import numpy as np
...@@ -898,7 +898,7 @@ def upgrade_to_float(*types): ...@@ -898,7 +898,7 @@ def upgrade_to_float(*types):
Upgrade any int types to float32 or float64 to avoid losing precision. Upgrade any int types to float32 or float64 to avoid losing precision.
""" """
conv = { conv: Mapping[Type, Type] = {
bool: float32, bool: float32,
int8: float32, int8: float32,
int16: float32, int16: float32,
...@@ -3954,7 +3954,7 @@ class Composite(ScalarOp): ...@@ -3954,7 +3954,7 @@ class Composite(ScalarOp):
""" """
init_param = ("inputs", "outputs") init_param: Union[Tuple[str, str], Tuple[str]] = ("inputs", "outputs")
def __str__(self): def __str__(self):
if self.name is None: if self.name is None:
...@@ -4331,7 +4331,7 @@ class Composite(ScalarOp): ...@@ -4331,7 +4331,7 @@ class Composite(ScalarOp):
class Compositef32: class Compositef32:
# This is a dict of scalar op classes that need special handling # This is a dict of scalar op classes that need special handling
special = {} special: Dict = {}
def apply(self, fgraph): def apply(self, fgraph):
mapping = {} mapping = {}
......
...@@ -5,12 +5,12 @@ __docformat__ = "restructuredtext en" ...@@ -5,12 +5,12 @@ __docformat__ = "restructuredtext en"
import warnings import warnings
from functools import singledispatch from functools import singledispatch
from typing import Callable, Optional from typing import Callable, NoReturn, Optional
def as_tensor_variable( def as_tensor_variable(
x, name: Optional[str] = None, ndim: Optional[int] = None, **kwargs x, name: Optional[str] = None, ndim: Optional[int] = None, **kwargs
): ) -> Callable:
"""Convert `x` into the appropriate `TensorType`. """Convert `x` into the appropriate `TensorType`.
This function is often used by `make_node` methods of `Op` subclasses to This function is often used by `make_node` methods of `Op` subclasses to
...@@ -39,7 +39,9 @@ def as_tensor_variable( ...@@ -39,7 +39,9 @@ def as_tensor_variable(
@singledispatch @singledispatch
def _as_tensor_variable(x, name: Optional[str], ndim: Optional[int], **kwargs): def _as_tensor_variable(
x, name: Optional[str], ndim: Optional[int], **kwargs
) -> NoReturn:
raise NotImplementedError("") raise NotImplementedError("")
......
...@@ -11,6 +11,7 @@ import warnings ...@@ -11,6 +11,7 @@ import warnings
from collections import OrderedDict from collections import OrderedDict
from collections.abc import Sequence from collections.abc import Sequence
from numbers import Number from numbers import Number
from typing import Dict
import numpy as np import numpy as np
...@@ -679,7 +680,7 @@ class Rebroadcast(COp): ...@@ -679,7 +680,7 @@ class Rebroadcast(COp):
# Mapping from Type to C code (and version) to use. # Mapping from Type to C code (and version) to use.
# In the C code, the name of the input variable is %(iname)s, # In the C code, the name of the input variable is %(iname)s,
# the output variable is %(oname)s. # the output variable is %(oname)s.
c_code_and_version = {} c_code_and_version: Dict = {}
check_input = False check_input = False
__props__ = ("axis",) __props__ = ("axis",)
......
...@@ -140,6 +140,7 @@ except ImportError: ...@@ -140,6 +140,7 @@ except ImportError:
pass pass
from functools import reduce from functools import reduce
from typing import Tuple, Union
import aesara.scalar import aesara.scalar
from aesara.compile.mode import optdb from aesara.compile.mode import optdb
...@@ -506,7 +507,7 @@ class GemmRelated(COp): ...@@ -506,7 +507,7 @@ class GemmRelated(COp):
""" """
__props__ = () __props__: Union[Tuple, Tuple[str]] = ()
def c_support_code(self, **kwargs): def c_support_code(self, **kwargs):
# return cblas_header_text() # return cblas_header_text()
......
from copy import copy from copy import copy
from typing import Tuple, Union
import numpy as np import numpy as np
...@@ -1297,7 +1298,9 @@ class CAReduce(COp): ...@@ -1297,7 +1298,9 @@ class CAReduce(COp):
""" """
__props__ = ("scalar_op", "axis") __props__: Union[
Tuple[str], Tuple[str, str], Tuple[str, str, str], Tuple[str, str, str, str]
] = ("scalar_op", "axis")
def __init__(self, scalar_op, axis=None): def __init__(self, scalar_op, axis=None):
if scalar_op.nin not in [-1, 2] or scalar_op.nout != 1: if scalar_op.nin not in [-1, 2] or scalar_op.nout != 1:
...@@ -1682,7 +1685,12 @@ class CAReduceDtype(CAReduce): ...@@ -1682,7 +1685,12 @@ class CAReduceDtype(CAReduce):
""" """
__props__ = ("scalar_op", "axis", "dtype", "acc_dtype") __props__: Union[Tuple[str, str, str], Tuple[str, str, str, str]] = (
"scalar_op",
"axis",
"dtype",
"acc_dtype",
)
def __init__(self, scalar_op, axis=None, dtype=None, acc_dtype=None): def __init__(self, scalar_op, axis=None, dtype=None, acc_dtype=None):
super().__init__(scalar_op, axis=axis) super().__init__(scalar_op, axis=axis)
......
import logging import logging
from functools import partial from functools import partial
from typing import Tuple, Union
import numpy as np import numpy as np
...@@ -224,7 +225,7 @@ class Eig(Op): ...@@ -224,7 +225,7 @@ class Eig(Op):
""" """
_numop = staticmethod(np.linalg.eig) _numop = staticmethod(np.linalg.eig)
__props__ = () __props__: Union[Tuple, Tuple[str]] = ()
def make_node(self, x): def make_node(self, x):
x = as_tensor_variable(x) x = as_tensor_variable(x)
......
from typing import List
import numpy as np import numpy as np
import aesara import aesara
...@@ -26,7 +28,7 @@ class SparseBlockGemv(Op): ...@@ -26,7 +28,7 @@ class SparseBlockGemv(Op):
__props__ = ("inplace",) __props__ = ("inplace",)
registered_opts = [] registered_opts: List = []
def __init__(self, inplace=False): def __init__(self, inplace=False):
self.inplace = inplace self.inplace = inplace
...@@ -147,7 +149,7 @@ class SparseBlockOuter(Op): ...@@ -147,7 +149,7 @@ class SparseBlockOuter(Op):
__props__ = ("inplace",) __props__ = ("inplace",)
registered_opts = [] registered_opts: List = []
def __init__(self, inplace=False): def __init__(self, inplace=False):
self.inplace = inplace self.inplace = inplace
......
import warnings import warnings
from typing import Dict
import numpy as np import numpy as np
...@@ -49,7 +50,7 @@ class Shape(COp): ...@@ -49,7 +50,7 @@ class Shape(COp):
# Mapping from Type to C code (and version) to use. # Mapping from Type to C code (and version) to use.
# In the C code, the name of the input variable is %(iname)s, # In the C code, the name of the input variable is %(iname)s,
# the output variable is %(oname)s. # the output variable is %(oname)s.
c_code_and_version = {} c_code_and_version: Dict = {}
check_input = False check_input = False
__props__ = () __props__ = ()
...@@ -143,7 +144,7 @@ class Shape_i(COp): ...@@ -143,7 +144,7 @@ class Shape_i(COp):
# Mapping from Type to C code (and version) to use. # Mapping from Type to C code (and version) to use.
# In the C code, the name of the input variable is %(iname)s, # In the C code, the name of the input variable is %(iname)s,
# the output variable is %(oname)s. # the output variable is %(oname)s.
c_code_and_version = {} c_code_and_version: Dict = {}
check_input = False check_input = False
...@@ -363,7 +364,7 @@ class SpecifyShape(COp): ...@@ -363,7 +364,7 @@ class SpecifyShape(COp):
# Mapping from Type to C code (and version) to use. # Mapping from Type to C code (and version) to use.
# In the C code, the name of the input variable is %(iname)s, # In the C code, the name of the input variable is %(iname)s,
# the output variable is %(oname)s. # the output variable is %(oname)s.
c_code_and_version = {} c_code_and_version: Dict = {}
__props__ = () __props__ = ()
_f16_ok = True _f16_ok = True
......
...@@ -12,6 +12,7 @@ import warnings ...@@ -12,6 +12,7 @@ import warnings
from collections import OrderedDict from collections import OrderedDict
from collections.abc import Callable from collections.abc import Callable
from functools import partial, wraps from functools import partial, wraps
from typing import List, Set
__all__ = [ __all__ = [
...@@ -29,7 +30,7 @@ __all__ = [ ...@@ -29,7 +30,7 @@ __all__ = [
] ]
__excepthooks = [] __excepthooks: List = []
LOCAL_BITWIDTH = struct.calcsize("P") * 8 LOCAL_BITWIDTH = struct.calcsize("P") * 8
...@@ -374,7 +375,7 @@ def apply_across_args(*fns): ...@@ -374,7 +375,7 @@ def apply_across_args(*fns):
class NoDuplicateOptWarningFilter(logging.Filter): class NoDuplicateOptWarningFilter(logging.Filter):
"""Filter to avoid duplicating optimization warnings.""" """Filter to avoid duplicating optimization warnings."""
prev_msgs = set() prev_msgs: Set = set()
def filter(self, record): def filter(self, record):
msg = record.getMessage() msg = record.getMessage()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论