Commit 55f6a181, authored by nouiz

Merge pull request #1150 from delallea/bugfix

Fix for new name of scipy.linalg.blas.fblas
...@@ -142,7 +142,7 @@ from theano.printing import pprint, FunctionPrinter, debugprint ...@@ -142,7 +142,7 @@ from theano.printing import pprint, FunctionPrinter, debugprint
from theano.compile.mode import optdb from theano.compile.mode import optdb
from theano.gof.python25 import all, any from theano.gof.python25 import all, any
import theano.scalar import theano.scalar
import basic as T from theano.tensor import basic as T
from theano.tensor.blas_headers import blas_header_text from theano.tensor.blas_headers import blas_header_text
from theano.tensor.opt import local_dimshuffle_lift from theano.tensor.opt import local_dimshuffle_lift
...@@ -150,16 +150,23 @@ _logger = logging.getLogger('theano.tensor.blas') ...@@ -150,16 +150,23 @@ _logger = logging.getLogger('theano.tensor.blas')
try: try:
import scipy.linalg.blas import scipy.linalg.blas
_have_fblas = True have_fblas = True
try:
fblas = scipy.linalg.blas.fblas
except AttributeError:
# A change merged in Scipy development version on 2012-12-02 replaced
# `scipy.linalg.blas.fblas` with `scipy.linalg.blas`.
# See http://github.com/scipy/scipy/pull/358
fblas = scipy.linalg.blas
_blas_gemv_fns = { _blas_gemv_fns = {
numpy.dtype('float32'): scipy.linalg.blas.fblas.sgemv, numpy.dtype('float32'): fblas.sgemv,
numpy.dtype('float64'): scipy.linalg.blas.fblas.dgemv, numpy.dtype('float64'): fblas.dgemv,
numpy.dtype('complex64'): scipy.linalg.blas.fblas.cgemv, numpy.dtype('complex64'): fblas.cgemv,
numpy.dtype('complex128'): scipy.linalg.blas.fblas.zgemv, numpy.dtype('complex128'): fblas.zgemv,
} }
except ImportError, e: except ImportError, e:
_have_fblas = False have_fblas = False
_logger.warning('Failed to import scipy.linalg.blas.fblas. ' _logger.warning('Failed to import scipy.linalg.blas. '
'Falling back on slower implementations (%s)', str(e)) 'Falling back on slower implementations (%s)', str(e))
...@@ -216,7 +223,7 @@ class Gemv(Op): ...@@ -216,7 +223,7 @@ class Gemv(Op):
def perform(self, node, inputs, out_storage): def perform(self, node, inputs, out_storage):
y, alpha, A, x, beta = inputs y, alpha, A, x, beta = inputs
if _have_fblas and y.shape[0] != 0 and x.shape[0] != 0: if have_fblas and y.shape[0] != 0 and x.shape[0] != 0:
gemv = _blas_gemv_fns[y.dtype] gemv = _blas_gemv_fns[y.dtype]
if (A.shape[0] != y.shape[0] or A.shape[1] != x.shape[0]): if (A.shape[0] != y.shape[0] or A.shape[1] != x.shape[0]):
......
...@@ -3,28 +3,20 @@ Implementations of BLAS Ops based on scipy's BLAS bindings. ...@@ -3,28 +3,20 @@ Implementations of BLAS Ops based on scipy's BLAS bindings.
""" """
import numpy import numpy
from blas import Ger, ger, ger_destructive from theano.tensor.blas import Ger, ger, ger_destructive, have_fblas
from blas import blas_optdb, optdb,local_optimizer from theano.tensor.blas import blas_optdb, optdb,local_optimizer
from theano.tensor.opt import in2out from theano.tensor.opt import in2out
try:
import scipy.linalg.blas if have_fblas:
have_fblas = True from theano.tensor.blas import fblas
_blas_gemv_fns = {
numpy.dtype('float32'):scipy.linalg.blas.fblas.sgemv,
numpy.dtype('float64'):scipy.linalg.blas.fblas.dgemv,
numpy.dtype('complex64'):scipy.linalg.blas.fblas.cgemv,
numpy.dtype('complex128'):scipy.linalg.blas.fblas.zgemv,
}
_blas_ger_fns = { _blas_ger_fns = {
numpy.dtype('float32'):scipy.linalg.blas.fblas.sger, numpy.dtype('float32'): fblas.sger,
numpy.dtype('float64'):scipy.linalg.blas.fblas.dger, numpy.dtype('float64'): fblas.dger,
numpy.dtype('complex64'):scipy.linalg.blas.fblas.cgeru, numpy.dtype('complex64'): fblas.cgeru,
numpy.dtype('complex128'):scipy.linalg.blas.fblas.zgeru, numpy.dtype('complex128'): fblas.zgeru,
} }
except ImportError, e:
have_fblas = False
class ScipyGer(Ger): class ScipyGer(Ger):
......
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment