提交 9f6d9f90 authored 作者: nouiz's avatar nouiz

Merge pull request #882 from goodfeli/rename_extend

renamed FunctionGraph.extend to FunctionGraph.attach_feature
......@@ -210,7 +210,7 @@ if it introduce cycle into the graph.
To allow using DebugMode more often, we can pre-check that our optimization will
get rejected in many cases (not for the cycle reason). For this you can use the
theano.gof.destroyhandler.fast_inplace_check() function that will tell you witch
theano.gof.destroyhandler.fast_inplace_check() function that will tell you which
op can be used.
......
......@@ -124,7 +124,7 @@ simplification described above:
class Simplify(gof.Optimizer):
def add_requirements(self, fgraph):
fgraph.extend(toolbox.ReplaceValidate())
fgraph.attach_feature(toolbox.ReplaceValidate())
def apply(self, fgraph):
for node in fgraph.toposort():
if node.op == div:
......
......@@ -681,12 +681,12 @@ def _optcheck_fgraph(input_specs, output_specs, accept_inplace=False):
node)
# We need to protect all immutable inputs from inplace operations.
fgraph.extend(Supervisor(input for spec, input in zip(input_specs, inputs)
fgraph.attach_feature(Supervisor(input for spec, input in zip(input_specs, inputs)
if not (spec.mutable or (hasattr(fgraph, 'destroyers')
and fgraph.destroyers(input)))))
for feature in std_fgraph.features:
fgraph.extend(feature())
fgraph.attach_feature(feature())
return fgraph, map(SymbolicOutput, updates), equivalence_tracker
......
......@@ -136,15 +136,15 @@ def std_fgraph(input_specs, output_specs, accept_inplace = False):
if not accept_inplace:
raise TypeError("Graph must not contain inplace operations", node, node.op)
else:
fgraph.extend(gof.DestroyHandler())
fgraph.attach_feature(gof.DestroyHandler())
break
# We need to protect all immutable inputs from inplace operations.
fgraph.extend(Supervisor(input for spec, input in zip(input_specs, inputs) if not (spec.mutable or (hasattr(fgraph, 'destroyers') and fgraph.destroyers(input)))))
fgraph.attach_feature(Supervisor(input for spec, input in zip(input_specs, inputs) if not (spec.mutable or (hasattr(fgraph, 'destroyers') and fgraph.destroyers(input)))))
# If named nodes are replaced, keep the name
for feature in std_fgraph.features:
fgraph.extend(feature())
fgraph.attach_feature(feature())
return fgraph, map(SymbolicOutput, updates)
......
......@@ -208,7 +208,7 @@ class AddDestroyHandler(gof.Optimizer):
def add_requirements(self, fgraph):
super(AddDestroyHandler, self).add_requirements(fgraph)
fgraph.extend(gof.DestroyHandler())
fgraph.attach_feature(gof.DestroyHandler())
class PrintCurrentFunctionGraph(gof.Optimizer):
......
......@@ -10,6 +10,7 @@ import utils
import toolbox
from python25 import all
from theano import config
import warnings
class InconsistencyError(Exception):
......@@ -46,17 +47,18 @@ class FunctionGraph(utils.object2):
The .clients field combined with the .owner field and the Apply nodes'
.inputs field allows the graph to be traversed in both directions.
It can also be "extended" using function_graph.extend(some_object).
It can also be extended with new features using
FunctionGraph.attach_feature(<toolbox.Feature instance>).
See toolbox.Feature for event types and documentation.
Extra features allow the FunctionGraph to verify new properties of
a graph as it is optimized.
# TODO: are there other things features can do to the fgraph?
Historically, the FunctionGraph was called an Env. Keep this in mind
while reading out-of-date documentation, e-mail support threads, etc.
"""
### Special ###
# TODO: document which things that features can do to the fgraph
def __init__(self, inputs, outputs, features=None):
"""
Create an FunctionGraph which operates on the subgraph bound by the inputs and
......@@ -87,8 +89,8 @@ class FunctionGraph(utils.object2):
self.outputs = outputs
for f in features:
self.extend(f)
self.extend(toolbox.ReplaceValidate())
self.attach_feature(f)
self.attach_feature(toolbox.ReplaceValidate())
for input in self.inputs:
if input.owner is not None:
......@@ -428,12 +430,13 @@ class FunctionGraph(utils.object2):
self.replace(r, new_r, reason=reason)
### features ###
# XXX: This is terribly named. The "extend" method of a list
# takes a sequence, and since this is a kind of container you
# would expect it to do similarly.
def extend(self, feature):
warnings.warn("FunctionGraph.extend is deprecated. It has been "
"renamed to FunctionGraph.attach_feature")
return self.attach_feature(feature)
def attach_feature(self, feature):
"""
Adds a gof.toolbox.Feature to this function_graph
and triggers its on_attach callback
......@@ -631,5 +634,5 @@ class FunctionGraph(utils.object2):
[equiv[o] for o in self.outputs])
e.check_integrity()
for feature in self._features:
e.extend(feature)
e.attach_feature(feature)
return e, equiv
......@@ -87,8 +87,8 @@ class Optimizer(object):
"""WRITEME
Add features to the fgraph that are required to apply the optimization.
For example:
fgraph.extend(History())
fgraph.extend(MyFeature())
fgraph.attach_feature(History())
fgraph.attach_feature(MyFeature())
etc.
"""
pass
......@@ -112,7 +112,7 @@ class FromFunctionOptimizer(Optimizer):
def add_requirements(self, fgraph):
# Added by default
#fgraph.extend(toolbox.ReplaceValidate())
#fgraph.attach_feature(toolbox.ReplaceValidate())
pass
def print_summary(self, stream=sys.stdout, level=0, depth=-1):
......@@ -557,9 +557,9 @@ class MergeOptimizer(Optimizer):
def add_requirements(self, fgraph):
# Added by default
#fgraph.extend(toolbox.ReplaceValidate())
#fgraph.attach_feature(toolbox.ReplaceValidate())
if not hasattr(fgraph, 'merge_feature'):
fgraph.extend(MergeFeature())
fgraph.attach_feature(MergeFeature())
def apply(self, fgraph):
# Constant and non-constant are now applied in the same phase.
......@@ -713,7 +713,7 @@ class LocalOptimizer(object):
This is the place to do it.
"""
# Added by default
#fgraph.extend(toolbox.ReplaceValidate())
#fgraph.attach_feature(toolbox.ReplaceValidate())
pass
def print_summary(self, stream=sys.stdout, level=0, depth=-1):
......@@ -1195,7 +1195,7 @@ class NavigatorOptimizer(Optimizer):
chin(node, i, r, new_r)
u = Updater()
fgraph.extend(u)
fgraph.attach_feature(u)
return u
def detach_updater(self, fgraph, u):
......@@ -1269,7 +1269,7 @@ class NavigatorOptimizer(Optimizer):
def add_requirements(self, fgraph):
super(NavigatorOptimizer, self).add_requirements(fgraph)
# Added by default
#fgraph.extend(toolbox.ReplaceValidate())
#fgraph.attach_feature(toolbox.ReplaceValidate())
if self.local_opt:
self.local_opt.add_requirements(fgraph)
......@@ -1370,7 +1370,7 @@ class OpKeyOptimizer(NavigatorOptimizer):
- ReplaceValidate(Added by default)
"""
super(OpKeyOptimizer, self).add_requirements(fgraph)
fgraph.extend(toolbox.NodeFinder())
fgraph.attach_feature(toolbox.NodeFinder())
class ChangeTracker:
......@@ -1426,7 +1426,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
def add_requirements(self, fgraph):
super(EquilibriumOptimizer, self).add_requirements(fgraph)
fgraph.extend(ChangeTracker())
fgraph.attach_feature(ChangeTracker())
for opt in self.local_optimizers:
opt.add_requirements(fgraph)
for opt in self.global_optimizers:
......@@ -1759,7 +1759,7 @@ class InplaceOptimizer(Optimizer):
self.inplace(fgraph)
def add_requirements(self, fgraph):
fgraph.extend(dh.DestroyHandler())
fgraph.attach_feature(dh.DestroyHandler())
class PureThenInplaceOptimizer(Optimizer):
......@@ -1770,5 +1770,5 @@ class PureThenInplaceOptimizer(Optimizer):
def apply(self, fgraph):
self.pure(fgraph)
fgraph.extend(dh.DestroyHandler())
fgraph.attach_feature(dh.DestroyHandler())
self.inplace(fgraph)
......@@ -93,8 +93,8 @@ def inputs():
_Env = Env
def Env(inputs, outputs, validate = True):
e = _Env(inputs, outputs)
e.extend(destroyhandler.DestroyHandler())
e.extend(ReplaceValidate())
e.attach_feature(destroyhandler.DestroyHandler())
e.attach_feature(ReplaceValidate())
if validate:
e.validate()
return e
......
......@@ -69,7 +69,8 @@ class TestNodeFinder:
e0 = dot(y, z)
e = add(add(sigmoid(x), sigmoid(sigmoid(z))), dot(add(x, y), e0))
g = Env([x, y, z], [e])
g.extend(NodeFinder())
g.attach_feature(NodeFinder())
assert hasattr(g, 'get_nodes')
for type, num in ((add, 3), (sigmoid, 3), (dot, 2)):
if not len([x for x in g.get_nodes(type)]) == num:
......
......@@ -32,7 +32,7 @@ class Feature(object):
def on_attach(self, function_graph):
"""
Called by FunctionGraph.extend, the method that attaches the feature
Called by FunctionGraph.attach_feature, the method that attaches the feature
to the FunctionGraph. Since this is called after the FunctionGraph
is initially populated, this is where you should run checks on the
initial contents of the FunctionGraph.
......
......@@ -537,7 +537,7 @@ def cond_merge_ifs_false(node):
class CondMerge(gof.Optimizer):
""" Graph Optimizer that merges different cond ops """
def add_requirements(self, fgraph):
fgraph.extend(gof.toolbox.ReplaceValidate())
fgraph.attach_feature(gof.toolbox.ReplaceValidate())
def apply(self, fgraph):
nodelist = list(fgraph.toposort())
......
......@@ -80,8 +80,8 @@ class InputToGpuOptimizer(Optimizer):
Optimizer.__init__(self)
def add_requirements(self, fgraph):
fgraph.extend(toolbox.ReplaceValidate())
fgraph.extend(DestroyHandler())
fgraph.attach_feature(toolbox.ReplaceValidate())
fgraph.attach_feature(DestroyHandler())
def apply(self, fgraph):
for input in fgraph.inputs:
......
......@@ -177,7 +177,7 @@ class HintsOptimizer(Optimizer):
Optimizer.__init__(self)
def add_requirements(self, fgraph):
fgraph.extend(HintsFeature())
fgraph.attach_feature(HintsFeature())
def apply(self, fgraph):
pass
......
......@@ -149,7 +149,7 @@ class PushOutNonSeqScan(gof.Optimizer):
gof.Optimizer.__init__(self)
def add_requirements(self, fgraph):
fgraph.extend(gof.toolbox.ReplaceValidate())
fgraph.attach_feature(gof.toolbox.ReplaceValidate())
def apply(self, fgraph):
nodelist = [x for x in fgraph.toposort() if isinstance(x.op,
......@@ -322,8 +322,8 @@ class ScanInplaceOptimizer(Optimizer):
self.gpu_flag = gpu_flag
def add_requirements(self, fgraph):
fgraph.extend(toolbox.ReplaceValidate())
fgraph.extend(DestroyHandler())
fgraph.attach_feature(toolbox.ReplaceValidate())
fgraph.attach_feature(DestroyHandler())
def apply(self, fgraph):
......@@ -388,7 +388,7 @@ class ScanSaveMem(gof.Optimizer):
gof.Optimizer.__init__(self)
def add_requirements(self, fgraph):
fgraph.extend(gof.toolbox.ReplaceValidate())
fgraph.attach_feature(gof.toolbox.ReplaceValidate())
def process_node(self, fgraph, node):
......@@ -857,7 +857,7 @@ scan_seqopt.register('scanOp_save_mem',
class ScanMerge(gof.Optimizer):
""" Graph Optimizer that merges different scan ops """
def add_requirements(self, fgraph):
fgraph.extend(gof.toolbox.ReplaceValidate())
fgraph.attach_feature(gof.toolbox.ReplaceValidate())
def merge(self, nodes):
......
......@@ -1353,8 +1353,8 @@ class GemmOptimizer(Optimizer):
self.warned = False
def add_requirements(self, fgraph):
fgraph.extend(toolbox.ReplaceValidate())
fgraph.extend(DestroyHandler())
fgraph.attach_feature(toolbox.ReplaceValidate())
fgraph.attach_feature(DestroyHandler())
def apply(self, fgraph):
did_something = True
......
......@@ -538,7 +538,7 @@ class MakeVector(T.Op):
def infer_shape(self, node, ishapes):
return [(len(ishapes),)]
def grad(self, inputs, output_gradients):
# If the output is of an integer dtype, no gradient shall pass
if 'int' in self.dtype:
......@@ -1036,7 +1036,7 @@ class ShapeOptimizer(Optimizer):
Optimizer.__init__(self)
def add_requirements(self, fgraph):
fgraph.extend(ShapeFeature())
fgraph.attach_feature(ShapeFeature())
def apply(self, fgraph):
pass
......@@ -4583,8 +4583,8 @@ class FusionOptimizer(Optimizer):
self.optimizer = local_optimizer
def add_requirements(self, fgraph):
fgraph.extend(toolbox.ReplaceValidate())
fgraph.extend(DestroyHandler())
fgraph.attach_feature(toolbox.ReplaceValidate())
fgraph.attach_feature(DestroyHandler())
def apply(self, fgraph):
did_something = True
......
"""
This file implement specialization optimization that break the canonization form of the graph.
Currently their is problem with the order of optimization and the definition of definition of canonized graph.
Currently there is a problem with the order of optimization and the definition of a
canonized graph.
Right now their is a canonization optimization phase that try to make all equivalent graph identical. This is not always the case, but it do many of the basic stuff canonical. We need to extend the definition of canonization to make this true more often.
Right now there is a canonization optimization phase that tries to make all equivalent graphs
identical. This is not always the case, but it does make many of the basic constructs canonical.
We need to extend the definition of canonization to make this true more often.
The problem this file indent to fix in the future is that in the "Equilibrium" specialization optimization phase, there is optimization that request that the graph is canonical, some other request that this is not true, and some other that break the canonicalization for some optimization. As we can't control the order of those optimization, their is case that some optimization requesting a canonical graph won't be applied as optimization that break the canonicalization form of the graph executed before.
The problem this file intends to fix in the future is that in the "Equilibrium" specialization
optimization phase, there are optimizations that require the graph to be canonical, some others
that require that this is not true, and some others that break the canonicalization needed by
other optimizations. As we can't control the order of those optimizations, there are cases where
an optimization requesting a canonical graph won't be applied because an optimization that broke
the canonicalization form of the graph executed before it.
To fix this, we need to split the specialization phase into a phase where optimizations can't break the canonicalization form and one where this is allowed. This is also needed for the stabilization optimization phase, but as it happens before the specialization phase, this causes fewer problems.
......@@ -45,7 +53,7 @@ class MaxAndArgmaxOptimizer(Optimizer):
"""
def add_requirements(self, fgraph):
fgraph.extend(toolbox.ReplaceValidate())
fgraph.attach_feature(toolbox.ReplaceValidate())
def apply(self, fgraph):
did_something = True
......
......@@ -72,7 +72,7 @@ def shape_of_variables(fgraph, input_shapes):
"""
if not hasattr(fgraph, 'shape_feature'):
fgraph.extend(theano.tensor.opt.ShapeFeature())
fgraph.attach_feature(theano.tensor.opt.ShapeFeature())
input_dims = [dimension for inp in fgraph.inputs
for dimension in fgraph.shape_feature.shape_of[inp]]
......
......@@ -393,7 +393,7 @@ class T_extending(unittest.TestCase):
class Simplify(gof.Optimizer):
def add_requirements(self, fgraph):
fgraph.extend(toolbox.ReplaceValidate())
fgraph.attach_feature(toolbox.ReplaceValidate())
def apply(self, fgraph):
for node in fgraph.toposort():
if node.op == div:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论