提交 8b86e270 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix typing issues in aesara.graph.fg

上级 3a764edd
"""A container for specifying and manipulating a graph with distinct inputs and outputs.""" """A container for specifying and manipulating a graph with distinct inputs and outputs."""
import time import time
from collections import OrderedDict from collections import OrderedDict
from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, Tuple, Union from typing import (
Any,
Dict,
Iterable,
List,
Optional,
Sequence,
Set,
Tuple,
Union,
cast,
)
from typing_extensions import Literal
import aesara import aesara
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import Apply, Constant, Variable, applys_between from aesara.graph.basic import Apply, Constant, Node, Variable, applys_between
from aesara.graph.basic import as_string as graph_as_string from aesara.graph.basic import as_string as graph_as_string
from aesara.graph.basic import clone_get_equiv, graph_inputs, io_toposort, vars_between from aesara.graph.basic import clone_get_equiv, graph_inputs, io_toposort, vars_between
from aesara.graph.features import AlreadyThere, Feature, ReplaceValidate from aesara.graph.features import AlreadyThere, Feature, ReplaceValidate
...@@ -13,6 +26,10 @@ from aesara.graph.utils import MetaObject, MissingInputError, TestValueError ...@@ -13,6 +26,10 @@ from aesara.graph.utils import MetaObject, MissingInputError, TestValueError
from aesara.misc.ordered_set import OrderedSet from aesara.misc.ordered_set import OrderedSet
ApplyOrOutput = Union[Apply, Literal["output"]]
ClientType = Tuple[ApplyOrOutput, int]
class FunctionGraph(MetaObject): class FunctionGraph(MetaObject):
r""" r"""
A `FunctionGraph` represents a subgraph bound by a set of input variables and A `FunctionGraph` represents a subgraph bound by a set of input variables and
...@@ -87,17 +104,17 @@ class FunctionGraph(MetaObject): ...@@ -87,17 +104,17 @@ class FunctionGraph(MetaObject):
inputs = [i for i in graph_inputs(outputs) if not isinstance(i, Constant)] inputs = [i for i in graph_inputs(outputs) if not isinstance(i, Constant)]
if clone: if clone:
memo = clone_get_equiv( _memo = clone_get_equiv(
inputs, inputs,
outputs, outputs,
copy_inputs=copy_inputs, copy_inputs=copy_inputs,
copy_orphans=copy_orphans, copy_orphans=copy_orphans,
memo=memo, memo=cast(Dict[Node, Node], memo),
) )
outputs = [memo[o] for o in outputs] outputs = [cast(Variable, _memo[o]) for o in outputs]
inputs = [memo[i] for i in inputs] inputs = [cast(Variable, _memo[i]) for i in inputs]
self.execute_callbacks_time = 0 self.execute_callbacks_time: float = 0.0
self.execute_callbacks_times: Dict[Feature, float] = {} self.execute_callbacks_times: Dict[Feature, float] = {}
if features is None: if features is None:
...@@ -116,7 +133,7 @@ class FunctionGraph(MetaObject): ...@@ -116,7 +133,7 @@ class FunctionGraph(MetaObject):
self.inputs: List[Variable] = [] self.inputs: List[Variable] = []
self.outputs: List[Variable] = list(outputs) self.outputs: List[Variable] = list(outputs)
self.clients: Dict[Variable, List[Tuple[Union[Apply, str], int]]] = {} self.clients: Dict[Variable, List[ClientType]] = {}
for f in features: for f in features:
self.attach_feature(f) self.attach_feature(f)
...@@ -167,11 +184,11 @@ class FunctionGraph(MetaObject): ...@@ -167,11 +184,11 @@ class FunctionGraph(MetaObject):
""" """
self.clients.setdefault(var, []) self.clients.setdefault(var, [])
def get_clients(self, var: Variable) -> List[Tuple[Apply, int]]: def get_clients(self, var: Variable) -> List[ClientType]:
"""Return a list of all the `(node, i)` pairs such that `node.inputs[i]` is `var`.""" """Return a list of all the `(node, i)` pairs such that `node.inputs[i]` is `var`."""
return self.clients[var] return self.clients[var]
def add_client(self, var: Variable, new_client: Tuple[Apply, int]) -> None: def add_client(self, var: Variable, new_client: ClientType) -> None:
"""Update the clients of `var` with `new_clients`. """Update the clients of `var` with `new_clients`.
Parameters Parameters
...@@ -182,10 +199,17 @@ class FunctionGraph(MetaObject): ...@@ -182,10 +199,17 @@ class FunctionGraph(MetaObject):
A ``(node, i)`` pair such that ``node.inputs[i]`` is `var`. A ``(node, i)`` pair such that ``node.inputs[i]`` is `var`.
""" """
if not isinstance(new_client[0], Apply) and new_client[0] != "output":
raise TypeError(
'The first entry of `new_client` must be an `Apply` node or the string `"output"`'
)
self.clients[var].append(new_client) self.clients[var].append(new_client)
def remove_client( def remove_client(
self, var: Variable, client_to_remove: Tuple[Apply, int], reason: str = None self,
var: Variable,
client_to_remove: ClientType,
reason: Optional[str] = None,
) -> None: ) -> None:
"""Recursively remove clients of a variable. """Recursively remove clients of a variable.
...@@ -251,7 +275,7 @@ class FunctionGraph(MetaObject): ...@@ -251,7 +275,7 @@ class FunctionGraph(MetaObject):
removal_stack.append((in_var, (apply_node, i))) removal_stack.append((in_var, (apply_node, i)))
def import_var( def import_var(
self, var: Variable, reason: str = None, import_missing: bool = False self, var: Variable, reason: Optional[str] = None, import_missing: bool = False
) -> None: ) -> None:
"""Import variables into this `FunctionGraph`. """Import variables into this `FunctionGraph`.
...@@ -292,7 +316,7 @@ class FunctionGraph(MetaObject): ...@@ -292,7 +316,7 @@ class FunctionGraph(MetaObject):
self, self,
apply_node: Apply, apply_node: Apply,
check: bool = True, check: bool = True,
reason: str = None, reason: Optional[str] = None,
import_missing: bool = False, import_missing: bool = False,
) -> None: ) -> None:
"""Recursively import everything between an ``Apply`` node and the ``FunctionGraph``'s outputs. """Recursively import everything between an ``Apply`` node and the ``FunctionGraph``'s outputs.
...@@ -354,10 +378,10 @@ class FunctionGraph(MetaObject): ...@@ -354,10 +378,10 @@ class FunctionGraph(MetaObject):
def change_node_input( def change_node_input(
self, self,
node: Union[Apply, str], node: ApplyOrOutput,
i: int, i: int,
new_var: Variable, new_var: Variable,
reason: str = None, reason: Optional[str] = None,
import_missing: bool = False, import_missing: bool = False,
check: bool = True, check: bool = True,
) -> None: ) -> None:
...@@ -398,6 +422,7 @@ class FunctionGraph(MetaObject): ...@@ -398,6 +422,7 @@ class FunctionGraph(MetaObject):
) )
self.outputs[i] = new_var self.outputs[i] = new_var
else: else:
assert isinstance(node, Apply)
r = node.inputs[i] r = node.inputs[i]
if check and not r.type.is_super(new_var.type): if check and not r.type.is_super(new_var.type):
raise TypeError( raise TypeError(
...@@ -421,8 +446,8 @@ class FunctionGraph(MetaObject): ...@@ -421,8 +446,8 @@ class FunctionGraph(MetaObject):
self, self,
var: Variable, var: Variable,
new_var: Variable, new_var: Variable,
reason: str = None, reason: Optional[str] = None,
verbose: bool = None, verbose: Optional[bool] = None,
import_missing: bool = False, import_missing: bool = False,
) -> None: ) -> None:
"""Replace a variable in the `FunctionGraph`. """Replace a variable in the `FunctionGraph`.
...@@ -481,7 +506,7 @@ class FunctionGraph(MetaObject): ...@@ -481,7 +506,7 @@ class FunctionGraph(MetaObject):
for node, i in list(self.clients[var]): for node, i in list(self.clients[var]):
assert (node == "output" and self.outputs[i] is var) or ( assert (node == "output" and self.outputs[i] is var) or (
node.inputs[i] is var isinstance(node, Apply) and node.inputs[i] is var
) )
self.change_node_input( self.change_node_input(
node, i, new_var, reason=reason, import_missing=import_missing node, i, new_var, reason=reason, import_missing=import_missing
...@@ -640,12 +665,12 @@ class FunctionGraph(MetaObject): ...@@ -640,12 +665,12 @@ class FunctionGraph(MetaObject):
"""Check the integrity of nodes in the graph.""" """Check the integrity of nodes in the graph."""
nodes = set(applys_between(self.inputs, self.outputs)) nodes = set(applys_between(self.inputs, self.outputs))
if self.apply_nodes != nodes: if self.apply_nodes != nodes:
missing = nodes.difference(self.apply_nodes) nodes_missing = nodes.difference(self.apply_nodes)
excess = self.apply_nodes.difference(nodes) nodes_excess = self.apply_nodes.difference(nodes)
raise Exception( raise Exception(
"The nodes are inappropriately cached. missing, in excess: ", "The nodes are inappropriately cached. missing, in excess: ",
missing, nodes_missing,
excess, nodes_excess,
) )
for node in nodes: for node in nodes:
for i, variable in enumerate(node.inputs): for i, variable in enumerate(node.inputs):
...@@ -656,12 +681,12 @@ class FunctionGraph(MetaObject): ...@@ -656,12 +681,12 @@ class FunctionGraph(MetaObject):
) )
variables = set(vars_between(self.inputs, self.outputs)) variables = set(vars_between(self.inputs, self.outputs))
if set(self.variables) != variables: if set(self.variables) != variables:
missing = variables.difference(self.variables) vars_missing = variables.difference(self.variables)
excess = self.variables.difference(variables) vars_excess = self.variables.difference(variables)
raise Exception( raise Exception(
"The variables are inappropriately cached. missing, in excess: ", "The variables are inappropriately cached. missing, in excess: ",
missing, vars_missing,
excess, vars_excess,
) )
for variable in variables: for variable in variables:
if ( if (
...@@ -670,20 +695,23 @@ class FunctionGraph(MetaObject): ...@@ -670,20 +695,23 @@ class FunctionGraph(MetaObject):
and not isinstance(variable, Constant) and not isinstance(variable, Constant)
): ):
raise Exception(f"Undeclared input: {variable}") raise Exception(f"Undeclared input: {variable}")
for node, i in self.clients[variable]: for cl_node, i in self.clients[variable]:
if node == "output": if cl_node == "output":
if self.outputs[i] is not variable: if self.outputs[i] is not variable:
raise Exception( raise Exception(
f"Inconsistent clients list: {variable}, {self.outputs[i]}" f"Inconsistent clients list: {variable}, {self.outputs[i]}"
) )
continue continue
if node not in nodes:
assert isinstance(cl_node, Apply)
if cl_node not in nodes:
raise Exception( raise Exception(
f"Client not in FunctionGraph: {variable}, {(node, i)}" f"Client not in FunctionGraph: {variable}, {(cl_node, i)}"
) )
if node.inputs[i] is not variable: if cl_node.inputs[i] is not variable:
raise Exception( raise Exception(
f"Inconsistent clients list: {variable}, {node.inputs[i]}" f"Inconsistent clients list: {variable}, {cl_node.inputs[i]}"
) )
def __repr__(self): def __repr__(self):
...@@ -695,7 +723,7 @@ class FunctionGraph(MetaObject): ...@@ -695,7 +723,7 @@ class FunctionGraph(MetaObject):
def clone_get_equiv( def clone_get_equiv(
self, check_integrity: bool = True, attach_feature: bool = True self, check_integrity: bool = True, attach_feature: bool = True
) -> Tuple["FunctionGraph", Dict[Variable, Variable]]: ) -> Tuple["FunctionGraph", Dict[Node, Node]]:
"""Clone the graph and return a ``dict`` that maps old nodes to new nodes. """Clone the graph and return a ``dict`` that maps old nodes to new nodes.
Parameters Parameters
...@@ -717,8 +745,8 @@ class FunctionGraph(MetaObject): ...@@ -717,8 +745,8 @@ class FunctionGraph(MetaObject):
if check_integrity: if check_integrity:
self.check_integrity() self.check_integrity()
e = FunctionGraph( e = FunctionGraph(
[equiv[i] for i in self.inputs], [cast(Variable, equiv[i]) for i in self.inputs],
[equiv[o] for o in self.outputs], [cast(Variable, equiv[o]) for o in self.outputs],
clone=False, clone=False,
) )
if check_integrity: if check_integrity:
......
...@@ -115,10 +115,6 @@ check_untyped_defs = False ...@@ -115,10 +115,6 @@ check_untyped_defs = False
ignore_errors = True ignore_errors = True
check_untyped_defs = False check_untyped_defs = False
[mypy-aesara.graph.fg]
ignore_errors = True
check_untyped_defs = False
[mypy-aesara.graph.opt] [mypy-aesara.graph.opt]
ignore_errors = True ignore_errors = True
check_untyped_defs = False check_untyped_defs = False
......
...@@ -355,12 +355,17 @@ class TestFunctionGraph: ...@@ -355,12 +355,17 @@ class TestFunctionGraph:
fg.remove_client(var4, ("output", 1)) fg.remove_client(var4, ("output", 1))
with pytest.raises(TypeError, match="The first entry of.*"):
fg.add_client(var4, (None, 0))
var7 = op1(var4)
with pytest.raises(Exception, match="Client not in FunctionGraph.*"): with pytest.raises(Exception, match="Client not in FunctionGraph.*"):
fg.add_client(var4, (var6.owner, 0)) fg.add_client(var4, (var7.owner, 0))
fg.check_integrity() fg.check_integrity()
fg.remove_client(var4, (var6.owner, 0)) fg.remove_client(var4, (var7.owner, 0))
with pytest.raises(Exception, match="Inconsistent clients list.*"): with pytest.raises(Exception, match="Inconsistent clients list.*"):
fg.add_client(var4, (var3.owner, 0)) fg.add_client(var4, (var3.owner, 0))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论