提交 d8a82f73 authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Refactor FunctionGraph interface

This commit does the following:
- changes `r` to `var`,
- adds missing docstrings, and
- removes unnecessary dunder method names.
上级 3c47f74a
......@@ -15,9 +15,6 @@ from theano.gof.utils import TestValueError, get_variable_trace_string
from theano.misc.ordered_set import OrderedSet
NullType = None
class InconsistencyError(Exception):
"""
This exception should be thrown by listeners to FunctionGraph when the
......@@ -105,14 +102,16 @@ class FunctionGraph(utils.object2):
Parameters
----------
inputs : list of variables
inputs : list of theano.gof.graph.Variable
Inputs nodes of the graph, usually declared by the user
outputs : list of variables
outputs : list of theano.gof.graph.Variable
Outputs nodes of the graph.
clone : boolean
If true, we will clone the graph. This is useful to remove the
constant cache problem.
update_mapping : dictionary
features : list of theano.gof.toolbox.Feature
A list of features to be added to the `FunctionGraph`.
update_mapping : dict
Mapping between the inputs with updates and the outputs
corresponding to their updates.
"""
......@@ -120,6 +119,12 @@ class FunctionGraph(utils.object2):
if clone:
inputs, outputs = graph.clone(inputs, outputs)
if not isinstance(inputs, list):
raise TypeError("Argument `inputs` should be a list")
if not isinstance(outputs, list):
raise TypeError("Argument `outputs` should be a list")
self.execute_callbacks_time = 0
self.execute_callbacks_times = {}
......@@ -139,47 +144,71 @@ class FunctionGraph(utils.object2):
# outputs even if they aren't used in the graph.
self.variables = set()
self.inputs = list(inputs)
# TODO FIXME: We should *not* be using a list created elsewhere!
self.outputs = outputs
for f in features:
self.attach_feature(f)
self.attach_feature(toolbox.ReplaceValidate())
for input in self.inputs:
if input.owner is not None:
self.inputs = []
for in_var in inputs:
if in_var.owner is not None:
raise ValueError(
"One of the provided inputs is the output of"
"One of the provided inputs is the output of "
"an already existing node. "
"If that is okay, either discard that "
"input's owner or use graph.clone."
)
self.__setup_r__(input)
self.variables.add(input)
self.add_input(in_var, check=False)
for output in outputs:
self.__import_r__(output, reason="init")
self.import_var(output, reason="init")
for i, output in enumerate(outputs):
output.clients.append(("output", i))
self.profile = None
self.update_mapping = update_mapping
def add_input(self, input):
if input not in self.inputs:
self.inputs.append(input)
self.__setup_r__(input)
self.variables.add(input)
def add_input(self, var, check=True):
"""Add a new variable as an input to this `FunctionGraph`.
Parameters
----------
var : theano.gof.graph.Variable
"""
if check and var in self.inputs:
return
self.inputs.append(var)
self.setup_var(var)
self.variables.add(var)
def setup_var(self, var):
"""Set up a variable so it belongs to this `FunctionGraph`.
Parameters
----------
var : theano.gof.graph.Variable
def __setup_r__(self, r):
if hasattr(r, "fgraph") and r.fgraph is not None and r.fgraph is not self:
raise Exception("%s is already owned by another fgraph" % r)
r.fgraph = self
r.clients = []
# self.execute_callbacks('on_setup_variable', r)
"""
if hasattr(var, "fgraph") and var.fgraph is not None and var.fgraph is not self:
raise Exception("%s is already owned by another fgraph" % var)
var.fgraph = self
var.clients = []
# self.execute_callbacks('on_setup_variable', var)
def setup_node(self, node):
"""Set up node so it belongs to this `FunctionGraph`.
Parameters
----------
node : theano.gof.graph.Apply
def __setup_node__(self, node):
# sets up node so it belongs to this fgraph
"""
if hasattr(node, "fgraph") and node.fgraph is not self:
raise Exception("%s is already owned by another fgraph" % node)
if hasattr(node.op, "view_map") and not all(
......@@ -226,125 +255,141 @@ class FunctionGraph(utils.object2):
self.profile = None
self.update_mapping = None
# clients #
def clients(self, r):
"""
Set of all the (node, i) pairs such that node.inputs[i] is r.
Told differently, a list of (node,i) such that each node have
r as input at index i.
def clients(self, var):
"""Return a list of all the `(node, i)` pairs such that `node.inputs[i]` is `var`.
"""
return r.clients
Put differently, a `list` of `(node, i)` such that each node has
`var` as input at index `i`.
def __add_client__(self, r, new_client):
"""
Updates the list of clients of r with new_clients.
return var.clients
def add_client(self, var, new_client):
"""Update the clients of `var` with `new_client`.
Parameters
----------
r
Variable.
new_client
(node, i) pair such that node.inputs[i] is r.
var : Variable.
new_client : (Apply, int)
A `(node, i)` pair such that `node.inputs[i]` is `var`.
"""
# No need to do the assert as it is always True. The logic
# that calls __add_client__ is valid. When the client list is
# long, the check is time consuming, so we don't enable it by
# default.
# assert not new_client in r.clients
r.clients.append(new_client)
var.clients.append(new_client)
def __remove_client__(self, r, client_to_remove, reason=None):
"""
Removes all from the clients list of r.
def remove_client(self, var, client_to_remove, reason=None):
"""Recursively removes clients of a variable.
This is the main method to remove variable or apply node from
an FunctionGraph.
This is the main method to remove variables or `Apply` nodes from
a `FunctionGraph`.
Remove r from this fgraph if it don't have clients left. If it
have an owner and all the outputs of the owner have no
clients, it will be removed.
This will remove `var` from the `FunctionGraph` if it doesn't have any
clients remaining. If it has an owner and all the outputs of the owner
have no clients, it will also be removed.
Parameters
----------
r : Variable
The clients of r will be removed.
client_to_remove : (op, i) pair
(op, i) pair such that node.inputs[i] is not r anymore.
"""
l = [(r, client_to_remove)]
while l:
r, client_to_remove = l.pop()
r.clients.remove(client_to_remove)
# entry should be uniq in r. No need to assert it as it is
# already asserted in __add_client__.
# assert entry not in r.clients
if r.clients:
var : Variable
The clients of `var` that will be removed.
client_to_remove : pair of (Apply, int)
A `(node, i)` pair such that `node.inputs[i]` will no longer be
`var` in this `FunctionGraph`.
"""
removal_stack = [(var, client_to_remove)]
while removal_stack:
var, client_to_remove = removal_stack.pop()
try:
var.clients.remove(client_to_remove)
except ValueError:
# In this case, the original `var` could've been removed from
# the current `var`'s client list before this call.
# There's nothing inherently wrong with that, so we continue as
# if it were removed here.
pass
if var.clients:
continue
# r have no more clients, so check if we need to remove it
# and its parent.
variable = r
if not variable.owner:
# A Constant or input without client. Remove it.
self.variables.remove(variable)
# This allow to quickly know if a var is still in the fgraph
# or not.
del variable.fgraph
# Now, `var` has no more clients, so check if we need to remove it
# and its `Apply` node
if not var.owner:
# The `var` is a `Constant` or an input without a client, so we
# remove it
self.variables.remove(var)
# This allows us to quickly determine if `var` is still in the
# `FunctionGraph`
# TODO: It's a poor approach; remove it
del var.fgraph
else:
apply_node = variable.owner
used = [output for output in apply_node.outputs if output.clients]
# If the apply node is not used and is not an output
if not used:
apply_node = var.owner
if not any(output.clients for output in apply_node.outputs):
# The `Apply` node is not used and is not an output, so we
# remove it and its outputs
if not hasattr(apply_node.tag, "removed_by"):
apply_node.tag.removed_by = []
apply_node.tag.removed_by.append(str(reason))
self.apply_nodes.remove(apply_node)
# del apply_node.fgraph
self.variables.difference_update(apply_node.outputs)
#
# for var in apply_node.outputs:
# del var.fgraph
self.variables.difference_update(apply_node.outputs)
self.execute_callbacks("on_prune", apply_node, reason)
for i, input in enumerate(apply_node.inputs):
l.append((input, (apply_node, i)))
for i, in_var in enumerate(apply_node.inputs):
removal_stack.append((in_var, (apply_node, i)))
def __import_r__(self, variable, reason):
"""
Import variables to this FunctionGraph and also their apply_node,
if those nodes are not in this graph.
def import_var(self, var, reason):
"""Import variables into this `FunctionGraph`.
This will also import the `variable`'s `Apply` node.
Parameters:
----------
reason
reason is the name of the optimization or operation in progress.
var : theano.gof.graph.Variable
The variable to be imported.
reason : str
The name of the optimization or operation in progress.
"""
# Imports the owners of the variables
if variable.owner and variable.owner not in self.apply_nodes:
self.__import__(variable.owner, reason=reason)
if var.owner and var.owner not in self.apply_nodes:
self.import_node(var.owner, reason=reason)
elif (
variable.owner is None
and not isinstance(variable, graph.Constant)
and variable not in self.inputs
var.owner is None
and not isinstance(var, graph.Constant)
and var not in self.inputs
):
global NullType
if NullType is None:
from .null_type import NullType
if isinstance(variable.type, NullType):
from theano.gof.null_type import NullType
if isinstance(var.type, NullType):
raise TypeError(
"Computation graph contains a NaN. " + variable.type.why_null
"Computation graph contains a NaN. " + var.type.why_null
)
raise MissingInputError("Undeclared input", variable=variable)
if not getattr(variable, "fgraph", None) is self:
self.__setup_r__(variable)
self.variables.add(variable)
raise MissingInputError("Undeclared input", variable=var)
if not getattr(var, "fgraph", None) is self:
self.setup_var(var)
self.variables.add(var)
def __import__(self, apply_node, check=True, reason=None):
"""
Given an apply_node, recursively search from this node to know graph,
and then add all unknown variables and apply_nodes to this graph.
def import_node(self, apply_node, check=True, reason=None):
"""Recursively import everything between an `Apply` node and the `FunctionGraph`'s outputs.
Parameters:
----------
apply_node : theano.gof.graph.Apply
The node to be imported.
check : bool
Check that the inputs for the imported nodes are also present in
the `FunctionGraph`.
reason : str
The name of the optimization or operation in progress.
"""
node = apply_node
......@@ -358,13 +403,13 @@ class FunctionGraph(utils.object2):
for node in new_nodes:
if hasattr(node, "fgraph") and node.fgraph is not self:
raise Exception("%s is already owned by another fgraph" % node)
for r in node.inputs:
if hasattr(r, "fgraph") and r.fgraph is not self:
raise Exception("%s is already owned by another fgraph" % r)
for var in node.inputs:
if hasattr(var, "fgraph") and var.fgraph is not self:
raise Exception("%s is already owned by another fgraph" % var)
if (
r.owner is None
and not isinstance(r, graph.Constant)
and r not in self.inputs
var.owner is None
and not isinstance(var, graph.Constant)
and var not in self.inputs
):
# Standard error message
error_msg = (
......@@ -373,51 +418,60 @@ class FunctionGraph(utils.object2):
"provided and not given a value. Use the "
"Theano flag exception_verbosity='high', "
"for more information on this error."
% (node.inputs.index(r), str(node))
% (node.inputs.index(var), str(node))
)
raise MissingInputError(error_msg, variable=r)
raise MissingInputError(error_msg, variable=var)
for node in new_nodes:
assert node not in self.apply_nodes
self.__setup_node__(node)
self.setup_node(node)
self.apply_nodes.add(node)
if not hasattr(node.tag, "imported_by"):
node.tag.imported_by = []
node.tag.imported_by.append(str(reason))
for output in node.outputs:
self.__setup_r__(output)
self.setup_var(output)
self.variables.add(output)
for i, input in enumerate(node.inputs):
if input not in self.variables:
self.__setup_r__(input)
self.setup_var(input)
self.variables.add(input)
self.__add_client__(input, (node, i))
self.add_client(input, (node, i))
assert node.fgraph is self
self.execute_callbacks("on_import", node, reason)
# change input #
def change_input(self, node, i, new_r, reason=None):
"""
Changes node.inputs[i] to new_r.
def change_input(self, node, i, new_var, reason=None):
"""Change ``node.inputs[i]`` to `new_var`.
new_r.type == old_r.type must be True, where old_r is the
current value of node.inputs[i] which we want to replace.
``new_var.type == old_var.type`` must be ``True``, where ``old_var`` is the
current value of ``node.inputs[i]`` which we want to replace.
For each feature that has a 'on_change_input' method, calls:
feature.on_change_input(function_graph, node, i, old_r, new_r, reason)
For each feature that has an `on_change_input` method, this method calls:
``feature.on_change_input(function_graph, node, i, old_var, new_var, reason)``
Parameters
----------
node : theano.gof.graph.Apply or str
The node for which an input is to be changed. If the value is
the string ``"output"`` then the ``self.outputs`` will be used
instead of ``node.inputs``.
i : int
The index in `node.inputs` that we want to change.
new_var : theano.gof.graph.Variable
The new variable to take the place of ``node.inputs[i]``.
"""
# TODO: ERROR HANDLING FOR LISTENERS (should it complete the change or revert it?)
if node == "output":
r = self.outputs[i]
if not r.type == new_r.type:
if not r.type == new_var.type:
raise TypeError(
"The type of the replacement must be the"
" same as the type of the original Variable.",
r,
new_r,
new_var,
)
self.outputs[i] = new_r
self.outputs[i] = new_var
else:
if node.fgraph is not self:
raise Exception(
......@@ -425,51 +479,63 @@ class FunctionGraph(utils.object2):
" belong to this FunctionGraph" % node
)
r = node.inputs[i]
if not r.type == new_r.type:
if not r.type == new_var.type:
raise TypeError(
"The type of the replacement must be the"
" same as the type of the original Variable.",
r,
new_r,
new_var,
)
node.inputs[i] = new_r
node.inputs[i] = new_var
if r is new_r:
if r is new_var:
return
self.__import_r__(new_r, reason=reason)
self.__add_client__(new_r, (node, i))
self.__remove_client__(r, (node, i), reason=reason)
self.import_var(new_var, reason=reason)
self.add_client(new_var, (node, i))
self.remove_client(r, (node, i), reason=reason)
# Precondition: the substitution is semantically valid
# However it may introduce cycles to the graph, in which case the
# transaction will be reverted later.
self.execute_callbacks("on_change_input", node, i, r, new_r, reason=reason)
self.execute_callbacks("on_change_input", node, i, r, new_var, reason=reason)
# replace #
def replace(self, r, new_r, reason=None, verbose=None):
"""
This is the main interface to manipulate the subgraph in FunctionGraph.
For every node that uses r as input, makes it use new_r instead.
def replace(self, var, new_var, reason=None, verbose=None):
"""Replace a variable in the `FunctionGraph`.
This is the main interface to manipulate the subgraph in `FunctionGraph`.
For every node that uses `var` as input, makes it use `new_var` instead.
Parameters:
----------
var : theano.gof.graph.Variable
The variable to be replaced.
new_var : theano.gof.graph.Variable
The variable to replace `var`.
reason : str
The name of the optimization or operation in progress.
verbose : bool
Print `reason`, `var`, and `new_var`.
"""
if verbose is None:
verbose = config.optimizer_verbose
if verbose:
print(reason, r, new_r)
if hasattr(r, "fgraph") and r.fgraph is not self:
print(reason, var, new_var)
if hasattr(var, "fgraph") and var.fgraph is not self:
raise Exception(
"Cannot replace %s because it does not belong "
"to this FunctionGraph" % r,
"to this FunctionGraph" % var,
str(reason),
)
if r.type != new_r.type:
new_r2 = r.type.convert_variable(new_r)
if var.type != new_var.type:
new_var_2 = var.type.convert_variable(new_var)
# We still make sure that the type converts correctly
if new_r2 is None or new_r2.type != r.type:
if new_var_2 is None or new_var_2.type != var.type:
done = dict()
used_ids = dict()
old = theano.compile.debugmode.debugprint(
r,
var,
prefix=" ",
depth=6,
file=StringIO(),
......@@ -478,7 +544,7 @@ class FunctionGraph(utils.object2):
used_ids=used_ids,
).getvalue()
new = theano.compile.debugmode.debugprint(
new_r,
new_var,
prefix=" ",
depth=6,
file=StringIO(),
......@@ -487,16 +553,17 @@ class FunctionGraph(utils.object2):
used_ids=used_ids,
).getvalue()
raise toolbox.BadOptimization(
r,
new_r,
var,
new_var,
None,
None,
str(reason) + ". The type of the replacement must be the same.",
old,
new,
)
new_r = new_r2
if r not in self.variables:
new_var = new_var_2
if var not in self.variables:
# this variable isn't in the graph... don't raise an
# exception here, just return silently because it makes it
# easier to implement some optimizations for
......@@ -505,8 +572,8 @@ class FunctionGraph(utils.object2):
if theano.config.compute_test_value != "off":
try:
tval = theano.gof.op.get_test_value(r)
new_tval = theano.gof.op.get_test_value(new_r)
tval = theano.gof.op.get_test_value(var)
new_tval = theano.gof.op.get_test_value(new_var)
except TestValueError:
pass
else:
......@@ -518,27 +585,21 @@ class FunctionGraph(utils.object2):
"a shape different from the original variable's "
"test value. Original: %s, new: %s"
% (tval_shape, new_tval_shape),
r,
new_r,
var,
new_var,
str(reason),
)
for node, i in list(r.clients): # copy the client list for iteration
assert (node == "output" and self.outputs[i] is r) or (node.inputs[i] is r)
self.change_input(node, i, new_r, reason=reason)
# sometimes the following is triggered. If you understand why, please explain to James.
# He's curious... -JB20090331
# if len(r.clients) != 0:
# print >> sys.stderr, "WARNING: CLIENTS LEFT AFTER REPLACE", r, r.clients
for node, i in list(var.clients): # copy the client list for iteration
assert (node == "output" and self.outputs[i] is var) or (
node.inputs[i] is var
)
self.change_input(node, i, new_var, reason=reason)
def replace_all(self, pairs, reason=None):
"""
For every node that uses r as input, makes it use new_r instead
"""
for r, new_r in pairs:
self.replace(r, new_r, reason=reason)
"""Replace variables in the `FunctionGraph` according to `(var, new_var)` pairs in a list."""
for var, new_var in pairs:
self.replace(var, new_var, reason=reason)
def attach_feature(self, feature):
"""
......@@ -587,7 +648,6 @@ class FunctionGraph(utils.object2):
if detach is not None:
detach(self)
# callback utils #
def execute_callbacks(self, name, *args, **kwargs):
"""Execute callbacks
......@@ -625,7 +685,6 @@ class FunctionGraph(utils.object2):
d[feature] = fn(*args)
return d
# misc #
def toposort(self):
"""Toposort
......@@ -655,17 +714,16 @@ class FunctionGraph(utils.object2):
return order
def orderings(self):
"""
Return dict d s.t. d[node] is a list of nodes that must be evaluated
before node itself can be evaluated.
"""Return `dict` `d` s.t. `d[node]` is a list of nodes that must be evaluated before `node` itself can be evaluated.
This is used primarily by the destroy_handler feature to ensure that
all clients of any destroyed inputs have already computed their outputs.
the clients of any destroyed inputs have already computed their
outputs.
Notes
-----
This only calls the orderings() fct on all features. It does not
take care of computing dependencies by itself.
This only calls the `orderings()` function on all features. It does not
take care of computing the dependencies by itself.
"""
assert isinstance(self._features, list)
......@@ -769,7 +827,6 @@ class FunctionGraph(utils.object2):
def __repr__(self):
return self.__str__()
# clone #
def clone(self, check_integrity=True):
"""
Clone the graph and get a memo( a dict )that map old node to new node
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论