提交 064ce7a0 authored 作者: Olivier Delalleau's avatar Olivier Delalleau

PEP8 fixes

上级 cc4ce0d0
...@@ -71,9 +71,10 @@ class Apply(utils.object2): ...@@ -71,9 +71,10 @@ class Apply(utils.object2):
self.inputs = [] self.inputs = []
self.tag = utils.scratchpad() self.tag = utils.scratchpad()
if not isinstance(inputs,(list,tuple)): if not isinstance(inputs, (list, tuple)):
raise TypeError("The inputs of an Apply must be a list or tuple") raise TypeError("The inputs of an Apply must be a list or tuple")
if not isinstance(outputs,(list,tuple)):
if not isinstance(outputs, (list, tuple)):
raise TypeError("The output of an Apply must be a list or tuple") raise TypeError("The output of an Apply must be a list or tuple")
## filter inputs to make sure each element is a Variable ## filter inputs to make sure each element is a Variable
...@@ -120,27 +121,25 @@ class Apply(utils.object2): ...@@ -120,27 +121,25 @@ class Apply(utils.object2):
raise AttributeError("%s.default_output is out of range." % self.op) raise AttributeError("%s.default_output is out of range." % self.op)
return self.outputs[do] return self.outputs[do]
def env_getter(self): def env_getter(self):
warnings.warn("Apply.env is deprecated, it has been renamed 'fgraph'", warnings.warn("Apply.env is deprecated, it has been renamed 'fgraph'",
stacklevel = 2) stacklevel=2)
return self.fgraph return self.fgraph
def env_setter(self,value): def env_setter(self, value):
warnings.warn("Apply.env is deprecated, it has been renamed 'fgraph'", warnings.warn("Apply.env is deprecated, it has been renamed 'fgraph'",
stacklevel = 2) stacklevel=2)
self.fgraph = value self.fgraph = value
def env_deleter(self): def env_deleter(self):
warnings.warn("Apply.env is deprecated, it has been renamed 'fgraph'", warnings.warn("Apply.env is deprecated, it has been renamed 'fgraph'",
stacklevel = 2) stacklevel=2)
del self.fgraph del self.fgraph
env = property(env_getter, env_setter, env_deleter) env = property(env_getter, env_setter, env_deleter)
out = property(default_output, out = property(default_output,
doc = "alias for self.default_output()") doc="alias for self.default_output()")
"""Alias for self.default_output()""" """Alias for self.default_output()"""
def __str__(self): def __str__(self):
...@@ -165,7 +164,7 @@ class Apply(utils.object2): ...@@ -165,7 +164,7 @@ class Apply(utils.object2):
cp.tag = copy(self.tag) cp.tag = copy(self.tag)
return cp return cp
def clone_with_new_inputs(self, inputs, strict = True): def clone_with_new_inputs(self, inputs, strict=True):
"""Duplicate this Apply instance in a new graph. """Duplicate this Apply instance in a new graph.
:param inputs: list of Variable instances to use as inputs. :param inputs: list of Variable instances to use as inputs.
...@@ -204,10 +203,10 @@ class Apply(utils.object2): ...@@ -204,10 +203,10 @@ class Apply(utils.object2):
return new_node return new_node
#convenience properties #convenience properties
nin = property(lambda self: len(self.inputs), doc = 'same as len(self.inputs)') nin = property(lambda self: len(self.inputs), doc='same as len(self.inputs)')
"""property: Number of inputs""" """property: Number of inputs"""
nout = property(lambda self: len(self.outputs), doc = 'same as len(self.outputs)') nout = property(lambda self: len(self.outputs), doc='same as len(self.outputs)')
"""property: Number of outputs""" """property: Number of outputs"""
...@@ -289,7 +288,7 @@ class Variable(utils.object2): ...@@ -289,7 +288,7 @@ class Variable(utils.object2):
""" """
#__slots__ = ['type', 'owner', 'index', 'name'] #__slots__ = ['type', 'owner', 'index', 'name']
def __init__(self, type, owner = None, index = None, name = None): def __init__(self, type, owner=None, index=None, name=None):
"""Initialize type, owner, index, name. """Initialize type, owner, index, name.
:type type: a Type instance :type type: a Type instance
...@@ -317,6 +316,7 @@ class Variable(utils.object2): ...@@ -317,6 +316,7 @@ class Variable(utils.object2):
if name is not None and not isinstance(name, basestring): if name is not None and not isinstance(name, basestring):
raise TypeError("name must be a string", name) raise TypeError("name must be a string", name)
self.name = name self.name = name
def __str__(self): def __str__(self):
"""WRITEME""" """WRITEME"""
if self.name is not None: if self.name is not None:
...@@ -329,8 +329,10 @@ class Variable(utils.object2): ...@@ -329,8 +329,10 @@ class Variable(utils.object2):
return str(self.owner.op) + "." + str(self.index) return str(self.owner.op) + "." + str(self.index)
else: else:
return "<%s>" % str(self.type) return "<%s>" % str(self.type)
def __repr__(self): def __repr__(self):
return str(self) return str(self)
def clone(self): def clone(self):
"""Return a new Variable like self. """Return a new Variable like self.
...@@ -345,39 +347,40 @@ class Variable(utils.object2): ...@@ -345,39 +347,40 @@ class Variable(utils.object2):
cp.tag = copy(self.tag) cp.tag = copy(self.tag)
return cp return cp
def __lt__(self,other): def __lt__(self, other):
raise NotImplementedError('Subclasses of Variable must provide __lt__', raise NotImplementedError('Subclasses of Variable must provide __lt__',
self.__class__.__name__) self.__class__.__name__)
def __le__(self,other):
def __le__(self, other):
raise NotImplementedError('Subclasses of Variable must provide __le__', raise NotImplementedError('Subclasses of Variable must provide __le__',
self.__class__.__name__) self.__class__.__name__)
def __gt__(self,other):
def __gt__(self, other):
raise NotImplementedError('Subclasses of Variable must provide __gt__', raise NotImplementedError('Subclasses of Variable must provide __gt__',
self.__class__.__name__) self.__class__.__name__)
def __ge__(self,other):
def __ge__(self, other):
raise NotImplementedError('Subclasses of Variable must provide __ge__', raise NotImplementedError('Subclasses of Variable must provide __ge__',
self.__class__.__name__) self.__class__.__name__)
def env_getter(self): def env_getter(self):
warnings.warn("Variable.env is deprecated, it has been renamed 'fgraph'", warnings.warn("Variable.env is deprecated, it has been renamed 'fgraph'",
stacklevel = 2) stacklevel=2)
return self.fgraph return self.fgraph
def env_setter(self,value): def env_setter(self, value):
warnings.warn("Variable.env is deprecated, it has been renamed 'fgraph'", warnings.warn("Variable.env is deprecated, it has been renamed 'fgraph'",
stacklevel = 2) stacklevel=2)
self.fgraph = value self.fgraph = value
def env_deleter(self): def env_deleter(self):
warnings.warn("Variable.env is deprecated, it has been renamed 'fgraph'", warnings.warn("Variable.env is deprecated, it has been renamed 'fgraph'",
stacklevel = 2) stacklevel=2)
del self.fgraph del self.fgraph
env = property(env_getter, env_setter, env_deleter) env = property(env_getter, env_setter, env_deleter)
class Constant(Variable): class Constant(Variable):
""" """
A :term:`Constant` is a `Variable` with a `value` field that cannot be changed at runtime. A :term:`Constant` is a `Variable` with a `value` field that cannot be changed at runtime.
...@@ -385,7 +388,7 @@ class Constant(Variable): ...@@ -385,7 +388,7 @@ class Constant(Variable):
Constant nodes make eligible numerous optimizations: constant inlining in C code, constant folding, etc. Constant nodes make eligible numerous optimizations: constant inlining in C code, constant folding, etc.
""" """
#__slots__ = ['data'] #__slots__ = ['data']
def __init__(self, type, data, name = None): def __init__(self, type, data, name=None):
"""Initialize self. """Initialize self.
:note: :note:
...@@ -439,7 +442,7 @@ class Constant(Variable): ...@@ -439,7 +442,7 @@ class Constant(Variable):
# index is not defined, because the `owner` attribute must necessarily be None # index is not defined, because the `owner` attribute must necessarily be None
def stack_search(start, expand, mode='bfs', build_inv = False): def stack_search(start, expand, mode='bfs', build_inv=False):
"""Search through a graph, either breadth- or depth-first """Search through a graph, either breadth- or depth-first
:type start: deque :type start: deque
...@@ -591,7 +594,7 @@ def orphans(i, o): ...@@ -591,7 +594,7 @@ def orphans(i, o):
return variables_and_orphans(i, o)[1] return variables_and_orphans(i, o)[1]
def clone(i, o, copy_inputs = True): def clone(i, o, copy_inputs=True):
""" """
Copies the subgraph contained between i and o. Copies the subgraph contained between i and o.
...@@ -670,7 +673,7 @@ def clone_get_equiv(inputs, outputs, ...@@ -670,7 +673,7 @@ def clone_get_equiv(inputs, outputs,
return memo return memo
def general_toposort(r_out, deps, debug_print = False): def general_toposort(r_out, deps, debug_print=False):
"""WRITEME """WRITEME
:note: :note:
...@@ -683,6 +686,7 @@ def general_toposort(r_out, deps, debug_print = False): ...@@ -683,6 +686,7 @@ def general_toposort(r_out, deps, debug_print = False):
The order of the return value list is determined by the order of nodes returned by the deps() function. The order of the return value list is determined by the order of nodes returned by the deps() function.
""" """
deps_cache = {} deps_cache = {}
def _deps(io): def _deps(io):
if io not in deps_cache: if io not in deps_cache:
d = deps(io) d = deps(io)
...@@ -696,7 +700,7 @@ def general_toposort(r_out, deps, debug_print = False): ...@@ -696,7 +700,7 @@ def general_toposort(r_out, deps, debug_print = False):
assert isinstance(r_out, (tuple, list, deque)) assert isinstance(r_out, (tuple, list, deque))
reachable, clients = stack_search( deque(r_out), _deps, 'dfs', True) reachable, clients = stack_search(deque(r_out), _deps, 'dfs', True)
sources = deque([r for r in reachable if not deps_cache.get(r, None)]) sources = deque([r for r in reachable if not deps_cache.get(r, None)])
rset = set() rset = set()
...@@ -728,6 +732,7 @@ def io_toposort(i, o, orderings=None): ...@@ -728,6 +732,7 @@ def io_toposort(i, o, orderings=None):
orderings = {} orderings = {}
#the inputs are used only here in the function that decides what 'predecessors' to explore #the inputs are used only here in the function that decides what 'predecessors' to explore
iset = set(i) iset = set(i)
def deps(obj): def deps(obj):
rval = [] rval = []
if obj not in iset: if obj not in iset:
...@@ -740,6 +745,7 @@ def io_toposort(i, o, orderings=None): ...@@ -740,6 +745,7 @@ def io_toposort(i, o, orderings=None):
else: else:
assert not orderings.get(obj, []) assert not orderings.get(obj, [])
return rval return rval
topo = general_toposort(o, deps) topo = general_toposort(o, deps)
return [o for o in topo if isinstance(o, Apply)] return [o for o in topo if isinstance(o, Apply)]
...@@ -817,9 +823,11 @@ def is_same_graph(var1, var2, givens=None, debug=False): ...@@ -817,9 +823,11 @@ def is_same_graph(var1, var2, givens=None, debug=False):
all_vars = [set(variables(v_i, v_o)) all_vars = [set(variables(v_i, v_o))
for v_i, v_o in ((inputs_var[0], [var1]), for v_i, v_o in ((inputs_var[0], [var1]),
(inputs_var[1], [var2]))] (inputs_var[1], [var2]))]
def in_var(x, k): def in_var(x, k):
# Return True iff `x` is in computation graph of variable `vark`. # Return True iff `x` is in computation graph of variable `vark`.
return x in all_vars[k - 1] return x in all_vars[k - 1]
for to_replace, replace_by in givens.iteritems(): for to_replace, replace_by in givens.iteritems():
# Map a substitution variable to the computational graphs it # Map a substitution variable to the computational graphs it
# belongs to. # belongs to.
...@@ -856,16 +864,16 @@ def is_same_graph(var1, var2, givens=None, debug=False): ...@@ -856,16 +864,16 @@ def is_same_graph(var1, var2, givens=None, debug=False):
def op_as_string(i, op, def op_as_string(i, op,
leaf_formatter = default_leaf_formatter, leaf_formatter=default_leaf_formatter,
node_formatter = default_node_formatter): node_formatter=default_node_formatter):
"""WRITEME""" """WRITEME"""
strs = as_string(i, op.inputs, leaf_formatter, node_formatter) strs = as_string(i, op.inputs, leaf_formatter, node_formatter)
return node_formatter(op, strs) return node_formatter(op, strs)
def as_string(i, o, def as_string(i, o,
leaf_formatter = default_leaf_formatter, leaf_formatter=default_leaf_formatter,
node_formatter = default_node_formatter): node_formatter=default_node_formatter):
"""WRITEME """WRITEME
:type i: list :type i: list
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论