提交 29647275 authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #3153 from harlouci/flake8_gof_sequel

Flake8 gof sequel
...@@ -7,18 +7,17 @@ To read about what theano graphs are from a user perspective, have a look at ...@@ -7,18 +7,17 @@ To read about what theano graphs are from a user perspective, have a look at
""" """
from __future__ import print_function from __future__ import print_function
__docformat__ = "restructuredtext en"
from collections import deque from collections import deque
from copy import copy from copy import copy
from itertools import count from itertools import count
import theano import theano
import warnings
from theano.gof import utils from theano.gof import utils
from six import string_types, integer_types, iteritems from six import string_types, integer_types, iteritems
from theano.misc.ordered_set import OrderedSet from theano.misc.ordered_set import OrderedSet
__docformat__ = "restructuredtext en"
# Lazy imports to avoid circular dependencies. # Lazy imports to avoid circular dependencies.
is_same_graph_with_merge = None is_same_graph_with_merge = None
equal_computations = None equal_computations = None
...@@ -310,7 +309,7 @@ class Variable(Node): ...@@ -310,7 +309,7 @@ class Variable(Node):
`compile.function` uses each `Apply` instance's `inputs` attribute `compile.function` uses each `Apply` instance's `inputs` attribute
together with each Variable's `owner` field to determine which inputs are necessary to compute the function's outputs. together with each Variable's `owner` field to determine which inputs are necessary to compute the function's outputs.
""" """
#__slots__ = ['type', 'owner', 'index', 'name'] # __slots__ = ['type', 'owner', 'index', 'name']
__count__ = count(0) __count__ = count(0)
def __init__(self, type, owner=None, index=None, name=None): def __init__(self, type, owner=None, index=None, name=None):
...@@ -409,7 +408,7 @@ class Variable(Node): ...@@ -409,7 +408,7 @@ class Variable(Node):
self._fn_cache = dict() self._fn_cache = dict()
inputs = tuple(sorted(inputs_to_values.keys(), key=id)) inputs = tuple(sorted(inputs_to_values.keys(), key=id))
if not inputs in self._fn_cache: if inputs not in self._fn_cache:
self._fn_cache[inputs] = theano.function(inputs, self) self._fn_cache[inputs] = theano.function(inputs, self)
args = [inputs_to_values[param] for param in inputs] args = [inputs_to_values[param] for param in inputs]
...@@ -429,7 +428,7 @@ class Constant(Variable): ...@@ -429,7 +428,7 @@ class Constant(Variable):
Constant nodes make eligible numerous optimizations: constant inlining in C code, constant folding, etc. Constant nodes make eligible numerous optimizations: constant inlining in C code, constant folding, etc.
""" """
#__slots__ = ['data'] # __slots__ = ['data']
def __init__(self, type, data, name=None): def __init__(self, type, data, name=None):
"""Initialize self. """Initialize self.
...@@ -481,8 +480,7 @@ class Constant(Variable): ...@@ -481,8 +480,7 @@ class Constant(Variable):
raise ValueError("Constant instances cannot have an owner.") raise ValueError("Constant instances cannot have an owner.")
owner = property(lambda self: None, __set_owner) owner = property(lambda self: None, __set_owner)
value = property(lambda self: self.data, value = property(lambda self: self.data, doc='read-only data access method')
doc='read-only data access method')
# index is not defined, because the `owner` attribute must necessarily be None # index is not defined, because the `owner` attribute must necessarily be None
...@@ -654,9 +652,7 @@ def clone(i, o, copy_inputs=True): ...@@ -654,9 +652,7 @@ def clone(i, o, copy_inputs=True):
return [equiv[input] for input in i], [equiv[output] for output in o] return [equiv[input] for input in i], [equiv[output] for output in o]
def clone_get_equiv(inputs, outputs, def clone_get_equiv(inputs, outputs, copy_inputs_and_orphans=True, memo=None):
copy_inputs_and_orphans=True,
memo=None):
""" """
Return a dictionary that maps from Variable and Apply nodes in the Return a dictionary that maps from Variable and Apply nodes in the
original graph to a new node (a clone) in a new graph. original graph to a new node (a clone) in a new graph.
...@@ -776,7 +772,8 @@ def general_toposort(r_out, deps, debug_print=False, ...@@ -776,7 +772,8 @@ def general_toposort(r_out, deps, debug_print=False,
rlist.append(node) rlist.append(node)
rset.add(node) rset.add(node)
for client in clients.get(node, []): for client in clients.get(node, []):
deps_cache[client] = [a for a in deps_cache[client] if a is not node] deps_cache[client] = [a for a in deps_cache[client]
if a is not node]
if not deps_cache[client]: if not deps_cache[client]:
sources.append(client) sources.append(client)
...@@ -818,7 +815,7 @@ def io_toposort(inputs, outputs, orderings=None): ...@@ -818,7 +815,7 @@ def io_toposort(inputs, outputs, orderings=None):
# Also include the cache in the function itself for speed up. # Also include the cache in the function itself for speed up.
def compute_deps_cache(obj): def compute_deps_cache(obj):
if obj in deps_cache: if obj in deps_cache:
return deps_cache[io] return deps_cache[obj]
rval = [] rval = []
if obj not in iset: if obj not in iset:
if isinstance(obj, Variable): if isinstance(obj, Variable):
...@@ -858,8 +855,10 @@ def io_toposort(inputs, outputs, orderings=None): ...@@ -858,8 +855,10 @@ def io_toposort(inputs, outputs, orderings=None):
default_leaf_formatter = str default_leaf_formatter = str
default_node_formatter = lambda op, argstrings: "%s(%s)" % (op.op,
", ".join(argstrings))
def default_node_formatter(op, argstrings):
return "%s(%s)" % (op.op, ", ".join(argstrings))
def io_connection_pattern(inputs, outputs): def io_connection_pattern(inputs, outputs):
...@@ -874,7 +873,6 @@ def io_connection_pattern(inputs, outputs): ...@@ -874,7 +873,6 @@ def io_connection_pattern(inputs, outputs):
# connected only to itself # connected only to itself
connect_pattern_by_var = {} connect_pattern_by_var = {}
nb_inputs = len(inputs) nb_inputs = len(inputs)
nb_outputs = len(outputs)
for i in range(nb_inputs): for i in range(nb_inputs):
input = inputs[i] input = inputs[i]
...@@ -1153,5 +1151,5 @@ def list_of_nodes(inputs, outputs): ...@@ -1153,5 +1151,5 @@ def list_of_nodes(inputs, outputs):
return stack_search( return stack_search(
deque([o.owner for o in outputs]), deque([o.owner for o in outputs]),
lambda o: [inp.owner for inp in o.inputs lambda o: [inp.owner for inp in o.inputs
if inp.owner if inp.owner and
and not any(i in inp.owner.outputs for i in inputs)]) not any(i in inp.owner.outputs for i in inputs)])
...@@ -4,14 +4,6 @@ The `Op` class is the base interface for all operations ...@@ -4,14 +4,6 @@ The `Op` class is the base interface for all operations
compatible with `gof`'s :doc:`graph` routines. compatible with `gof`'s :doc:`graph` routines.
""" """
__authors__ = "theano-dev"
__copyright__ = "(c) 2010, Universite de Montreal"
__license__ = "3-clause BSD License"
__contact__ = "theano-dev <theano-dev@googlegroups.com>"
__docformat__ = "restructuredtext en"
import inspect import inspect
import logging import logging
import numpy import numpy
...@@ -32,6 +24,13 @@ from theano.gof import utils ...@@ -32,6 +24,13 @@ from theano.gof import utils
from theano.gof.cmodule import GCC_compiler from theano.gof.cmodule import GCC_compiler
from theano.gof.fg import FunctionGraph from theano.gof.fg import FunctionGraph
__authors__ = "theano-dev"
__copyright__ = "(c) 2010, Universite de Montreal"
__license__ = "3-clause BSD License"
__contact__ = "theano-dev <theano-dev@googlegroups.com>"
__docformat__ = "restructuredtext en"
class CLinkerObject(object): class CLinkerObject(object):
"""Standard elements of an Op or Type used with the CLinker """Standard elements of an Op or Type used with the CLinker
...@@ -224,8 +223,7 @@ class CLinkerOp(CLinkerObject): ...@@ -224,8 +223,7 @@ class CLinkerOp(CLinkerObject):
- `MethodNotDefined`: the subclass does not override this method - `MethodNotDefined`: the subclass does not override this method
""" """
raise utils.MethodNotDefined('%s.c_code' \ raise utils.MethodNotDefined('%s.c_code' % self.__class__.__name__)
% self.__class__.__name__)
def c_code_cache_version_apply(self, node): def c_code_cache_version_apply(self, node):
"""Return a tuple of integers indicating the version of this Op. """Return a tuple of integers indicating the version of this Op.
...@@ -278,8 +276,8 @@ class CLinkerOp(CLinkerObject): ...@@ -278,8 +276,8 @@ class CLinkerOp(CLinkerObject):
:Exceptions: :Exceptions:
- `MethodNotDefined`: the subclass does not override this method - `MethodNotDefined`: the subclass does not override this method
""" """
raise utils.MethodNotDefined('%s.c_code_cleanup' \ raise utils.MethodNotDefined('%s.c_code_cleanup' %
% self.__class__.__name__) self.__class__.__name__)
def c_support_code_apply(self, node, name): def c_support_code_apply(self, node, name):
"""Optional: Return utility code for use by an `Op` that will be """Optional: Return utility code for use by an `Op` that will be
...@@ -720,7 +718,7 @@ class Op(utils.object2, PureOp, CLinkerOp): ...@@ -720,7 +718,7 @@ class Op(utils.object2, PureOp, CLinkerOp):
if (any(is_f16(i.type) for i in node.inputs) or if (any(is_f16(i.type) for i in node.inputs) or
any(is_f16(o.type) for o in node.outputs)): any(is_f16(o.type) for o in node.outputs)):
print ("Disabling C code for %s due to unsupported " print("Disabling C code for %s due to unsupported "
"float16" % (self,)) "float16" % (self,))
raise NotImplementedError("float16") raise NotImplementedError("float16")
e = FunctionGraph(node.inputs, node.outputs) e = FunctionGraph(node.inputs, node.outputs)
...@@ -766,6 +764,7 @@ class Op(utils.object2, PureOp, CLinkerOp): ...@@ -766,6 +764,7 @@ class Op(utils.object2, PureOp, CLinkerOp):
return r return r
else: else:
ctx_val = node.context_type.filter(ctx) ctx_val = node.context_type.filter(ctx)
def rval(p=p, i=node_input_storage, o=node_output_storage, n=node, def rval(p=p, i=node_input_storage, o=node_output_storage, n=node,
ctx=ctx_val): ctx=ctx_val):
r = p(n, [x[0] for x in i], o, ctx) r = p(n, [x[0] for x in i], o, ctx)
...@@ -1144,9 +1143,9 @@ class COp(Op): ...@@ -1144,9 +1143,9 @@ class COp(Op):
n = 1 n = 1
while n < len(split): while n < len(split):
if split[n] == 'APPLY': if split[n] == 'APPLY':
self.code_sections['support_code_apply'] = split[n+1] self.code_sections['support_code_apply'] = split[n + 1]
elif split[n] == 'SUPPORT': elif split[n] == 'SUPPORT':
self.code_sections['support_code'] = split[n+1] self.code_sections['support_code'] = split[n + 1]
n += 2 n += 2
continue continue
...@@ -1167,7 +1166,7 @@ class COp(Op): ...@@ -1167,7 +1166,7 @@ class COp(Op):
(self.func_files[i], split[n])) (self.func_files[i], split[n]))
if split[n] not in self.code_sections: if split[n] not in self.code_sections:
self.code_sections[split[n]] = "" self.code_sections[split[n]] = ""
self.code_sections[split[n]] += split[n+1] self.code_sections[split[n]] += split[n + 1]
n += 2 n += 2
else: else:
...@@ -1270,12 +1269,12 @@ class COp(Op): ...@@ -1270,12 +1269,12 @@ class COp(Op):
def get_sub_macros(self, sub): def get_sub_macros(self, sub):
define_macros = [] define_macros = []
undef_macros = [] undef_macros = []
define_macros.append("#define FAIL %s" % define_macros.append("#define FAIL %s" % (
(self._lquote_macro(sub['fail']),)) self._lquote_macro(sub['fail']),))
undef_macros.append("#undef FAIL") undef_macros.append("#undef FAIL")
if 'context' in sub: if 'context' in sub:
define_macros.append("#define CONTEXT %s" % (sub['context'],)) define_macros.append("#define CONTEXT %s" % (sub['context'],))
undef_macos.append("#undef CONTEXT") undef_macros.append("#undef CONTEXT")
return os.linesep.join(define_macros), os.linesep.join(undef_macros) return os.linesep.join(define_macros), os.linesep.join(undef_macros)
...@@ -1308,25 +1307,24 @@ class COp(Op): ...@@ -1308,25 +1307,24 @@ class COp(Op):
def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
if self.func_name is not None: if self.func_name is not None:
assert 'code' not in self.code_sections assert 'code' not in self.code_sections
func_name = self.func_name
func_args = self.format_c_function_args(inp, out)
fail = sub['fail']
define_macros, undef_macros = self.get_c_macros(node, name, define_macros, undef_macros = self.get_c_macros(node, name,
check_input=False) check_input=False)
# Generate the C code # Generate the C code
return """ return """
%(define_macros)s %(define_macros)s
{ {
if (%(func_name)s(%(func_args)s) != 0) { if (%(func_name)s(%(func_args)s) != 0) {
%(fail)s %(fail)s
} }
} }
%(undef_macros)s %(undef_macros)s
""" % dict(func_name=self.func_name, fail=sub['fail'], """ % dict(func_name=self.func_name,
fail=sub['fail'],
func_args=self.format_c_function_args(inp, out), func_args=self.format_c_function_args(inp, out),
define_macros=define_macros, undef_macros=undef_macros) define_macros=define_macros,
undef_macros=undef_macros)
else: else:
if 'code' in self.code_sections: if 'code' in self.code_sections:
op_code = self.code_sections['code'] op_code = self.code_sections['code']
...@@ -1348,7 +1346,7 @@ class COp(Op): ...@@ -1348,7 +1346,7 @@ class COp(Op):
def_macros, undef_macros = self.get_c_macros(node, name) def_macros, undef_macros = self.get_c_macros(node, name)
def_sub, undef_sub = self.get_sub_macros(sub) def_sub, undef_sub = self.get_sub_macros(sub)
def_io, undef_io = self.get_io_macros(inp, out) def_io, undef_io = self.get_io_macros(inputs, outputs)
return os.linesep.join([def_macros, def_sub, def_io, return os.linesep.join([def_macros, def_sub, def_io,
op_code, op_code,
......
...@@ -11,7 +11,7 @@ from __future__ import print_function ...@@ -11,7 +11,7 @@ from __future__ import print_function
from copy import copy from copy import copy
from functools import partial from functools import partial
from theano.gof.utils import * from theano.gof.utils import ANY_TYPE, comm_guard, FALL_THROUGH, iteritems
################################ ################################
...@@ -35,10 +35,12 @@ class Variable: ...@@ -35,10 +35,12 @@ class Variable:
""" """
def __init__(self, name="?"): def __init__(self, name="?"):
self.name = name self.name = name
def __str__(self): def __str__(self):
return (self.__class__.__name__ + "(" + return (self.__class__.__name__ + "(" +
", ".join("%s=%s" % (key, value) ", ".join("%s=%s" % (key, value)
for key, value in iteritems(self.__dict__)) + ")") for key, value in iteritems(self.__dict__)) + ")")
def __repr__(self): def __repr__(self):
return str(self) return str(self)
...@@ -348,8 +350,9 @@ def unify_walk(a, b, U): ...@@ -348,8 +350,9 @@ def unify_walk(a, b, U):
Checks for the existence of the __unify_walk__ method for one of Checks for the existence of the __unify_walk__ method for one of
the objects. the objects.
""" """
if not isinstance(a, Variable) and not isinstance(b, Variable) \ if (not isinstance(a, Variable) and
and hasattr(a, "__unify_walk__"): not isinstance(b, Variable) and
hasattr(a, "__unify_walk__")):
return a.__unify_walk__(b, U) return a.__unify_walk__(b, U)
else: else:
return FALL_THROUGH return FALL_THROUGH
...@@ -430,8 +433,9 @@ def unify_merge(vs, o, U): ...@@ -430,8 +433,9 @@ def unify_merge(vs, o, U):
@comm_guard(ANY_TYPE, ANY_TYPE) @comm_guard(ANY_TYPE, ANY_TYPE)
def unify_merge(a, b, U): def unify_merge(a, b, U):
if not isinstance(a, Variable) and not isinstance(b, Variable) \ if (not isinstance(a, Variable) and
and hasattr(a, "__unify_merge__"): not isinstance(b, Variable) and
hasattr(a, "__unify_merge__")):
return a.__unify_merge__(b, U) return a.__unify_merge__(b, U)
else: else:
return FALL_THROUGH return FALL_THROUGH
...@@ -502,4 +506,3 @@ if __name__ == "__main__": ...@@ -502,4 +506,3 @@ if __name__ == "__main__":
U = unify_walk((1, 2), (va, va), Unification()) U = unify_walk((1, 2), (va, va), Unification())
print(U[va]) print(U[va])
...@@ -216,9 +216,7 @@ whitelist_flake8 = [ ...@@ -216,9 +216,7 @@ whitelist_flake8 = [
"sparse/sandbox/truedot.py", "sparse/sandbox/truedot.py",
"sparse/sandbox/sp.py", "sparse/sandbox/sp.py",
"gof/unify.py", "gof/unify.py",
"gof/graph.py",
"gof/__init__.py", "gof/__init__.py",
"gof/op.py",
"gof/tests/test_cmodule.py", "gof/tests/test_cmodule.py",
"gof/tests/test_destroyhandler.py", "gof/tests/test_destroyhandler.py",
"gof/tests/test_opt.py", "gof/tests/test_opt.py",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论