提交 8b86e270 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix typing issues in aesara.graph.fg

上级 3a764edd
"""A container for specifying and manipulating a graph with distinct inputs and outputs."""
import time
from collections import OrderedDict
from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, Tuple, Union
from typing import (
Any,
Dict,
Iterable,
List,
Optional,
Sequence,
Set,
Tuple,
Union,
cast,
)
from typing_extensions import Literal
import aesara
from aesara.configdefaults import config
from aesara.graph.basic import Apply, Constant, Variable, applys_between
from aesara.graph.basic import Apply, Constant, Node, Variable, applys_between
from aesara.graph.basic import as_string as graph_as_string
from aesara.graph.basic import clone_get_equiv, graph_inputs, io_toposort, vars_between
from aesara.graph.features import AlreadyThere, Feature, ReplaceValidate
......@@ -13,6 +26,10 @@ from aesara.graph.utils import MetaObject, MissingInputError, TestValueError
from aesara.misc.ordered_set import OrderedSet
ApplyOrOutput = Union[Apply, Literal["output"]]
ClientType = Tuple[ApplyOrOutput, int]
class FunctionGraph(MetaObject):
r"""
A `FunctionGraph` represents a subgraph bound by a set of input variables and
......@@ -87,17 +104,17 @@ class FunctionGraph(MetaObject):
inputs = [i for i in graph_inputs(outputs) if not isinstance(i, Constant)]
if clone:
memo = clone_get_equiv(
_memo = clone_get_equiv(
inputs,
outputs,
copy_inputs=copy_inputs,
copy_orphans=copy_orphans,
memo=memo,
memo=cast(Dict[Node, Node], memo),
)
outputs = [memo[o] for o in outputs]
inputs = [memo[i] for i in inputs]
outputs = [cast(Variable, _memo[o]) for o in outputs]
inputs = [cast(Variable, _memo[i]) for i in inputs]
self.execute_callbacks_time = 0
self.execute_callbacks_time: float = 0.0
self.execute_callbacks_times: Dict[Feature, float] = {}
if features is None:
......@@ -116,7 +133,7 @@ class FunctionGraph(MetaObject):
self.inputs: List[Variable] = []
self.outputs: List[Variable] = list(outputs)
self.clients: Dict[Variable, List[Tuple[Union[Apply, str], int]]] = {}
self.clients: Dict[Variable, List[ClientType]] = {}
for f in features:
self.attach_feature(f)
......@@ -167,11 +184,11 @@ class FunctionGraph(MetaObject):
"""
self.clients.setdefault(var, [])
def get_clients(self, var: Variable) -> List[Tuple[Apply, int]]:
def get_clients(self, var: Variable) -> List[ClientType]:
"""Return a list of all the `(node, i)` pairs such that `node.inputs[i]` is `var`."""
return self.clients[var]
def add_client(self, var: Variable, new_client: Tuple[Apply, int]) -> None:
def add_client(self, var: Variable, new_client: ClientType) -> None:
"""Update the clients of `var` with `new_clients`.
Parameters
......@@ -182,10 +199,17 @@ class FunctionGraph(MetaObject):
A ``(node, i)`` pair such that ``node.inputs[i]`` is `var`.
"""
if not isinstance(new_client[0], Apply) and new_client[0] != "output":
raise TypeError(
'The first entry of `new_client` must be an `Apply` node or the string `"output"`'
)
self.clients[var].append(new_client)
def remove_client(
self, var: Variable, client_to_remove: Tuple[Apply, int], reason: str = None
self,
var: Variable,
client_to_remove: ClientType,
reason: Optional[str] = None,
) -> None:
"""Recursively remove clients of a variable.
......@@ -251,7 +275,7 @@ class FunctionGraph(MetaObject):
removal_stack.append((in_var, (apply_node, i)))
def import_var(
self, var: Variable, reason: str = None, import_missing: bool = False
self, var: Variable, reason: Optional[str] = None, import_missing: bool = False
) -> None:
"""Import variables into this `FunctionGraph`.
......@@ -292,7 +316,7 @@ class FunctionGraph(MetaObject):
self,
apply_node: Apply,
check: bool = True,
reason: str = None,
reason: Optional[str] = None,
import_missing: bool = False,
) -> None:
"""Recursively import everything between an ``Apply`` node and the ``FunctionGraph``'s outputs.
......@@ -354,10 +378,10 @@ class FunctionGraph(MetaObject):
def change_node_input(
self,
node: Union[Apply, str],
node: ApplyOrOutput,
i: int,
new_var: Variable,
reason: str = None,
reason: Optional[str] = None,
import_missing: bool = False,
check: bool = True,
) -> None:
......@@ -398,6 +422,7 @@ class FunctionGraph(MetaObject):
)
self.outputs[i] = new_var
else:
assert isinstance(node, Apply)
r = node.inputs[i]
if check and not r.type.is_super(new_var.type):
raise TypeError(
......@@ -421,8 +446,8 @@ class FunctionGraph(MetaObject):
self,
var: Variable,
new_var: Variable,
reason: str = None,
verbose: bool = None,
reason: Optional[str] = None,
verbose: Optional[bool] = None,
import_missing: bool = False,
) -> None:
"""Replace a variable in the `FunctionGraph`.
......@@ -481,7 +506,7 @@ class FunctionGraph(MetaObject):
for node, i in list(self.clients[var]):
assert (node == "output" and self.outputs[i] is var) or (
node.inputs[i] is var
isinstance(node, Apply) and node.inputs[i] is var
)
self.change_node_input(
node, i, new_var, reason=reason, import_missing=import_missing
......@@ -640,12 +665,12 @@ class FunctionGraph(MetaObject):
"""Check the integrity of nodes in the graph."""
nodes = set(applys_between(self.inputs, self.outputs))
if self.apply_nodes != nodes:
missing = nodes.difference(self.apply_nodes)
excess = self.apply_nodes.difference(nodes)
nodes_missing = nodes.difference(self.apply_nodes)
nodes_excess = self.apply_nodes.difference(nodes)
raise Exception(
"The nodes are inappropriately cached. missing, in excess: ",
missing,
excess,
nodes_missing,
nodes_excess,
)
for node in nodes:
for i, variable in enumerate(node.inputs):
......@@ -656,12 +681,12 @@ class FunctionGraph(MetaObject):
)
variables = set(vars_between(self.inputs, self.outputs))
if set(self.variables) != variables:
missing = variables.difference(self.variables)
excess = self.variables.difference(variables)
vars_missing = variables.difference(self.variables)
vars_excess = self.variables.difference(variables)
raise Exception(
"The variables are inappropriately cached. missing, in excess: ",
missing,
excess,
vars_missing,
vars_excess,
)
for variable in variables:
if (
......@@ -670,20 +695,23 @@ class FunctionGraph(MetaObject):
and not isinstance(variable, Constant)
):
raise Exception(f"Undeclared input: {variable}")
for node, i in self.clients[variable]:
if node == "output":
for cl_node, i in self.clients[variable]:
if cl_node == "output":
if self.outputs[i] is not variable:
raise Exception(
f"Inconsistent clients list: {variable}, {self.outputs[i]}"
)
continue
if node not in nodes:
assert isinstance(cl_node, Apply)
if cl_node not in nodes:
raise Exception(
f"Client not in FunctionGraph: {variable}, {(node, i)}"
f"Client not in FunctionGraph: {variable}, {(cl_node, i)}"
)
if node.inputs[i] is not variable:
if cl_node.inputs[i] is not variable:
raise Exception(
f"Inconsistent clients list: {variable}, {node.inputs[i]}"
f"Inconsistent clients list: {variable}, {cl_node.inputs[i]}"
)
def __repr__(self):
......@@ -695,7 +723,7 @@ class FunctionGraph(MetaObject):
def clone_get_equiv(
self, check_integrity: bool = True, attach_feature: bool = True
) -> Tuple["FunctionGraph", Dict[Variable, Variable]]:
) -> Tuple["FunctionGraph", Dict[Node, Node]]:
"""Clone the graph and return a ``dict`` that maps old nodes to new nodes.
Parameters
......@@ -717,8 +745,8 @@ class FunctionGraph(MetaObject):
if check_integrity:
self.check_integrity()
e = FunctionGraph(
[equiv[i] for i in self.inputs],
[equiv[o] for o in self.outputs],
[cast(Variable, equiv[i]) for i in self.inputs],
[cast(Variable, equiv[o]) for o in self.outputs],
clone=False,
)
if check_integrity:
......
......@@ -115,10 +115,6 @@ check_untyped_defs = False
ignore_errors = True
check_untyped_defs = False
[mypy-aesara.graph.fg]
ignore_errors = True
check_untyped_defs = False
[mypy-aesara.graph.opt]
ignore_errors = True
check_untyped_defs = False
......
......@@ -355,12 +355,17 @@ class TestFunctionGraph:
fg.remove_client(var4, ("output", 1))
with pytest.raises(TypeError, match="The first entry of.*"):
fg.add_client(var4, (None, 0))
var7 = op1(var4)
with pytest.raises(Exception, match="Client not in FunctionGraph.*"):
fg.add_client(var4, (var6.owner, 0))
fg.add_client(var4, (var7.owner, 0))
fg.check_integrity()
fg.remove_client(var4, (var6.owner, 0))
fg.remove_client(var4, (var7.owner, 0))
with pytest.raises(Exception, match="Inconsistent clients list.*"):
fg.add_client(var4, (var3.owner, 0))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论