提交 a9e8fd78 authored 作者: Iban Harlouchet's avatar Iban Harlouchet

flake8 for tensor/blas_scipy.py

上级 932dd94f
...@@ -12,11 +12,11 @@ from theano.tensor.opt import in2out ...@@ -12,11 +12,11 @@ from theano.tensor.opt import in2out
if have_fblas: if have_fblas:
from theano.tensor.blas import fblas from theano.tensor.blas import fblas
_blas_ger_fns = { _blas_ger_fns = {
numpy.dtype('float32'): fblas.sger, numpy.dtype('float32'): fblas.sger,
numpy.dtype('float64'): fblas.dger, numpy.dtype('float64'): fblas.dger,
numpy.dtype('complex64'): fblas.cgeru, numpy.dtype('complex64'): fblas.cgeru,
numpy.dtype('complex128'): fblas.zgeru, numpy.dtype('complex128'): fblas.zgeru,
} }
class ScipyGer(Ger): class ScipyGer(Ger):
...@@ -47,10 +47,10 @@ class ScipyGer(Ger): ...@@ -47,10 +47,10 @@ class ScipyGer(Ger):
A = A.copy() A = A.copy()
elif A.flags['C_CONTIGUOUS']: elif A.flags['C_CONTIGUOUS']:
A = local_ger(calpha[0], cy[0], cx[0], a=A.T, A = local_ger(calpha[0], cy[0], cx[0], a=A.T,
overwrite_a=int(self.destructive)).T overwrite_a=int(self.destructive)).T
else: else:
A = local_ger(calpha[0], cx[0], cy[0], a=A, A = local_ger(calpha[0], cx[0], cy[0], a=A,
overwrite_a=int(self.destructive)) overwrite_a=int(self.destructive))
cZ[0] = A cZ[0] = A
for o in node_output_compute: for o in node_output_compute:
o[0] = True o[0] = True
...@@ -87,10 +87,10 @@ if have_fblas: ...@@ -87,10 +87,10 @@ if have_fblas:
# precedence. Once the original Ger is replaced, then these optimizations # precedence. Once the original Ger is replaced, then these optimizations
# have no effect. # have no effect.
blas_optdb.register('scipy_blas', blas_optdb.register('scipy_blas',
use_scipy_blas, use_scipy_blas,
100, 'fast_run') 100, 'fast_run')
# this matches the InplaceBlasOpt defined in blas.py # this matches the InplaceBlasOpt defined in blas.py
optdb.register('make_scipy_blas_destructive', optdb.register('make_scipy_blas_destructive',
make_scipy_blas_destructive, make_scipy_blas_destructive,
70.0, 'fast_run', 'inplace') 70.0, 'fast_run', 'inplace')
...@@ -58,7 +58,6 @@ whitelist_flake8 = [ ...@@ -58,7 +58,6 @@ whitelist_flake8 = [
"typed_list/tests/test_opt.py", "typed_list/tests/test_opt.py",
"typed_list/tests/test_basic.py", "typed_list/tests/test_basic.py",
"tensor/__init__.py", "tensor/__init__.py",
"tensor/blas_scipy.py",
"tensor/tests/test_subtensor.py", "tensor/tests/test_subtensor.py",
"tensor/tests/test_utils.py", "tensor/tests/test_utils.py",
"tensor/tests/test_nlinalg.py", "tensor/tests/test_nlinalg.py",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论