提交 afcb5350 authored 作者: Iban Harlouchet's avatar Iban Harlouchet

flake8 for tensor/blas.py; two E left

上级 715c215a
...@@ -333,12 +333,10 @@ try: ...@@ -333,12 +333,10 @@ try:
# `scipy.linalg.blas.fblas` with `scipy.linalg.blas`. # `scipy.linalg.blas.fblas` with `scipy.linalg.blas`.
# See http://github.com/scipy/scipy/pull/358 # See http://github.com/scipy/scipy/pull/358
fblas = scipy.linalg.blas fblas = scipy.linalg.blas
_blas_gemv_fns = { _blas_gemv_fns = {numpy.dtype('float32'): fblas.sgemv,
numpy.dtype('float32'): fblas.sgemv,
numpy.dtype('float64'): fblas.dgemv, numpy.dtype('float64'): fblas.dgemv,
numpy.dtype('complex64'): fblas.cgemv, numpy.dtype('complex64'): fblas.cgemv,
numpy.dtype('complex128'): fblas.zgemv, numpy.dtype('complex128'): fblas.zgemv}
}
except ImportError as e: except ImportError as e:
have_fblas = False have_fblas = False
# This is used in Gemv and ScipyGer. We use CGemv and CGer # This is used in Gemv and ScipyGer. We use CGemv and CGer
...@@ -401,7 +399,7 @@ class Gemv(Op): ...@@ -401,7 +399,7 @@ class Gemv(Op):
# sizes are 1 at time of perform() there is no problem # sizes are 1 at time of perform() there is no problem
# if x.broadcastable[0] != A.broadcastable[1]: # if x.broadcastable[0] != A.broadcastable[1]:
# raise TypeError('broadcastable mismatch between x and A', # raise TypeError('broadcastable mismatch between x and A',
#(x.type, A.type)) # (x.type, A.type))
return Apply(self, [y, alpha, A, x, beta], [y.type()]) return Apply(self, [y, alpha, A, x, beta], [y.type()])
def perform(self, node, inputs, out_storage): def perform(self, node, inputs, out_storage):
...@@ -411,7 +409,8 @@ class Gemv(Op): ...@@ -411,7 +409,8 @@ class Gemv(Op):
gemv = _blas_gemv_fns[y.dtype] gemv = _blas_gemv_fns[y.dtype]
if (A.shape[0] != y.shape[0] or A.shape[1] != x.shape[0]): if (A.shape[0] != y.shape[0] or A.shape[1] != x.shape[0]):
raise ValueError('Incompatible shapes for gemv ' raise ValueError(
'Incompatible shapes for gemv '
'(beta * y + alpha * dot(A, x)). y: %s, A: %s, x: %s ' '(beta * y + alpha * dot(A, x)). y: %s, A: %s, x: %s '
% (y.shape, A.shape, x.shape)) % (y.shape, A.shape, x.shape))
...@@ -571,7 +570,8 @@ def _ldflags(ldflags_str, libs, flags, libs_dir, include_dir): ...@@ -571,7 +570,8 @@ def _ldflags(ldflags_str, libs, flags, libs_dir, include_dir):
if any([f.find(ll) >= 0 for ll in l]): if any([f.find(ll) >= 0 for ll in l]):
found_dyn = True found_dyn = True
if not found_dyn and dirs: if not found_dyn and dirs:
_logger.warning("We did not found a dynamic library into the " _logger.warning(
"We did not found a dynamic library into the "
"library_dir of the library we use for blas. If you use " "library_dir of the library we use for blas. If you use "
"ATLAS, make sure to compile it with dynamics library.") "ATLAS, make sure to compile it with dynamics library.")
...@@ -673,7 +673,7 @@ class GemmRelated(Op): ...@@ -673,7 +673,7 @@ class GemmRelated(Op):
int sx_0, sx_1, sy_0, sy_1, sz_0, sz_1; int sx_0, sx_1, sy_0, sy_1, sz_0, sz_1;
""" """
#setup_z_Nz_Sz = None # setup_z_Nz_Sz = None
check_xyz_rank2 = """ check_xyz_rank2 = """
if (PyArray_NDIM(%(_x)s) != 2) { if (PyArray_NDIM(%(_x)s) != 2) {
...@@ -823,7 +823,7 @@ class GemmRelated(Op): ...@@ -823,7 +823,7 @@ class GemmRelated(Op):
{ {
""" """
#case_float_ab_constants = None # case_float_ab_constants = None
case_float_gemm = """ case_float_gemm = """
float* x = (float*)PyArray_DATA(%(_x)s); float* x = (float*)PyArray_DATA(%(_x)s);
...@@ -856,7 +856,7 @@ class GemmRelated(Op): ...@@ -856,7 +856,7 @@ class GemmRelated(Op):
{ {
""" """
#case_double_ab_constants = None # case_double_ab_constants = None
case_double_gemm = """ case_double_gemm = """
double* x = (double*)PyArray_DATA(%(_x)s); double* x = (double*)PyArray_DATA(%(_x)s);
...@@ -1030,8 +1030,8 @@ class Gemm(GemmRelated): ...@@ -1030,8 +1030,8 @@ class Gemm(GemmRelated):
raise TypeError(Gemm.E_mixed, raise TypeError(Gemm.E_mixed,
(z.dtype, a.dtype, x.dtype, y.dtype, b.dtype)) (z.dtype, a.dtype, x.dtype, y.dtype, b.dtype))
if (not z.dtype.startswith('float') if (not z.dtype.startswith('float') and
and not z.dtype.startswith('complex')): not z.dtype.startswith('complex')):
raise TypeError(Gemm.E_float, (z.dtype)) raise TypeError(Gemm.E_float, (z.dtype))
output = z.type() output = z.type()
...@@ -1173,7 +1173,7 @@ class Gemm(GemmRelated): ...@@ -1173,7 +1173,7 @@ class Gemm(GemmRelated):
_z, _a, _x, _y, _b = inp _z, _a, _x, _y, _b = inp
_zout, = out _zout, = out
if node.inputs[0].type.dtype.startswith('complex'): if node.inputs[0].type.dtype.startswith('complex'):
raise utils.MethodNotDefined('%s.c_code' \ raise utils.MethodNotDefined('%s.c_code'
% self.__class__.__name__) % self.__class__.__name__)
if not config.blas.ldflags: if not config.blas.ldflags:
return super(Gemm, self).c_code(node, name, return super(Gemm, self).c_code(node, name,
...@@ -1203,9 +1203,9 @@ def res_is_a(node, op, maxclients=None): ...@@ -1203,9 +1203,9 @@ def res_is_a(node, op, maxclients=None):
else: else:
retval = True retval = True
return node.owner \ return (node.owner and
and node.owner.op == op \ node.owner.op == op and
and retval retval)
def _as_scalar(res, dtype=None): def _as_scalar(res, dtype=None):
...@@ -1235,16 +1235,16 @@ def _as_scalar(res, dtype=None): ...@@ -1235,16 +1235,16 @@ def _as_scalar(res, dtype=None):
def _is_real_matrix(res): def _is_real_matrix(res):
return res.type.dtype in ('float32', 'float64') \ return (res.type.dtype in ('float32', 'float64') and
and res.type.ndim == 2 \ res.type.ndim == 2 and
and res.type.broadcastable[0] == False \ res.type.broadcastable[0] == False and
and res.type.broadcastable[1] == False # cope with tuple vs. list res.type.broadcastable[1] == False) # cope with tuple vs. list
def _is_real_vector(res): def _is_real_vector(res):
return res.type.dtype in ('float32', 'float64') \ return (res.type.dtype in ('float32', 'float64') and
and res.type.ndim == 1 \ res.type.ndim == 1 and
and res.type.broadcastable[0] == False res.type.broadcastable[0] == False)
def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip=True): def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip=True):
...@@ -1583,7 +1583,7 @@ class GemmOptimizer(Optimizer): ...@@ -1583,7 +1583,7 @@ class GemmOptimizer(Optimizer):
(theano.scalar.Add, theano.scalar.Sub, (theano.scalar.Add, theano.scalar.Sub,
theano.scalar.Neg, theano.scalar.Mul))): theano.scalar.Neg, theano.scalar.Mul))):
continue continue
if not node in fgraph.apply_nodes: if node not in fgraph.apply_nodes:
# This mean that we already removed this node from # This mean that we already removed this node from
# the graph # the graph
continue continue
...@@ -1725,7 +1725,7 @@ class Dot22(GemmRelated): ...@@ -1725,7 +1725,7 @@ class Dot22(GemmRelated):
_x, _y = inp _x, _y = inp
_zout, = out _zout, = out
if node.inputs[0].type.dtype.startswith('complex'): if node.inputs[0].type.dtype.startswith('complex'):
raise utils.MethodNotDefined('%s.c_code' \ raise utils.MethodNotDefined('%s.c_code'
% self.__class__.__name__) % self.__class__.__name__)
if len(self.c_libraries()) <= 0: if len(self.c_libraries()) <= 0:
return super(Dot22, self).c_code(node, name, (_x, _y), return super(Dot22, self).c_code(node, name, (_x, _y),
...@@ -1898,8 +1898,7 @@ blas_optdb.register('gemm_optimizer', ...@@ -1898,8 +1898,7 @@ blas_optdb.register('gemm_optimizer',
GemmOptimizer(), GemmOptimizer(),
10, 'fast_run') 10, 'fast_run')
blas_optdb.register('local_gemm_to_gemv', blas_optdb.register('local_gemm_to_gemv',
EquilibriumOptimizer([ EquilibriumOptimizer([local_gemm_to_gemv,
local_gemm_to_gemv,
local_gemm_to_ger, local_gemm_to_ger,
local_dot22_to_ger_or_gemv, local_dot22_to_ger_or_gemv,
local_dimshuffle_lift], local_dimshuffle_lift],
...@@ -1938,8 +1937,8 @@ class Dot22Scalar(GemmRelated): ...@@ -1938,8 +1937,8 @@ class Dot22Scalar(GemmRelated):
raise TypeError('Dot22Scalar requires matching dtypes', raise TypeError('Dot22Scalar requires matching dtypes',
(a.dtype, x.dtype, y.dtype)) (a.dtype, x.dtype, y.dtype))
if (not a.dtype.startswith('float') if (not a.dtype.startswith('float') and
and not a.dtype.startswith('complex')): not a.dtype.startswith('complex')):
raise TypeError('Dot22Scalar requires float or complex args', raise TypeError('Dot22Scalar requires float or complex args',
a.dtype) a.dtype)
...@@ -1992,7 +1991,7 @@ class Dot22Scalar(GemmRelated): ...@@ -1992,7 +1991,7 @@ class Dot22Scalar(GemmRelated):
_x, _y, _a = inp _x, _y, _a = inp
_zout, = out _zout, = out
if node.inputs[0].type.dtype.startswith('complex'): if node.inputs[0].type.dtype.startswith('complex'):
raise utils.MethodNotDefined('%s.c_code' \ raise utils.MethodNotDefined('%s.c_code'
% self.__class__.__name__) % self.__class__.__name__)
if len(self.c_libraries()) <= 0: if len(self.c_libraries()) <= 0:
return super(Dot22Scalar, self).c_code(node, name, (_x, _y), return super(Dot22Scalar, self).c_code(node, name, (_x, _y),
...@@ -2065,8 +2064,7 @@ def local_dot22_to_dot22scalar(node): ...@@ -2065,8 +2064,7 @@ def local_dot22_to_dot22scalar(node):
scalar_idx = -1 scalar_idx = -1
for i, x in enumerate(m.owner.inputs): for i, x in enumerate(m.owner.inputs):
if _as_scalar(x, dtype=d.dtype) and (theano.scalar.upcast( if _as_scalar(x, dtype=d.dtype) and (theano.scalar.upcast(
x.type.dtype, d.type.dtype) x.type.dtype, d.type.dtype) == d.type.dtype):
== d.type.dtype):
scalar_idx = i scalar_idx = i
break break
...@@ -2128,8 +2126,8 @@ blas_optdb.register('local_dot22_to_dot22scalar', ...@@ -2128,8 +2126,8 @@ blas_optdb.register('local_dot22_to_dot22scalar',
11, 'fast_run') 11, 'fast_run')
#from opt import register_specialize, register_canonicalize # from opt import register_specialize, register_canonicalize
#@register_specialize # @register_specialize
@local_optimizer([T.sub, T.add]) @local_optimizer([T.sub, T.add])
def local_print_as_we_go_along(node): def local_print_as_we_go_along(node):
if node.op in (T.sub, T.add): if node.op in (T.sub, T.add):
......
...@@ -58,7 +58,6 @@ whitelist_flake8 = [ ...@@ -58,7 +58,6 @@ whitelist_flake8 = [
"typed_list/tests/test_opt.py", "typed_list/tests/test_opt.py",
"typed_list/tests/test_basic.py", "typed_list/tests/test_basic.py",
"tensor/__init__.py", "tensor/__init__.py",
"tensor/blas.py",
"tensor/extra_ops.py", "tensor/extra_ops.py",
"tensor/nlinalg.py", "tensor/nlinalg.py",
"tensor/blas_c.py", "tensor/blas_c.py",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论