提交 b4312087 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Move basic type declarations from aesara.link.basic to aesara.graph.op

上级 2fbc3bb6
......@@ -36,13 +36,16 @@ if TYPE_CHECKING:
from aesara.graph.fg import FunctionGraph
from aesara.graph.type import Type
StorageMapType = Dict[Variable, List[Optional[List[Any]]]]
StorageCellType = List[Optional[Any]]
StorageMapType = Dict[Variable, StorageCellType]
ComputeMapType = Dict[Variable, List[bool]]
OutputStorageType = List[Optional[List[Any]]]
InputStorageType = List[StorageCellType]
OutputStorageType = List[StorageCellType]
ParamsInputType = Optional[Tuple[Any]]
PerformMethodType = Callable[
[Apply, List[Any], OutputStorageType, ParamsInputType], None
]
BasicThunkType = Callable[[], None]
ThunkCallableType = Callable[
[PerformMethodType, StorageMapType, ComputeMapType, Apply], None
]
......@@ -470,8 +473,8 @@ class Op(MetaObject):
def prepare_node(
self,
node: Apply,
storage_map: StorageMapType,
compute_map: ComputeMapType,
storage_map: Optional[StorageMapType],
compute_map: Optional[ComputeMapType],
impl: Optional[Text],
) -> None:
"""Make any special modifications that the `Op` needs before doing :meth:`Op.make_thunk`.
......
......@@ -24,12 +24,16 @@ from aesara.utils import difference
if TYPE_CHECKING:
from aesara.compile.profiling import ProfileStats
from aesara.graph.op import (
BasicThunkType,
InputStorageType,
OutputStorageType,
StorageMapType,
)
from aesara.tensor.var import TensorVariable
StorageMapType = Dict[Variable, List[Optional[Union[ndarray, slice]]]]
OutputStorageType = List[Optional[List[Any]]]
InputStorageType = OutputStorageType
ThunkType = Tuple[Callable[[], None], List["Container"], List["Container"]]
ThunkAndContainersType = Tuple["BasicThunkType", List["Container"], List["Container"]]
class Container:
......@@ -192,7 +196,7 @@ class Linker(ABC):
@abstractmethod
def make_thunk(
self, **kwargs
) -> Tuple[Callable, InputStorageType, OutputStorageType]:
) -> Tuple[Callable, "InputStorageType", "OutputStorageType"]:
"""
This function must return a triplet (function, input_variables,
output_variables) where function is a thunk that operates on the
......@@ -242,11 +246,11 @@ class LocalLinker(Linker):
def make_thunk(
self,
input_storage: Optional[InputStorageType] = None,
output_storage: Optional[OutputStorageType] = None,
storage_map: Optional[StorageMapType] = None,
input_storage: Optional["InputStorageType"] = None,
output_storage: Optional["OutputStorageType"] = None,
storage_map: Optional["StorageMapType"] = None,
**kwargs,
) -> Tuple[Callable[[], None], InputStorageType, OutputStorageType]:
) -> Tuple["BasicThunkType", "InputStorageType", "OutputStorageType"]:
return self.make_all(
input_storage=input_storage,
output_storage=output_storage,
......@@ -255,14 +259,14 @@ class LocalLinker(Linker):
def make_all(
self,
input_storage: Optional[InputStorageType],
output_storage: Optional[OutputStorageType],
storage_map: Optional[StorageMapType],
input_storage: Optional["InputStorageType"] = None,
output_storage: Optional["OutputStorageType"] = None,
storage_map: Optional["StorageMapType"] = None,
) -> Tuple[
Callable[[], None],
InputStorageType,
OutputStorageType,
List[ThunkType],
"BasicThunkType",
"InputStorageType",
"OutputStorageType",
List[ThunkAndContainersType],
List[Apply],
]:
"""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论