提交 f0392db7 authored 作者: LegrandNico's avatar LegrandNico 提交者: Thomas Wiecki

Fix type hint errors (272->168)

上级 f26a5086
......@@ -14,7 +14,7 @@ import os
import re
import subprocess
import sys
from typing import Dict
def get_keywords():
"""Get the keywords needed to look up the version information."""
......@@ -51,8 +51,8 @@ class NotThisMethod(Exception):
"""Exception raised if a method is not valid for the current scenario."""
LONG_VERSION_PY = {}
HANDLERS = {}
LONG_VERSION_PY: Dict = {}
HANDLERS: Dict = {}
def register_vcs_handler(vcs, method): # decorator
......
......@@ -11,6 +11,7 @@ import pickle
import time
import warnings
from itertools import chain
from typing import List
import numpy as np
......@@ -1877,7 +1878,7 @@ def _constructor_FunctionMaker(kwargs):
return None
__checkers = []
__checkers: List = []
def check_equal(x, y):
......
......@@ -5,6 +5,7 @@ WRITEME
import logging
import warnings
from typing import Tuple, Union
import aesara
from aesara.compile.function.types import Supervisor
......@@ -252,6 +253,8 @@ optdb.register("add_destroy_handler", AddDestroyHandler(), 49.5, "fast_run", "in
# final pass just to make sure
optdb.register("merge3", MergeOptimizer(), 100, "fast_run", "merge")
_tags: Union[Tuple[str, str], Tuple]
if config.check_stack_trace in ["raise", "warn", "log"]:
_tags = ("fast_run", "fast_compile")
......
......@@ -8,6 +8,7 @@ help make new Ops more rapidly.
import copy
import pickle
import warnings
from typing import Dict, Tuple
from aesara.graph.basic import Apply
from aesara.graph.op import COp, Op
......@@ -42,9 +43,9 @@ class ViewOp(COp):
# Mapping from Type to C code (and version) to use.
# In the C code, the name of the input variable is %(iname)s,
# the output variable is %(oname)s.
c_code_and_version = {}
__props__ = ()
_f16_ok = True
c_code_and_version: Dict = {}
__props__: Tuple = ()
_f16_ok: bool = True
def make_node(self, x):
return Apply(self, [x], [x.type()])
......@@ -148,11 +149,11 @@ class DeepCopyOp(COp):
# Mapping from Type to C code (and version) to use.
# In the C code, the name of the input variable is %(iname)s,
# the output variable is %(oname)s.
c_code_and_version = {}
c_code_and_version: Dict = {}
check_input = False
__props__ = ()
_f16_ok = True
check_input: bool = False
__props__: Tuple = ()
_f16_ok: bool = True
def __init__(self):
pass
......
......@@ -17,6 +17,7 @@ import sys
import time
import warnings
from collections import defaultdict
from typing import Dict, List
import numpy as np
......@@ -37,7 +38,7 @@ total_fct_exec_time = 0.0
total_graph_opt_time = 0.0
total_time_linker = 0.0
_atexit_print_list = []
_atexit_print_list: List = []
_atexit_registered = False
......@@ -242,15 +243,15 @@ class ProfileStats:
# pretty string to print in summary, to identify this output
#
variable_shape = {}
variable_shape: Dict = {}
# Variable -> shapes
#
variable_strides = {}
variable_strides: Dict = {}
# Variable -> strides
#
variable_offset = {}
variable_offset: Dict = {}
# Variable -> offset
#
......@@ -270,7 +271,7 @@ class ProfileStats:
linker_node_make_thunks = 0.0
linker_make_thunk_time = {}
linker_make_thunk_time: Dict = {}
line_width = config.profiling__output_line_width
......
......@@ -13,6 +13,7 @@ from configparser import (
)
from functools import wraps
from io import StringIO
from typing import Optional
from aesara.utils import deprecated, hash_from_code
......@@ -94,7 +95,7 @@ class AesaraConfigParser:
self._flags_dict = flags_dict
self._aesara_cfg = aesara_cfg
self._aesara_raw_cfg = aesara_raw_cfg
self._config_var_dict = {}
self._config_var_dict: typing.Dict = {}
super().__init__()
def __str__(self, print_doc=True):
......@@ -325,7 +326,7 @@ class ConfigParam:
return self._apply(value)
return value
def validate(self, value) -> None:
def validate(self, value) -> Optional[bool]:
"""Validates that a parameter values falls into a supported set or range.
Raises
......
......@@ -2,6 +2,7 @@ import copy
import os
import re
from collections import deque
from typing import Union
import numpy as np
......@@ -306,7 +307,7 @@ class GpuKernelBase:
"""
params_type = gpu_context_type
params_type: Union[ParamsType, gpu_context_type] = gpu_context_type
def get_params(self, node):
# Default implementation, suitable for most sub-classes.
......
import os
import sys
from typing import Set
class PathParser:
......@@ -25,7 +26,7 @@ class PathParser:
"""
paths = set()
paths: Set = set()
def _add(self, path):
path = path.strip()
......
......@@ -571,7 +571,7 @@ class Variable(Node):
return d
# refer to doc in nodes_constructed.
construction_observers = []
construction_observers: List = []
@classmethod
def append_construction_observer(cls, observer):
......@@ -700,10 +700,11 @@ def walk(
rval_set: Set[T] = set()
nodes_pop: Callable[[], T]
if bfs:
nodes_pop: Callable[[], T] = nodes.popleft
nodes_pop = nodes.popleft
else:
nodes_pop: Callable[[], T] = nodes.pop
nodes_pop = nodes.pop
while nodes:
node: T = nodes_pop()
......@@ -714,7 +715,7 @@ def walk(
rval_set.add(node_hash)
new_nodes: Sequence[T] = expand(node)
new_nodes: Optional[Sequence[T]] = expand(node)
if return_children:
yield node, new_nodes
......@@ -862,8 +863,8 @@ def applys_between(
def clone(
inputs: Collection[Variable],
outputs: Collection[Variable],
inputs: List[Variable],
outputs: List[Variable],
copy_inputs: bool = True,
copy_orphans: Optional[bool] = None,
) -> Tuple[Collection[Variable], Collection[Variable]]:
......@@ -902,8 +903,8 @@ def clone(
def clone_get_equiv(
inputs: Collection[Variable],
outputs: Collection[Variable],
inputs: List[Variable],
outputs: List[Variable],
copy_inputs: bool = True,
copy_orphans: bool = True,
memo: Optional[Dict[Variable, Variable]] = None,
......@@ -1171,7 +1172,7 @@ def io_toposort(
compute_deps = None
compute_deps_cache = None
iset = set(inputs)
deps_cache = {}
deps_cache: Dict = {}
if not orderings: # ordering can be None or empty dict
# Specialized function that is faster when no ordering.
......@@ -1345,8 +1346,8 @@ def as_string(
orph = list(orphans_between(i, outputs))
multi = set()
seen = set()
multi: Set = set()
seen: Set = set()
for output in outputs:
op = output.owner
if op in seen:
......@@ -1362,8 +1363,8 @@ def as_string(
multi.add(op2)
else:
seen.add(input.owner)
multi = [x for x in multi]
done = set()
multi: Set = [x for x in multi]
done: Set = set()
def multi_index(x):
return multi.index(x) + 1
......
......@@ -157,7 +157,7 @@ class Op(MetaObject):
"""
default_output = None
default_output: Optional[int] = None
"""
An `int` that specifies which output `Op.__call__` should return. If
`None`, then all outputs are returned.
......@@ -852,7 +852,7 @@ def lquote_macro(txt: Text) -> Text:
return "\n".join(res)
def get_sub_macros(sub: Dict[Text, Text]) -> Tuple[Text]:
def get_sub_macros(sub: Dict[Text, Text]) -> Union[Tuple[Text], Tuple[Text, Text]]:
define_macros = []
undef_macros = []
define_macros.append(f"#define FAIL {lquote_macro(sub['fail'])}")
......@@ -864,7 +864,9 @@ def get_sub_macros(sub: Dict[Text, Text]) -> Tuple[Text]:
return "\n".join(define_macros), "\n".join(undef_macros)
def get_io_macros(inputs: List[Text], outputs: List[Text]) -> Tuple[List[Text]]:
def get_io_macros(
inputs: List[Text], outputs: List[Text]
) -> Union[Tuple[List[Text]], Tuple[str, str]]:
define_macros = []
undef_macros = []
......@@ -1023,7 +1025,7 @@ class ExternalCOp(COp):
f"No valid section marker was found in file {func_files[i]}"
)
def __get_op_params(self) -> List[Text]:
def __get_op_params(self) -> Union[List[Text], List[Tuple[str, Any]]]:
"""Construct name, value pairs that will be turned into macros for use within the `Op`'s code.
The names must be strings that are not a C keyword and the
......@@ -1130,7 +1132,7 @@ class ExternalCOp(COp):
def get_c_macros(
self, node: Apply, name: Text, check_input: Optional[bool] = None
) -> Tuple[Text]:
) -> Union[Tuple[str], Tuple[str, str]]:
"Construct a pair of C ``#define`` and ``#undef`` code strings."
define_template = "#define %s %s"
undef_template = "#undef %s"
......
......@@ -2,6 +2,7 @@ import copy
import math
import sys
from io import StringIO
from typing import Dict, Optional
from aesara.configdefaults import config
from aesara.graph import opt
......@@ -192,6 +193,7 @@ class Query:
self.exclude = exclude or OrderedSet()
self.subquery = subquery or {}
self.position_cutoff = position_cutoff
self.name: Optional[str] = None
if extra_optimizations is None:
extra_optimizations = []
self.extra_optimizations = extra_optimizations
......@@ -438,14 +440,18 @@ class LocalGroupDB(DB):
"""
def __init__(
self, apply_all_opts=False, profile=False, local_opt=opt.LocalOptGroup
self,
apply_all_opts: bool = False,
profile: bool = False,
local_opt=opt.LocalOptGroup,
):
super().__init__()
self.failure_callback = None
self.apply_all_opts = apply_all_opts
self.profile = profile
self.__position__ = {}
self.__position__: Dict = {}
self.local_opt = local_opt
self.__name__: str = ""
def register(self, name, obj, *tags, **kwargs):
super().register(name, obj, *tags)
......
......@@ -331,7 +331,7 @@ def unify_walk(a, b, U):
return False
@comm_guard(FreeVariable, ANY_TYPE) # noqa
@comm_guard(FreeVariable, ANY_TYPE) # type: ignore[no-redef] # noqa
def unify_walk(fv, o, U):
"""
FreeV is unified to BoundVariable(other_object).
......@@ -341,7 +341,7 @@ def unify_walk(fv, o, U):
return U.merge(v, fv)
@comm_guard(BoundVariable, ANY_TYPE) # noqa
@comm_guard(BoundVariable, ANY_TYPE) # type: ignore[no-redef] # noqa
def unify_walk(bv, o, U):
"""
The unification succeed iff BV.value == other_object.
......@@ -353,7 +353,7 @@ def unify_walk(bv, o, U):
return False
@comm_guard(OrVariable, ANY_TYPE) # noqa
@comm_guard(OrVariable, ANY_TYPE) # type: ignore[no-redef] # noqa
def unify_walk(ov, o, U):
"""
The unification succeeds iff other_object in OrV.options.
......@@ -366,7 +366,7 @@ def unify_walk(ov, o, U):
return False
@comm_guard(NotVariable, ANY_TYPE) # noqa
@comm_guard(NotVariable, ANY_TYPE) # type: ignore[no-redef] # noqa
def unify_walk(nv, o, U):
"""
The unification succeeds iff other_object not in NV.not_options.
......@@ -379,7 +379,7 @@ def unify_walk(nv, o, U):
return U.merge(v, nv)
@comm_guard(FreeVariable, Variable) # noqa
@comm_guard(FreeVariable, Variable) # type: ignore[no-redef] # noqa
def unify_walk(fv, v, U):
"""
Both variables are unified.
......@@ -389,7 +389,7 @@ def unify_walk(fv, v, U):
return U.merge(v, fv)
@comm_guard(BoundVariable, Variable) # noqa
@comm_guard(BoundVariable, Variable) # type: ignore[no-redef] # noqa
def unify_walk(bv, v, U):
"""
V is unified to BV.value.
......@@ -398,7 +398,7 @@ def unify_walk(bv, v, U):
return unify_walk(v, bv.value, U)
@comm_guard(OrVariable, OrVariable) # noqa
@comm_guard(OrVariable, OrVariable) # type: ignore[no-redef] # noqa
def unify_walk(a, b, U):
"""
OrV(list1) == OrV(list2) == OrV(intersection(list1, list2))
......@@ -414,7 +414,7 @@ def unify_walk(a, b, U):
return U.merge(v, a, b)
@comm_guard(NotVariable, NotVariable) # noqa
@comm_guard(NotVariable, NotVariable) # type: ignore[no-redef] # noqa
def unify_walk(a, b, U):
"""
NV(list1) == NV(list2) == NV(union(list1, list2))
......@@ -425,7 +425,7 @@ def unify_walk(a, b, U):
return U.merge(v, a, b)
@comm_guard(OrVariable, NotVariable) # noqa
@comm_guard(OrVariable, NotVariable) # type: ignore[no-redef] # noqa
def unify_walk(o, n, U):
r"""
OrV(list1) == NV(list2) == OrV(list1 \ list2)
......@@ -441,7 +441,7 @@ def unify_walk(o, n, U):
return U.merge(v, o, n)
@comm_guard(VariableInList, (list, tuple)) # noqa
@comm_guard(VariableInList, (list, tuple)) # type: ignore[no-redef] # noqa
def unify_walk(vil, l, U):
"""
Unifies VIL's inner Variable to OrV(list).
......@@ -452,7 +452,7 @@ def unify_walk(vil, l, U):
return unify_walk(v, ov, U)
@comm_guard((list, tuple), (list, tuple)) # noqa
@comm_guard((list, tuple), (list, tuple)) # type: ignore[no-redef] # noqa
def unify_walk(l1, l2, U):
"""
Tries to unify each corresponding pair of elements from l1 and l2.
......@@ -467,7 +467,7 @@ def unify_walk(l1, l2, U):
return U
@comm_guard(dict, dict) # noqa
@comm_guard(dict, dict) # type: ignore[no-redef] # noqa
def unify_walk(d1, d2, U):
"""
Tries to unify values of corresponding keys.
......@@ -481,7 +481,7 @@ def unify_walk(d1, d2, U):
return U
@comm_guard(ANY_TYPE, ANY_TYPE) # noqa
@comm_guard(ANY_TYPE, ANY_TYPE) # type: ignore[no-redef] # noqa
def unify_walk(a, b, U):
"""
Checks for the existence of the __unify_walk__ method for one of
......@@ -498,7 +498,7 @@ def unify_walk(a, b, U):
return FALL_THROUGH
@comm_guard(Variable, ANY_TYPE) # noqa
@comm_guard(Variable, ANY_TYPE) # type: ignore[no-redef] # noqa
def unify_walk(v, o, U):
"""
This simply checks if the Var has an unification in U and uses it
......@@ -528,27 +528,27 @@ def unify_merge(a, b, U):
return a
@comm_guard(Variable, ANY_TYPE) # noqa
@comm_guard(Variable, ANY_TYPE) # type: ignore[no-redef] # noqa
def unify_merge(v, o, U):
return v
@comm_guard(BoundVariable, ANY_TYPE) # noqa
@comm_guard(BoundVariable, ANY_TYPE) # type: ignore[no-redef] # noqa
def unify_merge(bv, o, U):
return bv.value
@comm_guard(VariableInList, (list, tuple)) # noqa
@comm_guard(VariableInList, (list, tuple)) # type: ignore[no-redef] # noqa
def unify_merge(vil, l, U):
return [unify_merge(x, x, U) for x in l]
@comm_guard((list, tuple), (list, tuple)) # noqa
@comm_guard((list, tuple), (list, tuple)) # type: ignore[no-redef] # noqa
def unify_merge(l1, l2, U):
return [unify_merge(x1, x2, U) for x1, x2 in zip(l1, l2)]
@comm_guard(dict, dict) # noqa
@comm_guard(dict, dict) # type: ignore[no-redef] # noqa
def unify_merge(d1, d2, U):
d = d1.__class__()
for k1, v1 in d1.items():
......@@ -562,12 +562,12 @@ def unify_merge(d1, d2, U):
return d
@comm_guard(FVar, ANY_TYPE) # noqa
@comm_guard(FVar, ANY_TYPE) # type: ignore[no-redef] # noqa
def unify_merge(vs, o, U):
return vs(U)
@comm_guard(ANY_TYPE, ANY_TYPE) # noqa
@comm_guard(ANY_TYPE, ANY_TYPE) # type: ignore[no-redef] # noqa
def unify_merge(a, b, U):
if (
not isinstance(a, Variable)
......@@ -579,7 +579,7 @@ def unify_merge(a, b, U):
return FALL_THROUGH
@comm_guard(Variable, ANY_TYPE) # noqa
@comm_guard(Variable, ANY_TYPE) # type: ignore[no-redef] # noqa
def unify_merge(v, o, U):
"""
This simply checks if the Var has an unification in U and uses it
......
......@@ -3,6 +3,7 @@ import sys
import traceback
from abc import ABCMeta
from io import StringIO
from typing import List
def simple_extract_stack(f=None, limit=None, skips=None):
......@@ -225,7 +226,7 @@ class MetaType(ABCMeta):
class MetaObject(metaclass=MetaType):
__slots__ = []
__slots__: List = []
def __ne__(self, other):
return not self == other
......
......@@ -8,6 +8,7 @@ import sys
from collections import defaultdict
from copy import copy
from io import StringIO
from typing import Dict
import numpy as np
......@@ -1795,7 +1796,7 @@ class OpWiseCLinker(LocalLinker):
"""
__cache__ = {}
__cache__: Dict = {}
def __init__(
self, fallback_on_perform=True, allow_gc=None, nice_errors=True, schedule=None
......
......@@ -19,6 +19,7 @@ import textwrap
import time
import warnings
from io import BytesIO, StringIO
from typing import Dict, List, Set
import numpy.distutils
......@@ -631,38 +632,38 @@ class ModuleCache:
"""
dirname = ""
dirname: str = ""
"""
The working directory that is managed by this interface.
"""
module_from_name = {}
module_from_name: Dict = {}
"""
Maps a module filename to the loaded module object.
"""
entry_from_key = {}
entry_from_key: Dict = {}
"""
Maps keys to the filename of a .so/.pyd.
"""
similar_keys = {}
similar_keys: Dict = {}
"""
Maps a part-of-key to all keys that share this same part.
"""
module_hash_to_key_data = {}
module_hash_to_key_data: Dict = {}
"""
Maps a module hash to its corresponding KeyData object.
"""
stats = []
stats: List = []
"""
A list with counters for the number of hits, loads, compiles issued by
module_from_key().
"""
loaded_key_pkl = set()
loaded_key_pkl: Set = set()
"""
Set of all key.pkl files that have been loaded.
......
from abc import abstractmethod
from typing import Callable, Dict, List, Text, Tuple
from typing import Callable, Dict, List, Text, Tuple, Union
from aesara.graph.basic import Apply, Constant
from aesara.graph.utils import MethodNotDefined
......@@ -129,7 +129,7 @@ class CLinkerObject:
"""Return a list of code snippets to be inserted in module initialization."""
return []
def c_code_cache_version(self) -> Tuple[int]:
def c_code_cache_version(self) -> Union[Tuple[int], Tuple]:
"""Return a tuple of integers indicating the version of this `Op`.
An empty tuple indicates an 'unversioned' `Op` that will not be cached
......@@ -551,7 +551,7 @@ class CLinkerType(CLinkerObject):
"""
return ""
def c_code_cache_version(self) -> Tuple[int]:
def c_code_cache_version(self) -> Union[Tuple, Tuple[int]]:
"""Return a tuple of integers indicating the version of this type.
An empty tuple indicates an 'unversioned' type that will not
......
......@@ -3,7 +3,7 @@ import sys
import traceback
import warnings
from operator import itemgetter
from typing import Callable, Dict, Iterable, List, NoReturn, Optional, Tuple
from typing import Callable, Dict, Iterable, List, NoReturn, Optional, Tuple, Union
import numpy as np
......@@ -315,6 +315,10 @@ def raise_with_op(
types = [getattr(ipt, "type", "No type") for ipt in node.inputs]
detailed_err_msg += f"\nInputs types: {types}\n"
shapes: Union[List, str]
strides: Union[List, str]
scalar_values: Union[List, str]
if thunk is not None:
if hasattr(thunk, "inputs"):
shapes = [getattr(ipt[0], "shape", "No shapes") for ipt in thunk.inputs]
......
import os
import pickle
import sys
from typing import Dict
from aesara.configdefaults import config
......@@ -15,8 +16,8 @@ if len(sys.argv) > 1:
else:
dirs = os.listdir(config.compiledir)
dirs = [os.path.join(config.compiledir, d) for d in dirs]
keys = {} # key -> nb seen
mods = {}
keys: Dict = {} # key -> nb seen
mods: Dict = {}
for dir in dirs:
key = None
......@@ -48,12 +49,12 @@ if DISPLAY_DUPLICATE_KEYS:
if v > 1:
print("Duplicate key (%i copies): %s" % (v, pickle.loads(k)))
nbs_keys = {} # nb seen -> now many key
nbs_keys: Dict = {} # nb seen -> now many key
for val in keys.values():
nbs_keys.setdefault(val, 0)
nbs_keys[val] += 1
nbs_mod = {} # nb seen -> how many key
nbs_mod: Dict = {} # nb seen -> how many key
nbs_mod_to_key = {} # nb seen -> keys
more_than_one = 0
for mod, kk in mods.items():
......
import copy
import warnings
from typing import Tuple, Union
import numpy as np
......@@ -18,7 +19,7 @@ class MultinomialFromUniform(COp):
TODO : need description for parameter 'odtype'
"""
__props__ = ("odtype",)
__props__: Union[Tuple[str], Tuple[str, str]] = ("odtype",)
def __init__(self, odtype):
self.odtype = odtype
......
......@@ -15,7 +15,7 @@ from collections.abc import Callable
from copy import copy
from itertools import chain
from textwrap import dedent
from typing import Optional
from typing import Dict, Mapping, Optional, Tuple, Type, Union
import numpy as np
......@@ -898,7 +898,7 @@ def upgrade_to_float(*types):
Upgrade any int types to float32 or float64 to avoid losing precision.
"""
conv = {
conv: Mapping[Type, Type] = {
bool: float32,
int8: float32,
int16: float32,
......@@ -3954,7 +3954,7 @@ class Composite(ScalarOp):
"""
init_param = ("inputs", "outputs")
init_param: Union[Tuple[str, str], Tuple[str]] = ("inputs", "outputs")
def __str__(self):
if self.name is None:
......@@ -4331,7 +4331,7 @@ class Composite(ScalarOp):
class Compositef32:
# This is a dict of scalar op classes that need special handling
special = {}
special: Dict = {}
def apply(self, fgraph):
mapping = {}
......
......@@ -5,12 +5,12 @@ __docformat__ = "restructuredtext en"
import warnings
from functools import singledispatch
from typing import Callable, Optional
from typing import Callable, NoReturn, Optional
def as_tensor_variable(
x, name: Optional[str] = None, ndim: Optional[int] = None, **kwargs
):
) -> Callable:
"""Convert `x` into the appropriate `TensorType`.
This function is often used by `make_node` methods of `Op` subclasses to
......@@ -39,7 +39,9 @@ def as_tensor_variable(
@singledispatch
def _as_tensor_variable(x, name: Optional[str], ndim: Optional[int], **kwargs):
def _as_tensor_variable(
x, name: Optional[str], ndim: Optional[int], **kwargs
) -> NoReturn:
raise NotImplementedError("")
......
......@@ -11,6 +11,7 @@ import warnings
from collections import OrderedDict
from collections.abc import Sequence
from numbers import Number
from typing import Dict
import numpy as np
......@@ -679,7 +680,7 @@ class Rebroadcast(COp):
# Mapping from Type to C code (and version) to use.
# In the C code, the name of the input variable is %(iname)s,
# the output variable is %(oname)s.
c_code_and_version = {}
c_code_and_version: Dict = {}
check_input = False
__props__ = ("axis",)
......
......@@ -140,6 +140,7 @@ except ImportError:
pass
from functools import reduce
from typing import Tuple, Union
import aesara.scalar
from aesara.compile.mode import optdb
......@@ -506,7 +507,7 @@ class GemmRelated(COp):
"""
__props__ = ()
__props__: Union[Tuple, Tuple[str]] = ()
def c_support_code(self, **kwargs):
# return cblas_header_text()
......
from copy import copy
from typing import Tuple, Union
import numpy as np
......@@ -1297,7 +1298,9 @@ class CAReduce(COp):
"""
__props__ = ("scalar_op", "axis")
__props__: Union[
Tuple[str], Tuple[str, str], Tuple[str, str, str], Tuple[str, str, str, str]
] = ("scalar_op", "axis")
def __init__(self, scalar_op, axis=None):
if scalar_op.nin not in [-1, 2] or scalar_op.nout != 1:
......@@ -1682,7 +1685,12 @@ class CAReduceDtype(CAReduce):
"""
__props__ = ("scalar_op", "axis", "dtype", "acc_dtype")
__props__: Union[Tuple[str, str, str], Tuple[str, str, str, str]] = (
"scalar_op",
"axis",
"dtype",
"acc_dtype",
)
def __init__(self, scalar_op, axis=None, dtype=None, acc_dtype=None):
super().__init__(scalar_op, axis=axis)
......
import logging
from functools import partial
from typing import Tuple, Union
import numpy as np
......@@ -224,7 +225,7 @@ class Eig(Op):
"""
_numop = staticmethod(np.linalg.eig)
__props__ = ()
__props__: Union[Tuple, Tuple[str]] = ()
def make_node(self, x):
x = as_tensor_variable(x)
......
from typing import List
import numpy as np
import aesara
......@@ -26,7 +28,7 @@ class SparseBlockGemv(Op):
__props__ = ("inplace",)
registered_opts = []
registered_opts: List = []
def __init__(self, inplace=False):
self.inplace = inplace
......@@ -147,7 +149,7 @@ class SparseBlockOuter(Op):
__props__ = ("inplace",)
registered_opts = []
registered_opts: List = []
def __init__(self, inplace=False):
self.inplace = inplace
......
import warnings
from typing import Dict
import numpy as np
......@@ -49,7 +50,7 @@ class Shape(COp):
# Mapping from Type to C code (and version) to use.
# In the C code, the name of the input variable is %(iname)s,
# the output variable is %(oname)s.
c_code_and_version = {}
c_code_and_version: Dict = {}
check_input = False
__props__ = ()
......@@ -143,7 +144,7 @@ class Shape_i(COp):
# Mapping from Type to C code (and version) to use.
# In the C code, the name of the input variable is %(iname)s,
# the output variable is %(oname)s.
c_code_and_version = {}
c_code_and_version: Dict = {}
check_input = False
......@@ -363,7 +364,7 @@ class SpecifyShape(COp):
# Mapping from Type to C code (and version) to use.
# In the C code, the name of the input variable is %(iname)s,
# the output variable is %(oname)s.
c_code_and_version = {}
c_code_and_version: Dict = {}
__props__ = ()
_f16_ok = True
......
......@@ -12,6 +12,7 @@ import warnings
from collections import OrderedDict
from collections.abc import Callable
from functools import partial, wraps
from typing import List, Set
__all__ = [
......@@ -29,7 +30,7 @@ __all__ = [
]
__excepthooks = []
__excepthooks: List = []
LOCAL_BITWIDTH = struct.calcsize("P") * 8
......@@ -374,7 +375,7 @@ def apply_across_args(*fns):
class NoDuplicateOptWarningFilter(logging.Filter):
"""Filter to avoid duplicating optimization warnings."""
prev_msgs = set()
prev_msgs: Set = set()
def filter(self, record):
msg = record.getMessage()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论