提交 aca35acc authored 作者: Frederic's avatar Frederic

pep8

上级 2de0bc5e
...@@ -16,6 +16,7 @@ NullType = None ...@@ -16,6 +16,7 @@ NullType = None
from theano.gof.python25 import OrderedDict from theano.gof.python25 import OrderedDict
from theano.misc.ordered_set import OrderedSet from theano.misc.ordered_set import OrderedSet
class InconsistencyError(Exception): class InconsistencyError(Exception):
""" """
This exception should be thrown by listeners to FunctionGraph when the This exception should be thrown by listeners to FunctionGraph when the
...@@ -82,7 +83,8 @@ class FunctionGraph(utils.object2): ...@@ -82,7 +83,8 @@ class FunctionGraph(utils.object2):
# so I probably am) this should be a set. # so I probably am) this should be a set.
self._features = [] self._features = []
# All apply nodes in the subgraph defined by inputs and outputs are cached in this field # All apply nodes in the subgraph defined by inputs and
# outputs are cached in this field
self.apply_nodes = set() self.apply_nodes = set()
# Ditto for variable nodes # Ditto for variable nodes
...@@ -112,12 +114,12 @@ class FunctionGraph(utils.object2): ...@@ -112,12 +114,12 @@ class FunctionGraph(utils.object2):
self.variable_locks = {} self.variable_locks = {}
self.profile = None self.profile = None
### Setup a Variable ### ### Setup a Variable ###
def __setup_r__(self, r): def __setup_r__(self, r):
# sets up r so it belongs to this fgraph # sets up r so it belongs to this fgraph
if hasattr(r, 'fgraph') and r.fgraph is not None and r.fgraph is not self: if (hasattr(r, 'fgraph') and
r.fgraph is not None and
r.fgraph is not self):
raise Exception("%s is already owned by another fgraph" % r) raise Exception("%s is already owned by another fgraph" % r)
r.fgraph = self r.fgraph = self
r.clients = [] r.clients = []
...@@ -165,13 +167,13 @@ class FunctionGraph(utils.object2): ...@@ -165,13 +167,13 @@ class FunctionGraph(utils.object2):
self.inputs = None self.inputs = None
self.outputs = None self.outputs = None
### clients ### ### clients ###
def clients(self, r): def clients(self, r):
""" """
Set of all the (node, i) pairs such that node.inputs[i] is r. Set of all the (node, i) pairs such that node.inputs[i] is r.
Tell differently, a list of (node,i) such that each node have r as input at index i. Tell differently, a list of (node,i) such that each node have
r as input at index i.
""" """
return r.clients return r.clients
...@@ -184,12 +186,14 @@ class FunctionGraph(utils.object2): ...@@ -184,12 +186,14 @@ class FunctionGraph(utils.object2):
""" """
if set(r.clients).intersection(set(new_clients)): if set(r.clients).intersection(set(new_clients)):
print >> sys.stderr, 'ERROR: clients intersect!' print >> sys.stderr, 'ERROR: clients intersect!'
print >> sys.stderr, ' RCLIENTS of', r, [(n,i, type(n), id(n)) for n,i in r.clients] print >> sys.stderr, ' RCLIENTS of', r, [(n, i, type(n), id(n))
print >> sys.stderr, ' NCLIENTS of', r, [(n,i, type(n), id(n)) for n,i in new_clients] for n, i in r.clients]
print >> sys.stderr, ' NCLIENTS of', r, [(n, i, type(n), id(n))
for n, i in new_clients]
assert not set(r.clients).intersection(set(new_clients)) assert not set(r.clients).intersection(set(new_clients))
r.clients += new_clients r.clients += new_clients
def __remove_clients__(self, r, clients_to_remove, prune = True): def __remove_clients__(self, r, clients_to_remove, prune=True):
""" WRITEME """ WRITEME
r -> variable r -> variable
clients_to_remove -> list of (op, i) pairs such that node.inputs[i] is not r anymore. clients_to_remove -> list of (op, i) pairs such that node.inputs[i] is not r anymore.
...@@ -202,7 +206,7 @@ class FunctionGraph(utils.object2): ...@@ -202,7 +206,7 @@ class FunctionGraph(utils.object2):
print >> sys.stderr, 'ERROR: DUPLICATE CLIENT ENTRY...' print >> sys.stderr, 'ERROR: DUPLICATE CLIENT ENTRY...'
print >> sys.stderr, ' ENTRY', repr(entry), type(entry[0]) print >> sys.stderr, ' ENTRY', repr(entry), type(entry[0])
print >> sys.stderr, ' CLIENTS', repr(r.clients) print >> sys.stderr, ' CLIENTS', repr(r.clients)
assert entry not in r.clients # an op,i pair should be unique assert entry not in r.clients # an op,i pair should be unique
if not r.clients: if not r.clients:
if prune: if prune:
self.__prune_r__([r]) self.__prune_r__([r])
...@@ -210,9 +214,7 @@ class FunctionGraph(utils.object2): ...@@ -210,9 +214,7 @@ class FunctionGraph(utils.object2):
return True return True
return False return False
### import ### ### import ###
def __import_r__(self, variables): def __import_r__(self, variables):
global NullType global NullType
if NullType is None: if NullType is None:
...@@ -225,14 +227,15 @@ class FunctionGraph(utils.object2): ...@@ -225,14 +227,15 @@ class FunctionGraph(utils.object2):
self.__import__(apply_node) self.__import__(apply_node)
for r in variables: for r in variables:
if r.owner is None and not isinstance(r, graph.Constant) and r not in self.inputs: if r.owner is None and not isinstance(r, graph.Constant) and r not in self.inputs:
if isinstance(r.type,NullType): if isinstance(r.type, NullType):
raise TypeError("Computation graph contains a NaN. "+r.type.why_null) raise TypeError("Computation graph contains a NaN. " +
r.type.why_null)
raise MissingInputError("Undeclared input", r) raise MissingInputError("Undeclared input", r)
if not getattr(r, 'fgraph', None) is self: if not getattr(r, 'fgraph', None) is self:
self.__setup_r__(r) self.__setup_r__(r)
self.variables.add(r) self.variables.add(r)
def __import__(self, apply_node, check = True): def __import__(self, apply_node, check=True):
node = apply_node node = apply_node
# We import the nodes in topological order. We only are interested # We import the nodes in topological order. We only are interested
...@@ -248,7 +251,9 @@ class FunctionGraph(utils.object2): ...@@ -248,7 +251,9 @@ class FunctionGraph(utils.object2):
for r in node.inputs: for r in node.inputs:
if hasattr(r, 'fgraph') and r.fgraph is not self: if hasattr(r, 'fgraph') and r.fgraph is not self:
raise Exception("%s is already owned by another fgraph" % r) raise Exception("%s is already owned by another fgraph" % r)
if r.owner is None and not isinstance(r, graph.Constant) and r not in self.inputs: if (r.owner is None and
not isinstance(r, graph.Constant) and
r not in self.inputs):
#Verbose error message #Verbose error message
#Show a complete chain of variables from the missing input to an output #Show a complete chain of variables from the missing input to an output
...@@ -330,9 +335,7 @@ class FunctionGraph(utils.object2): ...@@ -330,9 +335,7 @@ class FunctionGraph(utils.object2):
assert node.fgraph is self assert node.fgraph is self
self.execute_callbacks('on_import', node) self.execute_callbacks('on_import', node)
### prune ### ### prune ###
def __prune_r__(self, variables): def __prune_r__(self, variables):
# Prunes the owners of the variables. # Prunes the owners of the variables.
for node in set(r.owner for r in variables if r.owner is not None): for node in set(r.owner for r in variables if r.owner is not None):
...@@ -362,10 +365,7 @@ class FunctionGraph(utils.object2): ...@@ -362,10 +365,7 @@ class FunctionGraph(utils.object2):
self.__remove_clients__(input, [(node, i)]) self.__remove_clients__(input, [(node, i)])
#self.__prune_r__(node.inputs) #self.__prune_r__(node.inputs)
### change input ### ### change input ###
def change_input(self, node, i, new_r, reason=None): def change_input(self, node, i, new_r, reason=None):
"""WRITEME """WRITEME
Changes node.inputs[i] to new_r. Changes node.inputs[i] to new_r.
...@@ -381,18 +381,18 @@ class FunctionGraph(utils.object2): ...@@ -381,18 +381,18 @@ class FunctionGraph(utils.object2):
r = self.outputs[i] r = self.outputs[i]
if not r.type == new_r.type: if not r.type == new_r.type:
raise TypeError("The type of the replacement must be the" raise TypeError("The type of the replacement must be the"
" same as the type of the original Variable.", " same as the type of the original Variable.",
r, new_r) r, new_r)
self.outputs[i] = new_r self.outputs[i] = new_r
else: else:
if node.fgraph is not self: if node.fgraph is not self:
raise Exception("Cannot operate on %s because it does not" raise Exception("Cannot operate on %s because it does not"
" belong to this FunctionGraph" % node) " belong to this FunctionGraph" % node)
r = node.inputs[i] r = node.inputs[i]
if not r.type == new_r.type: if not r.type == new_r.type:
raise TypeError("The type of the replacement must be the" raise TypeError("The type of the replacement must be the"
" same as the type of the original Variable.", " same as the type of the original Variable.",
r, new_r) r, new_r)
node.inputs[i] = new_r node.inputs[i] = new_r
if r is new_r: if r is new_r:
...@@ -404,7 +404,8 @@ class FunctionGraph(utils.object2): ...@@ -404,7 +404,8 @@ class FunctionGraph(utils.object2):
# Precondition: the substitution is semantically valid # Precondition: the substitution is semantically valid
# However it may introduce cycles to the graph, in which case the # However it may introduce cycles to the graph, in which case the
# transaction will be reverted later. # transaction will be reverted later.
self.execute_callbacks('on_change_input', node, i, r, new_r, reason=reason) self.execute_callbacks('on_change_input', node, i,
r, new_r, reason=reason)
if prune: if prune:
self.__prune_r__([r]) self.__prune_r__([r])
...@@ -426,7 +427,7 @@ class FunctionGraph(utils.object2): ...@@ -426,7 +427,7 @@ class FunctionGraph(utils.object2):
# because it makes it easier to implement some optimizations for multiple-output ops # because it makes it easier to implement some optimizations for multiple-output ops
return return
for node, i in list(r.clients): # copy the client list for iteration for node, i in list(r.clients): # copy the client list for iteration
assert (node == 'output' and self.outputs[i] is r) or (node.inputs[i] is r) assert (node == 'output' and self.outputs[i] is r) or (node.inputs[i] is r)
self.change_input(node, i, new_r, reason=reason) self.change_input(node, i, new_r, reason=reason)
...@@ -440,11 +441,9 @@ class FunctionGraph(utils.object2): ...@@ -440,11 +441,9 @@ class FunctionGraph(utils.object2):
for r, new_r in pairs: for r, new_r in pairs:
self.replace(r, new_r, reason=reason) self.replace(r, new_r, reason=reason)
def extend(self, feature): def extend(self, feature):
warnings.warn("FunctionGraph.extend is deprecatd. It has been " warnings.warn("FunctionGraph.extend is deprecatd. It has been "
"renamed to FunctionGraph.attach_feature") "renamed to FunctionGraph.attach_feature")
return self.attach_feature(feature) return self.attach_feature(feature)
def attach_feature(self, feature): def attach_feature(self, feature):
...@@ -455,7 +454,7 @@ class FunctionGraph(utils.object2): ...@@ -455,7 +454,7 @@ class FunctionGraph(utils.object2):
# Filter out literally identical features # Filter out literally identical features
if feature in self._features: if feature in self._features:
return # the feature is already present return # the feature is already present
# Filter out functionally identical features. # Filter out functionally identical features.
# Features may use their on_attach method to raise # Features may use their on_attach method to raise
...@@ -481,7 +480,9 @@ class FunctionGraph(utils.object2): ...@@ -481,7 +480,9 @@ class FunctionGraph(utils.object2):
"""WRITEME """WRITEME
Removes the feature from the graph. Removes the feature from the graph.
Calls feature.on_detach(function_graph) if an on_detach method is defined. Calls feature.on_detach(function_graph) if an on_detach method
is defined.
""" """
try: try:
self._features.remove(feature) self._features.remove(feature)
...@@ -491,9 +492,7 @@ class FunctionGraph(utils.object2): ...@@ -491,9 +492,7 @@ class FunctionGraph(utils.object2):
if detach is not None: if detach is not None:
detach(self) detach(self)
### callback utils ### ### callback utils ###
def execute_callbacks(self, name, *args, **kwargs): def execute_callbacks(self, name, *args, **kwargs):
"""WRITEME """WRITEME
Calls Calls
...@@ -518,7 +517,6 @@ class FunctionGraph(utils.object2): ...@@ -518,7 +517,6 @@ class FunctionGraph(utils.object2):
else: else:
raise raise
def collect_callbacks(self, name, *args): def collect_callbacks(self, name, *args):
"""WRITEME """WRITEME
Returns a dictionary d such that: Returns a dictionary d such that:
...@@ -534,9 +532,7 @@ class FunctionGraph(utils.object2): ...@@ -534,9 +532,7 @@ class FunctionGraph(utils.object2):
d[feature] = fn(*args) d[feature] = fn(*args)
return d return d
### misc ### ### misc ###
def toposort(self): def toposort(self):
"""WRITEME """WRITEME
Returns an ordering of the graph's Apply nodes such that: Returns an ordering of the graph's Apply nodes such that:
...@@ -552,8 +548,8 @@ class FunctionGraph(utils.object2): ...@@ -552,8 +548,8 @@ class FunctionGraph(utils.object2):
if len(self.apply_nodes) < 2: if len(self.apply_nodes) < 2:
# optimization # optimization
# when there are 0 or 1 nodes, no sorting is necessary # when there are 0 or 1 nodes, no sorting is necessary
# This special case happens a lot because the OpWiseCLinker produces # This special case happens a lot because the OpWiseCLinker
# 1-element graphs. # produces 1-element graphs.
return list(self.apply_nodes) return list(self.apply_nodes)
fg = self fg = self
...@@ -568,14 +564,15 @@ class FunctionGraph(utils.object2): ...@@ -568,14 +564,15 @@ class FunctionGraph(utils.object2):
Return dict d s.t. d[node] is a list of nodes that must be evaluated Return dict d s.t. d[node] is a list of nodes that must be evaluated
before node itself can be evaluated. before node itself can be evaluated.
This is used primarily by the destroy_handler feature to ensure that all This is used primarily by the destroy_handler feature to ensure that
clients of any destroyed inputs have already computed their outputs. all clients of any destroyed inputs have already computed their
outputs.
:note: This only calls the orderings() fct on all features. It does not :note: This only calls the orderings() fct on all features. It does not
take care of computing dependencies by itself. take care of computing dependencies by itself.
""" """
ords = OrderedDict() ords = OrderedDict()
assert isinstance(self._features, list) assert isinstance(self._features, list)
for feature in self._features: for feature in self._features:
if hasattr(feature, 'orderings'): if hasattr(feature, 'orderings'):
...@@ -586,12 +583,13 @@ class FunctionGraph(utils.object2): ...@@ -586,12 +583,13 @@ class FunctionGraph(utils.object2):
+". Nondeterministic object is "+str(orderings)) +". Nondeterministic object is "+str(orderings))
for node, prereqs in orderings.items(): for node, prereqs in orderings.items():
if not isinstance(prereqs, (list, OrderedSet)): if not isinstance(prereqs, (list, OrderedSet)):
raise TypeError("prereqs must be a type with a " raise TypeError(
"deterministic iteration order, or toposort " "prereqs must be a type with a "
" will be non-deterministic.") "deterministic iteration order, or toposort "
" will be non-deterministic.")
ords.setdefault(node, []).extend(prereqs) ords.setdefault(node, []).extend(prereqs)
# eliminate duplicate prereqs # eliminate duplicate prereqs
for (node,prereqs) in ords.items(): for (node, prereqs) in ords.items():
ords[node] = list(OrderedSet(prereqs)) ords[node] = list(OrderedSet(prereqs))
return ords return ords
...@@ -624,34 +622,48 @@ class FunctionGraph(utils.object2): ...@@ -624,34 +622,48 @@ class FunctionGraph(utils.object2):
if self.apply_nodes != nodes: if self.apply_nodes != nodes:
missing = nodes.difference(self.apply_nodes) missing = nodes.difference(self.apply_nodes)
excess = self.apply_nodes.difference(nodes) excess = self.apply_nodes.difference(nodes)
raise Exception("The nodes are inappropriately cached. missing, in excess: ", missing, excess) raise Exception(
"The nodes are inappropriately cached. missing, in excess: ",
missing, excess)
for node in nodes: for node in nodes:
if node.fgraph is not self: if node.fgraph is not self:
raise Exception("Node should belong to the FunctionGraph.", node) raise Exception("Node should belong to the FunctionGraph.",
node)
for i, variable in enumerate(node.inputs): for i, variable in enumerate(node.inputs):
if variable.fgraph is not self: if variable.fgraph is not self:
raise Exception("Input of node should belong to the FunctionGraph.", variable, (node, i)) raise Exception(
"Input of node should belong to the FunctionGraph.",
variable, (node, i))
if (node, i) not in variable.clients: if (node, i) not in variable.clients:
raise Exception("Inconsistent clients list.", (node, i), variable.clients) raise Exception("Inconsistent clients list.",
(node, i), variable.clients)
variables = set(graph.variables(self.inputs, self.outputs)) variables = set(graph.variables(self.inputs, self.outputs))
if set(self.variables) != variables: if set(self.variables) != variables:
missing = variables.difference(self.variables) missing = variables.difference(self.variables)
excess = self.variables.difference(variables) excess = self.variables.difference(variables)
raise Exception("The variables are inappropriately cached. missing, in excess: ", missing, excess) raise Exception(
"The variables are inappropriately cached. missing, in excess: ",
missing, excess)
for variable in variables: for variable in variables:
if variable.owner is None and variable not in self.inputs and not isinstance(variable, graph.Constant): if (variable.owner is None and
variable not in self.inputs and
not isinstance(variable, graph.Constant)):
raise Exception("Undeclared input.", variable) raise Exception("Undeclared input.", variable)
if variable.fgraph is not self: if variable.fgraph is not self:
raise Exception("Variable should belong to the FunctionGraph.", variable) raise Exception("Variable should belong to the FunctionGraph.",
variable)
for node, i in variable.clients: for node, i in variable.clients:
if node == 'output': if node == 'output':
if self.outputs[i] is not variable: if self.outputs[i] is not variable:
raise Exception("Inconsistent clients list.", variable, self.outputs[i]) raise Exception("Inconsistent clients list.",
variable, self.outputs[i])
continue continue
if node not in nodes: if node not in nodes:
raise Exception("Client not in FunctionGraph.", variable, (node, i)) raise Exception("Client not in FunctionGraph.",
variable, (node, i))
if node.inputs[i] is not variable: if node.inputs[i] is not variable:
raise Exception("Inconsistent clients list.", variable, node.inputs[i]) raise Exception("Inconsistent clients list.",
variable, node.inputs[i])
def __str__(self): def __str__(self):
return "[%s]" % ", ".join(graph.as_string(self.inputs, self.outputs)) return "[%s]" % ", ".join(graph.as_string(self.inputs, self.outputs))
...@@ -659,9 +671,7 @@ class FunctionGraph(utils.object2): ...@@ -659,9 +671,7 @@ class FunctionGraph(utils.object2):
def __repr__(self): def __repr__(self):
return self.__str__() return self.__str__()
### clone ### ### clone ###
def clone(self): def clone(self):
"""WRITEME""" """WRITEME"""
return self.clone_get_equiv()[0] return self.clone_get_equiv()[0]
...@@ -671,7 +681,7 @@ class FunctionGraph(utils.object2): ...@@ -671,7 +681,7 @@ class FunctionGraph(utils.object2):
equiv = graph.clone_get_equiv(self.inputs, self.outputs) equiv = graph.clone_get_equiv(self.inputs, self.outputs)
self.check_integrity() self.check_integrity()
e = FunctionGraph([equiv[i] for i in self.inputs], e = FunctionGraph([equiv[i] for i in self.inputs],
[equiv[o] for o in self.outputs]) [equiv[o] for o in self.outputs])
e.check_integrity() e.check_integrity()
for feature in self._features: for feature in self._features:
e.attach_feature(feature) e.attach_feature(feature)
......
...@@ -3,11 +3,9 @@ import time ...@@ -3,11 +3,9 @@ import time
from theano.gof.python25 import partial from theano.gof.python25 import partial
from theano.gof.python25 import OrderedDict from theano.gof.python25 import OrderedDict
from theano.gof import graph from theano.gof import graph
class AlreadyThere(Exception): class AlreadyThere(Exception):
"""Raised by a Feature's on_attach callback method if the FunctionGraph """Raised by a Feature's on_attach callback method if the FunctionGraph
attempting to attach the feature already has a functionally identical attempting to attach the feature already has a functionally identical
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论