提交 37444647 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add an option to add missing inputs during FunctionGraph operations

上级 334d3fdf
"""A container for specifying and manipulating a graph with distinct inputs and outputs.""" """A container for specifying and manipulating a graph with distinct inputs and outputs."""
import time import time
import warnings
from collections import OrderedDict from collections import OrderedDict
import aesara import aesara
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph import toolbox, utils
from aesara.graph.basic import Apply, Constant, Variable, applys_between from aesara.graph.basic import Apply, Constant, Variable, applys_between
from aesara.graph.basic import as_string as graph_as_string from aesara.graph.basic import as_string as graph_as_string
from aesara.graph.basic import clone as clone_graph from aesara.graph.basic import clone as clone_graph
from aesara.graph.basic import clone_get_equiv, io_toposort, vars_between from aesara.graph.basic import clone_get_equiv, io_toposort, vars_between
from aesara.graph.utils import TestValueError, get_variable_trace_string from aesara.graph.toolbox import AlreadyThere, ReplaceValidate
from aesara.graph.utils import MetaObject, TestValueError, get_variable_trace_string
from aesara.misc.ordered_set import OrderedSet from aesara.misc.ordered_set import OrderedSet
...@@ -35,10 +36,10 @@ class MissingInputError(Exception): ...@@ -35,10 +36,10 @@ class MissingInputError(Exception):
if error_msg: if error_msg:
args = args + (error_msg,) args = args + (error_msg,)
s = "\n".join(args) # Needed to have the new line print correctly s = "\n".join(args) # Needed to have the new line print correctly
Exception.__init__(self, s) super().__init__(s)
class FunctionGraph(utils.MetaObject): class FunctionGraph(MetaObject):
""" """
A `FunctionGraph` represents a subgraph bound by a set of input variables and A `FunctionGraph` represents a subgraph bound by a set of input variables and
a set of output variables, ie a subgraph that specifies an Aesara function. a set of output variables, ie a subgraph that specifies an Aesara function.
...@@ -58,8 +59,8 @@ class FunctionGraph(utils.MetaObject): ...@@ -58,8 +59,8 @@ class FunctionGraph(utils.MetaObject):
both directions. both directions.
It can also be extended with new features using It can also be extended with new features using
`FunctionGraph.attach_feature`(<toolbox.Feature instance>). `FunctionGraph.attach_feature`(<Feature instance>).
See `toolbox.Feature` for event types and documentation. See `Feature` for event types and documentation.
Extra features allow the `FunctionGraph` to verify new properties of Extra features allow the `FunctionGraph` to verify new properties of
a graph as it is optimized. a graph as it is optimized.
...@@ -142,7 +143,7 @@ class FunctionGraph(utils.MetaObject): ...@@ -142,7 +143,7 @@ class FunctionGraph(utils.MetaObject):
for f in features: for f in features:
self.attach_feature(f) self.attach_feature(f)
self.attach_feature(toolbox.ReplaceValidate()) self.attach_feature(ReplaceValidate())
self.inputs = [] self.inputs = []
for in_var in inputs: for in_var in inputs:
...@@ -311,7 +312,7 @@ class FunctionGraph(utils.MetaObject): ...@@ -311,7 +312,7 @@ class FunctionGraph(utils.MetaObject):
for i, in_var in enumerate(apply_node.inputs): for i, in_var in enumerate(apply_node.inputs):
removal_stack.append((in_var, (apply_node, i))) removal_stack.append((in_var, (apply_node, i)))
def import_var(self, var, reason): def import_var(self, var, reason=None, import_missing=False):
"""Import variables into this `FunctionGraph`. """Import variables into this `FunctionGraph`.
This will also import the `variable`'s `Apply` node. This will also import the `variable`'s `Apply` node.
...@@ -322,10 +323,13 @@ class FunctionGraph(utils.MetaObject): ...@@ -322,10 +323,13 @@ class FunctionGraph(utils.MetaObject):
The variable to be imported. The variable to be imported.
reason : str reason : str
The name of the optimization or operation in progress. The name of the optimization or operation in progress.
import_missing : bool
Add missing inputs instead of raising an exception.
""" """
# Imports the owners of the variables # Imports the owners of the variables
if var.owner and var.owner not in self.apply_nodes: if var.owner and var.owner not in self.apply_nodes:
self.import_node(var.owner, reason=reason) self.import_node(var.owner, reason=reason, import_missing=import_missing)
elif ( elif (
var.owner is None var.owner is None
and not isinstance(var, Constant) and not isinstance(var, Constant)
...@@ -335,13 +339,16 @@ class FunctionGraph(utils.MetaObject): ...@@ -335,13 +339,16 @@ class FunctionGraph(utils.MetaObject):
if isinstance(var.type, NullType): if isinstance(var.type, NullType):
raise TypeError( raise TypeError(
"Computation graph contains a NaN. " + var.type.why_null f"Computation graph contains a NaN. {var.type.why_null}"
) )
raise MissingInputError("Undeclared input", variable=var) if import_missing:
self.add_input(var)
else:
raise MissingInputError(f"Undeclared input: {var}", variable=var)
self.setup_var(var) self.setup_var(var)
self.variables.add(var) self.variables.add(var)
def import_node(self, apply_node, check=True, reason=None): def import_node(self, apply_node, check=True, reason=None, import_missing=False):
"""Recursively import everything between an `Apply` node and the `FunctionGraph`'s outputs. """Recursively import everything between an `Apply` node and the `FunctionGraph`'s outputs.
Parameters: Parameters:
...@@ -353,13 +360,13 @@ class FunctionGraph(utils.MetaObject): ...@@ -353,13 +360,13 @@ class FunctionGraph(utils.MetaObject):
the `FunctionGraph`. the `FunctionGraph`.
reason : str reason : str
The name of the optimization or operation in progress. The name of the optimization or operation in progress.
import_missing : bool
Add missing inputs instead of raising an exception.
""" """
node = apply_node # We import the nodes in topological order. We only are interested in
# new nodes, so we use all variables we know of as if they were the
# We import the nodes in topological order. We only are interested # input set. (The functions in the graph module only use the input set
# in new nodes, so we use all variables we know of as if they were the input set. # to know where to stop going down.)
# (the functions in the graph module only use the input set to
# know where to stop going down)
new_nodes = io_toposort(self.variables, apply_node.outputs) new_nodes = io_toposort(self.variables, apply_node.outputs)
if check: if check:
...@@ -370,15 +377,18 @@ class FunctionGraph(utils.MetaObject): ...@@ -370,15 +377,18 @@ class FunctionGraph(utils.MetaObject):
and not isinstance(var, Constant) and not isinstance(var, Constant)
and var not in self.inputs and var not in self.inputs
): ):
# Standard error message if import_missing:
error_msg = ( self.add_input(var)
f"Input {int(node.inputs.index(var))} of the graph (indices start " else:
f"from 0), used to compute {node}, was not " error_msg = (
"provided and not given a value. Use the " f"Input {node.inputs.index(var)} ({var})"
"Aesara flag exception_verbosity='high', " " of the graph (indices start "
"for more information on this error." f"from 0), used to compute {node}, was not "
) "provided and not given a value. Use the "
raise MissingInputError(error_msg, variable=var) "Aesara flag exception_verbosity='high', "
"for more information on this error."
)
raise MissingInputError(error_msg, variable=var)
for node in new_nodes: for node in new_nodes:
assert node not in self.apply_nodes assert node not in self.apply_nodes
...@@ -397,7 +407,7 @@ class FunctionGraph(utils.MetaObject): ...@@ -397,7 +407,7 @@ class FunctionGraph(utils.MetaObject):
self.add_client(input, (node, i)) self.add_client(input, (node, i))
self.execute_callbacks("on_import", node, reason) self.execute_callbacks("on_import", node, reason)
def change_input(self, node, i, new_var, reason=None): def change_input(self, node, i, new_var, reason=None, import_missing=False):
"""Change ``node.inputs[i]`` to `new_var`. """Change ``node.inputs[i]`` to `new_var`.
``new_var.type == old_var.type`` must be ``True``, where ``old_var`` is the ``new_var.type == old_var.type`` must be ``True``, where ``old_var`` is the
...@@ -416,7 +426,8 @@ class FunctionGraph(utils.MetaObject): ...@@ -416,7 +426,8 @@ class FunctionGraph(utils.MetaObject):
The index in `node.inputs` that we want to change. The index in `node.inputs` that we want to change.
new_var : aesara.graph.basic.Variable new_var : aesara.graph.basic.Variable
The new variable to take the place of ``node.inputs[i]``. The new variable to take the place of ``node.inputs[i]``.
import_missing : bool
Add missing inputs instead of raising an exception.
""" """
# TODO: ERROR HANDLING FOR LISTENERS (should it complete the change or revert it?) # TODO: ERROR HANDLING FOR LISTENERS (should it complete the change or revert it?)
if node == "output": if node == "output":
...@@ -443,15 +454,15 @@ class FunctionGraph(utils.MetaObject): ...@@ -443,15 +454,15 @@ class FunctionGraph(utils.MetaObject):
if r is new_var: if r is new_var:
return return
self.import_var(new_var, reason=reason) self.import_var(new_var, reason=reason, import_missing=import_missing)
self.add_client(new_var, (node, i)) self.add_client(new_var, (node, i))
self.remove_client(r, (node, i), reason=reason) self.remove_client(r, (node, i), reason=reason)
# Precondition: the substitution is semantically valid # Precondition: the substitution is semantically valid However it may
# However it may introduce cycles to the graph, in which case the # introduce cycles to the graph, in which case the transaction will be
# transaction will be reverted later. # reverted later.
self.execute_callbacks("on_change_input", node, i, r, new_var, reason=reason) self.execute_callbacks("on_change_input", node, i, r, new_var, reason=reason)
def replace(self, var, new_var, reason=None, verbose=None): def replace(self, var, new_var, reason=None, verbose=None, import_missing=False):
"""Replace a variable in the `FunctionGraph`. """Replace a variable in the `FunctionGraph`.
This is the main interface to manipulate the subgraph in `FunctionGraph`. This is the main interface to manipulate the subgraph in `FunctionGraph`.
...@@ -467,6 +478,8 @@ class FunctionGraph(utils.MetaObject): ...@@ -467,6 +478,8 @@ class FunctionGraph(utils.MetaObject):
The name of the optimization or operation in progress. The name of the optimization or operation in progress.
verbose : bool verbose : bool
Print `reason`, `var`, and `new_var`. Print `reason`, `var`, and `new_var`.
import_missing : bool
Import missing variables.
""" """
if verbose is None: if verbose is None:
...@@ -477,10 +490,16 @@ class FunctionGraph(utils.MetaObject): ...@@ -477,10 +490,16 @@ class FunctionGraph(utils.MetaObject):
new_var = var.type.filter_variable(new_var, allow_convert=True) new_var = var.type.filter_variable(new_var, allow_convert=True)
if var not in self.variables: if var not in self.variables:
# TODO: Raise an actual exception here.
# Old comment:
# this variable isn't in the graph... don't raise an # this variable isn't in the graph... don't raise an
# exception here, just return silently because it makes it # exception here, just return silently because it makes it
# easier to implement some optimizations for # easier to implement some optimizations for
# multiple-output ops # multiple-output ops
# raise ValueError()
warnings.warn(
f"Variable {var} cannot be replaced; it isn't in the FunctionGraph"
)
return return
if config.compute_test_value != "off": if config.compute_test_value != "off":
...@@ -503,12 +522,14 @@ class FunctionGraph(utils.MetaObject): ...@@ -503,12 +522,14 @@ class FunctionGraph(utils.MetaObject):
assert (node == "output" and self.outputs[i] is var) or ( assert (node == "output" and self.outputs[i] is var) or (
node.inputs[i] is var node.inputs[i] is var
) )
self.change_input(node, i, new_var, reason=reason) self.change_input(
node, i, new_var, reason=reason, import_missing=import_missing
)
def replace_all(self, pairs, reason=None): def replace_all(self, pairs, **kwargs):
"""Replace variables in the `FunctionGraph` according to `(var, new_var)` pairs in a list.""" """Replace variables in the `FunctionGraph` according to `(var, new_var)` pairs in a list."""
for var, new_var in pairs: for var, new_var in pairs:
self.replace(var, new_var, reason=reason) self.replace(var, new_var, **kwargs)
def attach_feature(self, feature): def attach_feature(self, feature):
""" """
...@@ -516,25 +537,25 @@ class FunctionGraph(utils.MetaObject): ...@@ -516,25 +537,25 @@ class FunctionGraph(utils.MetaObject):
on_attach callback. on_attach callback.
""" """
# Filter out literally identical features # Filter out literally identical `Feature`s
if feature in self._features: if feature in self._features:
return # the feature is already present return # the feature is already present
# Filter out functionally identical features. # Filter out functionally identical `Feature`s.
# Features may use their on_attach method to raise # `Feature`s may use their `on_attach` method to raise
# toolbox.AlreadyThere if they detect that some # `AlreadyThere` if they detect that some
# installed feature does the same thing already # installed `Feature` does the same thing already
attach = getattr(feature, "on_attach", None) attach = getattr(feature, "on_attach", None)
if attach is not None: if attach is not None:
try: try:
attach(self) attach(self)
except toolbox.AlreadyThere: except AlreadyThere:
return return
self.execute_callbacks_times.setdefault(feature, 0) self.execute_callbacks_times.setdefault(feature, 0)
# it would be nice if we could require a specific class instead of # It would be nice if we could require a specific class instead of
# a "workalike" so we could do actual error checking # a "workalike" so we could do actual error checking
# if not isinstance(feature, toolbox.Feature): # if not isinstance(feature, Feature):
# raise TypeError("Expected graph.toolbox.Feature instance, got "+\ # raise TypeError("Expected Feature instance, got "+\
# str(type(feature))) # str(type(feature)))
# Add the feature # Add the feature
......
...@@ -856,9 +856,9 @@ class MergeOptimizer(GlobalOptimizer): ...@@ -856,9 +856,9 @@ class MergeOptimizer(GlobalOptimizer):
# Only need to check one of the var of each pairs. # Only need to check one of the var of each pairs.
# If it is a Constant, the other must also be a Constant as we merge them. # If it is a Constant, the other must also be a Constant as we merge them.
if all([isinstance(old, Constant) for old, new in pairs]): if all([isinstance(old, Constant) for old, new in pairs]):
fgraph.replace_all(pairs, "MergeOptimizer") fgraph.replace_all(pairs, reason="MergeOptimizer")
else: else:
fgraph.replace_all_validate(pairs, "MergeOptimizer") fgraph.replace_all_validate(pairs, reason="MergeOptimizer")
except InconsistencyError: except InconsistencyError:
success = False success = False
nb_fail += 1 nb_fail += 1
......
...@@ -555,10 +555,12 @@ class ReplaceValidate(History, Validator): ...@@ -555,10 +555,12 @@ class ReplaceValidate(History, Validator):
del fgraph.replace_all_validate del fgraph.replace_all_validate
del fgraph.replace_all_validate_remove del fgraph.replace_all_validate_remove
def replace_validate(self, fgraph, r, new_r, reason=None): def replace_validate(self, fgraph, r, new_r, reason=None, **kwargs):
self.replace_all_validate(fgraph, [(r, new_r)], reason=reason) self.replace_all_validate(fgraph, [(r, new_r)], reason=reason, **kwargs)
def replace_all_validate(self, fgraph, replacements, reason=None, verbose=None): def replace_all_validate(
self, fgraph, replacements, reason=None, verbose=None, **kwargs
):
chk = fgraph.checkpoint() chk = fgraph.checkpoint()
if verbose is None: if verbose is None:
verbose = config.optimizer_verbose verbose = config.optimizer_verbose
...@@ -569,7 +571,7 @@ class ReplaceValidate(History, Validator): ...@@ -569,7 +571,7 @@ class ReplaceValidate(History, Validator):
for r, new_r in replacements: for r, new_r in replacements:
try: try:
fgraph.replace(r, new_r, reason=reason, verbose=False) fgraph.replace(r, new_r, reason=reason, verbose=False, **kwargs)
except Exception as e: except Exception as e:
msg = str(e) msg = str(e)
s1 = "The type of the replacement must be the same" s1 = "The type of the replacement must be the same"
...@@ -630,14 +632,14 @@ class ReplaceValidate(History, Validator): ...@@ -630,14 +632,14 @@ class ReplaceValidate(History, Validator):
return chk return chk
def replace_all_validate_remove( def replace_all_validate_remove(
self, fgraph, replacements, remove, reason=None, warn=True self, fgraph, replacements, remove, reason=None, warn=True, **kwargs
): ):
""" """
As replace_all_validate, revert the replacement if the ops As replace_all_validate, revert the replacement if the ops
in the list remove are still in the graph. Also print a warning. in the list remove are still in the graph. Also print a warning.
""" """
chk = fgraph.replace_all_validate(replacements, reason) chk = fgraph.replace_all_validate(replacements, reason=reason, **kwargs)
self._nodes_removed.update(remove) self._nodes_removed.update(remove)
for rm in remove: for rm in remove:
if rm in fgraph.apply_nodes or rm in fgraph.variables: if rm in fgraph.apply_nodes or rm in fgraph.variables:
......
...@@ -111,20 +111,26 @@ class TestFunctionGraph: ...@@ -111,20 +111,26 @@ class TestFunctionGraph:
var5 = op3(var4, var2, var2) var5 = op3(var4, var2, var2)
fg = FunctionGraph([var1, var2], [var3, var5], clone=False) fg = FunctionGraph([var1, var2], [var3, var5], clone=False)
var5 = MyVariable("var5") var8 = MyVariable("var8")
var6 = op2(var5) var6 = op2(var8)
with pytest.raises(MissingInputError): with pytest.raises(MissingInputError):
fg.import_node(var6.owner) fg.import_node(var6.owner)
var6 = op2(var2) assert var8 not in fg.variables
assert not hasattr(var6.owner.tag, "imported_by")
fg.import_node(var6.owner)
assert hasattr(var6.owner.tag, "imported_by") fg.import_node(var6.owner, import_missing=True)
assert var6 in fg.variables assert var8 in fg.inputs
assert var6.owner in fg.apply_nodes assert var6.owner in fg.apply_nodes
assert (var6.owner, 0) in fg.get_clients(var2)
var7 = op2(var2)
assert not hasattr(var7.owner.tag, "imported_by")
fg.import_node(var7.owner)
assert hasattr(var7.owner.tag, "imported_by")
assert var7 in fg.variables
assert var7.owner in fg.apply_nodes
assert (var7.owner, 0) in fg.get_clients(var2)
def test_import_var(self): def test_import_var(self):
...@@ -135,12 +141,17 @@ class TestFunctionGraph: ...@@ -135,12 +141,17 @@ class TestFunctionGraph:
var5 = op3(var4, var2, var2) var5 = op3(var4, var2, var2)
fg = FunctionGraph([var1, var2], [var3, var5], clone=False) fg = FunctionGraph([var1, var2], [var3, var5], clone=False)
var0 = MyVariable("var0")
with pytest.raises(MissingInputError): with pytest.raises(MissingInputError):
var0 = MyVariable("var0")
# We can't import a new `FunctionGraph` input (i.e. something # We can't import a new `FunctionGraph` input (i.e. something
# without an owner) # without an owner), at least not without setting `import_missing`
fg.import_var(var0, "testing") fg.import_var(var0, "testing")
fg.import_var(var0, import_missing=True)
assert var0 in fg.inputs
var5 = op2() var5 = op2()
# We can import variables with owners # We can import variables with owners
fg.import_var(var5, "testing") fg.import_var(var5, "testing")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论