提交 41c10974 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix typing issues in aesara.link.utils

上级 b4312087
...@@ -9,8 +9,22 @@ from keyword import iskeyword ...@@ -9,8 +9,22 @@ from keyword import iskeyword
from operator import itemgetter from operator import itemgetter
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile
from textwrap import indent from textwrap import indent
from types import FunctionType from typing import (
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
List,
NoReturn,
Optional,
Sequence,
TextIO,
Tuple,
TypeVar,
Union,
cast,
)
import numpy as np import numpy as np
...@@ -19,13 +33,23 @@ from aesara.graph.basic import Apply, Constant, Variable ...@@ -19,13 +33,23 @@ from aesara.graph.basic import Apply, Constant, Variable
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
if TYPE_CHECKING:
from aesara.graph.op import (
BasicThunkType,
InputStorageType,
OutputStorageType,
StorageCellType,
StorageMapType,
)
def map_storage( def map_storage(
fgraph: FunctionGraph, fgraph: FunctionGraph,
order: Iterable[Apply], order: Iterable[Apply],
input_storage: Optional[List], input_storage: Optional["InputStorageType"] = None,
output_storage: Optional[List], output_storage: Optional["OutputStorageType"] = None,
storage_map: Optional[Dict] = None, storage_map: Optional["StorageMapType"] = None,
) -> Tuple[List, List, Dict]: ) -> Tuple["InputStorageType", "OutputStorageType", "StorageMapType"]:
"""Ensure there is storage (a length-1 list) for inputs, outputs, and interior nodes. """Ensure there is storage (a length-1 list) for inputs, outputs, and interior nodes.
Parameters Parameters
...@@ -125,14 +149,13 @@ def map_storage( ...@@ -125,14 +149,13 @@ def map_storage(
def streamline( def streamline(
fgraph: FunctionGraph, fgraph: FunctionGraph,
thunks, thunks: Sequence[Callable[[], None]],
order, order: Sequence[Apply],
post_thunk_old_storage=None, post_thunk_old_storage: Optional[List["StorageCellType"]] = None,
no_recycling=None, no_recycling: Optional[List["StorageCellType"]] = None,
nice_errors=True, nice_errors: bool = True,
) -> Callable[[], None]: ) -> "BasicThunkType":
""" """Construct a single thunk that runs a list of thunks.
WRITEME
Parameters Parameters
---------- ----------
...@@ -246,7 +269,7 @@ def gc_helper(node_list: List[Apply]): ...@@ -246,7 +269,7 @@ def gc_helper(node_list: List[Apply]):
def raise_with_op( def raise_with_op(
fgraph: FunctionGraph, node: Apply, thunk=None, exc_info=None, storage_map=None fgraph: FunctionGraph, node: Apply, thunk=None, exc_info=None, storage_map=None
): ) -> NoReturn:
""" """
Re-raise an exception while annotating the exception object with Re-raise an exception while annotating the exception object with
debug info. debug info.
...@@ -293,13 +316,11 @@ def raise_with_op( ...@@ -293,13 +316,11 @@ def raise_with_op(
if exc_type == KeyboardInterrupt: if exc_type == KeyboardInterrupt:
# print a simple traceback from KeyboardInterrupt # print a simple traceback from KeyboardInterrupt
raise exc_value.with_traceback(exc_trace) raise exc_value.with_traceback(exc_trace)
try:
trace = node.outputs[0].tag.trace trace = getattr(node.outputs[0].tag, "trace", ())
except AttributeError: if not trace and hasattr(node.op, "tag"):
try: trace = getattr(node.op.tag, "trace", ())
trace = node.op.tag.trace
except AttributeError:
trace = ()
exc_value.__thunk_trace__ = trace exc_value.__thunk_trace__ = trace
exc_value.__op_instance__ = node exc_value.__op_instance__ = node
topo = fgraph.toposort() topo = fgraph.toposort()
...@@ -310,9 +331,9 @@ def raise_with_op( ...@@ -310,9 +331,9 @@ def raise_with_op(
exc_value.__applynode_index__ = node_index exc_value.__applynode_index__ = node_index
hints = [] hints = []
detailed_err_msg = "\nApply node that caused the error: " + str(node) detailed_err_msg = f"\nApply node that caused the error: {node}"
if exc_value.__applynode_index__ is not None: if exc_value.__applynode_index__ is not None:
detailed_err_msg += f"\nToposort index: {int(node_index)}" detailed_err_msg += f"\nToposort index: {node_index}"
types = [getattr(ipt, "type", "No type") for ipt in node.inputs] types = [getattr(ipt, "type", "No type") for ipt in node.inputs]
detailed_err_msg += f"\nInputs types: {types}\n" detailed_err_msg += f"\nInputs types: {types}\n"
...@@ -517,7 +538,7 @@ def raise_with_op( ...@@ -517,7 +538,7 @@ def raise_with_op(
raise exc_value.with_traceback(exc_trace) raise exc_value.with_traceback(exc_trace)
def __log_thunk_trace(value, handler: io.TextIOWrapper): def __log_thunk_trace(value, handler: io.TextIOWrapper) -> None:
""" """
Log Aesara's diagnostic stack trace for an exception. Log Aesara's diagnostic stack trace for an exception.
...@@ -550,12 +571,12 @@ def __log_thunk_trace(value, handler: io.TextIOWrapper): ...@@ -550,12 +571,12 @@ def __log_thunk_trace(value, handler: io.TextIOWrapper):
) )
def register_thunk_trace_excepthook(handler: io.TextIOWrapper = sys.stdout): def register_thunk_trace_excepthook(handler: TextIO = sys.stdout) -> None:
"""Adds the __log_thunk_trace except hook to the collection in aesara.utils. """Adds the `__log_thunk_trace` except hook to the collection in `aesara.utils`.
Parameters Parameters
---------- ----------
handler : TextIOWrapper handler
Target for printing the output. Target for printing the output.
""" """
...@@ -568,7 +589,12 @@ def register_thunk_trace_excepthook(handler: io.TextIOWrapper = sys.stdout): ...@@ -568,7 +589,12 @@ def register_thunk_trace_excepthook(handler: io.TextIOWrapper = sys.stdout):
register_thunk_trace_excepthook() register_thunk_trace_excepthook()
def compile_function_src(src, function_name, global_env=None, local_env=None): def compile_function_src(
src: str,
function_name: str,
global_env: Optional[Dict[Any, Any]] = None,
local_env: Optional[Dict[Any, Any]] = None,
) -> Callable:
with NamedTemporaryFile(delete=False) as f: with NamedTemporaryFile(delete=False) as f:
filename = f.name filename = f.name
...@@ -583,12 +609,12 @@ def compile_function_src(src, function_name, global_env=None, local_env=None): ...@@ -583,12 +609,12 @@ def compile_function_src(src, function_name, global_env=None, local_env=None):
mod_code = compile(src, filename, mode="exec") mod_code = compile(src, filename, mode="exec")
exec(mod_code, global_env, local_env) exec(mod_code, global_env, local_env)
res = local_env[function_name] res = cast(Callable, local_env[function_name])
res.__source__ = src res.__source__ = src # type: ignore
return res return res
def get_name_for_object(x: Any): def get_name_for_object(x: Any) -> str:
"""Get the name for an arbitrary object.""" """Get the name for an arbitrary object."""
if isinstance(x, Variable): if isinstance(x, Variable):
...@@ -603,7 +629,7 @@ def get_name_for_object(x: Any): ...@@ -603,7 +629,7 @@ def get_name_for_object(x: Any):
else x.auto_name else x.auto_name
) )
else: else:
name = getattr(x, "__name__", None) name = getattr(x, "__name__", "")
if not name or (not name.isidentifier() or iskeyword(name)): if not name or (not name.isidentifier() or iskeyword(name)):
name = type(x).__name__ name = type(x).__name__
...@@ -619,27 +645,31 @@ def unique_name_generator( ...@@ -619,27 +645,31 @@ def unique_name_generator(
if external_names is None: if external_names is None:
external_names = [] external_names = []
def unique_name(x, force_unique=False): T = TypeVar("T")
if not force_unique and x in unique_name.obj_to_names:
return unique_name.obj_to_names[x] def unique_name(
x: T,
force_unique=False,
names_counter=Counter(external_names),
objs_to_names: Dict[T, str] = {},
) -> str:
if not force_unique and x in objs_to_names:
return objs_to_names[x]
name = get_name_for_object(x) name = get_name_for_object(x)
name_suffix = unique_name.names_counter.get(name, "") name_suffix = names_counter.get(name, "")
if name_suffix: if name_suffix:
local_name = f"{name}{suffix_sep}{name_suffix}" local_name = f"{name}{suffix_sep}{name_suffix}"
unique_name.names_counter.update((name,)) names_counter.update((name,))
else: else:
local_name = name local_name = name
unique_name.names_counter.update((local_name,)) names_counter.update((local_name,))
unique_name.obj_to_names[x] = local_name objs_to_names[x] = local_name
return local_name return local_name
unique_name.names_counter = Counter(external_names)
unique_name.obj_to_names = {}
return unique_name return unique_name
...@@ -647,18 +677,18 @@ def fgraph_to_python( ...@@ -647,18 +677,18 @@ def fgraph_to_python(
fgraph: FunctionGraph, fgraph: FunctionGraph,
op_conversion_fn: Callable, op_conversion_fn: Callable,
*, *,
type_conversion_fn: Optional[Callable] = lambda x, **kwargs: x, type_conversion_fn: Callable = lambda x, **kwargs: x,
order: Optional[List[Variable]] = None, order: Optional[List[Apply]] = None,
input_storage: Optional[List[Any]] = None, input_storage: Optional["InputStorageType"] = None,
output_storage: Optional[List[Any]] = None, output_storage: Optional["OutputStorageType"] = None,
storage_map: Optional[Dict[Variable, List[Any]]] = None, storage_map: Optional["StorageMapType"] = None,
fgraph_name: str = "fgraph_to_python", fgraph_name: str = "fgraph_to_python",
global_env: Optional[Dict[Any, Any]] = None, global_env: Optional[Dict[Any, Any]] = None,
local_env: Optional[Dict[Any, Any]] = None, local_env: Optional[Dict[Any, Any]] = None,
get_name_for_object: Callable[[Any], str] = get_name_for_object, get_name_for_object: Callable[[Any], str] = get_name_for_object,
squeeze_output: bool = False, squeeze_output: bool = False,
**kwargs, **kwargs,
) -> FunctionType: ) -> Callable:
"""Convert a ``FunctionGraph`` into a regular Python function. """Convert a ``FunctionGraph`` into a regular Python function.
Parameters Parameters
......
...@@ -115,10 +115,6 @@ check_untyped_defs = False ...@@ -115,10 +115,6 @@ check_untyped_defs = False
ignore_errors = True ignore_errors = True
check_untyped_defs = False check_untyped_defs = False
[mypy-aesara.link.utils]
ignore_errors = True
check_untyped_defs = False
[mypy-aesara.link.c.cmodule] [mypy-aesara.link.c.cmodule]
ignore_errors = True ignore_errors = True
check_untyped_defs = False check_untyped_defs = False
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论