提交 ebea2074 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix typing issues in aesara.graph.op and aesara.link.c.op

上级 7ba74079
......@@ -9,11 +9,15 @@ from typing import (
Dict,
List,
Optional,
Sequence,
Text,
Tuple,
Union,
cast,
)
from typing_extensions import Protocol
import aesara
from aesara.configdefaults import config
from aesara.graph.basic import Apply, NoParams, Variable
......@@ -30,6 +34,7 @@ from aesara.graph.utils import (
if TYPE_CHECKING:
from aesara.compile.function.types import Function
from aesara.graph.fg import FunctionGraph
from aesara.graph.type import Type
StorageMapType = Dict[Variable, List[Optional[List[Any]]]]
ComputeMapType = Dict[Variable, List[bool]]
......@@ -38,7 +43,22 @@ ParamsInputType = Optional[Tuple[Any]]
PerformMethodType = Callable[
[Apply, List[Any], OutputStorageType, ParamsInputType], None
]
ThunkType = Callable[[PerformMethodType, StorageMapType, ComputeMapType, Apply], Any]
ThunkCallableType = Callable[
[PerformMethodType, StorageMapType, ComputeMapType, Apply], None
]
class ThunkType(Protocol):
inputs: List[List[Optional[List[Any]]]]
outputs: List[List[Optional[List[Any]]]]
lazy: bool
__call__: ThunkCallableType
perform: PerformMethodType
def is_thunk_type(thunk: ThunkCallableType) -> ThunkType:
res = cast(ThunkType, thunk)
return res
def compute_test_value(node: Apply):
......@@ -180,8 +200,8 @@ class Op(MetaObject):
"""
itypes = None
otypes = None
itypes: Optional[Sequence["Type"]] = None
otypes: Optional[Sequence["Type"]] = None
params_type: Optional[ParamsType] = None
def make_node(self, *inputs: Variable) -> Apply:
......@@ -276,7 +296,7 @@ class Op(MetaObject):
if self.default_output is not None:
rval = node.outputs[self.default_output]
if return_list:
rval = [rval]
return [rval]
return rval
else:
if return_list:
......@@ -382,17 +402,17 @@ class Op(MetaObject):
Parameters
----------
node : Apply
node
The symbolic `Apply` node that represents this computation.
inputs : Sequence
inputs
Immutable sequence of non-symbolic/numeric inputs. These
are the values of each `Variable` in :attr:`node.inputs`.
output_storage : list of list
output_storage
List of mutable single-element lists (do not change the length of
these lists). Each sub-list corresponds to value of each
`Variable` in :attr:`node.outputs`. The primary purpose of this method
is to set the values of these sub-lists.
params : tuple
params
A tuple containing the values of each entry in :attr:`Op.__props__`.
Notes
......@@ -492,7 +512,10 @@ class Op(MetaObject):
if params is NoParams:
# default arguments are stored in the closure of `rval`
def rval(p=p, i=node_input_storage, o=node_output_storage, n=node):
@is_thunk_type
def rval(
p=p, i=node_input_storage, o=node_output_storage, n=node, params=None
):
r = p(n, [x[0] for x in i], o)
for o in node.outputs:
compute_map[o][0] = True
......@@ -501,6 +524,7 @@ class Op(MetaObject):
else:
params_val = node.params_type.filter(params)
@is_thunk_type
def rval(
p=p,
i=node_input_storage,
......@@ -515,7 +539,7 @@ class Op(MetaObject):
rval.inputs = node_input_storage
rval.outputs = node_output_storage
rval.perform = p
setattr(rval, "perform", p)
rval.lazy = False
return rval
......@@ -604,7 +628,7 @@ class HasInnerGraph:
"""The inner function's outputs."""
def get_test_value(v: Variable) -> Any:
def get_test_value(v: Any) -> Any:
"""Get the test value for `v`.
If input `v` is not already a variable, it is turned into one by calling
......
......@@ -554,7 +554,7 @@ class ParamsType(CType):
else self.__const_to_enum[alias][alias]
)
def get_params(self, *objects, **kwargs):
def get_params(self, *objects, **kwargs) -> Params:
"""
Convenient method to extract fields values from a list of Python objects and key-value args,
and wrap them into a :class:`Params` object compatible with current ParamsType.
......
......@@ -3,7 +3,9 @@ import os
import re
import warnings
from typing import (
TYPE_CHECKING,
Any,
Callable,
ClassVar,
Collection,
Dict,
......@@ -14,18 +16,34 @@ from typing import (
Text,
Tuple,
Union,
cast,
)
import numpy as np
from aesara.configdefaults import config
from aesara.graph.basic import Apply
from aesara.graph.basic import Apply, Variable
from aesara.graph.op import ComputeMapType, Op, StorageMapType, ThunkType
from aesara.graph.params_type import ParamsType
from aesara.graph.type import HasDataType
from aesara.graph.utils import MethodNotDefined
from aesara.link.c.interface import CLinkerOp
if TYPE_CHECKING:
from aesara.link.c.basic import _CThunk
class CThunkWrapperType(ThunkType):
thunk: "_CThunk"
cthunk: ThunkType
def is_cthunk_wrapper_type(thunk: Callable[[], None]) -> CThunkWrapperType:
res = cast(CThunkWrapperType, thunk)
return res
class COp(Op, CLinkerOp):
"""An `Op` with a C implementation."""
......@@ -34,8 +52,8 @@ class COp(Op, CLinkerOp):
node: Apply,
storage_map: StorageMapType,
compute_map: ComputeMapType,
no_recycling: Collection[Apply],
) -> ThunkType:
no_recycling: Collection[Variable],
) -> CThunkWrapperType:
"""Create a thunk for a C implementation.
Like :meth:`Op.make_thunk`, but will only try to make a C thunk.
......@@ -80,6 +98,7 @@ class COp(Op, CLinkerOp):
)
thunk, node_input_filters, node_output_filters = outputs
@is_cthunk_wrapper_type
def rval():
thunk()
for o in node.outputs:
......@@ -520,17 +539,18 @@ class ExternalCOp(COp):
# Generate dtype macros
for i, v in enumerate(variables):
if not hasattr(v, "dtype"):
if not isinstance(v.type, HasDataType):
continue
vname = variable_names[i]
macro_name = "DTYPE_" + vname
macro_value = "npy_" + v.dtype
macro_value = "npy_" + v.type.dtype
define_macros.append(define_template % (macro_name, macro_value))
undef_macros.append(undef_template % macro_name)
d = np.dtype(v.dtype)
d = np.dtype(v.type.dtype)
macro_name = "TYPENUM_" + vname
macro_value = d.num
......
......@@ -115,10 +115,6 @@ check_untyped_defs = False
ignore_errors = True
check_untyped_defs = False
[mypy-aesara.graph.op]
ignore_errors = True
check_untyped_defs = False
[mypy-aesara.graph.fg]
ignore_errors = True
check_untyped_defs = False
......@@ -151,10 +147,6 @@ check_untyped_defs = False
ignore_errors = True
check_untyped_defs = False
[mypy-aesara.link.c.op]
ignore_errors = True
check_untyped_defs = False
[mypy-aesara.link.utils]
ignore_errors = True
check_untyped_defs = False
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论