提交 b39d1749 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Flake8 for compile/ops.py

上级 9eced720
...@@ -71,11 +71,12 @@ class ViewOp(gof.Op): ...@@ -71,11 +71,12 @@ class ViewOp(gof.Op):
version = [] version = []
# If any of the c code is unversionned, we have to return () # If any of the c code is unversionned, we have to return ()
# Else, we will return a list of (type name, version) pairs. # Else, we will return a list of (type name, version) pairs.
for t, (c, v) in sorted(self.c_code_and_version.items(), key=lambda pair: str(pair[0])): for t, (c, v) in sorted(self.c_code_and_version.items(),
key=lambda pair: str(pair[0])):
if not v: if not v:
warnings.warn("Type %s has C code for ViewOp, but it has " warnings.warn("Type %s has C code for ViewOp, but it has no "
"no version. You should add a 'version' keyword arg " "version. You should add a 'version' keyword "
"when calling register_view_op_c_code." % t, "arg when calling register_view_op_c_code." % t,
stacklevel=2) stacklevel=2)
return () return ()
version.append((str(t), v)) version.append((str(t), v))
...@@ -165,11 +166,13 @@ class DeepCopyOp(gof.Op): ...@@ -165,11 +166,13 @@ class DeepCopyOp(gof.Op):
version = [] version = []
# If any of the c code is unversionned, we have to return () # If any of the c code is unversionned, we have to return ()
# Else, we will return a list of (type name, version) pairs. # Else, we will return a list of (type name, version) pairs.
for t, (c, v) in sorted(self.c_code_and_version.items(), key=lambda pair: str(pair[0])): for t, (c, v) in sorted(self.c_code_and_version.items(),
key=lambda pair: str(pair[0])):
if not v: if not v:
warnings.warn("Type %s has C code for DeepCopyOp, but it has " warnings.warn("Type %s has C code for DeepCopyOp, but it has "
"no version. You should add a 'version' keyword arg " "no version. You should add a 'version' keyword"
"when calling register_deep_copy_op_c_code." % t, " arg when calling "
"register_deep_copy_op_c_code." % t,
stacklevel=2) stacklevel=2)
return () return ()
version.append((str(t), v)) version.append((str(t), v))
...@@ -284,11 +287,12 @@ class Shape(gof.Op): ...@@ -284,11 +287,12 @@ class Shape(gof.Op):
version = [] version = []
# If any of the c code is unversionned, we have to return () # If any of the c code is unversionned, we have to return ()
# Else, we will return a list of (type name, version) pairs. # Else, we will return a list of (type name, version) pairs.
for t, (c, v) in sorted(self.c_code_and_version.items(), key=lambda pair: str(pair[0])): for t, (c, v) in sorted(self.c_code_and_version.items(),
key=lambda pair: str(pair[0])):
if not v: if not v:
warnings.warn("Type %s has C code for Shape, but it has " warnings.warn("Type %s has C code for Shape, but it has no "
"no version. You should add a 'version' keyword arg " "version. You should add a 'version' keyword "
"when calling register_shape_c_code." % t, "arg when calling register_shape_c_code." % t,
stacklevel=2) stacklevel=2)
return () return ()
version.append((str(t), v)) version.append((str(t), v))
...@@ -301,7 +305,6 @@ class Shape(gof.Op): ...@@ -301,7 +305,6 @@ class Shape(gof.Op):
shape = Shape() shape = Shape()
_shape = shape # was used in the past, now use shape directly. _shape = shape # was used in the past, now use shape directly.
#pprint.assign(_shape, printing.MemberPrinter('shape'))
class Shape_i(gof.Op): class Shape_i(gof.Op):
...@@ -389,8 +392,11 @@ class Shape_i(gof.Op): ...@@ -389,8 +392,11 @@ class Shape_i(gof.Op):
return [()] return [()]
def grad(self, inp, grads): def grad(self, inp, grads):
return [theano.gradient.grad_not_implemented(op=self, x_pos=0, x=inp[0], return [theano.gradient.grad_not_implemented(
comment="No gradient for the shape of a matrix is implemented.")] op=self, x_pos=0, x=inp[0],
comment=("No gradient for the shape of a matrix "
"is implemented."))]
def shape_i(var, i, fgraph=None): def shape_i(var, i, fgraph=None):
"""Equivalent of var.shape[i], but apply if possible the shape """Equivalent of var.shape[i], but apply if possible the shape
...@@ -435,9 +441,10 @@ def shape_i(var, i, fgraph=None): ...@@ -435,9 +441,10 @@ def shape_i(var, i, fgraph=None):
def register_shape_i_c_code(typ, code, check_input, version=()): def register_shape_i_c_code(typ, code, check_input, version=()):
""" Tell Shape_i how to generate C code for a Theano Type """ Tell Shape_i how to generate C code for a Theano Type
:param typ: A Theano type. It must be the Theano class itself and not an :param typ: A Theano type. It must be the Theano class itself and not
instance of the class. an instance of the class.
:param code: C code that gets the shape of dimensions %(i)s for the Theano type 'typ'. :param code: C code that gets the shape of dimensions %(i)s for the
Theano type 'typ'.
Use %(iname)s and %(oname)s for the input and output C Use %(iname)s and %(oname)s for the input and output C
variable names respectively. variable names respectively.
:param version: A number indicating the version of the code, for cache. :param version: A number indicating the version of the code, for cache.
...@@ -620,7 +627,8 @@ class Rebroadcast(gof.Op): ...@@ -620,7 +627,8 @@ class Rebroadcast(gof.Op):
return type(self) == type(other) and self.axis == other.axis return type(self) == type(other) and self.axis == other.axis
def __hash__(self): def __hash__(self):
items = sorted(self.axis.iteritems()) # no ambiguity because each item key is unique # no ambiguity because each item key is unique
items = sorted(self.axis.iteritems())
return hash((type(self), tuple(items))) return hash((type(self), tuple(items)))
def __str__(self): def __str__(self):
...@@ -637,9 +645,9 @@ class Rebroadcast(gof.Op): ...@@ -637,9 +645,9 @@ class Rebroadcast(gof.Op):
def make_node(self, x): def make_node(self, x):
if self.axis.keys() and (x.ndim <= numpy.max(self.axis.keys())): if self.axis.keys() and (x.ndim <= numpy.max(self.axis.keys())):
raise ValueError('Trying to rebroadcast non-existent dimension') raise ValueError('Trying to rebroadcast non-existent dimension')
t = x.type.clone(broadcastable=[self.axis.get(i, b) t = x.type.clone(
for i, b in enumerate( broadcastable=[self.axis.get(i, b)
x.type.broadcastable)]) for i, b in enumerate(x.type.broadcastable)])
return gof.Apply(self, [x], [t()]) return gof.Apply(self, [x], [t()])
def perform(self, node, inp, out_): def perform(self, node, inp, out_):
...@@ -702,9 +710,10 @@ class Rebroadcast(gof.Op): ...@@ -702,9 +710,10 @@ class Rebroadcast(gof.Op):
for t, (c, v) in sorted(self.c_code_and_version.items(), for t, (c, v) in sorted(self.c_code_and_version.items(),
key=lambda pair: str(pair[0])): key=lambda pair: str(pair[0])):
if not v: if not v:
warnings.warn("Type %s has C code for Rebroadcast, but it has " warnings.warn("Type %s has C code for Rebroadcast, but it "
"no version. You should add a 'version' keyword arg " "has no version. You should add a 'version' "
"when calling register_rebroadcast_c_code." % t, "keyword arg when calling "
"register_rebroadcast_c_code." % t,
stacklevel=2) stacklevel=2)
return () return ()
version.append((str(t), v)) version.append((str(t), v))
...@@ -718,17 +727,18 @@ def register_specify_shape_c_code(typ, code, version=(), ...@@ -718,17 +727,18 @@ def register_specify_shape_c_code(typ, code, version=(),
c_support_code_apply=None): c_support_code_apply=None):
""" Tell SpecifyShape how to generate C code for a Theano Type """ Tell SpecifyShape how to generate C code for a Theano Type
:param typ: A Theano type. It must be the Theano class itself and not an :param typ: A Theano type. It must be the Theano class itself and
instance of the class. not an instance of the class.
:param code: C code that checks the shape and returns a view for the Theano type 'typ'. :param code: C code that checks the shape and returns a view for
Use %(iname)s and %(oname)s for the input and output C the Theano type 'typ'. Use %(iname)s and %(oname)s
variable names respectively. for the input and output C variable names
%(shape)s is the vector of shape of %(iname)s. respectively. %(shape)s is the vector of shape of
Check that its length is good. %(iname)s. Check that its length is good.
:param version: A number indicating the version of the code, for cache. :param version: A number indicating the version of the code, for cache.
:param c_support_code_apply: extra code. :param c_support_code_apply: extra code.
""" """
SpecifyShape.c_code_and_version[typ] = (code, version, c_support_code_apply) SpecifyShape.c_code_and_version[typ] = (code, version,
c_support_code_apply)
class SpecifyShape(gof.Op): class SpecifyShape(gof.Op):
...@@ -784,7 +794,8 @@ class SpecifyShape(gof.Op): ...@@ -784,7 +794,8 @@ class SpecifyShape(gof.Op):
new_shape = [] new_shape = []
for dim in xrange(node.inputs[0].ndim): for dim in xrange(node.inputs[0].ndim):
try: try:
s = theano.tensor.get_scalar_constant_value(node.inputs[1][dim]) s = theano.tensor.get_scalar_constant_value(
node.inputs[1][dim])
s = theano.tensor.as_tensor_variable(s) s = theano.tensor.as_tensor_variable(s)
new_shape.append(s) new_shape.append(s)
except theano.tensor.NotScalarConstantError: except theano.tensor.NotScalarConstantError:
...@@ -832,7 +843,8 @@ class SpecifyShape(gof.Op): ...@@ -832,7 +843,8 @@ class SpecifyShape(gof.Op):
code, version, _ = self.c_code_and_version[itype] code, version, _ = self.c_code_and_version[itype]
return code % locals() return code % locals()
return super(SpecifyShape, self).c_code(node, node, inames, onames, sub) return super(SpecifyShape, self).c_code(node, node, inames,
onames, sub)
def c_code_cache_version(self): def c_code_cache_version(self):
version = [] version = []
...@@ -841,9 +853,10 @@ class SpecifyShape(gof.Op): ...@@ -841,9 +853,10 @@ class SpecifyShape(gof.Op):
for t, (c, v, _) in sorted(self.c_code_and_version.items(), for t, (c, v, _) in sorted(self.c_code_and_version.items(),
key=lambda pair: str(pair[0])): key=lambda pair: str(pair[0])):
if not v: if not v:
warnings.warn("Type %s has C code for SpecifyShape, but it has " warnings.warn("Type %s has C code for SpecifyShape, but it "
"no version. You should add a 'version' keyword arg " "has no version. You should add a 'version' "
"when calling register_specify_shape_c_code." % t, "keyword arg when calling "
"register_specify_shape_c_code." % t,
stacklevel=2) stacklevel=2)
return () return ()
version.append((str(t), v)) version.append((str(t), v))
......
...@@ -38,7 +38,6 @@ whitelist_flake8 = [ ...@@ -38,7 +38,6 @@ whitelist_flake8 = [
"tests/test_tutorial.py", "tests/test_tutorial.py",
"tests/disturb_mem.py", "tests/disturb_mem.py",
"tests/unittest_tools.py", "tests/unittest_tools.py",
"compile/ops.py",
"compile/debugmode.py", "compile/debugmode.py",
"compile/function.py", "compile/function.py",
"compile/pfunc.py", "compile/pfunc.py",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论