提交 65c410b0 authored 作者: LegrandNico's avatar LegrandNico 提交者: Thomas Wiecki

Fix type hint errors (n errors: 386->273)

上级 22c8c46d
......@@ -3,6 +3,7 @@ import os
import sys
import warnings
from functools import reduce
from typing import Optional
import numpy as np
......@@ -1904,8 +1905,8 @@ class GpuDnnPoolBase(DnnBase):
"""
# c_file and c_function must be defined in sub-classes.
c_file = None
c_function = None
c_file: Optional[str] = None
c_function: Optional[str] = None
_f16_ok = True
__props__ = ("mode",)
......
......@@ -12,6 +12,7 @@ that satisfies the constraints. That's useful for pattern matching.
from copy import copy
from functools import partial
from typing import Dict
class Keyword:
......@@ -218,7 +219,7 @@ class VariableInList: # not a subclass of Variable
self.variable = variable
_all = {}
_all: Dict = {}
def var_lookup(vartype, name, *args, **kwargs):
......
from abc import abstractmethod
from typing import Dict, List, Text, Tuple
from typing import Callable, Dict, List, Text, Tuple
from aesara.graph.basic import Apply, Constant
from aesara.graph.utils import MethodNotDefined
......@@ -568,25 +568,25 @@ class HideC(CLinkerOp):
def __hide(*args):
raise MethodNotDefined()
c_code = __hide
c_code_cleanup = __hide
c_code: Callable = __hide
c_code_cleanup: Callable = __hide
c_headers = __hide
c_header_dirs = __hide
c_libraries = __hide
c_lib_dirs = __hide
c_headers: Callable = __hide
c_header_dirs: Callable = __hide
c_libraries: Callable = __hide
c_lib_dirs: Callable = __hide
c_support_code = __hide
c_support_code_apply = __hide
c_support_code: Callable = __hide
c_support_code_apply: Callable = __hide
c_compile_args = __hide
c_no_compile_args = __hide
c_init_code = __hide
c_init_code_apply = __hide
c_compile_args: Callable = __hide
c_no_compile_args: Callable = __hide
c_init_code: Callable = __hide
c_init_code_apply: Callable = __hide
c_init_code_struct = __hide
c_support_code_struct = __hide
c_cleanup_code_struct = __hide
c_init_code_struct: Callable = __hide
c_support_code_struct: Callable = __hide
c_cleanup_code_struct: Callable = __hide
def c_code_cache_version(self):
return ()
......
import io
import sys
import traceback
import typing
import warnings
from operator import itemgetter
from typing import Callable, Dict, Iterable, List, NoReturn, Optional, Tuple
import numpy as np
......@@ -15,11 +15,11 @@ from aesara.graph.fg import FunctionGraph
def map_storage(
fgraph: FunctionGraph,
order: typing.Iterable[Apply],
input_storage: typing.Optional[typing.List],
output_storage: typing.Optional[typing.List],
storage_map: typing.Dict = None,
) -> typing.Tuple[typing.List, typing.List, typing.Dict]:
order: Iterable[Apply],
input_storage: Optional[List],
output_storage: Optional[List],
storage_map: Dict = None,
) -> Tuple[List, List, Dict]:
"""Ensure there is storage (a length-1 list) for inputs, outputs, and interior nodes.
:param fgraph: The current fgraph. This function uses the inputs and outputs attributes.
......@@ -132,7 +132,7 @@ def streamline(
post_thunk_old_storage=None,
no_recycling=None,
nice_errors=True,
) -> typing.Callable[[], typing.NoReturn]:
) -> Callable[[], NoReturn]:
"""
WRITEME
......@@ -210,7 +210,7 @@ def streamline(
return f
def gc_helper(node_list: typing.List[Apply]):
def gc_helper(node_list: List[Apply]):
"""
Return the set of Variable instances which are computed by node_list.
Parameters
......@@ -390,11 +390,11 @@ def raise_with_op(
for item in fgraph.inputs
if not isinstance(item, aesara.compile.SharedVariable)
]
storage_map_list = []
storage_map_list: List = []
total_size = 0
total_size_inputs = 0
for k in storage_map:
storage_map_item = []
storage_map_item: List = []
# storage_map_item[0]: the variable
storage_map_item.append(str(k))
......
......@@ -15,6 +15,7 @@ from collections.abc import Callable
from copy import copy
from itertools import chain
from textwrap import dedent
from typing import Optional
import numpy as np
......@@ -1215,8 +1216,8 @@ class ScalarOp(COp):
class UnaryScalarOp(ScalarOp):
nin = 1
amd_float32 = None
amd_float64 = None
amd_float32: Optional[str] = None
amd_float64: Optional[str] = None
def c_code_contiguous(self, node, name, inputs, outputs, sub):
(x,) = inputs
......@@ -1261,9 +1262,9 @@ class BinaryScalarOp(ScalarOp):
# One may define in subclasses the following fields:
# - `commutative`: whether op(a, b) == op(b, a)
# - `associative`: whether op(op(a, b), c) == op(a, op(b, c))
commutative = None
associative = None
identity = None
commutative: Optional[bool] = None
associative: Optional[bool] = None
identity: Optional[bool] = None
"""
For an associative operation, the identity object corresponds to the neutral
element. For instance, it will be ``0`` for addition, ``1`` for multiplication,
......
......@@ -3253,7 +3253,7 @@ def structured_add(x):
# Sparse operation (map 0 to 0)
@structured_monoid(sin)
@structured_monoid(sin) # type: ignore[no-redef]
def sin(x):
"""
Elemwise sinus of `x`.
......@@ -3262,7 +3262,7 @@ def sin(x):
# see decorator for function body
@structured_monoid(tan)
@structured_monoid(tan) # type: ignore[no-redef]
def tan(x):
"""
Elemwise tan of `x`.
......@@ -3271,7 +3271,7 @@ def tan(x):
# see decorator for function body
@structured_monoid(arcsin)
@structured_monoid(arcsin) # type: ignore[no-redef]
def arcsin(x):
"""
Elemwise arcsinus of `x`.
......@@ -3280,7 +3280,7 @@ def arcsin(x):
# see decorator for function body
@structured_monoid(arctan)
@structured_monoid(arctan) # type: ignore[no-redef]
def arctan(x):
"""
Elemwise arctan of `x`.
......@@ -3289,7 +3289,7 @@ def arctan(x):
# see decorator for function body
@structured_monoid(sinh)
@structured_monoid(sinh) # type: ignore[no-redef]
def sinh(x):
"""
Elemwise sinh of `x`.
......@@ -3298,7 +3298,7 @@ def sinh(x):
# see decorator for function body
@structured_monoid(arcsinh)
@structured_monoid(arcsinh) # type: ignore[no-redef]
def arcsinh(x):
"""
Elemwise arcsinh of `x`.
......@@ -3307,7 +3307,7 @@ def arcsinh(x):
# see decorator for function body
@structured_monoid(tanh)
@structured_monoid(tanh) # type: ignore[no-redef]
def tanh(x):
"""
Elemwise tanh of `x`.
......@@ -3316,7 +3316,7 @@ def tanh(x):
# see decorator for function body
@structured_monoid(arctanh)
@structured_monoid(arctanh) # type: ignore[no-redef]
def arctanh(x):
"""
Elemwise arctanh of `x`.
......@@ -3339,7 +3339,7 @@ def rint(x):
rint.__name__ = "rint"
@structured_monoid(sgn)
@structured_monoid(sgn) # type: ignore[no-redef]
def sgn(x):
"""
Elemwise signe of `x`.
......@@ -3348,7 +3348,7 @@ def sgn(x):
# see decorator for function body
@structured_monoid(ceil)
@structured_monoid(ceil) # type: ignore[no-redef]
def ceil(x):
"""
Elemwise ceiling of `x`.
......@@ -3357,7 +3357,7 @@ def ceil(x):
# see decorator for function body
@structured_monoid(floor)
@structured_monoid(floor) # type: ignore[no-redef]
def floor(x):
"""
Elemwise floor of `x`.
......@@ -3366,7 +3366,7 @@ def floor(x):
# see decorator for function body
@structured_monoid(log1p)
@structured_monoid(log1p) # type: ignore[no-redef]
def log1p(x):
"""
Elemwise log(1 + `x`).
......@@ -3375,7 +3375,7 @@ def log1p(x):
# see decorator for function body
@structured_monoid(expm1)
@structured_monoid(expm1) # type: ignore[no-redef]
def expm1(x):
"""
Elemwise e^`x` - 1.
......@@ -3384,7 +3384,7 @@ def expm1(x):
# see decorator for function body
@structured_monoid(deg2rad)
@structured_monoid(deg2rad) # type: ignore[no-redef]
def deg2rad(x):
"""
Elemwise degree to radian.
......@@ -3393,7 +3393,7 @@ def deg2rad(x):
# see decorator for function body
@structured_monoid(rad2deg)
@structured_monoid(rad2deg) # type: ignore[no-redef]
def rad2deg(x):
"""
Elemwise radian to degree.
......@@ -3402,7 +3402,7 @@ def rad2deg(x):
# see decorator for function body
@structured_monoid(trunc)
@structured_monoid(trunc) # type: ignore[no-redef]
def trunc(x):
"""
Elemwise truncature.
......@@ -3411,7 +3411,7 @@ def trunc(x):
# see decorator for function body
@structured_monoid(sqr)
@structured_monoid(sqr) # type: ignore[no-redef]
def sqr(x):
"""
Elemwise `x` * `x`.
......@@ -3420,7 +3420,7 @@ def sqr(x):
# see decorator for function body
@structured_monoid(sqrt)
@structured_monoid(sqrt) # type: ignore[no-redef]
def sqrt(x):
"""
Elemwise square root of `x`.
......@@ -3429,7 +3429,7 @@ def sqrt(x):
# see decorator for function body
@structured_monoid(conj)
@structured_monoid(conj) # type: ignore[no-redef]
def conj(x):
"""
Elemwise complex conjugate of `x`.
......
......@@ -5,9 +5,12 @@ __docformat__ = "restructuredtext en"
import warnings
from functools import singledispatch
from typing import Callable, Optional
def as_tensor_variable(x, name=None, ndim=None, **kwargs):
def as_tensor_variable(
x, name: Optional[str] = None, ndim: Optional[int] = None, **kwargs
):
"""Convert `x` into the appropriate `TensorType`.
This function is often used by `make_node` methods of `Op` subclasses to
......@@ -36,7 +39,7 @@ def as_tensor_variable(x, name=None, ndim=None, **kwargs):
@singledispatch
def _as_tensor_variable(x, name, ndim, **kwargs):
def _as_tensor_variable(x, name: Optional[str], ndim: Optional[int], **kwargs):
raise NotImplementedError("")
......
import logging
import os
from typing import Optional
import aesara
from aesara.configdefaults import config
......@@ -56,7 +57,7 @@ class BaseCorrMM(OpenMPOp, _NoPythonOp):
"unshared",
)
_direction = None
_direction: Optional[str] = None
params_type = ParamsType(
direction=EnumList(
......
import logging
import os
from typing import Optional
import aesara
from aesara.configdefaults import config
......@@ -48,7 +49,7 @@ class BaseCorr3dMM(OpenMPOp, _NoPythonOp):
check_broadcast = False
__props__ = ("border_mode", "subsample", "filter_dilation", "num_groups")
_direction = None
_direction: Optional[str] = None
params_type = ParamsType(
direction=EnumList(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论