提交 41c10974 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix typing issues in aesara.link.utils

上级 b4312087
......@@ -9,8 +9,22 @@ from keyword import iskeyword
from operator import itemgetter
from tempfile import NamedTemporaryFile
from textwrap import indent
from types import FunctionType
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
List,
NoReturn,
Optional,
Sequence,
TextIO,
Tuple,
TypeVar,
Union,
cast,
)
import numpy as np
......@@ -19,13 +33,23 @@ from aesara.graph.basic import Apply, Constant, Variable
from aesara.graph.fg import FunctionGraph
if TYPE_CHECKING:
from aesara.graph.op import (
BasicThunkType,
InputStorageType,
OutputStorageType,
StorageCellType,
StorageMapType,
)
def map_storage(
fgraph: FunctionGraph,
order: Iterable[Apply],
input_storage: Optional[List],
output_storage: Optional[List],
storage_map: Optional[Dict] = None,
) -> Tuple[List, List, Dict]:
input_storage: Optional["InputStorageType"] = None,
output_storage: Optional["OutputStorageType"] = None,
storage_map: Optional["StorageMapType"] = None,
) -> Tuple["InputStorageType", "OutputStorageType", "StorageMapType"]:
"""Ensure there is storage (a length-1 list) for inputs, outputs, and interior nodes.
Parameters
......@@ -125,14 +149,13 @@ def map_storage(
def streamline(
fgraph: FunctionGraph,
thunks,
order,
post_thunk_old_storage=None,
no_recycling=None,
nice_errors=True,
) -> Callable[[], None]:
"""
WRITEME
thunks: Sequence[Callable[[], None]],
order: Sequence[Apply],
post_thunk_old_storage: Optional[List["StorageCellType"]] = None,
no_recycling: Optional[List["StorageCellType"]] = None,
nice_errors: bool = True,
) -> "BasicThunkType":
"""Construct a single thunk that runs a list of thunks.
Parameters
----------
......@@ -246,7 +269,7 @@ def gc_helper(node_list: List[Apply]):
def raise_with_op(
fgraph: FunctionGraph, node: Apply, thunk=None, exc_info=None, storage_map=None
):
) -> NoReturn:
"""
Re-raise an exception while annotating the exception object with
debug info.
......@@ -293,13 +316,11 @@ def raise_with_op(
if exc_type == KeyboardInterrupt:
# print a simple traceback from KeyboardInterrupt
raise exc_value.with_traceback(exc_trace)
try:
trace = node.outputs[0].tag.trace
except AttributeError:
try:
trace = node.op.tag.trace
except AttributeError:
trace = ()
trace = getattr(node.outputs[0].tag, "trace", ())
if not trace and hasattr(node.op, "tag"):
trace = getattr(node.op.tag, "trace", ())
exc_value.__thunk_trace__ = trace
exc_value.__op_instance__ = node
topo = fgraph.toposort()
......@@ -310,9 +331,9 @@ def raise_with_op(
exc_value.__applynode_index__ = node_index
hints = []
detailed_err_msg = "\nApply node that caused the error: " + str(node)
detailed_err_msg = f"\nApply node that caused the error: {node}"
if exc_value.__applynode_index__ is not None:
detailed_err_msg += f"\nToposort index: {int(node_index)}"
detailed_err_msg += f"\nToposort index: {node_index}"
types = [getattr(ipt, "type", "No type") for ipt in node.inputs]
detailed_err_msg += f"\nInputs types: {types}\n"
......@@ -517,7 +538,7 @@ def raise_with_op(
raise exc_value.with_traceback(exc_trace)
def __log_thunk_trace(value, handler: io.TextIOWrapper):
def __log_thunk_trace(value, handler: io.TextIOWrapper) -> None:
"""
Log Aesara's diagnostic stack trace for an exception.
......@@ -550,12 +571,12 @@ def __log_thunk_trace(value, handler: io.TextIOWrapper):
)
def register_thunk_trace_excepthook(handler: io.TextIOWrapper = sys.stdout):
"""Adds the __log_thunk_trace except hook to the collection in aesara.utils.
def register_thunk_trace_excepthook(handler: TextIO = sys.stdout) -> None:
"""Adds the `__log_thunk_trace` except hook to the collection in `aesara.utils`.
Parameters
----------
handler : TextIOWrapper
handler
Target for printing the output.
"""
......@@ -568,7 +589,12 @@ def register_thunk_trace_excepthook(handler: io.TextIOWrapper = sys.stdout):
register_thunk_trace_excepthook()
def compile_function_src(src, function_name, global_env=None, local_env=None):
def compile_function_src(
src: str,
function_name: str,
global_env: Optional[Dict[Any, Any]] = None,
local_env: Optional[Dict[Any, Any]] = None,
) -> Callable:
with NamedTemporaryFile(delete=False) as f:
filename = f.name
......@@ -583,12 +609,12 @@ def compile_function_src(src, function_name, global_env=None, local_env=None):
mod_code = compile(src, filename, mode="exec")
exec(mod_code, global_env, local_env)
res = local_env[function_name]
res.__source__ = src
res = cast(Callable, local_env[function_name])
res.__source__ = src # type: ignore
return res
def get_name_for_object(x: Any):
def get_name_for_object(x: Any) -> str:
"""Get the name for an arbitrary object."""
if isinstance(x, Variable):
......@@ -603,7 +629,7 @@ def get_name_for_object(x: Any):
else x.auto_name
)
else:
name = getattr(x, "__name__", None)
name = getattr(x, "__name__", "")
if not name or (not name.isidentifier() or iskeyword(name)):
name = type(x).__name__
......@@ -619,27 +645,31 @@ def unique_name_generator(
if external_names is None:
external_names = []
def unique_name(x, force_unique=False):
if not force_unique and x in unique_name.obj_to_names:
return unique_name.obj_to_names[x]
T = TypeVar("T")
def unique_name(
x: T,
force_unique=False,
names_counter=Counter(external_names),
objs_to_names: Dict[T, str] = {},
) -> str:
if not force_unique and x in objs_to_names:
return objs_to_names[x]
name = get_name_for_object(x)
name_suffix = unique_name.names_counter.get(name, "")
name_suffix = names_counter.get(name, "")
if name_suffix:
local_name = f"{name}{suffix_sep}{name_suffix}"
unique_name.names_counter.update((name,))
names_counter.update((name,))
else:
local_name = name
unique_name.names_counter.update((local_name,))
unique_name.obj_to_names[x] = local_name
names_counter.update((local_name,))
objs_to_names[x] = local_name
return local_name
unique_name.names_counter = Counter(external_names)
unique_name.obj_to_names = {}
return unique_name
......@@ -647,18 +677,18 @@ def fgraph_to_python(
fgraph: FunctionGraph,
op_conversion_fn: Callable,
*,
type_conversion_fn: Optional[Callable] = lambda x, **kwargs: x,
order: Optional[List[Variable]] = None,
input_storage: Optional[List[Any]] = None,
output_storage: Optional[List[Any]] = None,
storage_map: Optional[Dict[Variable, List[Any]]] = None,
type_conversion_fn: Callable = lambda x, **kwargs: x,
order: Optional[List[Apply]] = None,
input_storage: Optional["InputStorageType"] = None,
output_storage: Optional["OutputStorageType"] = None,
storage_map: Optional["StorageMapType"] = None,
fgraph_name: str = "fgraph_to_python",
global_env: Optional[Dict[Any, Any]] = None,
local_env: Optional[Dict[Any, Any]] = None,
get_name_for_object: Callable[[Any], str] = get_name_for_object,
squeeze_output: bool = False,
**kwargs,
) -> FunctionType:
) -> Callable:
"""Convert a ``FunctionGraph`` into a regular Python function.
Parameters
......
......@@ -115,10 +115,6 @@ check_untyped_defs = False
ignore_errors = True
check_untyped_defs = False
[mypy-aesara.link.utils]
ignore_errors = True
check_untyped_defs = False
[mypy-aesara.link.c.cmodule]
ignore_errors = True
check_untyped_defs = False
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论