提交 f9618a62 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix typing issues in aesara.graph.basic

上级 90c4ce82
...@@ -4,15 +4,17 @@ from collections import deque ...@@ -4,15 +4,17 @@ from collections import deque
from copy import copy from copy import copy
from itertools import count from itertools import count
from typing import ( from typing import (
TYPE_CHECKING,
Callable, Callable,
Collection, Collection,
Deque, Deque,
Dict, Dict,
Generator, Generator,
Hashable,
Iterable, Iterable,
Iterator,
List, List,
Optional, Optional,
Reversible,
Sequence, Sequence,
Set, Set,
Tuple, Tuple,
...@@ -36,9 +38,13 @@ from aesara.graph.utils import ( ...@@ -36,9 +38,13 @@ from aesara.graph.utils import (
from aesara.misc.ordered_set import OrderedSet from aesara.misc.ordered_set import OrderedSet
T = TypeVar("T") if TYPE_CHECKING:
from aesara.graph.type import Type
T = TypeVar("T", bound="Node")
NoParams = object() NoParams = object()
NodeAndChildren = Tuple[T, Optional[Iterable[T]]]
class Node(MetaObject): class Node(MetaObject):
...@@ -49,6 +55,8 @@ class Node(MetaObject): ...@@ -49,6 +55,8 @@ class Node(MetaObject):
keeps track of its parents via `Variable.owner` / `Apply.inputs`. keeps track of its parents via `Variable.owner` / `Apply.inputs`.
""" """
type: "Type"
name: Optional[str]
def get_parents(self): def get_parents(self):
""" """
...@@ -377,7 +385,23 @@ class Variable(Node): ...@@ -377,7 +385,23 @@ class Variable(Node):
# __slots__ = ['type', 'owner', 'index', 'name'] # __slots__ = ['type', 'owner', 'index', 'name']
__count__ = count(0) __count__ = count(0)
def __init__(self, type, owner=None, index=None, name=None): _owner: Optional[Apply]
@property
def owner(self) -> Optional[Apply]:
return self._owner
@owner.setter
def owner(self, value) -> None:
self._owner = value
def __init__(
self,
type,
owner: Optional[Apply] = None,
index: Optional[int] = None,
name: Optional[str] = None,
):
super().__init__() super().__init__()
self.tag = ValidatingScratchpad("test_value", type.filter) self.tag = ValidatingScratchpad("test_value", type.filter)
...@@ -386,7 +410,7 @@ class Variable(Node): ...@@ -386,7 +410,7 @@ class Variable(Node):
if owner is not None and not isinstance(owner, Apply): if owner is not None and not isinstance(owner, Apply):
raise TypeError("owner must be an Apply instance", owner) raise TypeError("owner must be an Apply instance", owner)
self.owner = owner self._owner = owner
if index is not None and not isinstance(index, int): if index is not None and not isinstance(index, int):
raise TypeError("index must be an int", index) raise TypeError("index must be an int", index)
...@@ -607,19 +631,15 @@ class Constant(Variable): ...@@ -607,19 +631,15 @@ class Constant(Variable):
"""Return `self`, because there's no reason to clone a constant.""" """Return `self`, because there's no reason to clone a constant."""
return self return self
def __set_owner(self, value): @property
"""Prevent the :prop:`owner` property from being set. def owner(self) -> None:
return None
Raises
------
ValueError
If `value` is not ``None``.
""" @owner.setter
def owner(self, value) -> None:
if value is not None: if value is not None:
raise ValueError("Constant instances cannot have an owner.") raise ValueError("Constant instances cannot have an owner.")
owner = property(lambda self: None, __set_owner)
value = property(lambda self: self.data, doc="read-only data access method") value = property(lambda self: self.data, doc="read-only data access method")
# index is not defined, because the `owner` attribute must necessarily be None # index is not defined, because the `owner` attribute must necessarily be None
...@@ -627,34 +647,30 @@ class Constant(Variable): ...@@ -627,34 +647,30 @@ class Constant(Variable):
def walk( def walk(
nodes: Iterable[T], nodes: Iterable[T],
expand: Callable[[T], Optional[Sequence[T]]], expand: Callable[[T], Optional[Iterable[T]]],
bfs: bool = True, bfs: bool = True,
return_children: bool = False, return_children: bool = False,
hash_fn: Callable[[T], Hashable] = id, hash_fn: Callable[[T], int] = id,
) -> Generator[Union[T, Tuple[T, Optional[Sequence[T]]]], None, None]: ) -> Generator[Union[T, NodeAndChildren], None, None]:
"""Walk through a graph, either breadth- or depth-first. r"""Walk through a graph, either breadth- or depth-first.
Parameters Parameters
---------- ----------
nodes : deque nodes
The nodes from which to start walking. The nodes from which to start walking.
expand : callable expand
A callable that is applied to each node in `nodes`, the results of A callable that is applied to each node in `nodes`, the results of
which are either new nodes to visit or ``None``. which are either new nodes to visit or ``None``.
bfs : bool bfs
If ``True``, breath first search is used; otherwise, depth first If ``True``, breath first search is used; otherwise, depth first
search. search.
return_children : bool return_children
If ``True``, each output node will be accompanied by the output of If ``True``, each output node will be accompanied by the output of
`expand` (i.e. the corresponding child nodes). `expand` (i.e. the corresponding child nodes).
hash_fn : callable hash_fn
The function used to produce hashes of the elements in `nodes`. The function used to produce hashes of the elements in `nodes`.
The default is ``id``. The default is ``id``.
Yields
------
nodes
Notes Notes
----- -----
A node will appear at most once in the return value, even if it A node will appear at most once in the return value, even if it
...@@ -664,7 +680,7 @@ def walk( ...@@ -664,7 +680,7 @@ def walk(
nodes = deque(nodes) nodes = deque(nodes)
rval_set: Set[T] = set() rval_set: Set[int] = set()
nodes_pop: Callable[[], T] nodes_pop: Callable[[], T]
if bfs: if bfs:
...@@ -675,13 +691,13 @@ def walk( ...@@ -675,13 +691,13 @@ def walk(
while nodes: while nodes:
node: T = nodes_pop() node: T = nodes_pop()
node_hash: Hashable = hash_fn(node) node_hash: int = hash_fn(node)
if node_hash not in rval_set: if node_hash not in rval_set:
rval_set.add(node_hash) rval_set.add(node_hash)
new_nodes: Optional[Sequence[T]] = expand(node) new_nodes: Optional[Iterable[T]] = expand(node)
if return_children: if return_children:
yield node, new_nodes yield node, new_nodes
...@@ -693,7 +709,7 @@ def walk( ...@@ -693,7 +709,7 @@ def walk(
def ancestors( def ancestors(
graphs: Iterable[Variable], blockers: Collection[Variable] = None graphs: Iterable[Variable], blockers: Optional[Collection[Variable]] = None
) -> Generator[Variable, None, None]: ) -> Generator[Variable, None, None]:
r"""Return the variables that contribute to those in given graphs (inclusive). r"""Return the variables that contribute to those in given graphs (inclusive).
...@@ -714,15 +730,15 @@ def ancestors( ...@@ -714,15 +730,15 @@ def ancestors(
""" """
def expand(r): def expand(r: Variable) -> Optional[Iterator[Variable]]:
if r.owner and (not blockers or r not in blockers): if r.owner and (not blockers or r not in blockers):
return reversed(r.owner.inputs) return reversed(r.owner.inputs)
yield from walk(graphs, expand, False) yield from cast(Generator[Variable, None, None], walk(graphs, expand, False))
def graph_inputs( def graph_inputs(
graphs: Iterable[Variable], blockers: Collection[Variable] = None graphs: Iterable[Variable], blockers: Optional[Collection[Variable]] = None
) -> Generator[Variable, None, None]: ) -> Generator[Variable, None, None]:
r"""Return the inputs required to compute the given Variables. r"""Return the inputs required to compute the given Variables.
...@@ -751,26 +767,25 @@ def vars_between( ...@@ -751,26 +767,25 @@ def vars_between(
Parameters Parameters
---------- ----------
ins : list ins
Input `Variable`\s. Input `Variable`\s.
outs : list outs
Output `Variable`\s. Output `Variable`\s.
Yields Yields
------ ------
`Variable`\s The `Variable`\s that are involved in the subgraph that lies
The `Variable`\s that are involved in the subgraph that lies between `ins` and `outs`. This includes `ins`, `outs`,
between `ins` and `outs`. This includes `ins`, `outs`, ``orphans_between(ins, outs)`` and all values of all intermediary steps from
``orphans_between(ins, outs)`` and all values of all intermediary steps from `ins` to `outs`.
`ins` to `outs`.
""" """
def expand(r): def expand(r: Variable) -> Optional[Iterable[Variable]]:
if r.owner and r not in ins: if r.owner and r not in ins:
return reversed(r.owner.inputs + r.owner.outputs) return reversed(r.owner.inputs + r.owner.outputs)
yield from walk(outs, expand) yield from cast(Generator[Variable, None, None], walk(outs, expand))
def orphans_between( def orphans_between(
...@@ -862,16 +877,18 @@ def clone( ...@@ -862,16 +877,18 @@ def clone(
if copy_orphans is None: if copy_orphans is None:
copy_orphans = copy_inputs copy_orphans = copy_inputs
equiv = clone_get_equiv(inputs, outputs, copy_inputs, copy_orphans) equiv = clone_get_equiv(inputs, outputs, copy_inputs, copy_orphans)
return [equiv[input] for input in inputs], [equiv[output] for output in outputs] return [cast(Variable, equiv[input]) for input in inputs], [
cast(Variable, equiv[output]) for output in outputs
]
def clone_get_equiv( def clone_get_equiv(
inputs: List[Variable], inputs: Sequence[Variable],
outputs: List[Variable], outputs: Sequence[Variable],
copy_inputs: bool = True, copy_inputs: bool = True,
copy_orphans: bool = True, copy_orphans: bool = True,
memo: Optional[Dict[Variable, Variable]] = None, memo: Optional[Dict[Node, Node]] = None,
) -> Dict[Variable, Variable]: ) -> Dict[Node, Node]:
""" """
Return a dictionary that maps from `Variable` and `Apply` nodes in the Return a dictionary that maps from `Variable` and `Apply` nodes in the
original graph to a new node (a clone) in a new graph. original graph to a new node (a clone) in a new graph.
...@@ -934,11 +951,13 @@ def clone_get_equiv( ...@@ -934,11 +951,13 @@ def clone_get_equiv(
def clone_replace( def clone_replace(
output: Collection[Variable], output: List[Variable],
replace: Optional[Dict[Variable, Variable]] = None, replace: Optional[
Union[Iterable[Tuple[Variable, Variable]], Dict[Variable, Variable]]
] = None,
strict: bool = True, strict: bool = True,
share_inputs: bool = True, share_inputs: bool = True,
) -> Collection[Variable]: ) -> List[Variable]:
"""Clone a graph and replace subgraphs within it. """Clone a graph and replace subgraphs within it.
It returns a copy of the initial subgraph with the corresponding It returns a copy of the initial subgraph with the corresponding
...@@ -959,6 +978,7 @@ def clone_replace( ...@@ -959,6 +978,7 @@ def clone_replace(
""" """
from aesara.compile.function.pfunc import rebuild_collect_shared from aesara.compile.function.pfunc import rebuild_collect_shared
items: Union[List[Tuple[Variable, Variable]], Tuple[Tuple[Variable, Variable], ...]]
if isinstance(replace, dict): if isinstance(replace, dict):
items = list(replace.items()) items = list(replace.items())
elif isinstance(replace, (list, tuple)): elif isinstance(replace, (list, tuple)):
...@@ -984,13 +1004,15 @@ def clone_replace( ...@@ -984,13 +1004,15 @@ def clone_replace(
_outs, [], new_replace, [], strict, share_inputs _outs, [], new_replace, [], strict, share_inputs
) )
return outs return cast(List[Variable], outs)
def general_toposort( def general_toposort(
outputs: Iterable[T], outputs: Iterable[T],
deps: Callable[[T], Union[OrderedSet, List[T]]], deps: Callable[[T], Union[OrderedSet, List[T]]],
compute_deps_cache: Optional[Callable[[T], Union[OrderedSet, List[T]]]] = None, compute_deps_cache: Optional[
Callable[[T], Optional[Union[OrderedSet, List[T]]]]
] = None,
deps_cache: Optional[Dict[T, List[T]]] = None, deps_cache: Optional[Dict[T, List[T]]] = None,
clients: Optional[Dict[T, List[T]]] = None, clients: Optional[Dict[T, List[T]]] = None,
) -> List[T]: ) -> List[T]:
...@@ -1030,7 +1052,7 @@ def general_toposort( ...@@ -1030,7 +1052,7 @@ def general_toposort(
if deps_cache is None: if deps_cache is None:
deps_cache = {} deps_cache = {}
def compute_deps_cache(io): def _compute_deps_cache(io):
if io not in deps_cache: if io not in deps_cache:
d = deps(io) d = deps(io)
...@@ -1048,23 +1070,27 @@ def general_toposort( ...@@ -1048,23 +1070,27 @@ def general_toposort(
else: else:
return deps_cache[io] return deps_cache[io]
else:
_compute_deps_cache = compute_deps_cache
if deps_cache is None: if deps_cache is None:
raise ValueError("deps_cache cannot be None") raise ValueError("deps_cache cannot be None")
search_res: List[Tuple[T, Optional[List[T]]]] = list( search_res: List[NodeAndChildren] = cast(
walk(outputs, compute_deps_cache, bfs=False, return_children=True) List[NodeAndChildren],
list(walk(outputs, _compute_deps_cache, bfs=False, return_children=True)),
) )
_clients: Dict[T, List[T]] = {} _clients: Dict[T, List[T]] = {}
sources: Deque[T] = deque() sources: Deque[T] = deque()
search_res_len: int = 0 search_res_len: int = 0
for node, children in search_res: for snode, children in search_res:
search_res_len += 1 search_res_len += 1
if children: if children:
for child in children: for child in children:
_clients.setdefault(child, []).append(node) _clients.setdefault(child, []).append(snode)
if not deps_cache.get(node): if not deps_cache.get(snode):
sources.append(node) sources.append(snode)
if clients is not None: if clients is not None:
clients.update(_clients) clients.update(_clients)
...@@ -1090,7 +1116,7 @@ def general_toposort( ...@@ -1090,7 +1116,7 @@ def general_toposort(
def io_toposort( def io_toposort(
inputs: Iterable[Variable], inputs: Iterable[Variable],
outputs: Iterable[Variable], outputs: Reversible[Variable],
orderings: Optional[Dict[Apply, List[Apply]]] = None, orderings: Optional[Dict[Apply, List[Apply]]] = None,
clients: Optional[Dict[Variable, List[Variable]]] = None, clients: Optional[Dict[Variable, List[Variable]]] = None,
) -> List[Apply]: ) -> List[Apply]:
...@@ -1301,8 +1327,8 @@ def as_string( ...@@ -1301,8 +1327,8 @@ def as_string(
orph = list(orphans_between(i, outputs)) orph = list(orphans_between(i, outputs))
multi: Set = set() multi = set()
seen: Set = set() seen = set()
for output in outputs: for output in outputs:
op = output.owner op = output.owner
if op in seen: if op in seen:
...@@ -1318,11 +1344,11 @@ def as_string( ...@@ -1318,11 +1344,11 @@ def as_string(
multi.add(op2) multi.add(op2)
else: else:
seen.add(input.owner) seen.add(input.owner)
multi: Set = [x for x in multi] multi_list = [x for x in multi]
done: Set = set() done: Set = set()
def multi_index(x): def multi_index(x):
return multi.index(x) + 1 return multi_list.index(x) + 1
def describe(r): def describe(r):
if r.owner is not None and r not in i and r not in orph: if r.owner is not None and r not in i and r not in orph:
...@@ -1337,7 +1363,7 @@ def as_string( ...@@ -1337,7 +1363,7 @@ def as_string(
else: else:
done.add(op) done.add(op)
s = node_formatter(op, [describe(input) for input in op.inputs]) s = node_formatter(op, [describe(input) for input in op.inputs])
if op in multi: if op in multi_list:
return f"*{multi_index(op)} -> {s}" return f"*{multi_index(op)} -> {s}"
else: else:
return s return s
...@@ -1380,14 +1406,21 @@ def list_of_nodes( ...@@ -1380,14 +1406,21 @@ def list_of_nodes(
Output `Variable`\s. Output `Variable`\s.
""" """
def expand(o: Apply) -> List[Apply]:
return [
inp.owner
for inp in o.inputs
if inp.owner and not any(i in inp.owner.outputs for i in inputs)
]
return list( return list(
walk( cast(
[o.owner for o in outputs], Iterable[Apply],
lambda o: [ walk(
inp.owner [o.owner for o in outputs if o.owner],
for inp in o.inputs expand,
if inp.owner and not any(i in inp.owner.outputs for i in inputs) ),
],
) )
) )
...@@ -1423,7 +1456,12 @@ def is_in_ancestors(l_apply: Apply, f_node: Apply) -> bool: ...@@ -1423,7 +1456,12 @@ def is_in_ancestors(l_apply: Apply, f_node: Apply) -> bool:
return False return False
def equal_computations(xs, ys, in_xs=None, in_ys=None): def equal_computations(
xs: List[Union[np.ndarray, Variable]],
ys: List[Union[np.ndarray, Variable]],
in_xs: Optional[List[Variable]] = None,
in_ys: Optional[List[Variable]] = None,
) -> bool:
"""Checks if Aesara graphs represent the same computations. """Checks if Aesara graphs represent the same computations.
The two lists `xs`, `ys` should have the same number of entries. The The two lists `xs`, `ys` should have the same number of entries. The
...@@ -1458,22 +1496,20 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None): ...@@ -1458,22 +1496,20 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None):
for x, y in zip(xs, ys): for x, y in zip(xs, ys):
if not isinstance(x, Variable) and not isinstance(y, Variable): if not isinstance(x, Variable) and not isinstance(y, Variable):
return np.array_equal(x, y) return cast(bool, np.array_equal(x, y))
if not isinstance(x, Variable): if not isinstance(x, Variable):
if isinstance(y, Constant): if isinstance(y, Constant):
return np.array_equal(y.data, x) return cast(bool, np.array_equal(y.data, x))
return False return False
if not isinstance(y, Variable): if not isinstance(y, Variable):
if isinstance(x, Constant): if isinstance(x, Constant):
return np.array_equal(x.data, y) return cast(bool, np.array_equal(x.data, y))
return False
if not isinstance(y, Variable):
return False return False
if x.owner and not y.owner: if x.owner and not y.owner:
return False return False
if y.owner and not x.owner: if y.owner and not x.owner:
return False return False
if x.owner: # Check above tell that y.owner eval to True too. if x.owner and y.owner:
if x.owner.outputs.index(x) != y.owner.outputs.index(y): if x.owner.outputs.index(x) != y.owner.outputs.index(y):
return False return False
if x not in in_xs and not (y.type.in_same_class(x.type)): if x not in in_xs and not (y.type.in_same_class(x.type)):
...@@ -1487,8 +1523,9 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None): ...@@ -1487,8 +1523,9 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None):
return False return False
common = set(zip(in_xs, in_ys)) common = set(zip(in_xs, in_ys))
different = set() different: Set[Tuple[Variable, Variable]] = set()
for dx, dy in zip(xs, ys): for dx, dy in zip(xs, ys):
assert isinstance(dx, Variable)
# We checked above that both dx and dy have an owner or not # We checked above that both dx and dy have an owner or not
if not dx.owner: if not dx.owner:
if isinstance(dx, Constant) and isinstance(dy, Constant): if isinstance(dx, Constant) and isinstance(dy, Constant):
...@@ -1575,10 +1612,12 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None): ...@@ -1575,10 +1612,12 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None):
# Validate that each xs[i], ys[i] pair represents the same computation # Validate that each xs[i], ys[i] pair represents the same computation
for i in range(len(xs)): for i in range(len(xs)):
if xs[i].owner: x_i: Variable = cast(Variable, xs[i])
if x_i.owner:
y_i: Variable = cast(Variable, ys[i])
# The case where pairs of x[i]s and y[i]s don't both have an owner # The case where pairs of x[i]s and y[i]s don't both have an owner
# have already been addressed. # have already been addressed.
is_equal = compare_nodes(xs[i].owner, ys[i].owner, common, different) is_equal = compare_nodes(x_i.owner, y_i.owner, common, different)
if not is_equal: if not is_equal:
return False return False
...@@ -1605,7 +1644,7 @@ def get_var_by_name( ...@@ -1605,7 +1644,7 @@ def get_var_by_name(
""" """
from aesara.graph.op import HasInnerGraph from aesara.graph.op import HasInnerGraph
def expand(r): def expand(r) -> Optional[List[Variable]]:
if r.owner: if r.owner:
res = list(r.owner.inputs) res = list(r.owner.inputs)
......
...@@ -8,7 +8,7 @@ from typing import List, Optional, Sequence, Tuple ...@@ -8,7 +8,7 @@ from typing import List, Optional, Sequence, Tuple
def simple_extract_stack( def simple_extract_stack(
f=None, limit: Optional[int] = None, skips: Optional[Sequence[str]] = None f=None, limit: Optional[int] = None, skips: Optional[Sequence[str]] = None
) -> List[Tuple[Optional[str], int, str, str]]: ) -> List[Tuple[Optional[str], int, str, Optional[str]]]:
"""This is traceback.extract_stack from python 2.7 with this change: """This is traceback.extract_stack from python 2.7 with this change:
- Comment the update of the cache. - Comment the update of the cache.
...@@ -28,14 +28,12 @@ def simple_extract_stack( ...@@ -28,14 +28,12 @@ def simple_extract_stack(
skips = [] skips = []
if f is None: if f is None:
try: f = sys._getframe().f_back
raise ZeroDivisionError
except ZeroDivisionError:
f = sys.exc_info()[2].tb_frame.f_back
if limit is None: if limit is None:
if hasattr(sys, "tracebacklimit"): if hasattr(sys, "tracebacklimit"):
limit = sys.tracebacklimit limit = sys.tracebacklimit
trace: List[Tuple[Optional[str], int, str, str]] = [] trace: List[Tuple[Optional[str], int, str, Optional[str]]] = []
n = 0 n = 0
while f is not None and (limit is None or n < limit): while f is not None and (limit is None or n < limit):
lineno = f.f_lineno lineno = f.f_lineno
...@@ -43,7 +41,7 @@ def simple_extract_stack( ...@@ -43,7 +41,7 @@ def simple_extract_stack(
filename = co.co_filename filename = co.co_filename
name = co.co_name name = co.co_name
# linecache.checkcache(filename) # linecache.checkcache(filename)
line = linecache.getline(filename, lineno, f.f_globals) line: Optional[str] = linecache.getline(filename, lineno, f.f_globals)
if line: if line:
line = line.strip() line = line.strip()
else: else:
......
...@@ -115,14 +115,6 @@ check_untyped_defs = False ...@@ -115,14 +115,6 @@ check_untyped_defs = False
ignore_errors = True ignore_errors = True
check_untyped_defs = False check_untyped_defs = False
[mypy-aesara.graph.utils]
ignore_errors = True
check_untyped_defs = False
[mypy-aesara.graph.basic]
ignore_errors = True
check_untyped_defs = False
[mypy-aesara.graph.op] [mypy-aesara.graph.op]
ignore_errors = True ignore_errors = True
check_untyped_defs = False check_untyped_defs = False
...@@ -179,6 +171,10 @@ check_untyped_defs = False ...@@ -179,6 +171,10 @@ check_untyped_defs = False
ignore_errors = True ignore_errors = True
check_untyped_defs = False check_untyped_defs = False
[mypy-aesara.compile.builders]
ignore_errors = True
check_untyped_defs = False
[mypy-aesara.compile.sharedvalue] [mypy-aesara.compile.sharedvalue]
ignore_errors = True ignore_errors = True
check_untyped_defs = False check_untyped_defs = False
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论