提交 8d71eb32 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

renamed Env to FunctionGraph. closes 719

上级 56888c31
...@@ -36,32 +36,32 @@ Here is an overview of the various steps that are done with the ...@@ -36,32 +36,32 @@ Here is an overview of the various steps that are done with the
computation graph in the compilation phase: computation graph in the compilation phase:
Step 1 - Create an Env Step 1 - Create a FunctionGraph
^^^^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^^^^
The subgraph given by the end user is wrapped in a structure called The subgraph given by the end user is wrapped in a structure called
*Env*. That structure defines several hooks on adding and *FunctionGraph*. That structure defines several hooks on adding and
removing (pruning) nodes as well as on modifying links between nodes removing (pruning) nodes as well as on modifying links between nodes
(for example, modifying an input of an :ref:`apply` node) (see the (for example, modifying an input of an :ref:`apply` node) (see the
article about :ref:`libdoc_gof_env` for more information). article about :ref:`libdoc_gof_env` for more information).
Env provides a method to change the input of an Apply node from one FunctionGraph provides a method to change the input of an Apply node from one
Variable to another and a more high-level method to replace a Variable Variable to another and a more high-level method to replace a Variable
with another. This is the structure that :ref:`Optimizers with another. This is the structure that :ref:`Optimizers
<optimization>` work on. <optimization>` work on.
Some relevant :ref:`Features <libdoc_gof_envfeature>` are typically added to the Some relevant :ref:`Features <libdoc_gof_envfeature>` are typically added to the
Env, namely to prevent any optimization from operating inplace on FunctionGraph, namely to prevent any optimization from operating inplace on
inputs declared as immutable. inputs declared as immutable.
Step 2 - Execute main Optimizer Step 2 - Execute main Optimizer
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Once the Env is made, an :term:`optimizer` is produced Once the FunctionGraph is made, an :term:`optimizer` is produced
by the :term:`mode` passed to ``function`` or to the Method/Module's by the :term:`mode` passed to ``function`` or to the Method/Module's
``make`` (the Mode basically has two important fields, ``linker`` and ``make`` (the Mode basically has two important fields, ``linker`` and
``optimizer``). That optimizer is applied on the Env using its ``optimizer``). That optimizer is applied on the FunctionGraph using its
optimize() method. optimize() method.
The optimizer is typically obtained through :attr:`optdb`. The optimizer is typically obtained through :attr:`optdb`.
...@@ -71,7 +71,8 @@ Step 3 - Execute linker to obtain a thunk ...@@ -71,7 +71,8 @@ Step 3 - Execute linker to obtain a thunk
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Once the computation graph is optimized, the :term:`linker` is Once the computation graph is optimized, the :term:`linker` is
extracted from the Mode. It is then called with the Env as argument to extracted from the Mode. It is then called with the FunctionGraph as
argument to
produce a ``thunk``, which is a function with no arguments that produce a ``thunk``, which is a function with no arguments that
returns nothing. Along with the thunk, one list of input containers (a returns nothing. Along with the thunk, one list of input containers (a
theano.gof.Container is a sort of object that wraps another and does theano.gof.Container is a sort of object that wraps another and does
......
...@@ -13,7 +13,7 @@ import numpy ...@@ -13,7 +13,7 @@ import numpy
import theano import theano
from theano import gof from theano import gof
from theano.gof import Env, graph, utils, link, ops_with_inner_function from theano.gof import FunctionGraph as Env, graph, utils, link, ops_with_inner_function
from theano.gof.link import raise_with_op from theano.gof.link import raise_with_op
from theano.gof.cc import CLinker from theano.gof.cc import CLinker
from theano.gof.python25 import all, any, product as itertools_product from theano.gof.python25 import all, any, product as itertools_product
...@@ -667,7 +667,7 @@ def _optcheck_env(input_specs, output_specs, accept_inplace=False): ...@@ -667,7 +667,7 @@ def _optcheck_env(input_specs, output_specs, accept_inplace=False):
inputs, outputs = gof.graph.clone(orig_inputs, orig_outputs) inputs, outputs = gof.graph.clone(orig_inputs, orig_outputs)
equivalence_tracker = _VariableEquivalenceTracker() equivalence_tracker = _VariableEquivalenceTracker()
env = gof.env.Env(inputs, outputs, env = gof.fg.FunctionGraph(inputs, outputs,
#DestroyHandler is not needed because it is actually installed by an optimization #DestroyHandler is not needed because it is actually installed by an optimization
# after canonicalization. This variables in a big speed gain. # after canonicalization. This variables in a big speed gain.
#features=[equivalence_tracker, gof.DestroyHandler(do_imports_on_attach=False)]) #features=[equivalence_tracker, gof.DestroyHandler(do_imports_on_attach=False)])
......
...@@ -129,7 +129,7 @@ def std_env(input_specs, output_specs, accept_inplace = False): ...@@ -129,7 +129,7 @@ def std_env(input_specs, output_specs, accept_inplace = False):
orig_outputs = [spec.variable for spec in output_specs] + updates orig_outputs = [spec.variable for spec in output_specs] + updates
inputs, outputs = gof.graph.clone(orig_inputs, orig_outputs) inputs, outputs = gof.graph.clone(orig_inputs, orig_outputs)
env = gof.env.Env(inputs, outputs) env = gof.fg.FunctionGraph(inputs, outputs)
for node in env.nodes: for node in env.nodes:
if getattr(node.op, 'destroy_map', None): if getattr(node.op, 'destroy_map', None):
......
...@@ -5,8 +5,10 @@ from cc import \ ...@@ -5,8 +5,10 @@ from cc import \
import compiledir # adds config vars import compiledir # adds config vars
from env import \ from fg import \
InconsistencyError, MissingInputError, Env InconsistencyError, MissingInputError, FunctionGraph
#deprecated alias to support code written with old name
Env = FunctionGraph
from destroyhandler import \ from destroyhandler import \
DestroyHandler DestroyHandler
......
...@@ -10,7 +10,7 @@ import toolbox ...@@ -10,7 +10,7 @@ import toolbox
import graph import graph
from theano.gof.python25 import deque from theano.gof.python25 import deque
from env import InconsistencyError from fg import InconsistencyError
class ProtocolError(Exception): class ProtocolError(Exception):
"""WRITEME""" """WRITEME"""
......
"""WRITEME""" """
fg.py: fg stands for FunctionGraph
Contains the FunctionGraph class and exception
types that it can raise
"""
import sys import sys
from copy import copy
import graph import graph
import utils import utils
import toolbox import toolbox
...@@ -11,7 +14,7 @@ from theano import config ...@@ -11,7 +14,7 @@ from theano import config
class InconsistencyError(Exception): class InconsistencyError(Exception):
""" """
This exception should be thrown by listeners to Env when the This exception should be thrown by listeners to FunctionGraph when the
graph's state is invalid. graph's state is invalid.
""" """
pass pass
...@@ -24,14 +27,18 @@ class MissingInputError(Exception): ...@@ -24,14 +27,18 @@ class MissingInputError(Exception):
class Env(utils.object2): class FunctionGraph(utils.object2):
""" WRITEME """ WRITEME
An Env represents a subgraph bound by a set of input variables and a A FunctionGraph represents a subgraph bound by a set of input variables and a
set of output variables. The inputs list should contain all the inputs set of output variables, ie a subgraph that specifies a theano function.
The inputs list should contain all the inputs
on which the outputs depend. Variables of type Constant are on which the outputs depend. Variables of type Constant are
not counted as inputs. not counted as inputs.
The Env supports the replace operation which allows to replace a Historically, the FunctionGraph was called an Env. Many other objects refer
to the FunctionGraph they belong to as their "env".
The FunctionGraph supports the replace operation which allows to replace a
variable in the subgraph by another, e.g. replace (x + x).out by (2 variable in the subgraph by another, e.g. replace (x + x).out by (2
* x).out. This is the basis for optimization in theano. * x).out. This is the basis for optimization in theano.
...@@ -42,34 +49,34 @@ class Env(utils.object2): ...@@ -42,34 +49,34 @@ class Env(utils.object2):
The .clients field combined with the .owner field and the Apply nodes' The .clients field combined with the .owner field and the Apply nodes'
.inputs field allows the graph to be traversed in both directions. .inputs field allows the graph to be traversed in both directions.
It can also be "extended" using env.extend(some_object). See the It can also be "extended" using function_graph.extend(some_object). See the
toolbox and ext modules for common extensions. toolbox and ext modules for common extensions.
Features added with the`extend` function can handle the following events: Features added with the`extend` function can handle the following events:
- feature.on_attach(env) - feature.on_attach(function_graph)
Called by extend. The feature has great freedom in what Called by extend. The feature has great freedom in what
it can do with the env: it may, for example, add methods it can do with the function_graph: it may, for example, add methods
to it dynamically. to it dynamically.
- feature.on_detach(env) - feature.on_detach(function_graph)
Called by remove_feature(feature). Should remove any dynamically-added Called by remove_feature(feature). Should remove any dynamically-added
functionality that it installed into the env. functionality that it installed into the function_graph.
- feature.on_import(env, node)* - feature.on_import(function_graph, node)*
Called whenever a node is imported into env, which is Called whenever a node is imported into function_graph, which is
just before the node is actually connected to the graph. just before the node is actually connected to the graph.
- feature.on_prune(env, node)* - feature.on_prune(function_graph, node)*
Called whenever a node is pruned (removed) from the env, Called whenever a node is pruned (removed) from the function_graph,
after it is disconnected from the graph. after it is disconnected from the graph.
- feature.on_change_input(env, node, i, r, new_r, [reason=None])* - feature.on_change_input(function_graph, node, i, r, new_r, [reason=None])*
Called whenever node.inputs[i] is changed from r to new_r. Called whenever node.inputs[i] is changed from r to new_r.
At the moment the callback is done, the change has already At the moment the callback is done, the change has already
taken place. taken place.
- feature.orderings(env) - feature.orderings(function_graph)
Called by toposort. It should return a dictionary of Called by toposort. It should return a dictionary of
{node: predecessors} where predecessors is a list of {node: predecessors} where predecessors is a list of
nodes that should be computed before the key node. nodes that should be computed before the key node.
...@@ -77,10 +84,10 @@ class Env(utils.object2): ...@@ -77,10 +84,10 @@ class Env(utils.object2):
* If you raise an exception in the functions marked with an * If you raise an exception in the functions marked with an
asterisk, the state of the graph might be inconsistent. asterisk, the state of the graph might be inconsistent.
- feature.on_setup_node(env, node): - feature.on_setup_node(function_graph, node):
WRITEME WRITEME
- feature.on_setup_variable(env, variable): - feature.on_setup_variable(function_graph, variable):
WRITEME WRITEME
""" """
...@@ -90,14 +97,14 @@ class Env(utils.object2): ...@@ -90,14 +97,14 @@ class Env(utils.object2):
def __init__(self, inputs, outputs, features=None): def __init__(self, inputs, outputs, features=None):
""" """
Create an Env which operates on the subgraph bound by the inputs and Create an FunctionGraph which operates on the subgraph bound by the inputs and
outputs sets. outputs sets.
This class keeps a pointer to the inputs and outputs, and also modifies This class keeps a pointer to the inputs and outputs, and also modifies
them. them.
#TODO: document what variables are[not] set in the env when a feature #TODO: document what variables are[not] set in the FunctionGraph when a feature
is added via the constructor. How constructed is the env? is added via the constructor. How constructed is the FunctionGraph?
""" """
...@@ -171,12 +178,12 @@ class Env(utils.object2): ...@@ -171,12 +178,12 @@ class Env(utils.object2):
def disown(self): def disown(self):
""" WRITEME """ WRITEME
Cleans up all of this Env's nodes and variables so they are not Cleans up all of this FunctionGraph's nodes and variables so they are not
associated with this Env anymore. associated with this FunctionGraph anymore.
The Env should not be used anymore after disown is called. The FunctionGraph should not be used anymore after disown is called.
This may not clean everything this Env's features set in the This may not clean everything this FunctionGraph's features set in the
nodes and variables. If there are no features, this should set nodes and variables. If there are no features, this should set
them back to what they were originally. them back to what they were originally.
""" """
...@@ -362,7 +369,7 @@ class Env(utils.object2): ...@@ -362,7 +369,7 @@ class Env(utils.object2):
def __prune__(self, node): def __prune__(self, node):
if node not in self.nodes: if node not in self.nodes:
raise Exception("%s does not belong to this Env and cannot be pruned." % node) raise Exception("%s does not belong to this FunctionGraph and cannot be pruned." % node)
assert node.env is self assert node.env is self
# If node's outputs have no clients, removes it from the graph # If node's outputs have no clients, removes it from the graph
# and recursively tries to prune its inputs. If at least one # and recursively tries to prune its inputs. If at least one
...@@ -392,7 +399,7 @@ class Env(utils.object2): ...@@ -392,7 +399,7 @@ class Env(utils.object2):
current value of node.inputs[i] which we want to replace. current value of node.inputs[i] which we want to replace.
For each feature that has a 'on_change_input' method, calls: For each feature that has a 'on_change_input' method, calls:
feature.on_change_input(env, node, i, old_r, new_r, [reason]) feature.on_change_input(function_graph, node, i, old_r, new_r, [reason])
""" """
# TODO: ERROR HANDLING FOR LISTENERS (should it complete the change or revert it?) # TODO: ERROR HANDLING FOR LISTENERS (should it complete the change or revert it?)
if node == 'output': if node == 'output':
...@@ -405,7 +412,7 @@ class Env(utils.object2): ...@@ -405,7 +412,7 @@ class Env(utils.object2):
else: else:
if node.env is not self: if node.env is not self:
raise Exception("Cannot operate on %s because it does not" raise Exception("Cannot operate on %s because it does not"
" belong to this Env" % node) " belong to this FunctionGraph" % node)
r = node.inputs[i] r = node.inputs[i]
if not r.type == new_r.type: if not r.type == new_r.type:
raise TypeError("The type of the replacement must be the" raise TypeError("The type of the replacement must be the"
...@@ -432,11 +439,11 @@ class Env(utils.object2): ...@@ -432,11 +439,11 @@ class Env(utils.object2):
def replace(self, r, new_r, reason=None): def replace(self, r, new_r, reason=None):
""" WRITEME """ WRITEME
This is the main interface to manipulate the subgraph in Env. This is the main interface to manipulate the subgraph in FunctionGraph.
For every node that uses r as input, makes it use new_r instead. For every node that uses r as input, makes it use new_r instead.
""" """
if r.env is not self: if r.env is not self:
raise Exception("Cannot replace %s because it does not belong to this Env" % r, str(reason)) raise Exception("Cannot replace %s because it does not belong to this FunctionGraph" % r, str(reason))
if not r.type == new_r.type: if not r.type == new_r.type:
raise TypeError("The type of the replacement must be the same as the type of the original Variable.", r, new_r, r.type, new_r.type, str(reason)) raise TypeError("The type of the replacement must be the same as the type of the original Variable.", r, new_r, r.type, new_r.type, str(reason))
if r not in self.variables: if r not in self.variables:
...@@ -466,7 +473,7 @@ class Env(utils.object2): ...@@ -466,7 +473,7 @@ class Env(utils.object2):
# would expect it to do similarly. # would expect it to do similarly.
def extend(self, feature): def extend(self, feature):
"""WRITEME """WRITEME
Adds a feature to this env. The feature may define one Adds a feature to this function_graph. The feature may define one
or more of the following methods: or more of the following methods:
""" """
...@@ -484,7 +491,7 @@ class Env(utils.object2): ...@@ -484,7 +491,7 @@ class Env(utils.object2):
"""WRITEME """WRITEME
Removes the feature from the graph. Removes the feature from the graph.
Calls feature.on_detach(env) if an on_detach method is defined. Calls feature.on_detach(function_graph) if an on_detach method is defined.
""" """
try: try:
self._features.remove(feature) self._features.remove(feature)
...@@ -545,7 +552,7 @@ class Env(utils.object2): ...@@ -545,7 +552,7 @@ class Env(utils.object2):
an 'orderings' method. an 'orderings' method.
If a feature has an 'orderings' method, it will be called with If a feature has an 'orderings' method, it will be called with
this env as sole argument. It should return a dictionary of this FunctionGraph as sole argument. It should return a dictionary of
{node: predecessors} where predecessors is a list of nodes {node: predecessors} where predecessors is a list of nodes
that should be computed before the key node. that should be computed before the key node.
""" """
...@@ -555,9 +562,9 @@ class Env(utils.object2): ...@@ -555,9 +562,9 @@ class Env(utils.object2):
# This special case happens a lot because the OpWiseCLinker produces # This special case happens a lot because the OpWiseCLinker produces
# 1-element graphs. # 1-element graphs.
return list(self.nodes) return list(self.nodes)
env = self fg = self
ords = self.orderings() ords = self.orderings()
order = graph.io_toposort(env.inputs, env.outputs, ords) order = graph.io_toposort(fg.inputs, fg.outputs, ords)
return order return order
def orderings(self): def orderings(self):
...@@ -605,10 +612,10 @@ class Env(utils.object2): ...@@ -605,10 +612,10 @@ class Env(utils.object2):
raise Exception("The nodes are inappropriately cached. missing, in excess: ", missing, excess) raise Exception("The nodes are inappropriately cached. missing, in excess: ", missing, excess)
for node in nodes: for node in nodes:
if node.env is not self: if node.env is not self:
raise Exception("Node should belong to the env.", node) raise Exception("Node should belong to the FunctionGraph.", node)
for i, variable in enumerate(node.inputs): for i, variable in enumerate(node.inputs):
if variable.env is not self: if variable.env is not self:
raise Exception("Input of node should belong to the env.", variable, (node, i)) raise Exception("Input of node should belong to the FunctionGraph.", variable, (node, i))
if (node, i) not in variable.clients: if (node, i) not in variable.clients:
raise Exception("Inconsistent clients list.", (node, i), variable.clients) raise Exception("Inconsistent clients list.", (node, i), variable.clients)
variables = set(graph.variables(self.inputs, self.outputs)) variables = set(graph.variables(self.inputs, self.outputs))
...@@ -620,14 +627,14 @@ class Env(utils.object2): ...@@ -620,14 +627,14 @@ class Env(utils.object2):
if variable.owner is None and variable not in self.inputs and not isinstance(variable, graph.Constant): if variable.owner is None and variable not in self.inputs and not isinstance(variable, graph.Constant):
raise Exception("Undeclared input.", variable) raise Exception("Undeclared input.", variable)
if variable.env is not self: if variable.env is not self:
raise Exception("Variable should belong to the env.", variable) raise Exception("Variable should belong to the FunctionGraph.", variable)
for node, i in variable.clients: for node, i in variable.clients:
if node == 'output': if node == 'output':
if self.outputs[i] is not variable: if self.outputs[i] is not variable:
raise Exception("Inconsistent clients list.", variable, self.outputs[i]) raise Exception("Inconsistent clients list.", variable, self.outputs[i])
continue continue
if node not in nodes: if node not in nodes:
raise Exception("Client not in env.", variable, (node, i)) raise Exception("Client not in FunctionGraph.", variable, (node, i))
if node.inputs[i] is not variable: if node.inputs[i] is not variable:
raise Exception("Inconsistent clients list.", variable, node.inputs[i]) raise Exception("Inconsistent clients list.", variable, node.inputs[i])
...@@ -648,7 +655,7 @@ class Env(utils.object2): ...@@ -648,7 +655,7 @@ class Env(utils.object2):
"""WRITEME""" """WRITEME"""
equiv = graph.clone_get_equiv(self.inputs, self.outputs) equiv = graph.clone_get_equiv(self.inputs, self.outputs)
self.check_integrity() self.check_integrity()
e = Env([equiv[i] for i in self.inputs], e = FunctionGraph([equiv[i] for i in self.inputs],
[equiv[o] for o in self.outputs]) [equiv[o] for o in self.outputs])
e.check_integrity() e.check_integrity()
for feature in self._features: for feature in self._features:
......
...@@ -21,7 +21,7 @@ from theano import config ...@@ -21,7 +21,7 @@ from theano import config
import cc import cc
import graph import graph
import utils import utils
from env import Env from fg import FunctionGraph as Env
class CLinkerObject(object): class CLinkerObject(object):
......
...@@ -11,7 +11,7 @@ import time ...@@ -11,7 +11,7 @@ import time
import numpy import numpy
import graph import graph
from env import InconsistencyError from fg import InconsistencyError
import op import op
import utils import utils
import unify import unify
...@@ -598,7 +598,7 @@ def is_same_graph_with_merge(var1, var2, givens=None): ...@@ -598,7 +598,7 @@ def is_same_graph_with_merge(var1, var2, givens=None):
givens = copied[2] givens = copied[2]
# Create Env. # Create Env.
inputs = theano.gof.graph.inputs(vars) inputs = theano.gof.graph.inputs(vars)
env = theano.gof.env.Env(inputs, vars) env = theano.gof.fg.FunctionGraph(inputs, vars)
# Perform Variable substitution. # Perform Variable substitution.
for to_replace, replace_by in givens.iteritems(): for to_replace, replace_by in givens.iteritems():
env.replace(to_replace, replace_by) env.replace(to_replace, replace_by)
......
...@@ -6,7 +6,8 @@ from theano.gof.cc import * ...@@ -6,7 +6,8 @@ from theano.gof.cc import *
from theano.gof.type import Type from theano.gof.type import Type
from theano.gof.graph import Variable, Apply, Constant from theano.gof.graph import Variable, Apply, Constant
from theano.gof.op import Op from theano.gof.op import Op
from theano.gof import env from theano.gof import fg
env = fg
from theano.gof import toolbox from theano.gof import toolbox
...@@ -171,7 +172,7 @@ def inputs(): ...@@ -171,7 +172,7 @@ def inputs():
def Env(inputs, outputs): def Env(inputs, outputs):
e = env.Env(inputs, outputs) e = fg.FunctionGraph(inputs, outputs)
return e return e
......
...@@ -8,7 +8,7 @@ from theano.gof.op import Op ...@@ -8,7 +8,7 @@ from theano.gof.op import Op
from theano.gof.opt import * from theano.gof.opt import *
from theano.gof import destroyhandler from theano.gof import destroyhandler
from theano.gof.env import Env, InconsistencyError from theano.gof.fg import FunctionGraph as Env, InconsistencyError
from theano.gof.toolbox import ReplaceValidate from theano.gof.toolbox import ReplaceValidate
from copy import copy from copy import copy
......
...@@ -3,7 +3,7 @@ from theano.gof.type import Type ...@@ -3,7 +3,7 @@ from theano.gof.type import Type
from theano.gof.graph import Variable, Apply, Constant from theano.gof.graph import Variable, Apply, Constant
from theano.gof.op import Op from theano.gof.op import Op
from theano.gof.opt import * from theano.gof.opt import *
from theano.gof.env import Env from theano.gof.fg import FunctionGraph as Env
from theano.gof.toolbox import * from theano.gof.toolbox import *
......
...@@ -3,7 +3,7 @@ from theano.gof.graph import Variable, Apply ...@@ -3,7 +3,7 @@ from theano.gof.graph import Variable, Apply
from theano.gof.type import Type from theano.gof.type import Type
from theano.gof.op import Op from theano.gof.op import Op
from theano.gof.env import Env, InconsistencyError from theano.gof.fg import FunctionGraph as Env, InconsistencyError
from theano.gof.toolbox import * from theano.gof.toolbox import *
...@@ -85,4 +85,4 @@ class TestNodeFinder: ...@@ -85,4 +85,4 @@ class TestNodeFinder:
raise Exception("Expected: %i times %s" % (num, type)) raise Exception("Expected: %i times %s" % (num, type))
...@@ -3,7 +3,7 @@ from theano.gof.type import Type ...@@ -3,7 +3,7 @@ from theano.gof.type import Type
from theano.gof.graph import Variable, Apply, Constant from theano.gof.graph import Variable, Apply, Constant
from theano.gof.op import Op from theano.gof.op import Op
from theano.gof.opt import * from theano.gof.opt import *
from theano.gof.env import Env from theano.gof.fg import FunctionGraph as Env
from theano.gof.toolbox import * from theano.gof.toolbox import *
import theano.tensor.basic as T import theano.tensor.basic as T
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论