提交 afcb5350 authored 作者: Iban Harlouchet's avatar Iban Harlouchet

flake8 for tensor/blas.py; two E left

上级 715c215a
...@@ -164,7 +164,7 @@ def default_blas_ldflags(): ...@@ -164,7 +164,7 @@ def default_blas_ldflags():
global numpy global numpy
try: try:
if (hasattr(numpy.distutils, '__config__') and if (hasattr(numpy.distutils, '__config__') and
numpy.distutils.__config__): numpy.distutils.__config__):
# If the old private interface is available use it as it # If the old private interface is available use it as it
# don't print information to the user. # don't print information to the user.
blas_info = numpy.distutils.__config__.blas_opt_info blas_info = numpy.distutils.__config__.blas_opt_info
...@@ -319,8 +319,8 @@ SOMEPATH/Canopy_64bit/User/lib/python2.7/site-packages/numpy/distutils/system_in ...@@ -319,8 +319,8 @@ SOMEPATH/Canopy_64bit/User/lib/python2.7/site-packages/numpy/distutils/system_in
AddConfigVar('blas.ldflags', AddConfigVar('blas.ldflags',
"lib[s] to include for [Fortran] level-3 blas implementation", "lib[s] to include for [Fortran] level-3 blas implementation",
StrParam(default_blas_ldflags)) StrParam(default_blas_ldflags))
try: try:
...@@ -333,12 +333,10 @@ try: ...@@ -333,12 +333,10 @@ try:
# `scipy.linalg.blas.fblas` with `scipy.linalg.blas`. # `scipy.linalg.blas.fblas` with `scipy.linalg.blas`.
# See http://github.com/scipy/scipy/pull/358 # See http://github.com/scipy/scipy/pull/358
fblas = scipy.linalg.blas fblas = scipy.linalg.blas
_blas_gemv_fns = { _blas_gemv_fns = {numpy.dtype('float32'): fblas.sgemv,
numpy.dtype('float32'): fblas.sgemv, numpy.dtype('float64'): fblas.dgemv,
numpy.dtype('float64'): fblas.dgemv, numpy.dtype('complex64'): fblas.cgemv,
numpy.dtype('complex64'): fblas.cgemv, numpy.dtype('complex128'): fblas.zgemv}
numpy.dtype('complex128'): fblas.zgemv,
}
except ImportError as e: except ImportError as e:
have_fblas = False have_fblas = False
# This is used in Gemv and ScipyGer. We use CGemv and CGer # This is used in Gemv and ScipyGer. We use CGemv and CGer
...@@ -400,8 +398,8 @@ class Gemv(Op): ...@@ -400,8 +398,8 @@ class Gemv(Op):
# The following is not grounds for error because as long as # The following is not grounds for error because as long as
# sizes are 1 at time of perform() there is no problem # sizes are 1 at time of perform() there is no problem
# if x.broadcastable[0] != A.broadcastable[1]: # if x.broadcastable[0] != A.broadcastable[1]:
# raise TypeError('broadcastable mismatch between x and A', # raise TypeError('broadcastable mismatch between x and A',
#(x.type, A.type)) # (x.type, A.type))
return Apply(self, [y, alpha, A, x, beta], [y.type()]) return Apply(self, [y, alpha, A, x, beta], [y.type()])
def perform(self, node, inputs, out_storage): def perform(self, node, inputs, out_storage):
...@@ -411,9 +409,10 @@ class Gemv(Op): ...@@ -411,9 +409,10 @@ class Gemv(Op):
gemv = _blas_gemv_fns[y.dtype] gemv = _blas_gemv_fns[y.dtype]
if (A.shape[0] != y.shape[0] or A.shape[1] != x.shape[0]): if (A.shape[0] != y.shape[0] or A.shape[1] != x.shape[0]):
raise ValueError('Incompatible shapes for gemv ' raise ValueError(
'(beta * y + alpha * dot(A, x)). y: %s, A: %s, x: %s ' 'Incompatible shapes for gemv '
% (y.shape, A.shape, x.shape)) '(beta * y + alpha * dot(A, x)). y: %s, A: %s, x: %s '
% (y.shape, A.shape, x.shape))
# Here I suppose that A is in c order. If we don't make it # Here I suppose that A is in c order. If we don't make it
# explicitly as fortran order, scipy 0.7.2 seam to create # explicitly as fortran order, scipy 0.7.2 seam to create
...@@ -479,7 +478,7 @@ class Ger(Op): ...@@ -479,7 +478,7 @@ class Ger(Op):
alpha = T.as_tensor_variable(alpha) alpha = T.as_tensor_variable(alpha)
if len(set([A.dtype, alpha.dtype, x.dtype, y.dtype])) != 1: if len(set([A.dtype, alpha.dtype, x.dtype, y.dtype])) != 1:
raise TypeError('ger requires matching dtypes', raise TypeError('ger requires matching dtypes',
(A.dtype, alpha.dtype, x.dtype, y.dtype)) (A.dtype, alpha.dtype, x.dtype, y.dtype))
if alpha.ndim != 0: if alpha.ndim != 0:
raise TypeError('ger requires scalar alpha', alpha.type) raise TypeError('ger requires scalar alpha', alpha.type)
if A.ndim != 2: if A.ndim != 2:
...@@ -567,13 +566,14 @@ def _ldflags(ldflags_str, libs, flags, libs_dir, include_dir): ...@@ -567,13 +566,14 @@ def _ldflags(ldflags_str, libs, flags, libs_dir, include_dir):
for d in dirs: for d in dirs:
for f in os.listdir(d): for f in os.listdir(d):
if (f.endswith('.so') or f.endswith('.dylib') or if (f.endswith('.so') or f.endswith('.dylib') or
f.endswith('.dll')): f.endswith('.dll')):
if any([f.find(ll) >= 0 for ll in l]): if any([f.find(ll) >= 0 for ll in l]):
found_dyn = True found_dyn = True
if not found_dyn and dirs: if not found_dyn and dirs:
_logger.warning("We did not found a dynamic library into the " _logger.warning(
"library_dir of the library we use for blas. If you use " "We did not found a dynamic library into the "
"ATLAS, make sure to compile it with dynamics library.") "library_dir of the library we use for blas. If you use "
"ATLAS, make sure to compile it with dynamics library.")
for t in ldflags_str.split(): for t in ldflags_str.split():
# Remove extra quote. # Remove extra quote.
...@@ -644,7 +644,7 @@ class GemmRelated(Op): ...@@ -644,7 +644,7 @@ class GemmRelated(Op):
return ldflags() return ldflags()
# code_cache_version is built by subclasses from # code_cache_version is built by subclasses from
# build_gemm_version # build_gemm_version
def c_compile_args(self): def c_compile_args(self):
return ldflags(libs=False, flags=True) return ldflags(libs=False, flags=True)
...@@ -673,7 +673,7 @@ class GemmRelated(Op): ...@@ -673,7 +673,7 @@ class GemmRelated(Op):
int sx_0, sx_1, sy_0, sy_1, sz_0, sz_1; int sx_0, sx_1, sy_0, sy_1, sz_0, sz_1;
""" """
#setup_z_Nz_Sz = None # setup_z_Nz_Sz = None
check_xyz_rank2 = """ check_xyz_rank2 = """
if (PyArray_NDIM(%(_x)s) != 2) { if (PyArray_NDIM(%(_x)s) != 2) {
...@@ -823,7 +823,7 @@ class GemmRelated(Op): ...@@ -823,7 +823,7 @@ class GemmRelated(Op):
{ {
""" """
#case_float_ab_constants = None # case_float_ab_constants = None
case_float_gemm = """ case_float_gemm = """
float* x = (float*)PyArray_DATA(%(_x)s); float* x = (float*)PyArray_DATA(%(_x)s);
...@@ -856,7 +856,7 @@ class GemmRelated(Op): ...@@ -856,7 +856,7 @@ class GemmRelated(Op):
{ {
""" """
#case_double_ab_constants = None # case_double_ab_constants = None
case_double_gemm = """ case_double_gemm = """
double* x = (double*)PyArray_DATA(%(_x)s); double* x = (double*)PyArray_DATA(%(_x)s);
...@@ -1028,10 +1028,10 @@ class Gemm(GemmRelated): ...@@ -1028,10 +1028,10 @@ class Gemm(GemmRelated):
if not (z.dtype == a.dtype == x.dtype == y.dtype == b.dtype): if not (z.dtype == a.dtype == x.dtype == y.dtype == b.dtype):
raise TypeError(Gemm.E_mixed, raise TypeError(Gemm.E_mixed,
(z.dtype, a.dtype, x.dtype, y.dtype, b.dtype)) (z.dtype, a.dtype, x.dtype, y.dtype, b.dtype))
if (not z.dtype.startswith('float') if (not z.dtype.startswith('float') and
and not z.dtype.startswith('complex')): not z.dtype.startswith('complex')):
raise TypeError(Gemm.E_float, (z.dtype)) raise TypeError(Gemm.E_float, (z.dtype))
output = z.type() output = z.type()
...@@ -1173,8 +1173,8 @@ class Gemm(GemmRelated): ...@@ -1173,8 +1173,8 @@ class Gemm(GemmRelated):
_z, _a, _x, _y, _b = inp _z, _a, _x, _y, _b = inp
_zout, = out _zout, = out
if node.inputs[0].type.dtype.startswith('complex'): if node.inputs[0].type.dtype.startswith('complex'):
raise utils.MethodNotDefined('%s.c_code' \ raise utils.MethodNotDefined('%s.c_code'
% self.__class__.__name__) % self.__class__.__name__)
if not config.blas.ldflags: if not config.blas.ldflags:
return super(Gemm, self).c_code(node, name, return super(Gemm, self).c_code(node, name,
(_z, _a, _x, _y, _b), (_zout, ), (_z, _a, _x, _y, _b), (_zout, ),
...@@ -1203,9 +1203,9 @@ def res_is_a(node, op, maxclients=None): ...@@ -1203,9 +1203,9 @@ def res_is_a(node, op, maxclients=None):
else: else:
retval = True retval = True
return node.owner \ return (node.owner and
and node.owner.op == op \ node.owner.op == op and
and retval retval)
def _as_scalar(res, dtype=None): def _as_scalar(res, dtype=None):
...@@ -1235,16 +1235,16 @@ def _as_scalar(res, dtype=None): ...@@ -1235,16 +1235,16 @@ def _as_scalar(res, dtype=None):
def _is_real_matrix(res): def _is_real_matrix(res):
return res.type.dtype in ('float32', 'float64') \ return (res.type.dtype in ('float32', 'float64') and
and res.type.ndim == 2 \ res.type.ndim == 2 and
and res.type.broadcastable[0] == False \ res.type.broadcastable[0] == False and
and res.type.broadcastable[1] == False # cope with tuple vs. list res.type.broadcastable[1] == False) # cope with tuple vs. list
def _is_real_vector(res): def _is_real_vector(res):
return res.type.dtype in ('float32', 'float64') \ return (res.type.dtype in ('float32', 'float64') and
and res.type.ndim == 1 \ res.type.ndim == 1 and
and res.type.broadcastable[0] == False res.type.broadcastable[0] == False)
def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip=True): def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip=True):
...@@ -1262,8 +1262,8 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip=True): ...@@ -1262,8 +1262,8 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip=True):
# it also might be the case that there is a dimshuffle between the + # it also might be the case that there is a dimshuffle between the +
# and the dot22. local_dot_to_dot22 in particular will put in such things. # and the dot22. local_dot_to_dot22 in particular will put in such things.
if (M.owner and isinstance(M.owner.op, T.DimShuffle) and if (M.owner and isinstance(M.owner.op, T.DimShuffle) and
M.owner.inputs[0].owner and M.owner.inputs[0].owner and
isinstance(M.owner.inputs[0].owner.op, Dot22)): isinstance(M.owner.inputs[0].owner.op, Dot22)):
MM = M.owner.inputs[0] MM = M.owner.inputs[0]
if M.owner.op.new_order == (0,): if M.owner.op.new_order == (0,):
# it is making a column MM into a vector # it is making a column MM into a vector
...@@ -1493,7 +1493,7 @@ def _gemm_from_factored_list(lst): ...@@ -1493,7 +1493,7 @@ def _gemm_from_factored_list(lst):
assert len(gemm_of_sM_list) == 1 assert len(gemm_of_sM_list) == 1
add_inputs = [item_to_var(input) add_inputs = [item_to_var(input)
for k, input in enumerate(lst) if k not in (i, j)] for k, input in enumerate(lst) if k not in (i, j)]
add_inputs.extend(gemm_of_sM_list) add_inputs.extend(gemm_of_sM_list)
if len(add_inputs) > 1: if len(add_inputs) > 1:
rval = [T.add(*add_inputs)] rval = [T.add(*add_inputs)]
...@@ -1583,7 +1583,7 @@ class GemmOptimizer(Optimizer): ...@@ -1583,7 +1583,7 @@ class GemmOptimizer(Optimizer):
(theano.scalar.Add, theano.scalar.Sub, (theano.scalar.Add, theano.scalar.Sub,
theano.scalar.Neg, theano.scalar.Mul))): theano.scalar.Neg, theano.scalar.Mul))):
continue continue
if not node in fgraph.apply_nodes: if node not in fgraph.apply_nodes:
# This mean that we already removed this node from # This mean that we already removed this node from
# the graph # the graph
continue continue
...@@ -1725,8 +1725,8 @@ class Dot22(GemmRelated): ...@@ -1725,8 +1725,8 @@ class Dot22(GemmRelated):
_x, _y = inp _x, _y = inp
_zout, = out _zout, = out
if node.inputs[0].type.dtype.startswith('complex'): if node.inputs[0].type.dtype.startswith('complex'):
raise utils.MethodNotDefined('%s.c_code' \ raise utils.MethodNotDefined('%s.c_code'
% self.__class__.__name__) % self.__class__.__name__)
if len(self.c_libraries()) <= 0: if len(self.c_libraries()) <= 0:
return super(Dot22, self).c_code(node, name, (_x, _y), return super(Dot22, self).c_code(node, name, (_x, _y),
(_zout, ), sub) (_zout, ), sub)
...@@ -1895,17 +1895,16 @@ blas_optdb.register('local_dot_to_dot22', ...@@ -1895,17 +1895,16 @@ blas_optdb.register('local_dot_to_dot22',
in2out(local_dot_to_dot22), in2out(local_dot_to_dot22),
0, 'fast_run', 'fast_compile') 0, 'fast_run', 'fast_compile')
blas_optdb.register('gemm_optimizer', blas_optdb.register('gemm_optimizer',
GemmOptimizer(), GemmOptimizer(),
10, 'fast_run') 10, 'fast_run')
blas_optdb.register('local_gemm_to_gemv', blas_optdb.register('local_gemm_to_gemv',
EquilibriumOptimizer([ EquilibriumOptimizer([local_gemm_to_gemv,
local_gemm_to_gemv, local_gemm_to_ger,
local_gemm_to_ger, local_dot22_to_ger_or_gemv,
local_dot22_to_ger_or_gemv, local_dimshuffle_lift],
local_dimshuffle_lift], max_use_ratio=5,
max_use_ratio=5, ignore_newtrees=False),
ignore_newtrees=False), 15, 'fast_run')
15, 'fast_run')
# After destroyhandler(49.5) but before we try to make elemwise things # After destroyhandler(49.5) but before we try to make elemwise things
...@@ -1936,12 +1935,12 @@ class Dot22Scalar(GemmRelated): ...@@ -1936,12 +1935,12 @@ class Dot22Scalar(GemmRelated):
if not (a.dtype == x.dtype == y.dtype): if not (a.dtype == x.dtype == y.dtype):
raise TypeError('Dot22Scalar requires matching dtypes', raise TypeError('Dot22Scalar requires matching dtypes',
(a.dtype, x.dtype, y.dtype)) (a.dtype, x.dtype, y.dtype))
if (not a.dtype.startswith('float') if (not a.dtype.startswith('float') and
and not a.dtype.startswith('complex')): not a.dtype.startswith('complex')):
raise TypeError('Dot22Scalar requires float or complex args', raise TypeError('Dot22Scalar requires float or complex args',
a.dtype) a.dtype)
bz = [x.type.broadcastable[0], y.type.broadcastable[1]] bz = [x.type.broadcastable[0], y.type.broadcastable[1]]
outputs = [T.tensor(x.type.dtype, bz)] outputs = [T.tensor(x.type.dtype, bz)]
...@@ -1992,8 +1991,8 @@ class Dot22Scalar(GemmRelated): ...@@ -1992,8 +1991,8 @@ class Dot22Scalar(GemmRelated):
_x, _y, _a = inp _x, _y, _a = inp
_zout, = out _zout, = out
if node.inputs[0].type.dtype.startswith('complex'): if node.inputs[0].type.dtype.startswith('complex'):
raise utils.MethodNotDefined('%s.c_code' \ raise utils.MethodNotDefined('%s.c_code'
% self.__class__.__name__) % self.__class__.__name__)
if len(self.c_libraries()) <= 0: if len(self.c_libraries()) <= 0:
return super(Dot22Scalar, self).c_code(node, name, (_x, _y), return super(Dot22Scalar, self).c_code(node, name, (_x, _y),
(_zout, ), sub) (_zout, ), sub)
...@@ -2051,7 +2050,7 @@ def local_dot22_to_dot22scalar(node): ...@@ -2051,7 +2050,7 @@ def local_dot22_to_dot22scalar(node):
# The canonizer should have merged those mul together. # The canonizer should have merged those mul together.
i_mul = [x.owner and x.owner.op == T.mul and i_mul = [x.owner and x.owner.op == T.mul and
any([_as_scalar(x_i, dtype=d.dtype) any([_as_scalar(x_i, dtype=d.dtype)
for x_i in x.owner.inputs]) for x_i in x.owner.inputs])
for x in node.inputs] for x in node.inputs]
if not any(i_mul): if not any(i_mul):
# no scalar in input and no multiplication # no scalar in input and no multiplication
...@@ -2065,8 +2064,7 @@ def local_dot22_to_dot22scalar(node): ...@@ -2065,8 +2064,7 @@ def local_dot22_to_dot22scalar(node):
scalar_idx = -1 scalar_idx = -1
for i, x in enumerate(m.owner.inputs): for i, x in enumerate(m.owner.inputs):
if _as_scalar(x, dtype=d.dtype) and (theano.scalar.upcast( if _as_scalar(x, dtype=d.dtype) and (theano.scalar.upcast(
x.type.dtype, d.type.dtype) x.type.dtype, d.type.dtype) == d.type.dtype):
== d.type.dtype):
scalar_idx = i scalar_idx = i
break break
...@@ -2103,8 +2101,8 @@ def local_dot22_to_dot22scalar(node): ...@@ -2103,8 +2101,8 @@ def local_dot22_to_dot22scalar(node):
break break
if scalar_idx < 0: if scalar_idx < 0:
_logger.info('Not optimizing dot22 with inputs %s %s, as the type ' _logger.info('Not optimizing dot22 with inputs %s %s, as the type '
'of the scalar cannot be upcasted to the matrix type', 'of the scalar cannot be upcasted to the matrix type',
node.inputs, [x.type for x in node.inputs]) node.inputs, [x.type for x in node.inputs])
return False return False
assert scalar_idx < len(node.inputs) assert scalar_idx < len(node.inputs)
s = node.inputs[scalar_idx] s = node.inputs[scalar_idx]
...@@ -2128,8 +2126,8 @@ blas_optdb.register('local_dot22_to_dot22scalar', ...@@ -2128,8 +2126,8 @@ blas_optdb.register('local_dot22_to_dot22scalar',
11, 'fast_run') 11, 'fast_run')
#from opt import register_specialize, register_canonicalize # from opt import register_specialize, register_canonicalize
#@register_specialize # @register_specialize
@local_optimizer([T.sub, T.add]) @local_optimizer([T.sub, T.add])
def local_print_as_we_go_along(node): def local_print_as_we_go_along(node):
if node.op in (T.sub, T.add): if node.op in (T.sub, T.add):
......
...@@ -58,7 +58,6 @@ whitelist_flake8 = [ ...@@ -58,7 +58,6 @@ whitelist_flake8 = [
"typed_list/tests/test_opt.py", "typed_list/tests/test_opt.py",
"typed_list/tests/test_basic.py", "typed_list/tests/test_basic.py",
"tensor/__init__.py", "tensor/__init__.py",
"tensor/blas.py",
"tensor/extra_ops.py", "tensor/extra_ops.py",
"tensor/nlinalg.py", "tensor/nlinalg.py",
"tensor/blas_c.py", "tensor/blas_c.py",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论