提交 65c410b0 authored 作者: LegrandNico's avatar LegrandNico 提交者: Thomas Wiecki

Fix type hint errors (n errors: 386->273)

上级 22c8c46d
...@@ -3,6 +3,7 @@ import os ...@@ -3,6 +3,7 @@ import os
import sys import sys
import warnings import warnings
from functools import reduce from functools import reduce
from typing import Optional
import numpy as np import numpy as np
...@@ -1904,8 +1905,8 @@ class GpuDnnPoolBase(DnnBase): ...@@ -1904,8 +1905,8 @@ class GpuDnnPoolBase(DnnBase):
""" """
# c_file and c_function must be defined in sub-classes. # c_file and c_function must be defined in sub-classes.
c_file = None c_file: Optional[str] = None
c_function = None c_function: Optional[str] = None
_f16_ok = True _f16_ok = True
__props__ = ("mode",) __props__ = ("mode",)
......
...@@ -12,6 +12,7 @@ that satisfies the constraints. That's useful for pattern matching. ...@@ -12,6 +12,7 @@ that satisfies the constraints. That's useful for pattern matching.
from copy import copy from copy import copy
from functools import partial from functools import partial
from typing import Dict
class Keyword: class Keyword:
...@@ -218,7 +219,7 @@ class VariableInList: # not a subclass of Variable ...@@ -218,7 +219,7 @@ class VariableInList: # not a subclass of Variable
self.variable = variable self.variable = variable
_all = {} _all: Dict = {}
def var_lookup(vartype, name, *args, **kwargs): def var_lookup(vartype, name, *args, **kwargs):
......
from abc import abstractmethod from abc import abstractmethod
from typing import Dict, List, Text, Tuple from typing import Callable, Dict, List, Text, Tuple
from aesara.graph.basic import Apply, Constant from aesara.graph.basic import Apply, Constant
from aesara.graph.utils import MethodNotDefined from aesara.graph.utils import MethodNotDefined
...@@ -568,25 +568,25 @@ class HideC(CLinkerOp): ...@@ -568,25 +568,25 @@ class HideC(CLinkerOp):
def __hide(*args): def __hide(*args):
raise MethodNotDefined() raise MethodNotDefined()
c_code = __hide c_code: Callable = __hide
c_code_cleanup = __hide c_code_cleanup: Callable = __hide
c_headers = __hide c_headers: Callable = __hide
c_header_dirs = __hide c_header_dirs: Callable = __hide
c_libraries = __hide c_libraries: Callable = __hide
c_lib_dirs = __hide c_lib_dirs: Callable = __hide
c_support_code = __hide c_support_code: Callable = __hide
c_support_code_apply = __hide c_support_code_apply: Callable = __hide
c_compile_args = __hide c_compile_args: Callable = __hide
c_no_compile_args = __hide c_no_compile_args: Callable = __hide
c_init_code = __hide c_init_code: Callable = __hide
c_init_code_apply = __hide c_init_code_apply: Callable = __hide
c_init_code_struct = __hide c_init_code_struct: Callable = __hide
c_support_code_struct = __hide c_support_code_struct: Callable = __hide
c_cleanup_code_struct = __hide c_cleanup_code_struct: Callable = __hide
def c_code_cache_version(self): def c_code_cache_version(self):
return () return ()
......
import io import io
import sys import sys
import traceback import traceback
import typing
import warnings import warnings
from operator import itemgetter from operator import itemgetter
from typing import Callable, Dict, Iterable, List, NoReturn, Optional, Tuple
import numpy as np import numpy as np
...@@ -15,11 +15,11 @@ from aesara.graph.fg import FunctionGraph ...@@ -15,11 +15,11 @@ from aesara.graph.fg import FunctionGraph
def map_storage( def map_storage(
fgraph: FunctionGraph, fgraph: FunctionGraph,
order: typing.Iterable[Apply], order: Iterable[Apply],
input_storage: typing.Optional[typing.List], input_storage: Optional[List],
output_storage: typing.Optional[typing.List], output_storage: Optional[List],
storage_map: typing.Dict = None, storage_map: Dict = None,
) -> typing.Tuple[typing.List, typing.List, typing.Dict]: ) -> Tuple[List, List, Dict]:
"""Ensure there is storage (a length-1 list) for inputs, outputs, and interior nodes. """Ensure there is storage (a length-1 list) for inputs, outputs, and interior nodes.
:param fgraph: The current fgraph. This function uses the inputs and outputs attributes. :param fgraph: The current fgraph. This function uses the inputs and outputs attributes.
...@@ -132,7 +132,7 @@ def streamline( ...@@ -132,7 +132,7 @@ def streamline(
post_thunk_old_storage=None, post_thunk_old_storage=None,
no_recycling=None, no_recycling=None,
nice_errors=True, nice_errors=True,
) -> typing.Callable[[], typing.NoReturn]: ) -> Callable[[], NoReturn]:
""" """
WRITEME WRITEME
...@@ -210,7 +210,7 @@ def streamline( ...@@ -210,7 +210,7 @@ def streamline(
return f return f
def gc_helper(node_list: typing.List[Apply]): def gc_helper(node_list: List[Apply]):
""" """
Return the set of Variable instances which are computed by node_list. Return the set of Variable instances which are computed by node_list.
Parameters Parameters
...@@ -390,11 +390,11 @@ def raise_with_op( ...@@ -390,11 +390,11 @@ def raise_with_op(
for item in fgraph.inputs for item in fgraph.inputs
if not isinstance(item, aesara.compile.SharedVariable) if not isinstance(item, aesara.compile.SharedVariable)
] ]
storage_map_list = [] storage_map_list: List = []
total_size = 0 total_size = 0
total_size_inputs = 0 total_size_inputs = 0
for k in storage_map: for k in storage_map:
storage_map_item = [] storage_map_item: List = []
# storage_map_item[0]: the variable # storage_map_item[0]: the variable
storage_map_item.append(str(k)) storage_map_item.append(str(k))
......
...@@ -15,6 +15,7 @@ from collections.abc import Callable ...@@ -15,6 +15,7 @@ from collections.abc import Callable
from copy import copy from copy import copy
from itertools import chain from itertools import chain
from textwrap import dedent from textwrap import dedent
from typing import Optional
import numpy as np import numpy as np
...@@ -1215,8 +1216,8 @@ class ScalarOp(COp): ...@@ -1215,8 +1216,8 @@ class ScalarOp(COp):
class UnaryScalarOp(ScalarOp): class UnaryScalarOp(ScalarOp):
nin = 1 nin = 1
amd_float32 = None amd_float32: Optional[str] = None
amd_float64 = None amd_float64: Optional[str] = None
def c_code_contiguous(self, node, name, inputs, outputs, sub): def c_code_contiguous(self, node, name, inputs, outputs, sub):
(x,) = inputs (x,) = inputs
...@@ -1261,9 +1262,9 @@ class BinaryScalarOp(ScalarOp): ...@@ -1261,9 +1262,9 @@ class BinaryScalarOp(ScalarOp):
# One may define in subclasses the following fields: # One may define in subclasses the following fields:
# - `commutative`: whether op(a, b) == op(b, a) # - `commutative`: whether op(a, b) == op(b, a)
# - `associative`: whether op(op(a, b), c) == op(a, op(b, c)) # - `associative`: whether op(op(a, b), c) == op(a, op(b, c))
commutative = None commutative: Optional[bool] = None
associative = None associative: Optional[bool] = None
identity = None identity: Optional[bool] = None
""" """
For an associative operation, the identity object corresponds to the neutral For an associative operation, the identity object corresponds to the neutral
element. For instance, it will be ``0`` for addition, ``1`` for multiplication, element. For instance, it will be ``0`` for addition, ``1`` for multiplication,
......
...@@ -3253,7 +3253,7 @@ def structured_add(x): ...@@ -3253,7 +3253,7 @@ def structured_add(x):
# Sparse operation (map 0 to 0) # Sparse operation (map 0 to 0)
@structured_monoid(sin) @structured_monoid(sin) # type: ignore[no-redef]
def sin(x): def sin(x):
""" """
Elemwise sinus of `x`. Elemwise sinus of `x`.
...@@ -3262,7 +3262,7 @@ def sin(x): ...@@ -3262,7 +3262,7 @@ def sin(x):
# see decorator for function body # see decorator for function body
@structured_monoid(tan) @structured_monoid(tan) # type: ignore[no-redef]
def tan(x): def tan(x):
""" """
Elemwise tan of `x`. Elemwise tan of `x`.
...@@ -3271,7 +3271,7 @@ def tan(x): ...@@ -3271,7 +3271,7 @@ def tan(x):
# see decorator for function body # see decorator for function body
@structured_monoid(arcsin) @structured_monoid(arcsin) # type: ignore[no-redef]
def arcsin(x): def arcsin(x):
""" """
Elemwise arcsinus of `x`. Elemwise arcsinus of `x`.
...@@ -3280,7 +3280,7 @@ def arcsin(x): ...@@ -3280,7 +3280,7 @@ def arcsin(x):
# see decorator for function body # see decorator for function body
@structured_monoid(arctan) @structured_monoid(arctan) # type: ignore[no-redef]
def arctan(x): def arctan(x):
""" """
Elemwise arctan of `x`. Elemwise arctan of `x`.
...@@ -3289,7 +3289,7 @@ def arctan(x): ...@@ -3289,7 +3289,7 @@ def arctan(x):
# see decorator for function body # see decorator for function body
@structured_monoid(sinh) @structured_monoid(sinh) # type: ignore[no-redef]
def sinh(x): def sinh(x):
""" """
Elemwise sinh of `x`. Elemwise sinh of `x`.
...@@ -3298,7 +3298,7 @@ def sinh(x): ...@@ -3298,7 +3298,7 @@ def sinh(x):
# see decorator for function body # see decorator for function body
@structured_monoid(arcsinh) @structured_monoid(arcsinh) # type: ignore[no-redef]
def arcsinh(x): def arcsinh(x):
""" """
Elemwise arcsinh of `x`. Elemwise arcsinh of `x`.
...@@ -3307,7 +3307,7 @@ def arcsinh(x): ...@@ -3307,7 +3307,7 @@ def arcsinh(x):
# see decorator for function body # see decorator for function body
@structured_monoid(tanh) @structured_monoid(tanh) # type: ignore[no-redef]
def tanh(x): def tanh(x):
""" """
Elemwise tanh of `x`. Elemwise tanh of `x`.
...@@ -3316,7 +3316,7 @@ def tanh(x): ...@@ -3316,7 +3316,7 @@ def tanh(x):
# see decorator for function body # see decorator for function body
@structured_monoid(arctanh) @structured_monoid(arctanh) # type: ignore[no-redef]
def arctanh(x): def arctanh(x):
""" """
Elemwise arctanh of `x`. Elemwise arctanh of `x`.
...@@ -3339,7 +3339,7 @@ def rint(x): ...@@ -3339,7 +3339,7 @@ def rint(x):
rint.__name__ = "rint" rint.__name__ = "rint"
@structured_monoid(sgn) @structured_monoid(sgn) # type: ignore[no-redef]
def sgn(x): def sgn(x):
""" """
Elemwise signe of `x`. Elemwise signe of `x`.
...@@ -3348,7 +3348,7 @@ def sgn(x): ...@@ -3348,7 +3348,7 @@ def sgn(x):
# see decorator for function body # see decorator for function body
@structured_monoid(ceil) @structured_monoid(ceil) # type: ignore[no-redef]
def ceil(x): def ceil(x):
""" """
Elemwise ceiling of `x`. Elemwise ceiling of `x`.
...@@ -3357,7 +3357,7 @@ def ceil(x): ...@@ -3357,7 +3357,7 @@ def ceil(x):
# see decorator for function body # see decorator for function body
@structured_monoid(floor) @structured_monoid(floor) # type: ignore[no-redef]
def floor(x): def floor(x):
""" """
Elemwise floor of `x`. Elemwise floor of `x`.
...@@ -3366,7 +3366,7 @@ def floor(x): ...@@ -3366,7 +3366,7 @@ def floor(x):
# see decorator for function body # see decorator for function body
@structured_monoid(log1p) @structured_monoid(log1p) # type: ignore[no-redef]
def log1p(x): def log1p(x):
""" """
Elemwise log(1 + `x`). Elemwise log(1 + `x`).
...@@ -3375,7 +3375,7 @@ def log1p(x): ...@@ -3375,7 +3375,7 @@ def log1p(x):
# see decorator for function body # see decorator for function body
@structured_monoid(expm1) @structured_monoid(expm1) # type: ignore[no-redef]
def expm1(x): def expm1(x):
""" """
Elemwise e^`x` - 1. Elemwise e^`x` - 1.
...@@ -3384,7 +3384,7 @@ def expm1(x): ...@@ -3384,7 +3384,7 @@ def expm1(x):
# see decorator for function body # see decorator for function body
@structured_monoid(deg2rad) @structured_monoid(deg2rad) # type: ignore[no-redef]
def deg2rad(x): def deg2rad(x):
""" """
Elemwise degree to radian. Elemwise degree to radian.
...@@ -3393,7 +3393,7 @@ def deg2rad(x): ...@@ -3393,7 +3393,7 @@ def deg2rad(x):
# see decorator for function body # see decorator for function body
@structured_monoid(rad2deg) @structured_monoid(rad2deg) # type: ignore[no-redef]
def rad2deg(x): def rad2deg(x):
""" """
Elemwise radian to degree. Elemwise radian to degree.
...@@ -3402,7 +3402,7 @@ def rad2deg(x): ...@@ -3402,7 +3402,7 @@ def rad2deg(x):
# see decorator for function body # see decorator for function body
@structured_monoid(trunc) @structured_monoid(trunc) # type: ignore[no-redef]
def trunc(x): def trunc(x):
""" """
Elemwise truncature. Elemwise truncature.
...@@ -3411,7 +3411,7 @@ def trunc(x): ...@@ -3411,7 +3411,7 @@ def trunc(x):
# see decorator for function body # see decorator for function body
@structured_monoid(sqr) @structured_monoid(sqr) # type: ignore[no-redef]
def sqr(x): def sqr(x):
""" """
Elemwise `x` * `x`. Elemwise `x` * `x`.
...@@ -3420,7 +3420,7 @@ def sqr(x): ...@@ -3420,7 +3420,7 @@ def sqr(x):
# see decorator for function body # see decorator for function body
@structured_monoid(sqrt) @structured_monoid(sqrt) # type: ignore[no-redef]
def sqrt(x): def sqrt(x):
""" """
Elemwise square root of `x`. Elemwise square root of `x`.
...@@ -3429,7 +3429,7 @@ def sqrt(x): ...@@ -3429,7 +3429,7 @@ def sqrt(x):
# see decorator for function body # see decorator for function body
@structured_monoid(conj) @structured_monoid(conj) # type: ignore[no-redef]
def conj(x): def conj(x):
""" """
Elemwise complex conjugate of `x`. Elemwise complex conjugate of `x`.
......
...@@ -5,9 +5,12 @@ __docformat__ = "restructuredtext en" ...@@ -5,9 +5,12 @@ __docformat__ = "restructuredtext en"
import warnings import warnings
from functools import singledispatch from functools import singledispatch
from typing import Callable, Optional
def as_tensor_variable(x, name=None, ndim=None, **kwargs): def as_tensor_variable(
x, name: Optional[str] = None, ndim: Optional[int] = None, **kwargs
):
"""Convert `x` into the appropriate `TensorType`. """Convert `x` into the appropriate `TensorType`.
This function is often used by `make_node` methods of `Op` subclasses to This function is often used by `make_node` methods of `Op` subclasses to
...@@ -36,7 +39,7 @@ def as_tensor_variable(x, name=None, ndim=None, **kwargs): ...@@ -36,7 +39,7 @@ def as_tensor_variable(x, name=None, ndim=None, **kwargs):
@singledispatch @singledispatch
def _as_tensor_variable(x, name, ndim, **kwargs): def _as_tensor_variable(x, name: Optional[str], ndim: Optional[int], **kwargs):
raise NotImplementedError("") raise NotImplementedError("")
......
import logging import logging
import os import os
from typing import Optional
import aesara import aesara
from aesara.configdefaults import config from aesara.configdefaults import config
...@@ -56,7 +57,7 @@ class BaseCorrMM(OpenMPOp, _NoPythonOp): ...@@ -56,7 +57,7 @@ class BaseCorrMM(OpenMPOp, _NoPythonOp):
"unshared", "unshared",
) )
_direction = None _direction: Optional[str] = None
params_type = ParamsType( params_type = ParamsType(
direction=EnumList( direction=EnumList(
......
import logging import logging
import os import os
from typing import Optional
import aesara import aesara
from aesara.configdefaults import config from aesara.configdefaults import config
...@@ -48,7 +49,7 @@ class BaseCorr3dMM(OpenMPOp, _NoPythonOp): ...@@ -48,7 +49,7 @@ class BaseCorr3dMM(OpenMPOp, _NoPythonOp):
check_broadcast = False check_broadcast = False
__props__ = ("border_mode", "subsample", "filter_dilation", "num_groups") __props__ = ("border_mode", "subsample", "filter_dilation", "num_groups")
_direction = None _direction: Optional[str] = None
params_type = ParamsType( params_type = ParamsType(
direction=EnumList( direction=EnumList(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论