提交 ebea2074 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix typing issues in aesara.graph.op and aesara.link.c.op

上级 7ba74079
...@@ -9,11 +9,15 @@ from typing import ( ...@@ -9,11 +9,15 @@ from typing import (
Dict, Dict,
List, List,
Optional, Optional,
Sequence,
Text, Text,
Tuple, Tuple,
Union, Union,
cast,
) )
from typing_extensions import Protocol
import aesara import aesara
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import Apply, NoParams, Variable from aesara.graph.basic import Apply, NoParams, Variable
...@@ -30,6 +34,7 @@ from aesara.graph.utils import ( ...@@ -30,6 +34,7 @@ from aesara.graph.utils import (
if TYPE_CHECKING: if TYPE_CHECKING:
from aesara.compile.function.types import Function from aesara.compile.function.types import Function
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.type import Type
StorageMapType = Dict[Variable, List[Optional[List[Any]]]] StorageMapType = Dict[Variable, List[Optional[List[Any]]]]
ComputeMapType = Dict[Variable, List[bool]] ComputeMapType = Dict[Variable, List[bool]]
...@@ -38,7 +43,22 @@ ParamsInputType = Optional[Tuple[Any]] ...@@ -38,7 +43,22 @@ ParamsInputType = Optional[Tuple[Any]]
PerformMethodType = Callable[ PerformMethodType = Callable[
[Apply, List[Any], OutputStorageType, ParamsInputType], None [Apply, List[Any], OutputStorageType, ParamsInputType], None
] ]
ThunkType = Callable[[PerformMethodType, StorageMapType, ComputeMapType, Apply], Any] ThunkCallableType = Callable[
[PerformMethodType, StorageMapType, ComputeMapType, Apply], None
]
class ThunkType(Protocol):
inputs: List[List[Optional[List[Any]]]]
outputs: List[List[Optional[List[Any]]]]
lazy: bool
__call__: ThunkCallableType
perform: PerformMethodType
def is_thunk_type(thunk: ThunkCallableType) -> ThunkType:
res = cast(ThunkType, thunk)
return res
def compute_test_value(node: Apply): def compute_test_value(node: Apply):
...@@ -180,8 +200,8 @@ class Op(MetaObject): ...@@ -180,8 +200,8 @@ class Op(MetaObject):
""" """
itypes = None itypes: Optional[Sequence["Type"]] = None
otypes = None otypes: Optional[Sequence["Type"]] = None
params_type: Optional[ParamsType] = None params_type: Optional[ParamsType] = None
def make_node(self, *inputs: Variable) -> Apply: def make_node(self, *inputs: Variable) -> Apply:
...@@ -276,7 +296,7 @@ class Op(MetaObject): ...@@ -276,7 +296,7 @@ class Op(MetaObject):
if self.default_output is not None: if self.default_output is not None:
rval = node.outputs[self.default_output] rval = node.outputs[self.default_output]
if return_list: if return_list:
rval = [rval] return [rval]
return rval return rval
else: else:
if return_list: if return_list:
...@@ -382,17 +402,17 @@ class Op(MetaObject): ...@@ -382,17 +402,17 @@ class Op(MetaObject):
Parameters Parameters
---------- ----------
node : Apply node
The symbolic `Apply` node that represents this computation. The symbolic `Apply` node that represents this computation.
inputs : Sequence inputs
Immutable sequence of non-symbolic/numeric inputs. These Immutable sequence of non-symbolic/numeric inputs. These
are the values of each `Variable` in :attr:`node.inputs`. are the values of each `Variable` in :attr:`node.inputs`.
output_storage : list of list output_storage
List of mutable single-element lists (do not change the length of List of mutable single-element lists (do not change the length of
these lists). Each sub-list corresponds to value of each these lists). Each sub-list corresponds to value of each
`Variable` in :attr:`node.outputs`. The primary purpose of this method `Variable` in :attr:`node.outputs`. The primary purpose of this method
is to set the values of these sub-lists. is to set the values of these sub-lists.
params : tuple params
A tuple containing the values of each entry in :attr:`Op.__props__`. A tuple containing the values of each entry in :attr:`Op.__props__`.
Notes Notes
...@@ -492,7 +512,10 @@ class Op(MetaObject): ...@@ -492,7 +512,10 @@ class Op(MetaObject):
if params is NoParams: if params is NoParams:
# default arguments are stored in the closure of `rval` # default arguments are stored in the closure of `rval`
def rval(p=p, i=node_input_storage, o=node_output_storage, n=node): @is_thunk_type
def rval(
p=p, i=node_input_storage, o=node_output_storage, n=node, params=None
):
r = p(n, [x[0] for x in i], o) r = p(n, [x[0] for x in i], o)
for o in node.outputs: for o in node.outputs:
compute_map[o][0] = True compute_map[o][0] = True
...@@ -501,6 +524,7 @@ class Op(MetaObject): ...@@ -501,6 +524,7 @@ class Op(MetaObject):
else: else:
params_val = node.params_type.filter(params) params_val = node.params_type.filter(params)
@is_thunk_type
def rval( def rval(
p=p, p=p,
i=node_input_storage, i=node_input_storage,
...@@ -515,7 +539,7 @@ class Op(MetaObject): ...@@ -515,7 +539,7 @@ class Op(MetaObject):
rval.inputs = node_input_storage rval.inputs = node_input_storage
rval.outputs = node_output_storage rval.outputs = node_output_storage
rval.perform = p setattr(rval, "perform", p)
rval.lazy = False rval.lazy = False
return rval return rval
...@@ -604,7 +628,7 @@ class HasInnerGraph: ...@@ -604,7 +628,7 @@ class HasInnerGraph:
"""The inner function's outputs.""" """The inner function's outputs."""
def get_test_value(v: Variable) -> Any: def get_test_value(v: Any) -> Any:
"""Get the test value for `v`. """Get the test value for `v`.
If input `v` is not already a variable, it is turned into one by calling If input `v` is not already a variable, it is turned into one by calling
......
...@@ -554,7 +554,7 @@ class ParamsType(CType): ...@@ -554,7 +554,7 @@ class ParamsType(CType):
else self.__const_to_enum[alias][alias] else self.__const_to_enum[alias][alias]
) )
def get_params(self, *objects, **kwargs): def get_params(self, *objects, **kwargs) -> Params:
""" """
Convenient method to extract fields values from a list of Python objects and key-value args, Convenient method to extract fields values from a list of Python objects and key-value args,
and wrap them into a :class:`Params` object compatible with current ParamsType. and wrap them into a :class:`Params` object compatible with current ParamsType.
......
...@@ -3,7 +3,9 @@ import os ...@@ -3,7 +3,9 @@ import os
import re import re
import warnings import warnings
from typing import ( from typing import (
TYPE_CHECKING,
Any, Any,
Callable,
ClassVar, ClassVar,
Collection, Collection,
Dict, Dict,
...@@ -14,18 +16,34 @@ from typing import ( ...@@ -14,18 +16,34 @@ from typing import (
Text, Text,
Tuple, Tuple,
Union, Union,
cast,
) )
import numpy as np import numpy as np
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import Apply from aesara.graph.basic import Apply, Variable
from aesara.graph.op import ComputeMapType, Op, StorageMapType, ThunkType from aesara.graph.op import ComputeMapType, Op, StorageMapType, ThunkType
from aesara.graph.params_type import ParamsType from aesara.graph.params_type import ParamsType
from aesara.graph.type import HasDataType
from aesara.graph.utils import MethodNotDefined from aesara.graph.utils import MethodNotDefined
from aesara.link.c.interface import CLinkerOp from aesara.link.c.interface import CLinkerOp
if TYPE_CHECKING:
from aesara.link.c.basic import _CThunk
class CThunkWrapperType(ThunkType):
thunk: "_CThunk"
cthunk: ThunkType
def is_cthunk_wrapper_type(thunk: Callable[[], None]) -> CThunkWrapperType:
res = cast(CThunkWrapperType, thunk)
return res
class COp(Op, CLinkerOp): class COp(Op, CLinkerOp):
"""An `Op` with a C implementation.""" """An `Op` with a C implementation."""
...@@ -34,8 +52,8 @@ class COp(Op, CLinkerOp): ...@@ -34,8 +52,8 @@ class COp(Op, CLinkerOp):
node: Apply, node: Apply,
storage_map: StorageMapType, storage_map: StorageMapType,
compute_map: ComputeMapType, compute_map: ComputeMapType,
no_recycling: Collection[Apply], no_recycling: Collection[Variable],
) -> ThunkType: ) -> CThunkWrapperType:
"""Create a thunk for a C implementation. """Create a thunk for a C implementation.
Like :meth:`Op.make_thunk`, but will only try to make a C thunk. Like :meth:`Op.make_thunk`, but will only try to make a C thunk.
...@@ -80,6 +98,7 @@ class COp(Op, CLinkerOp): ...@@ -80,6 +98,7 @@ class COp(Op, CLinkerOp):
) )
thunk, node_input_filters, node_output_filters = outputs thunk, node_input_filters, node_output_filters = outputs
@is_cthunk_wrapper_type
def rval(): def rval():
thunk() thunk()
for o in node.outputs: for o in node.outputs:
...@@ -520,17 +539,18 @@ class ExternalCOp(COp): ...@@ -520,17 +539,18 @@ class ExternalCOp(COp):
# Generate dtype macros # Generate dtype macros
for i, v in enumerate(variables): for i, v in enumerate(variables):
if not hasattr(v, "dtype"): if not isinstance(v.type, HasDataType):
continue continue
vname = variable_names[i] vname = variable_names[i]
macro_name = "DTYPE_" + vname macro_name = "DTYPE_" + vname
macro_value = "npy_" + v.dtype macro_value = "npy_" + v.type.dtype
define_macros.append(define_template % (macro_name, macro_value)) define_macros.append(define_template % (macro_name, macro_value))
undef_macros.append(undef_template % macro_name) undef_macros.append(undef_template % macro_name)
d = np.dtype(v.dtype) d = np.dtype(v.type.dtype)
macro_name = "TYPENUM_" + vname macro_name = "TYPENUM_" + vname
macro_value = d.num macro_value = d.num
......
...@@ -115,10 +115,6 @@ check_untyped_defs = False ...@@ -115,10 +115,6 @@ check_untyped_defs = False
ignore_errors = True ignore_errors = True
check_untyped_defs = False check_untyped_defs = False
[mypy-aesara.graph.op]
ignore_errors = True
check_untyped_defs = False
[mypy-aesara.graph.fg] [mypy-aesara.graph.fg]
ignore_errors = True ignore_errors = True
check_untyped_defs = False check_untyped_defs = False
...@@ -151,10 +147,6 @@ check_untyped_defs = False ...@@ -151,10 +147,6 @@ check_untyped_defs = False
ignore_errors = True ignore_errors = True
check_untyped_defs = False check_untyped_defs = False
[mypy-aesara.link.c.op]
ignore_errors = True
check_untyped_defs = False
[mypy-aesara.link.utils] [mypy-aesara.link.utils]
ignore_errors = True ignore_errors = True
check_untyped_defs = False check_untyped_defs = False
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论