提交 94aab099 authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Fix for new name of scipy.linalg.blas.fblas

Also cleaned up duplicated / useless code at the same time. This fixes gh-1144.
上级 7b8d668f
......@@ -142,7 +142,7 @@ from theano.printing import pprint, FunctionPrinter, debugprint
from theano.compile.mode import optdb
from theano.gof.python25 import all, any
import theano.scalar
import basic as T
from theano.tensor import basic as T
from theano.tensor.blas_headers import blas_header_text
from theano.tensor.opt import local_dimshuffle_lift
......@@ -150,16 +150,22 @@ _logger = logging.getLogger('theano.tensor.blas')
try:
import scipy.linalg.blas
_have_fblas = True
have_fblas = True
try:
fblas = scipy.linalg.blas.fblas
except AttributeError:
# In more recent versions of Scipy, `scipy.linalg.blas.fblas` has been
# replaced with `scipy.linalg.blas`.
fblas = scipy.linalg.blas
_blas_gemv_fns = {
numpy.dtype('float32'): scipy.linalg.blas.fblas.sgemv,
numpy.dtype('float64'): scipy.linalg.blas.fblas.dgemv,
numpy.dtype('complex64'): scipy.linalg.blas.fblas.cgemv,
numpy.dtype('complex128'): scipy.linalg.blas.fblas.zgemv,
numpy.dtype('float32'): fblas.sgemv,
numpy.dtype('float64'): fblas.dgemv,
numpy.dtype('complex64'): fblas.cgemv,
numpy.dtype('complex128'): fblas.zgemv,
}
except ImportError, e:
_have_fblas = False
_logger.warning('Failed to import scipy.linalg.blas.fblas. '
have_fblas = False
_logger.warning('Failed to import scipy.linalg.blas. '
'Falling back on slower implementations (%s)', str(e))
......@@ -216,7 +222,7 @@ class Gemv(Op):
def perform(self, node, inputs, out_storage):
y, alpha, A, x, beta = inputs
if _have_fblas and y.shape[0] != 0 and x.shape[0] != 0:
if have_fblas and y.shape[0] != 0 and x.shape[0] != 0:
gemv = _blas_gemv_fns[y.dtype]
if (A.shape[0] != y.shape[0] or A.shape[1] != x.shape[0]):
......
......@@ -3,28 +3,20 @@ Implementations of BLAS Ops based on scipy's BLAS bindings.
"""
import numpy
from blas import Ger, ger, ger_destructive
from blas import blas_optdb, optdb,local_optimizer
from theano.tensor.blas import Ger, ger, ger_destructive, have_fblas
from theano.tensor.blas import blas_optdb, optdb,local_optimizer
from theano.tensor.opt import in2out
try:
import scipy.linalg.blas
have_fblas = True
_blas_gemv_fns = {
numpy.dtype('float32'):scipy.linalg.blas.fblas.sgemv,
numpy.dtype('float64'):scipy.linalg.blas.fblas.dgemv,
numpy.dtype('complex64'):scipy.linalg.blas.fblas.cgemv,
numpy.dtype('complex128'):scipy.linalg.blas.fblas.zgemv,
}
if have_fblas:
from theano.tensor.blas import fblas
_blas_ger_fns = {
numpy.dtype('float32'):scipy.linalg.blas.fblas.sger,
numpy.dtype('float64'):scipy.linalg.blas.fblas.dger,
numpy.dtype('complex64'):scipy.linalg.blas.fblas.cgeru,
numpy.dtype('complex128'):scipy.linalg.blas.fblas.zgeru,
numpy.dtype('float32'): fblas.sger,
numpy.dtype('float64'): fblas.dger,
numpy.dtype('complex64'): fblas.cgeru,
numpy.dtype('complex128'): fblas.zgeru,
}
except ImportError, e:
have_fblas = False
class ScipyGer(Ger):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论