提交 55f6a181 authored 作者: nouiz's avatar nouiz

Merge pull request #1150 from delallea/bugfix

Fix for new name of scipy.linalg.blas.fblas
......@@ -142,7 +142,7 @@ from theano.printing import pprint, FunctionPrinter, debugprint
from theano.compile.mode import optdb
from theano.gof.python25 import all, any
import theano.scalar
import basic as T
from theano.tensor import basic as T
from theano.tensor.blas_headers import blas_header_text
from theano.tensor.opt import local_dimshuffle_lift
......@@ -150,16 +150,23 @@ _logger = logging.getLogger('theano.tensor.blas')
try:
import scipy.linalg.blas
_have_fblas = True
have_fblas = True
try:
fblas = scipy.linalg.blas.fblas
except AttributeError:
# A change merged in Scipy development version on 2012-12-02 replaced
# `scipy.linalg.blas.fblas` with `scipy.linalg.blas`.
# See http://github.com/scipy/scipy/pull/358
fblas = scipy.linalg.blas
_blas_gemv_fns = {
numpy.dtype('float32'): scipy.linalg.blas.fblas.sgemv,
numpy.dtype('float64'): scipy.linalg.blas.fblas.dgemv,
numpy.dtype('complex64'): scipy.linalg.blas.fblas.cgemv,
numpy.dtype('complex128'): scipy.linalg.blas.fblas.zgemv,
numpy.dtype('float32'): fblas.sgemv,
numpy.dtype('float64'): fblas.dgemv,
numpy.dtype('complex64'): fblas.cgemv,
numpy.dtype('complex128'): fblas.zgemv,
}
except ImportError, e:
_have_fblas = False
_logger.warning('Failed to import scipy.linalg.blas.fblas. '
have_fblas = False
_logger.warning('Failed to import scipy.linalg.blas. '
'Falling back on slower implementations (%s)', str(e))
......@@ -216,7 +223,7 @@ class Gemv(Op):
def perform(self, node, inputs, out_storage):
y, alpha, A, x, beta = inputs
if _have_fblas and y.shape[0] != 0 and x.shape[0] != 0:
if have_fblas and y.shape[0] != 0 and x.shape[0] != 0:
gemv = _blas_gemv_fns[y.dtype]
if (A.shape[0] != y.shape[0] or A.shape[1] != x.shape[0]):
......
......@@ -3,28 +3,20 @@ Implementations of BLAS Ops based on scipy's BLAS bindings.
"""
import numpy
from blas import Ger, ger, ger_destructive
from blas import blas_optdb, optdb,local_optimizer
from theano.tensor.blas import Ger, ger, ger_destructive, have_fblas
from theano.tensor.blas import blas_optdb, optdb,local_optimizer
from theano.tensor.opt import in2out
try:
import scipy.linalg.blas
have_fblas = True
_blas_gemv_fns = {
numpy.dtype('float32'):scipy.linalg.blas.fblas.sgemv,
numpy.dtype('float64'):scipy.linalg.blas.fblas.dgemv,
numpy.dtype('complex64'):scipy.linalg.blas.fblas.cgemv,
numpy.dtype('complex128'):scipy.linalg.blas.fblas.zgemv,
}
if have_fblas:
from theano.tensor.blas import fblas
_blas_ger_fns = {
numpy.dtype('float32'):scipy.linalg.blas.fblas.sger,
numpy.dtype('float64'):scipy.linalg.blas.fblas.dger,
numpy.dtype('complex64'):scipy.linalg.blas.fblas.cgeru,
numpy.dtype('complex128'):scipy.linalg.blas.fblas.zgeru,
numpy.dtype('float32'): fblas.sger,
numpy.dtype('float64'): fblas.dger,
numpy.dtype('complex64'): fblas.cgeru,
numpy.dtype('complex128'): fblas.zgeru,
}
except ImportError, e:
have_fblas = False
class ScipyGer(Ger):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论