提交 01020a18 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add memo option, automatic inputs, and copy options to FunctionGraph

上级 52bad109
......@@ -2,14 +2,14 @@
import time
import warnings
from collections import OrderedDict
from typing import Any, Dict, List, NoReturn, Optional, Tuple, Union
import aesara
from aesara.configdefaults import config
from aesara.graph.basic import Apply, Constant, Variable, applys_between
from aesara.graph.basic import as_string as graph_as_string
from aesara.graph.basic import clone as clone_graph
from aesara.graph.basic import clone_get_equiv, io_toposort, vars_between
from aesara.graph.toolbox import AlreadyThere, ReplaceValidate
from aesara.graph.basic import clone_get_equiv, graph_inputs, io_toposort, vars_between
from aesara.graph.toolbox import AlreadyThere, Feature, ReplaceValidate
from aesara.graph.utils import MetaObject, TestValueError, get_variable_trace_string
from aesara.misc.ordered_set import OrderedSet
......@@ -44,23 +44,23 @@ class FunctionGraph(MetaObject):
A `FunctionGraph` represents a subgraph bound by a set of input variables and
a set of output variables, ie a subgraph that specifies an Aesara function.
The inputs list should contain all the inputs on which the outputs depend.
`Variable`s of type `Constant` are not counted as inputs.
``Variable``s of type ``Constant`` are not counted as inputs.
The `FunctionGraph` supports the replace operation which allows to replace
a variable in the subgraph by another, e.g. replace ``(x + x).out`` by ``(2
* x).out``. This is the basis for optimization in Aesara.
a variable in the subgraph by another, e.g. replace ``(x + x).out`` by
``(2 * x).out``. This is the basis for optimization in Aesara.
This class is also responsible for verifying that a graph is valid
(ie, all the dtypes and broadcast patterns are compatible with the
way the the `Variable`s are used) and for tracking the `Variable`s with
a `clients` field that specifies which `Apply` nodes use the `Variable`.
The `clients` field combined with the `Variable.owner` field and the
`Apply` nodes' `Apply.inputs` field allows the graph to be traversed in
way the ``Variable``s are used) and for tracking the ``Variable``s with
a ``clients`` field that specifies which ``Apply`` nodes use the ``Variable``.
The ``clients`` field combined with the ``Variable.owner`` field and the
``Apply`` nodes' ``Apply.inputs`` field allows the graph to be traversed in
both directions.
It can also be extended with new features using
`FunctionGraph.attach_feature`(<Feature instance>).
See `Feature` for event types and documentation.
``FunctionGraph.attach_feature(<Feature instance>)``.
See ``Feature`` for event types and documentation.
Extra features allow the `FunctionGraph` to verify new properties of
a graph as it is optimized.
......@@ -73,52 +73,59 @@ class FunctionGraph(MetaObject):
This class keeps a pointer to the inputs and outputs, and also modifies
them.
Parameters
----------
inputs
Inputs nodes of the graph, usually declared by the user.
outputs
Outputs nodes of the graph.
clone
If true, we will clone the graph. This is useful to remove the constant
cache problem.
Notes
-----
The intermediate nodes between 'inputs' and 'outputs' are not explicitely
passed.
"""
def __init__(self, inputs, outputs, features=None, clone=True, update_mapping=None):
def __init__(
self,
inputs: Optional[List[Variable]] = None,
outputs: Optional[List[Variable]] = None,
features: Optional[List[Feature]] = None,
clone: bool = True,
update_mapping: Optional[Dict[Variable, Variable]] = None,
memo: Optional[Dict[Variable, Variable]] = None,
copy_inputs: bool = True,
copy_orphans: bool = True,
):
"""
Create an FunctionGraph which operates on the subgraph bound by the
inputs and outputs sets.
Create a `FunctionGraph` which operates on the subgraph between the
`inputs` and `outputs`.
Parameters
----------
inputs : list of aesara.graph.basic.Variable
Inputs nodes of the graph, usually declared by the user
outputs : list of aesara.graph.basic.Variable
Outputs nodes of the graph.
clone : boolean
If true, we will clone the graph. This is useful to remove the
constant cache problem.
features : list of aesara.graph.toolbox.Feature
inputs
Input variables of the graph.
outputs
Output variables of the graph.
clone
If ``True``, the graph will be cloned.
features
A list of features to be added to the `FunctionGraph`.
update_mapping : dict
Mapping between the inputs with updates and the outputs
update_mapping
Mapping between the `inputs` with updates and the `outputs`
corresponding to their updates.
memo
See ``clone_get_equiv``.
copy_inputs
See ``clone_get_equiv``.
copy_orphans
See ``clone_get_equiv``.
"""
if outputs is None:
raise ValueError("No outputs specified")
if clone:
inputs, outputs = clone_graph(inputs, outputs)
if not isinstance(inputs, list):
raise TypeError("Argument `inputs` should be a list")
if inputs is None:
inputs = [i for i in graph_inputs(outputs)]
if not isinstance(outputs, list):
raise TypeError("Argument `outputs` should be a list")
if clone:
memo = clone_get_equiv(
inputs,
outputs,
copy_inputs=copy_inputs,
copy_orphans=copy_orphans,
memo=memo,
)
outputs = [memo[o] for o in outputs]
inputs = [memo[i] for i in inputs]
self.execute_callbacks_time = 0
self.execute_callbacks_times = {}
......@@ -165,7 +172,7 @@ class FunctionGraph(MetaObject):
self.profile = None
self.update_mapping = update_mapping
def add_input(self, var, check=True):
def add_input(self, var: Variable, check: bool = True) -> NoReturn:
"""Add a new variable as an input to this `FunctionGraph`.
Parameters
......@@ -180,7 +187,7 @@ class FunctionGraph(MetaObject):
self.setup_var(var)
self.variables.add(var)
def setup_var(self, var):
def setup_var(self, var: Variable) -> NoReturn:
"""Set up a variable so it belongs to this `FunctionGraph`.
Parameters
......@@ -190,7 +197,7 @@ class FunctionGraph(MetaObject):
"""
self.clients.setdefault(var, [])
def setup_node(self, node):
def setup_node(self, node: Apply) -> NoReturn:
"""Set up node so it belongs to this `FunctionGraph`.
Parameters
......@@ -214,14 +221,8 @@ class FunctionGraph(MetaObject):
" the values must be tuples or lists."
)
def disown(self):
"""
Cleans up all of this FunctionGraph's nodes and variables so they are
not associated with this FunctionGraph anymore.
The FunctionGraph should not be used anymore after disown is called.
"""
def disown(self) -> NoReturn:
"""Clear internal variables."""
for f in self._features:
self.remove_feature(f)
self.clients = {}
......@@ -232,11 +233,11 @@ class FunctionGraph(MetaObject):
self.profile = None
self.update_mapping = None
def get_clients(self, var):
def get_clients(self, var: Variable) -> List[Tuple[Apply, int]]:
"""Return a list of all the `(node, i)` pairs such that `node.inputs[i]` is `var`."""
return self.clients[var]
def add_client(self, var, new_client):
def add_client(self, var: Variable, new_client: Tuple[Apply, int]) -> NoReturn:
"""Update the clients of `var` with `new_clients`.
Parameters
......@@ -248,7 +249,9 @@ class FunctionGraph(MetaObject):
"""
self.clients[var].append(new_client)
def remove_client(self, var, client_to_remove, reason=None):
def remove_client(
self, var: Variable, client_to_remove: Tuple[Apply, int], reason: str = None
) -> NoReturn:
"""Recursively removes clients of a variable.
This is the main method to remove variables or `Apply` nodes from
......@@ -312,7 +315,9 @@ class FunctionGraph(MetaObject):
for i, in_var in enumerate(apply_node.inputs):
removal_stack.append((in_var, (apply_node, i)))
def import_var(self, var, reason=None, import_missing=False):
def import_var(
self, var: Variable, reason: str = None, import_missing: bool = False
) -> NoReturn:
"""Import variables into this `FunctionGraph`.
This will also import the `variable`'s `Apply` node.
......@@ -348,7 +353,13 @@ class FunctionGraph(MetaObject):
self.setup_var(var)
self.variables.add(var)
def import_node(self, apply_node, check=True, reason=None, import_missing=False):
def import_node(
self,
apply_node: Apply,
check: bool = True,
reason: str = None,
import_missing: bool = False,
) -> NoReturn:
"""Recursively import everything between an `Apply` node and the `FunctionGraph`'s outputs.
Parameters:
......@@ -407,7 +418,14 @@ class FunctionGraph(MetaObject):
self.add_client(input, (node, i))
self.execute_callbacks("on_import", node, reason)
def change_input(self, node, i, new_var, reason=None, import_missing=False):
def change_input(
self,
node: Apply,
i: int,
new_var: Variable,
reason: str = None,
import_missing: bool = False,
) -> NoReturn:
"""Change ``node.inputs[i]`` to `new_var`.
``new_var.type == old_var.type`` must be ``True``, where ``old_var`` is the
......@@ -462,7 +480,14 @@ class FunctionGraph(MetaObject):
# reverted later.
self.execute_callbacks("on_change_input", node, i, r, new_var, reason=reason)
def replace(self, var, new_var, reason=None, verbose=None, import_missing=False):
def replace(
self,
var: Variable,
new_var: Variable,
reason: str = None,
verbose: bool = None,
import_missing: bool = False,
) -> NoReturn:
"""Replace a variable in the `FunctionGraph`.
This is the main interface to manipulate the subgraph in `FunctionGraph`.
......@@ -526,12 +551,12 @@ class FunctionGraph(MetaObject):
node, i, new_var, reason=reason, import_missing=import_missing
)
def replace_all(self, pairs, **kwargs):
def replace_all(self, pairs: List[Tuple[Variable, Variable]], **kwargs) -> NoReturn:
"""Replace variables in the `FunctionGraph` according to `(var, new_var)` pairs in a list."""
for var, new_var in pairs:
self.replace(var, new_var, **kwargs)
def attach_feature(self, feature):
def attach_feature(self, feature: Feature) -> NoReturn:
"""
Adds a graph.toolbox.Feature to this function_graph and triggers its
on_attach callback.
......@@ -561,7 +586,7 @@ class FunctionGraph(MetaObject):
# Add the feature
self._features.append(feature)
def remove_feature(self, feature):
def remove_feature(self, feature: Feature) -> NoReturn:
"""
Removes the feature from the graph.
......@@ -578,7 +603,7 @@ class FunctionGraph(MetaObject):
if detach is not None:
detach(self)
def execute_callbacks(self, name, *args, **kwargs):
def execute_callbacks(self, name: str, *args, **kwargs) -> NoReturn:
"""Execute callbacks
Calls `getattr(feature, name)(*args)` for each feature which has
......@@ -599,7 +624,7 @@ class FunctionGraph(MetaObject):
self.execute_callbacks_times[feature] += time.time() - tf0
self.execute_callbacks_time += time.time() - t0
def collect_callbacks(self, name, *args):
def collect_callbacks(self, name: str, *args) -> Dict[Feature, Any]:
"""Collects callbacks
Returns a dictionary d such that
......@@ -615,7 +640,7 @@ class FunctionGraph(MetaObject):
d[feature] = fn(*args)
return d
def toposort(self):
def toposort(self) -> List[Apply]:
"""Toposort
Return an ordering of the graph's Apply nodes such that
......@@ -643,7 +668,7 @@ class FunctionGraph(MetaObject):
return order
def orderings(self):
def orderings(self) -> Dict[Apply, List[Apply]]:
"""Return `dict` `d` s.t. `d[node]` is a list of nodes that must be evaluated before `node` itself can be evaluated.
This is used primarily by the destroy_handler feature to ensure that
......@@ -689,7 +714,7 @@ class FunctionGraph(MetaObject):
ords.setdefault(node, []).extend(prereqs)
return ords
def check_integrity(self):
def check_integrity(self) -> NoReturn:
"""
Call this for a diagnosis if things go awry.
......@@ -745,14 +770,16 @@ class FunctionGraph(MetaObject):
def __repr__(self):
return f"FunctionGraph({', '.join(graph_as_string(self.inputs, self.outputs))})"
def clone(self, check_integrity=True):
def clone(self, check_integrity=True) -> "FunctionGraph":
"""
Clone the graph and get a memo( a dict )that map old node to new node
"""
return self.clone_get_equiv(check_integrity)[0]
def clone_get_equiv(self, check_integrity=True, attach_feature=True):
def clone_get_equiv(
self, check_integrity: bool = True, attach_feature: bool = True
) -> Union["FunctionGraph", Dict[Variable, Variable]]:
"""Clone the graph and get a dict that maps old nodes to new ones
Parameters:
......@@ -810,7 +837,7 @@ class FunctionGraph(MetaObject):
if hasattr(feature, "unpickle"):
feature.unpickle(self)
def __contains__(self, item):
def __contains__(self, item: Union[Variable, Apply]) -> bool:
if isinstance(item, Variable):
return item in self.variables
elif isinstance(item, Apply):
......
......@@ -41,6 +41,10 @@ class TestFunctionGraph:
var3 = op1(var1)
FunctionGraph([var3], [var2], clone=False)
with pytest.raises(ValueError):
var3 = op1(var1)
FunctionGraph([var3], clone=False)
def test_init(self):
var1 = MyVariable("var1")
var2 = MyVariable("var2")
......@@ -58,6 +62,19 @@ class TestFunctionGraph:
assert fg.get_clients(var3) == [(var4.owner, 0), ("output", 0)]
assert fg.get_clients(var4) == [("output", 1)]
fg = FunctionGraph(outputs=[var3, var4], clone=False)
assert fg.inputs == [var1, var2]
memo = {}
fg = FunctionGraph(outputs=[var3, var4], clone=True, memo=memo)
assert memo[var1].type == var1.type
assert memo[var1].name == var1.name
assert memo[var2].type == var2.type
assert memo[var2].name == var2.name
assert var3 in memo
assert var4 in memo
def test_remove_client(self):
var1 = MyVariable("var1")
var2 = MyVariable("var2")
......
......@@ -58,7 +58,7 @@ class MyOp(Op):
return Apply(self, inputs, outputs)
def perform(self, node, inputs, outputs):
outputs[0] = np.array(inputs)
outputs[0] = np.array(inputs, dtype=np.object)
def __str__(self):
return self.name
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论