提交 94aab099 authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Fix for new name of scipy.linalg.blas.fblas

Also cleaned up duplicated / useless code at the same time. This fixes gh-1144.
上级 7b8d668f
...@@ -142,7 +142,7 @@ from theano.printing import pprint, FunctionPrinter, debugprint ...@@ -142,7 +142,7 @@ from theano.printing import pprint, FunctionPrinter, debugprint
from theano.compile.mode import optdb from theano.compile.mode import optdb
from theano.gof.python25 import all, any from theano.gof.python25 import all, any
import theano.scalar import theano.scalar
import basic as T from theano.tensor import basic as T
from theano.tensor.blas_headers import blas_header_text from theano.tensor.blas_headers import blas_header_text
from theano.tensor.opt import local_dimshuffle_lift from theano.tensor.opt import local_dimshuffle_lift
...@@ -150,16 +150,22 @@ _logger = logging.getLogger('theano.tensor.blas') ...@@ -150,16 +150,22 @@ _logger = logging.getLogger('theano.tensor.blas')
try: try:
import scipy.linalg.blas import scipy.linalg.blas
_have_fblas = True have_fblas = True
try:
fblas = scipy.linalg.blas.fblas
except AttributeError:
# In more recent versions of Scipy, `scipy.linalg.blas.fblas` has been
# replaced with `scipy.linalg.blas`.
fblas = scipy.linalg.blas
_blas_gemv_fns = { _blas_gemv_fns = {
numpy.dtype('float32'): scipy.linalg.blas.fblas.sgemv, numpy.dtype('float32'): fblas.sgemv,
numpy.dtype('float64'): scipy.linalg.blas.fblas.dgemv, numpy.dtype('float64'): fblas.dgemv,
numpy.dtype('complex64'): scipy.linalg.blas.fblas.cgemv, numpy.dtype('complex64'): fblas.cgemv,
numpy.dtype('complex128'): scipy.linalg.blas.fblas.zgemv, numpy.dtype('complex128'): fblas.zgemv,
} }
except ImportError, e: except ImportError, e:
_have_fblas = False have_fblas = False
_logger.warning('Failed to import scipy.linalg.blas.fblas. ' _logger.warning('Failed to import scipy.linalg.blas. '
'Falling back on slower implementations (%s)', str(e)) 'Falling back on slower implementations (%s)', str(e))
...@@ -216,7 +222,7 @@ class Gemv(Op): ...@@ -216,7 +222,7 @@ class Gemv(Op):
def perform(self, node, inputs, out_storage): def perform(self, node, inputs, out_storage):
y, alpha, A, x, beta = inputs y, alpha, A, x, beta = inputs
if _have_fblas and y.shape[0] != 0 and x.shape[0] != 0: if have_fblas and y.shape[0] != 0 and x.shape[0] != 0:
gemv = _blas_gemv_fns[y.dtype] gemv = _blas_gemv_fns[y.dtype]
if (A.shape[0] != y.shape[0] or A.shape[1] != x.shape[0]): if (A.shape[0] != y.shape[0] or A.shape[1] != x.shape[0]):
......
...@@ -3,28 +3,20 @@ Implementations of BLAS Ops based on scipy's BLAS bindings. ...@@ -3,28 +3,20 @@ Implementations of BLAS Ops based on scipy's BLAS bindings.
""" """
import numpy import numpy
from blas import Ger, ger, ger_destructive from theano.tensor.blas import Ger, ger, ger_destructive, have_fblas
from blas import blas_optdb, optdb,local_optimizer from theano.tensor.blas import blas_optdb, optdb,local_optimizer
from theano.tensor.opt import in2out from theano.tensor.opt import in2out
try:
import scipy.linalg.blas if have_fblas:
have_fblas = True from theano.tensor.blas import fblas
_blas_gemv_fns = {
numpy.dtype('float32'):scipy.linalg.blas.fblas.sgemv,
numpy.dtype('float64'):scipy.linalg.blas.fblas.dgemv,
numpy.dtype('complex64'):scipy.linalg.blas.fblas.cgemv,
numpy.dtype('complex128'):scipy.linalg.blas.fblas.zgemv,
}
_blas_ger_fns = { _blas_ger_fns = {
numpy.dtype('float32'):scipy.linalg.blas.fblas.sger, numpy.dtype('float32'): fblas.sger,
numpy.dtype('float64'):scipy.linalg.blas.fblas.dger, numpy.dtype('float64'): fblas.dger,
numpy.dtype('complex64'):scipy.linalg.blas.fblas.cgeru, numpy.dtype('complex64'): fblas.cgeru,
numpy.dtype('complex128'):scipy.linalg.blas.fblas.zgeru, numpy.dtype('complex128'): fblas.zgeru,
} }
except ImportError, e:
have_fblas = False
class ScipyGer(Ger): class ScipyGer(Ger):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论