提交 37444647 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add an option to add missing inputs during FunctionGraph operations

上级 334d3fdf
"""A container for specifying and manipulating a graph with distinct inputs and outputs."""
import time
import warnings
from collections import OrderedDict
import aesara
from aesara.configdefaults import config
from aesara.graph import toolbox, utils
from aesara.graph.basic import Apply, Constant, Variable, applys_between
from aesara.graph.basic import as_string as graph_as_string
from aesara.graph.basic import clone as clone_graph
from aesara.graph.basic import clone_get_equiv, io_toposort, vars_between
from aesara.graph.utils import TestValueError, get_variable_trace_string
from aesara.graph.toolbox import AlreadyThere, ReplaceValidate
from aesara.graph.utils import MetaObject, TestValueError, get_variable_trace_string
from aesara.misc.ordered_set import OrderedSet
......@@ -35,10 +36,10 @@ class MissingInputError(Exception):
if error_msg:
args = args + (error_msg,)
s = "\n".join(args) # Needed to have the new line print correctly
Exception.__init__(self, s)
super().__init__(s)
class FunctionGraph(utils.MetaObject):
class FunctionGraph(MetaObject):
"""
A `FunctionGraph` represents a subgraph bound by a set of input variables and
a set of output variables, ie a subgraph that specifies an Aesara function.
......@@ -58,8 +59,8 @@ class FunctionGraph(utils.MetaObject):
both directions.
It can also be extended with new features using
`FunctionGraph.attach_feature`(<toolbox.Feature instance>).
See `toolbox.Feature` for event types and documentation.
`FunctionGraph.attach_feature`(<Feature instance>).
See `Feature` for event types and documentation.
Extra features allow the `FunctionGraph` to verify new properties of
a graph as it is optimized.
......@@ -142,7 +143,7 @@ class FunctionGraph(utils.MetaObject):
for f in features:
self.attach_feature(f)
self.attach_feature(toolbox.ReplaceValidate())
self.attach_feature(ReplaceValidate())
self.inputs = []
for in_var in inputs:
......@@ -311,7 +312,7 @@ class FunctionGraph(utils.MetaObject):
for i, in_var in enumerate(apply_node.inputs):
removal_stack.append((in_var, (apply_node, i)))
def import_var(self, var, reason):
def import_var(self, var, reason=None, import_missing=False):
"""Import variables into this `FunctionGraph`.
This will also import the `variable`'s `Apply` node.
......@@ -322,10 +323,13 @@ class FunctionGraph(utils.MetaObject):
The variable to be imported.
reason : str
The name of the optimization or operation in progress.
import_missing : bool
Add missing inputs instead of raising an exception.
"""
# Imports the owners of the variables
if var.owner and var.owner not in self.apply_nodes:
self.import_node(var.owner, reason=reason)
self.import_node(var.owner, reason=reason, import_missing=import_missing)
elif (
var.owner is None
and not isinstance(var, Constant)
......@@ -335,13 +339,16 @@ class FunctionGraph(utils.MetaObject):
if isinstance(var.type, NullType):
raise TypeError(
"Computation graph contains a NaN. " + var.type.why_null
f"Computation graph contains a NaN. {var.type.why_null}"
)
raise MissingInputError("Undeclared input", variable=var)
if import_missing:
self.add_input(var)
else:
raise MissingInputError(f"Undeclared input: {var}", variable=var)
self.setup_var(var)
self.variables.add(var)
def import_node(self, apply_node, check=True, reason=None):
def import_node(self, apply_node, check=True, reason=None, import_missing=False):
"""Recursively import everything between an `Apply` node and the `FunctionGraph`'s outputs.
Parameters:
......@@ -353,13 +360,13 @@ class FunctionGraph(utils.MetaObject):
the `FunctionGraph`.
reason : str
The name of the optimization or operation in progress.
import_missing : bool
Add missing inputs instead of raising an exception.
"""
node = apply_node
# We import the nodes in topological order. We only are interested
# in new nodes, so we use all variables we know of as if they were the input set.
# (the functions in the graph module only use the input set to
# know where to stop going down)
# We import the nodes in topological order. We only are interested in
# new nodes, so we use all variables we know of as if they were the
# input set. (The functions in the graph module only use the input set
# to know where to stop going down.)
new_nodes = io_toposort(self.variables, apply_node.outputs)
if check:
......@@ -370,15 +377,18 @@ class FunctionGraph(utils.MetaObject):
and not isinstance(var, Constant)
and var not in self.inputs
):
# Standard error message
error_msg = (
f"Input {int(node.inputs.index(var))} of the graph (indices start "
f"from 0), used to compute {node}, was not "
"provided and not given a value. Use the "
"Aesara flag exception_verbosity='high', "
"for more information on this error."
)
raise MissingInputError(error_msg, variable=var)
if import_missing:
self.add_input(var)
else:
error_msg = (
f"Input {node.inputs.index(var)} ({var})"
" of the graph (indices start "
f"from 0), used to compute {node}, was not "
"provided and not given a value. Use the "
"Aesara flag exception_verbosity='high', "
"for more information on this error."
)
raise MissingInputError(error_msg, variable=var)
for node in new_nodes:
assert node not in self.apply_nodes
......@@ -397,7 +407,7 @@ class FunctionGraph(utils.MetaObject):
self.add_client(input, (node, i))
self.execute_callbacks("on_import", node, reason)
def change_input(self, node, i, new_var, reason=None):
def change_input(self, node, i, new_var, reason=None, import_missing=False):
"""Change ``node.inputs[i]`` to `new_var`.
``new_var.type == old_var.type`` must be ``True``, where ``old_var`` is the
......@@ -416,7 +426,8 @@ class FunctionGraph(utils.MetaObject):
The index in `node.inputs` that we want to change.
new_var : aesara.graph.basic.Variable
The new variable to take the place of ``node.inputs[i]``.
import_missing : bool
Add missing inputs instead of raising an exception.
"""
# TODO: ERROR HANDLING FOR LISTENERS (should it complete the change or revert it?)
if node == "output":
......@@ -443,15 +454,15 @@ class FunctionGraph(utils.MetaObject):
if r is new_var:
return
self.import_var(new_var, reason=reason)
self.import_var(new_var, reason=reason, import_missing=import_missing)
self.add_client(new_var, (node, i))
self.remove_client(r, (node, i), reason=reason)
# Precondition: the substitution is semantically valid
# However it may introduce cycles to the graph, in which case the
# transaction will be reverted later.
# Precondition: the substitution is semantically valid However it may
# introduce cycles to the graph, in which case the transaction will be
# reverted later.
self.execute_callbacks("on_change_input", node, i, r, new_var, reason=reason)
def replace(self, var, new_var, reason=None, verbose=None):
def replace(self, var, new_var, reason=None, verbose=None, import_missing=False):
"""Replace a variable in the `FunctionGraph`.
This is the main interface to manipulate the subgraph in `FunctionGraph`.
......@@ -467,6 +478,8 @@ class FunctionGraph(utils.MetaObject):
The name of the optimization or operation in progress.
verbose : bool
Print `reason`, `var`, and `new_var`.
import_missing : bool
Import missing variables.
"""
if verbose is None:
......@@ -477,10 +490,16 @@ class FunctionGraph(utils.MetaObject):
new_var = var.type.filter_variable(new_var, allow_convert=True)
if var not in self.variables:
# TODO: Raise an actual exception here.
# Old comment:
# this variable isn't in the graph... don't raise an
# exception here, just return silently because it makes it
# easier to implement some optimizations for
# multiple-output ops
# raise ValueError()
warnings.warn(
f"Variable {var} cannot be replaced; it isn't in the FunctionGraph"
)
return
if config.compute_test_value != "off":
......@@ -503,12 +522,14 @@ class FunctionGraph(utils.MetaObject):
assert (node == "output" and self.outputs[i] is var) or (
node.inputs[i] is var
)
self.change_input(node, i, new_var, reason=reason)
self.change_input(
node, i, new_var, reason=reason, import_missing=import_missing
)
def replace_all(self, pairs, reason=None):
def replace_all(self, pairs, **kwargs):
"""Replace variables in the `FunctionGraph` according to `(var, new_var)` pairs in a list."""
for var, new_var in pairs:
self.replace(var, new_var, reason=reason)
self.replace(var, new_var, **kwargs)
def attach_feature(self, feature):
"""
......@@ -516,25 +537,25 @@ class FunctionGraph(utils.MetaObject):
on_attach callback.
"""
# Filter out literally identical features
# Filter out literally identical `Feature`s
if feature in self._features:
return # the feature is already present
# Filter out functionally identical features.
# Features may use their on_attach method to raise
# toolbox.AlreadyThere if they detect that some
# installed feature does the same thing already
# Filter out functionally identical `Feature`s.
# `Feature`s may use their `on_attach` method to raise
# `AlreadyThere` if they detect that some
# installed `Feature` does the same thing already
attach = getattr(feature, "on_attach", None)
if attach is not None:
try:
attach(self)
except toolbox.AlreadyThere:
except AlreadyThere:
return
self.execute_callbacks_times.setdefault(feature, 0)
# it would be nice if we could require a specific class instead of
# It would be nice if we could require a specific class instead of
# a "workalike" so we could do actual error checking
# if not isinstance(feature, toolbox.Feature):
# raise TypeError("Expected graph.toolbox.Feature instance, got "+\
# if not isinstance(feature, Feature):
# raise TypeError("Expected Feature instance, got "+\
# str(type(feature)))
# Add the feature
......
......@@ -856,9 +856,9 @@ class MergeOptimizer(GlobalOptimizer):
# Only need to check one of the var of each pairs.
# If it is a Constant, the other must also be a Constant as we merge them.
if all([isinstance(old, Constant) for old, new in pairs]):
fgraph.replace_all(pairs, "MergeOptimizer")
fgraph.replace_all(pairs, reason="MergeOptimizer")
else:
fgraph.replace_all_validate(pairs, "MergeOptimizer")
fgraph.replace_all_validate(pairs, reason="MergeOptimizer")
except InconsistencyError:
success = False
nb_fail += 1
......
......@@ -555,10 +555,12 @@ class ReplaceValidate(History, Validator):
del fgraph.replace_all_validate
del fgraph.replace_all_validate_remove
def replace_validate(self, fgraph, r, new_r, reason=None):
self.replace_all_validate(fgraph, [(r, new_r)], reason=reason)
def replace_validate(self, fgraph, r, new_r, reason=None, **kwargs):
self.replace_all_validate(fgraph, [(r, new_r)], reason=reason, **kwargs)
def replace_all_validate(self, fgraph, replacements, reason=None, verbose=None):
def replace_all_validate(
self, fgraph, replacements, reason=None, verbose=None, **kwargs
):
chk = fgraph.checkpoint()
if verbose is None:
verbose = config.optimizer_verbose
......@@ -569,7 +571,7 @@ class ReplaceValidate(History, Validator):
for r, new_r in replacements:
try:
fgraph.replace(r, new_r, reason=reason, verbose=False)
fgraph.replace(r, new_r, reason=reason, verbose=False, **kwargs)
except Exception as e:
msg = str(e)
s1 = "The type of the replacement must be the same"
......@@ -630,14 +632,14 @@ class ReplaceValidate(History, Validator):
return chk
def replace_all_validate_remove(
self, fgraph, replacements, remove, reason=None, warn=True
self, fgraph, replacements, remove, reason=None, warn=True, **kwargs
):
"""
As replace_all_validate, revert the replacement if the ops
in the list remove are still in the graph. Also print a warning.
"""
chk = fgraph.replace_all_validate(replacements, reason)
chk = fgraph.replace_all_validate(replacements, reason=reason, **kwargs)
self._nodes_removed.update(remove)
for rm in remove:
if rm in fgraph.apply_nodes or rm in fgraph.variables:
......
......@@ -111,20 +111,26 @@ class TestFunctionGraph:
var5 = op3(var4, var2, var2)
fg = FunctionGraph([var1, var2], [var3, var5], clone=False)
var5 = MyVariable("var5")
var6 = op2(var5)
var8 = MyVariable("var8")
var6 = op2(var8)
with pytest.raises(MissingInputError):
fg.import_node(var6.owner)
var6 = op2(var2)
assert not hasattr(var6.owner.tag, "imported_by")
fg.import_node(var6.owner)
assert var8 not in fg.variables
assert hasattr(var6.owner.tag, "imported_by")
assert var6 in fg.variables
fg.import_node(var6.owner, import_missing=True)
assert var8 in fg.inputs
assert var6.owner in fg.apply_nodes
assert (var6.owner, 0) in fg.get_clients(var2)
var7 = op2(var2)
assert not hasattr(var7.owner.tag, "imported_by")
fg.import_node(var7.owner)
assert hasattr(var7.owner.tag, "imported_by")
assert var7 in fg.variables
assert var7.owner in fg.apply_nodes
assert (var7.owner, 0) in fg.get_clients(var2)
def test_import_var(self):
......@@ -135,12 +141,17 @@ class TestFunctionGraph:
var5 = op3(var4, var2, var2)
fg = FunctionGraph([var1, var2], [var3, var5], clone=False)
var0 = MyVariable("var0")
with pytest.raises(MissingInputError):
var0 = MyVariable("var0")
# We can't import a new `FunctionGraph` input (i.e. something
# without an owner)
# without an owner), at least not without setting `import_missing`
fg.import_var(var0, "testing")
fg.import_var(var0, import_missing=True)
assert var0 in fg.inputs
var5 = op2()
# We can import variables with owners
fg.import_var(var5, "testing")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论