提交 afcb5350 authored 作者: Iban Harlouchet's avatar Iban Harlouchet

flake8 for tensor/blas.py; two E left

上级 715c215a
......@@ -164,7 +164,7 @@ def default_blas_ldflags():
global numpy
try:
if (hasattr(numpy.distutils, '__config__') and
numpy.distutils.__config__):
numpy.distutils.__config__):
# If the old private interface is available use it as it
# don't print information to the user.
blas_info = numpy.distutils.__config__.blas_opt_info
......@@ -319,8 +319,8 @@ SOMEPATH/Canopy_64bit/User/lib/python2.7/site-packages/numpy/distutils/system_in
AddConfigVar('blas.ldflags',
"lib[s] to include for [Fortran] level-3 blas implementation",
StrParam(default_blas_ldflags))
"lib[s] to include for [Fortran] level-3 blas implementation",
StrParam(default_blas_ldflags))
try:
......@@ -333,12 +333,10 @@ try:
# `scipy.linalg.blas.fblas` with `scipy.linalg.blas`.
# See http://github.com/scipy/scipy/pull/358
fblas = scipy.linalg.blas
_blas_gemv_fns = {
numpy.dtype('float32'): fblas.sgemv,
numpy.dtype('float64'): fblas.dgemv,
numpy.dtype('complex64'): fblas.cgemv,
numpy.dtype('complex128'): fblas.zgemv,
}
_blas_gemv_fns = {numpy.dtype('float32'): fblas.sgemv,
numpy.dtype('float64'): fblas.dgemv,
numpy.dtype('complex64'): fblas.cgemv,
numpy.dtype('complex128'): fblas.zgemv}
except ImportError as e:
have_fblas = False
# This is used in Gemv and ScipyGer. We use CGemv and CGer
......@@ -400,8 +398,8 @@ class Gemv(Op):
# The following is not grounds for error because as long as
# sizes are 1 at time of perform() there is no problem
# if x.broadcastable[0] != A.broadcastable[1]:
# raise TypeError('broadcastable mismatch between x and A',
#(x.type, A.type))
# raise TypeError('broadcastable mismatch between x and A',
# (x.type, A.type))
return Apply(self, [y, alpha, A, x, beta], [y.type()])
def perform(self, node, inputs, out_storage):
......@@ -411,9 +409,10 @@ class Gemv(Op):
gemv = _blas_gemv_fns[y.dtype]
if (A.shape[0] != y.shape[0] or A.shape[1] != x.shape[0]):
raise ValueError('Incompatible shapes for gemv '
'(beta * y + alpha * dot(A, x)). y: %s, A: %s, x: %s '
% (y.shape, A.shape, x.shape))
raise ValueError(
'Incompatible shapes for gemv '
'(beta * y + alpha * dot(A, x)). y: %s, A: %s, x: %s '
% (y.shape, A.shape, x.shape))
# Here I suppose that A is in c order. If we don't make it
# explicitly as fortran order, scipy 0.7.2 seam to create
......@@ -479,7 +478,7 @@ class Ger(Op):
alpha = T.as_tensor_variable(alpha)
if len(set([A.dtype, alpha.dtype, x.dtype, y.dtype])) != 1:
raise TypeError('ger requires matching dtypes',
(A.dtype, alpha.dtype, x.dtype, y.dtype))
(A.dtype, alpha.dtype, x.dtype, y.dtype))
if alpha.ndim != 0:
raise TypeError('ger requires scalar alpha', alpha.type)
if A.ndim != 2:
......@@ -567,13 +566,14 @@ def _ldflags(ldflags_str, libs, flags, libs_dir, include_dir):
for d in dirs:
for f in os.listdir(d):
if (f.endswith('.so') or f.endswith('.dylib') or
f.endswith('.dll')):
f.endswith('.dll')):
if any([f.find(ll) >= 0 for ll in l]):
found_dyn = True
if not found_dyn and dirs:
_logger.warning("We did not found a dynamic library into the "
"library_dir of the library we use for blas. If you use "
"ATLAS, make sure to compile it with dynamics library.")
_logger.warning(
"We did not found a dynamic library into the "
"library_dir of the library we use for blas. If you use "
"ATLAS, make sure to compile it with dynamics library.")
for t in ldflags_str.split():
# Remove extra quote.
......@@ -644,7 +644,7 @@ class GemmRelated(Op):
return ldflags()
# code_cache_version is built by subclasses from
# build_gemm_version
# build_gemm_version
def c_compile_args(self):
return ldflags(libs=False, flags=True)
......@@ -673,7 +673,7 @@ class GemmRelated(Op):
int sx_0, sx_1, sy_0, sy_1, sz_0, sz_1;
"""
#setup_z_Nz_Sz = None
# setup_z_Nz_Sz = None
check_xyz_rank2 = """
if (PyArray_NDIM(%(_x)s) != 2) {
......@@ -823,7 +823,7 @@ class GemmRelated(Op):
{
"""
#case_float_ab_constants = None
# case_float_ab_constants = None
case_float_gemm = """
float* x = (float*)PyArray_DATA(%(_x)s);
......@@ -856,7 +856,7 @@ class GemmRelated(Op):
{
"""
#case_double_ab_constants = None
# case_double_ab_constants = None
case_double_gemm = """
double* x = (double*)PyArray_DATA(%(_x)s);
......@@ -1028,10 +1028,10 @@ class Gemm(GemmRelated):
if not (z.dtype == a.dtype == x.dtype == y.dtype == b.dtype):
raise TypeError(Gemm.E_mixed,
(z.dtype, a.dtype, x.dtype, y.dtype, b.dtype))
(z.dtype, a.dtype, x.dtype, y.dtype, b.dtype))
if (not z.dtype.startswith('float')
and not z.dtype.startswith('complex')):
if (not z.dtype.startswith('float') and
not z.dtype.startswith('complex')):
raise TypeError(Gemm.E_float, (z.dtype))
output = z.type()
......@@ -1173,8 +1173,8 @@ class Gemm(GemmRelated):
_z, _a, _x, _y, _b = inp
_zout, = out
if node.inputs[0].type.dtype.startswith('complex'):
raise utils.MethodNotDefined('%s.c_code' \
% self.__class__.__name__)
raise utils.MethodNotDefined('%s.c_code'
% self.__class__.__name__)
if not config.blas.ldflags:
return super(Gemm, self).c_code(node, name,
(_z, _a, _x, _y, _b), (_zout, ),
......@@ -1203,9 +1203,9 @@ def res_is_a(node, op, maxclients=None):
else:
retval = True
return node.owner \
and node.owner.op == op \
and retval
return (node.owner and
node.owner.op == op and
retval)
def _as_scalar(res, dtype=None):
......@@ -1235,16 +1235,16 @@ def _as_scalar(res, dtype=None):
def _is_real_matrix(res):
return res.type.dtype in ('float32', 'float64') \
and res.type.ndim == 2 \
and res.type.broadcastable[0] == False \
and res.type.broadcastable[1] == False # cope with tuple vs. list
return (res.type.dtype in ('float32', 'float64') and
res.type.ndim == 2 and
res.type.broadcastable[0] == False and
res.type.broadcastable[1] == False) # cope with tuple vs. list
def _is_real_vector(res):
return res.type.dtype in ('float32', 'float64') \
and res.type.ndim == 1 \
and res.type.broadcastable[0] == False
return (res.type.dtype in ('float32', 'float64') and
res.type.ndim == 1 and
res.type.broadcastable[0] == False)
def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip=True):
......@@ -1262,8 +1262,8 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip=True):
# it also might be the case that there is a dimshuffle between the +
# and the dot22. local_dot_to_dot22 in particular will put in such things.
if (M.owner and isinstance(M.owner.op, T.DimShuffle) and
M.owner.inputs[0].owner and
isinstance(M.owner.inputs[0].owner.op, Dot22)):
M.owner.inputs[0].owner and
isinstance(M.owner.inputs[0].owner.op, Dot22)):
MM = M.owner.inputs[0]
if M.owner.op.new_order == (0,):
# it is making a column MM into a vector
......@@ -1493,7 +1493,7 @@ def _gemm_from_factored_list(lst):
assert len(gemm_of_sM_list) == 1
add_inputs = [item_to_var(input)
for k, input in enumerate(lst) if k not in (i, j)]
for k, input in enumerate(lst) if k not in (i, j)]
add_inputs.extend(gemm_of_sM_list)
if len(add_inputs) > 1:
rval = [T.add(*add_inputs)]
......@@ -1583,7 +1583,7 @@ class GemmOptimizer(Optimizer):
(theano.scalar.Add, theano.scalar.Sub,
theano.scalar.Neg, theano.scalar.Mul))):
continue
if not node in fgraph.apply_nodes:
if node not in fgraph.apply_nodes:
# This mean that we already removed this node from
# the graph
continue
......@@ -1725,8 +1725,8 @@ class Dot22(GemmRelated):
_x, _y = inp
_zout, = out
if node.inputs[0].type.dtype.startswith('complex'):
raise utils.MethodNotDefined('%s.c_code' \
% self.__class__.__name__)
raise utils.MethodNotDefined('%s.c_code'
% self.__class__.__name__)
if len(self.c_libraries()) <= 0:
return super(Dot22, self).c_code(node, name, (_x, _y),
(_zout, ), sub)
......@@ -1895,17 +1895,16 @@ blas_optdb.register('local_dot_to_dot22',
in2out(local_dot_to_dot22),
0, 'fast_run', 'fast_compile')
blas_optdb.register('gemm_optimizer',
GemmOptimizer(),
10, 'fast_run')
GemmOptimizer(),
10, 'fast_run')
blas_optdb.register('local_gemm_to_gemv',
EquilibriumOptimizer([
local_gemm_to_gemv,
local_gemm_to_ger,
local_dot22_to_ger_or_gemv,
local_dimshuffle_lift],
max_use_ratio=5,
ignore_newtrees=False),
15, 'fast_run')
EquilibriumOptimizer([local_gemm_to_gemv,
local_gemm_to_ger,
local_dot22_to_ger_or_gemv,
local_dimshuffle_lift],
max_use_ratio=5,
ignore_newtrees=False),
15, 'fast_run')
# After destroyhandler(49.5) but before we try to make elemwise things
......@@ -1936,12 +1935,12 @@ class Dot22Scalar(GemmRelated):
if not (a.dtype == x.dtype == y.dtype):
raise TypeError('Dot22Scalar requires matching dtypes',
(a.dtype, x.dtype, y.dtype))
(a.dtype, x.dtype, y.dtype))
if (not a.dtype.startswith('float')
and not a.dtype.startswith('complex')):
if (not a.dtype.startswith('float') and
not a.dtype.startswith('complex')):
raise TypeError('Dot22Scalar requires float or complex args',
a.dtype)
a.dtype)
bz = [x.type.broadcastable[0], y.type.broadcastable[1]]
outputs = [T.tensor(x.type.dtype, bz)]
......@@ -1992,8 +1991,8 @@ class Dot22Scalar(GemmRelated):
_x, _y, _a = inp
_zout, = out
if node.inputs[0].type.dtype.startswith('complex'):
raise utils.MethodNotDefined('%s.c_code' \
% self.__class__.__name__)
raise utils.MethodNotDefined('%s.c_code'
% self.__class__.__name__)
if len(self.c_libraries()) <= 0:
return super(Dot22Scalar, self).c_code(node, name, (_x, _y),
(_zout, ), sub)
......@@ -2051,7 +2050,7 @@ def local_dot22_to_dot22scalar(node):
# The canonizer should have merged those mul together.
i_mul = [x.owner and x.owner.op == T.mul and
any([_as_scalar(x_i, dtype=d.dtype)
for x_i in x.owner.inputs])
for x_i in x.owner.inputs])
for x in node.inputs]
if not any(i_mul):
# no scalar in input and no multiplication
......@@ -2065,8 +2064,7 @@ def local_dot22_to_dot22scalar(node):
scalar_idx = -1
for i, x in enumerate(m.owner.inputs):
if _as_scalar(x, dtype=d.dtype) and (theano.scalar.upcast(
x.type.dtype, d.type.dtype)
== d.type.dtype):
x.type.dtype, d.type.dtype) == d.type.dtype):
scalar_idx = i
break
......@@ -2103,8 +2101,8 @@ def local_dot22_to_dot22scalar(node):
break
if scalar_idx < 0:
_logger.info('Not optimizing dot22 with inputs %s %s, as the type '
'of the scalar cannot be upcasted to the matrix type',
node.inputs, [x.type for x in node.inputs])
'of the scalar cannot be upcasted to the matrix type',
node.inputs, [x.type for x in node.inputs])
return False
assert scalar_idx < len(node.inputs)
s = node.inputs[scalar_idx]
......@@ -2128,8 +2126,8 @@ blas_optdb.register('local_dot22_to_dot22scalar',
11, 'fast_run')
#from opt import register_specialize, register_canonicalize
#@register_specialize
# from opt import register_specialize, register_canonicalize
# @register_specialize
@local_optimizer([T.sub, T.add])
def local_print_as_we_go_along(node):
if node.op in (T.sub, T.add):
......
......@@ -58,7 +58,6 @@ whitelist_flake8 = [
"typed_list/tests/test_opt.py",
"typed_list/tests/test_basic.py",
"tensor/__init__.py",
"tensor/blas.py",
"tensor/extra_ops.py",
"tensor/nlinalg.py",
"tensor/blas_c.py",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论