提交 f9618a62 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix typing issues in aesara.graph.basic

上级 90c4ce82
......@@ -4,15 +4,17 @@ from collections import deque
from copy import copy
from itertools import count
from typing import (
TYPE_CHECKING,
Callable,
Collection,
Deque,
Dict,
Generator,
Hashable,
Iterable,
Iterator,
List,
Optional,
Reversible,
Sequence,
Set,
Tuple,
......@@ -36,9 +38,13 @@ from aesara.graph.utils import (
from aesara.misc.ordered_set import OrderedSet
T = TypeVar("T")
if TYPE_CHECKING:
from aesara.graph.type import Type
T = TypeVar("T", bound="Node")
NoParams = object()
NodeAndChildren = Tuple[T, Optional[Iterable[T]]]
class Node(MetaObject):
......@@ -49,6 +55,8 @@ class Node(MetaObject):
keeps track of its parents via `Variable.owner` / `Apply.inputs`.
"""
type: "Type"
name: Optional[str]
def get_parents(self):
"""
......@@ -377,7 +385,23 @@ class Variable(Node):
# __slots__ = ['type', 'owner', 'index', 'name']
__count__ = count(0)
def __init__(self, type, owner=None, index=None, name=None):
_owner: Optional[Apply]
@property
def owner(self) -> Optional[Apply]:
return self._owner
@owner.setter
def owner(self, value) -> None:
self._owner = value
def __init__(
self,
type,
owner: Optional[Apply] = None,
index: Optional[int] = None,
name: Optional[str] = None,
):
super().__init__()
self.tag = ValidatingScratchpad("test_value", type.filter)
......@@ -386,7 +410,7 @@ class Variable(Node):
if owner is not None and not isinstance(owner, Apply):
raise TypeError("owner must be an Apply instance", owner)
self.owner = owner
self._owner = owner
if index is not None and not isinstance(index, int):
raise TypeError("index must be an int", index)
......@@ -607,19 +631,15 @@ class Constant(Variable):
"""Return `self`, because there's no reason to clone a constant."""
return self
def __set_owner(self, value):
"""Prevent the :prop:`owner` property from being set.
Raises
------
ValueError
If `value` is not ``None``.
@property
def owner(self) -> None:
return None
"""
@owner.setter
def owner(self, value) -> None:
if value is not None:
raise ValueError("Constant instances cannot have an owner.")
owner = property(lambda self: None, __set_owner)
value = property(lambda self: self.data, doc="read-only data access method")
# index is not defined, because the `owner` attribute must necessarily be None
......@@ -627,34 +647,30 @@ class Constant(Variable):
def walk(
nodes: Iterable[T],
expand: Callable[[T], Optional[Sequence[T]]],
expand: Callable[[T], Optional[Iterable[T]]],
bfs: bool = True,
return_children: bool = False,
hash_fn: Callable[[T], Hashable] = id,
) -> Generator[Union[T, Tuple[T, Optional[Sequence[T]]]], None, None]:
"""Walk through a graph, either breadth- or depth-first.
hash_fn: Callable[[T], int] = id,
) -> Generator[Union[T, NodeAndChildren], None, None]:
r"""Walk through a graph, either breadth- or depth-first.
Parameters
----------
nodes : deque
nodes
The nodes from which to start walking.
expand : callable
expand
A callable that is applied to each node in `nodes`, the results of
which are either new nodes to visit or ``None``.
bfs : bool
bfs
If ``True``, breath first search is used; otherwise, depth first
search.
return_children : bool
return_children
If ``True``, each output node will be accompanied by the output of
`expand` (i.e. the corresponding child nodes).
hash_fn : callable
hash_fn
The function used to produce hashes of the elements in `nodes`.
The default is ``id``.
Yields
------
nodes
Notes
-----
A node will appear at most once in the return value, even if it
......@@ -664,7 +680,7 @@ def walk(
nodes = deque(nodes)
rval_set: Set[T] = set()
rval_set: Set[int] = set()
nodes_pop: Callable[[], T]
if bfs:
......@@ -675,13 +691,13 @@ def walk(
while nodes:
node: T = nodes_pop()
node_hash: Hashable = hash_fn(node)
node_hash: int = hash_fn(node)
if node_hash not in rval_set:
rval_set.add(node_hash)
new_nodes: Optional[Sequence[T]] = expand(node)
new_nodes: Optional[Iterable[T]] = expand(node)
if return_children:
yield node, new_nodes
......@@ -693,7 +709,7 @@ def walk(
def ancestors(
graphs: Iterable[Variable], blockers: Collection[Variable] = None
graphs: Iterable[Variable], blockers: Optional[Collection[Variable]] = None
) -> Generator[Variable, None, None]:
r"""Return the variables that contribute to those in given graphs (inclusive).
......@@ -714,15 +730,15 @@ def ancestors(
"""
def expand(r):
def expand(r: Variable) -> Optional[Iterator[Variable]]:
if r.owner and (not blockers or r not in blockers):
return reversed(r.owner.inputs)
yield from walk(graphs, expand, False)
yield from cast(Generator[Variable, None, None], walk(graphs, expand, False))
def graph_inputs(
graphs: Iterable[Variable], blockers: Collection[Variable] = None
graphs: Iterable[Variable], blockers: Optional[Collection[Variable]] = None
) -> Generator[Variable, None, None]:
r"""Return the inputs required to compute the given Variables.
......@@ -751,26 +767,25 @@ def vars_between(
Parameters
----------
ins : list
ins
Input `Variable`\s.
outs : list
outs
Output `Variable`\s.
Yields
------
`Variable`\s
The `Variable`\s that are involved in the subgraph that lies
between `ins` and `outs`. This includes `ins`, `outs`,
``orphans_between(ins, outs)`` and all values of all intermediary steps from
`ins` to `outs`.
The `Variable`\s that are involved in the subgraph that lies
between `ins` and `outs`. This includes `ins`, `outs`,
``orphans_between(ins, outs)`` and all values of all intermediary steps from
`ins` to `outs`.
"""
def expand(r):
def expand(r: Variable) -> Optional[Iterable[Variable]]:
if r.owner and r not in ins:
return reversed(r.owner.inputs + r.owner.outputs)
yield from walk(outs, expand)
yield from cast(Generator[Variable, None, None], walk(outs, expand))
def orphans_between(
......@@ -862,16 +877,18 @@ def clone(
if copy_orphans is None:
copy_orphans = copy_inputs
equiv = clone_get_equiv(inputs, outputs, copy_inputs, copy_orphans)
return [equiv[input] for input in inputs], [equiv[output] for output in outputs]
return [cast(Variable, equiv[input]) for input in inputs], [
cast(Variable, equiv[output]) for output in outputs
]
def clone_get_equiv(
inputs: List[Variable],
outputs: List[Variable],
inputs: Sequence[Variable],
outputs: Sequence[Variable],
copy_inputs: bool = True,
copy_orphans: bool = True,
memo: Optional[Dict[Variable, Variable]] = None,
) -> Dict[Variable, Variable]:
memo: Optional[Dict[Node, Node]] = None,
) -> Dict[Node, Node]:
"""
Return a dictionary that maps from `Variable` and `Apply` nodes in the
original graph to a new node (a clone) in a new graph.
......@@ -934,11 +951,13 @@ def clone_get_equiv(
def clone_replace(
output: Collection[Variable],
replace: Optional[Dict[Variable, Variable]] = None,
output: List[Variable],
replace: Optional[
Union[Iterable[Tuple[Variable, Variable]], Dict[Variable, Variable]]
] = None,
strict: bool = True,
share_inputs: bool = True,
) -> Collection[Variable]:
) -> List[Variable]:
"""Clone a graph and replace subgraphs within it.
It returns a copy of the initial subgraph with the corresponding
......@@ -959,6 +978,7 @@ def clone_replace(
"""
from aesara.compile.function.pfunc import rebuild_collect_shared
items: Union[List[Tuple[Variable, Variable]], Tuple[Tuple[Variable, Variable], ...]]
if isinstance(replace, dict):
items = list(replace.items())
elif isinstance(replace, (list, tuple)):
......@@ -984,13 +1004,15 @@ def clone_replace(
_outs, [], new_replace, [], strict, share_inputs
)
return outs
return cast(List[Variable], outs)
def general_toposort(
outputs: Iterable[T],
deps: Callable[[T], Union[OrderedSet, List[T]]],
compute_deps_cache: Optional[Callable[[T], Union[OrderedSet, List[T]]]] = None,
compute_deps_cache: Optional[
Callable[[T], Optional[Union[OrderedSet, List[T]]]]
] = None,
deps_cache: Optional[Dict[T, List[T]]] = None,
clients: Optional[Dict[T, List[T]]] = None,
) -> List[T]:
......@@ -1030,7 +1052,7 @@ def general_toposort(
if deps_cache is None:
deps_cache = {}
def compute_deps_cache(io):
def _compute_deps_cache(io):
if io not in deps_cache:
d = deps(io)
......@@ -1048,23 +1070,27 @@ def general_toposort(
else:
return deps_cache[io]
else:
_compute_deps_cache = compute_deps_cache
if deps_cache is None:
raise ValueError("deps_cache cannot be None")
search_res: List[Tuple[T, Optional[List[T]]]] = list(
walk(outputs, compute_deps_cache, bfs=False, return_children=True)
search_res: List[NodeAndChildren] = cast(
List[NodeAndChildren],
list(walk(outputs, _compute_deps_cache, bfs=False, return_children=True)),
)
_clients: Dict[T, List[T]] = {}
sources: Deque[T] = deque()
search_res_len: int = 0
for node, children in search_res:
for snode, children in search_res:
search_res_len += 1
if children:
for child in children:
_clients.setdefault(child, []).append(node)
if not deps_cache.get(node):
sources.append(node)
_clients.setdefault(child, []).append(snode)
if not deps_cache.get(snode):
sources.append(snode)
if clients is not None:
clients.update(_clients)
......@@ -1090,7 +1116,7 @@ def general_toposort(
def io_toposort(
inputs: Iterable[Variable],
outputs: Iterable[Variable],
outputs: Reversible[Variable],
orderings: Optional[Dict[Apply, List[Apply]]] = None,
clients: Optional[Dict[Variable, List[Variable]]] = None,
) -> List[Apply]:
......@@ -1301,8 +1327,8 @@ def as_string(
orph = list(orphans_between(i, outputs))
multi: Set = set()
seen: Set = set()
multi = set()
seen = set()
for output in outputs:
op = output.owner
if op in seen:
......@@ -1318,11 +1344,11 @@ def as_string(
multi.add(op2)
else:
seen.add(input.owner)
multi: Set = [x for x in multi]
multi_list = [x for x in multi]
done: Set = set()
def multi_index(x):
return multi.index(x) + 1
return multi_list.index(x) + 1
def describe(r):
if r.owner is not None and r not in i and r not in orph:
......@@ -1337,7 +1363,7 @@ def as_string(
else:
done.add(op)
s = node_formatter(op, [describe(input) for input in op.inputs])
if op in multi:
if op in multi_list:
return f"*{multi_index(op)} -> {s}"
else:
return s
......@@ -1380,14 +1406,21 @@ def list_of_nodes(
Output `Variable`\s.
"""
def expand(o: Apply) -> List[Apply]:
return [
inp.owner
for inp in o.inputs
if inp.owner and not any(i in inp.owner.outputs for i in inputs)
]
return list(
walk(
[o.owner for o in outputs],
lambda o: [
inp.owner
for inp in o.inputs
if inp.owner and not any(i in inp.owner.outputs for i in inputs)
],
cast(
Iterable[Apply],
walk(
[o.owner for o in outputs if o.owner],
expand,
),
)
)
......@@ -1423,7 +1456,12 @@ def is_in_ancestors(l_apply: Apply, f_node: Apply) -> bool:
return False
def equal_computations(xs, ys, in_xs=None, in_ys=None):
def equal_computations(
xs: List[Union[np.ndarray, Variable]],
ys: List[Union[np.ndarray, Variable]],
in_xs: Optional[List[Variable]] = None,
in_ys: Optional[List[Variable]] = None,
) -> bool:
"""Checks if Aesara graphs represent the same computations.
The two lists `xs`, `ys` should have the same number of entries. The
......@@ -1458,22 +1496,20 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None):
for x, y in zip(xs, ys):
if not isinstance(x, Variable) and not isinstance(y, Variable):
return np.array_equal(x, y)
return cast(bool, np.array_equal(x, y))
if not isinstance(x, Variable):
if isinstance(y, Constant):
return np.array_equal(y.data, x)
return cast(bool, np.array_equal(y.data, x))
return False
if not isinstance(y, Variable):
if isinstance(x, Constant):
return np.array_equal(x.data, y)
return False
if not isinstance(y, Variable):
return cast(bool, np.array_equal(x.data, y))
return False
if x.owner and not y.owner:
return False
if y.owner and not x.owner:
return False
if x.owner: # Check above tell that y.owner eval to True too.
if x.owner and y.owner:
if x.owner.outputs.index(x) != y.owner.outputs.index(y):
return False
if x not in in_xs and not (y.type.in_same_class(x.type)):
......@@ -1487,8 +1523,9 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None):
return False
common = set(zip(in_xs, in_ys))
different = set()
different: Set[Tuple[Variable, Variable]] = set()
for dx, dy in zip(xs, ys):
assert isinstance(dx, Variable)
# We checked above that both dx and dy have an owner or not
if not dx.owner:
if isinstance(dx, Constant) and isinstance(dy, Constant):
......@@ -1575,10 +1612,12 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None):
# Validate that each xs[i], ys[i] pair represents the same computation
for i in range(len(xs)):
if xs[i].owner:
x_i: Variable = cast(Variable, xs[i])
if x_i.owner:
y_i: Variable = cast(Variable, ys[i])
# The case where pairs of x[i]s and y[i]s don't both have an owner
# have already been addressed.
is_equal = compare_nodes(xs[i].owner, ys[i].owner, common, different)
is_equal = compare_nodes(x_i.owner, y_i.owner, common, different)
if not is_equal:
return False
......@@ -1605,7 +1644,7 @@ def get_var_by_name(
"""
from aesara.graph.op import HasInnerGraph
def expand(r):
def expand(r) -> Optional[List[Variable]]:
if r.owner:
res = list(r.owner.inputs)
......
......@@ -8,7 +8,7 @@ from typing import List, Optional, Sequence, Tuple
def simple_extract_stack(
f=None, limit: Optional[int] = None, skips: Optional[Sequence[str]] = None
) -> List[Tuple[Optional[str], int, str, str]]:
) -> List[Tuple[Optional[str], int, str, Optional[str]]]:
"""This is traceback.extract_stack from python 2.7 with this change:
- Comment the update of the cache.
......@@ -28,14 +28,12 @@ def simple_extract_stack(
skips = []
if f is None:
try:
raise ZeroDivisionError
except ZeroDivisionError:
f = sys.exc_info()[2].tb_frame.f_back
f = sys._getframe().f_back
if limit is None:
if hasattr(sys, "tracebacklimit"):
limit = sys.tracebacklimit
trace: List[Tuple[Optional[str], int, str, str]] = []
trace: List[Tuple[Optional[str], int, str, Optional[str]]] = []
n = 0
while f is not None and (limit is None or n < limit):
lineno = f.f_lineno
......@@ -43,7 +41,7 @@ def simple_extract_stack(
filename = co.co_filename
name = co.co_name
# linecache.checkcache(filename)
line = linecache.getline(filename, lineno, f.f_globals)
line: Optional[str] = linecache.getline(filename, lineno, f.f_globals)
if line:
line = line.strip()
else:
......
......@@ -115,14 +115,6 @@ check_untyped_defs = False
ignore_errors = True
check_untyped_defs = False
[mypy-aesara.graph.utils]
ignore_errors = True
check_untyped_defs = False
[mypy-aesara.graph.basic]
ignore_errors = True
check_untyped_defs = False
[mypy-aesara.graph.op]
ignore_errors = True
check_untyped_defs = False
......@@ -179,6 +171,10 @@ check_untyped_defs = False
ignore_errors = True
check_untyped_defs = False
[mypy-aesara.compile.builders]
ignore_errors = True
check_untyped_defs = False
[mypy-aesara.compile.sharedvalue]
ignore_errors = True
check_untyped_defs = False
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论