提交 afcb5350 作者: Iban Harlouchet

flake8 for tensor/blas.py; two E left

上级 715c215a
......@@ -333,12 +333,10 @@ try:
# `scipy.linalg.blas.fblas` with `scipy.linalg.blas`.
# See http://github.com/scipy/scipy/pull/358
fblas = scipy.linalg.blas
_blas_gemv_fns = {
numpy.dtype('float32'): fblas.sgemv,
_blas_gemv_fns = {numpy.dtype('float32'): fblas.sgemv,
numpy.dtype('float64'): fblas.dgemv,
numpy.dtype('complex64'): fblas.cgemv,
numpy.dtype('complex128'): fblas.zgemv,
}
numpy.dtype('complex128'): fblas.zgemv}
except ImportError as e:
have_fblas = False
# This is used in Gemv and ScipyGer. We use CGemv and CGer
......@@ -401,7 +399,7 @@ class Gemv(Op):
# sizes are 1 at time of perform() there is no problem
# if x.broadcastable[0] != A.broadcastable[1]:
# raise TypeError('broadcastable mismatch between x and A',
#(x.type, A.type))
# (x.type, A.type))
return Apply(self, [y, alpha, A, x, beta], [y.type()])
def perform(self, node, inputs, out_storage):
......@@ -411,7 +409,8 @@ class Gemv(Op):
gemv = _blas_gemv_fns[y.dtype]
if (A.shape[0] != y.shape[0] or A.shape[1] != x.shape[0]):
raise ValueError('Incompatible shapes for gemv '
raise ValueError(
'Incompatible shapes for gemv '
'(beta * y + alpha * dot(A, x)). y: %s, A: %s, x: %s '
% (y.shape, A.shape, x.shape))
......@@ -571,7 +570,8 @@ def _ldflags(ldflags_str, libs, flags, libs_dir, include_dir):
if any([f.find(ll) >= 0 for ll in l]):
found_dyn = True
if not found_dyn and dirs:
_logger.warning("We did not found a dynamic library into the "
_logger.warning(
"We did not found a dynamic library into the "
"library_dir of the library we use for blas. If you use "
"ATLAS, make sure to compile it with dynamics library.")
......@@ -673,7 +673,7 @@ class GemmRelated(Op):
int sx_0, sx_1, sy_0, sy_1, sz_0, sz_1;
"""
#setup_z_Nz_Sz = None
# setup_z_Nz_Sz = None
check_xyz_rank2 = """
if (PyArray_NDIM(%(_x)s) != 2) {
......@@ -823,7 +823,7 @@ class GemmRelated(Op):
{
"""
#case_float_ab_constants = None
# case_float_ab_constants = None
case_float_gemm = """
float* x = (float*)PyArray_DATA(%(_x)s);
......@@ -856,7 +856,7 @@ class GemmRelated(Op):
{
"""
#case_double_ab_constants = None
# case_double_ab_constants = None
case_double_gemm = """
double* x = (double*)PyArray_DATA(%(_x)s);
......@@ -1030,8 +1030,8 @@ class Gemm(GemmRelated):
raise TypeError(Gemm.E_mixed,
(z.dtype, a.dtype, x.dtype, y.dtype, b.dtype))
if (not z.dtype.startswith('float')
and not z.dtype.startswith('complex')):
if (not z.dtype.startswith('float') and
not z.dtype.startswith('complex')):
raise TypeError(Gemm.E_float, (z.dtype))
output = z.type()
......@@ -1173,7 +1173,7 @@ class Gemm(GemmRelated):
_z, _a, _x, _y, _b = inp
_zout, = out
if node.inputs[0].type.dtype.startswith('complex'):
raise utils.MethodNotDefined('%s.c_code' \
raise utils.MethodNotDefined('%s.c_code'
% self.__class__.__name__)
if not config.blas.ldflags:
return super(Gemm, self).c_code(node, name,
......@@ -1203,9 +1203,9 @@ def res_is_a(node, op, maxclients=None):
else:
retval = True
return node.owner \
and node.owner.op == op \
and retval
return (node.owner and
node.owner.op == op and
retval)
def _as_scalar(res, dtype=None):
......@@ -1235,16 +1235,16 @@ def _as_scalar(res, dtype=None):
def _is_real_matrix(res):
return res.type.dtype in ('float32', 'float64') \
and res.type.ndim == 2 \
and res.type.broadcastable[0] == False \
and res.type.broadcastable[1] == False # cope with tuple vs. list
return (res.type.dtype in ('float32', 'float64') and
res.type.ndim == 2 and
res.type.broadcastable[0] == False and
res.type.broadcastable[1] == False) # cope with tuple vs. list
def _is_real_vector(res):
return res.type.dtype in ('float32', 'float64') \
and res.type.ndim == 1 \
and res.type.broadcastable[0] == False
return (res.type.dtype in ('float32', 'float64') and
res.type.ndim == 1 and
res.type.broadcastable[0] == False)
def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip=True):
......@@ -1583,7 +1583,7 @@ class GemmOptimizer(Optimizer):
(theano.scalar.Add, theano.scalar.Sub,
theano.scalar.Neg, theano.scalar.Mul))):
continue
if not node in fgraph.apply_nodes:
if node not in fgraph.apply_nodes:
# This mean that we already removed this node from
# the graph
continue
......@@ -1725,7 +1725,7 @@ class Dot22(GemmRelated):
_x, _y = inp
_zout, = out
if node.inputs[0].type.dtype.startswith('complex'):
raise utils.MethodNotDefined('%s.c_code' \
raise utils.MethodNotDefined('%s.c_code'
% self.__class__.__name__)
if len(self.c_libraries()) <= 0:
return super(Dot22, self).c_code(node, name, (_x, _y),
......@@ -1898,8 +1898,7 @@ blas_optdb.register('gemm_optimizer',
GemmOptimizer(),
10, 'fast_run')
blas_optdb.register('local_gemm_to_gemv',
EquilibriumOptimizer([
local_gemm_to_gemv,
EquilibriumOptimizer([local_gemm_to_gemv,
local_gemm_to_ger,
local_dot22_to_ger_or_gemv,
local_dimshuffle_lift],
......@@ -1938,8 +1937,8 @@ class Dot22Scalar(GemmRelated):
raise TypeError('Dot22Scalar requires matching dtypes',
(a.dtype, x.dtype, y.dtype))
if (not a.dtype.startswith('float')
and not a.dtype.startswith('complex')):
if (not a.dtype.startswith('float') and
not a.dtype.startswith('complex')):
raise TypeError('Dot22Scalar requires float or complex args',
a.dtype)
......@@ -1992,7 +1991,7 @@ class Dot22Scalar(GemmRelated):
_x, _y, _a = inp
_zout, = out
if node.inputs[0].type.dtype.startswith('complex'):
raise utils.MethodNotDefined('%s.c_code' \
raise utils.MethodNotDefined('%s.c_code'
% self.__class__.__name__)
if len(self.c_libraries()) <= 0:
return super(Dot22Scalar, self).c_code(node, name, (_x, _y),
......@@ -2065,8 +2064,7 @@ def local_dot22_to_dot22scalar(node):
scalar_idx = -1
for i, x in enumerate(m.owner.inputs):
if _as_scalar(x, dtype=d.dtype) and (theano.scalar.upcast(
x.type.dtype, d.type.dtype)
== d.type.dtype):
x.type.dtype, d.type.dtype) == d.type.dtype):
scalar_idx = i
break
......@@ -2128,8 +2126,8 @@ blas_optdb.register('local_dot22_to_dot22scalar',
11, 'fast_run')
#from opt import register_specialize, register_canonicalize
#@register_specialize
# from opt import register_specialize, register_canonicalize
# @register_specialize
@local_optimizer([T.sub, T.add])
def local_print_as_we_go_along(node):
if node.op in (T.sub, T.add):
......
......@@ -58,7 +58,6 @@ whitelist_flake8 = [
"typed_list/tests/test_opt.py",
"typed_list/tests/test_basic.py",
"tensor/__init__.py",
"tensor/blas.py",
"tensor/extra_ops.py",
"tensor/nlinalg.py",
"tensor/blas_c.py",
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论