提交 79eee675 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Scipy blas is no longer optional

上级 6bdfbae8
......@@ -108,50 +108,19 @@ from pytensor.tensor.type import DenseTensorType, tensor
_logger = logging.getLogger("pytensor.tensor.blas")
try:
import scipy.linalg.blas
have_fblas = True
try:
fblas = scipy.linalg.blas.fblas
except AttributeError:
# A change merged in Scipy development version on 2012-12-02 replaced
# `scipy.linalg.blas.fblas` with `scipy.linalg.blas`.
# See http://github.com/scipy/scipy/pull/358
fblas = scipy.linalg.blas
_blas_gemv_fns = {
np.dtype("float32"): fblas.sgemv,
np.dtype("float64"): fblas.dgemv,
np.dtype("complex64"): fblas.cgemv,
np.dtype("complex128"): fblas.zgemv,
}
except ImportError as e:
have_fblas = False
# This is used in Gemv and ScipyGer. We use CGemv and CGer
# when config.blas__ldflags is defined. So we don't need a
# warning in that case.
if not config.blas__ldflags:
_logger.warning(
"Failed to import scipy.linalg.blas, and "
"PyTensor flag blas__ldflags is empty. "
"Falling back on slower implementations for "
"dot(matrix, vector), dot(vector, matrix) and "
f"dot(vector, vector) ({e!s})"
)
# If check_init_y() == True we need to initialize y when beta == 0.
def check_init_y():
# TODO: What is going on here?
from scipy.linalg.blas import get_blas_funcs
if check_init_y._result is None:
if not have_fblas: # pragma: no cover
check_init_y._result = False
else:
y = float("NaN") * np.ones((2,))
x = np.ones((2,))
A = np.ones((2, 2))
gemv = _blas_gemv_fns[y.dtype]
gemv(1.0, A.T, x, 0.0, y, overwrite_y=True, trans=True)
check_init_y._result = np.isnan(y).any()
y = float("NaN") * np.ones((2,))
x = np.ones((2,))
A = np.ones((2, 2))
gemv = get_blas_funcs("gemv", dtype=y.dtype)
gemv(1.0, A.T, x, 0.0, y, overwrite_y=True, trans=True)
check_init_y._result = np.isnan(y).any()
return check_init_y._result
......@@ -208,14 +177,15 @@ class Gemv(Op):
return Apply(self, inputs, [y.type()])
def perform(self, node, inputs, out_storage):
from scipy.linalg.blas import get_blas_funcs
y, alpha, A, x, beta = inputs
if (
have_fblas
and y.shape[0] != 0
y.shape[0] != 0
and x.shape[0] != 0
and y.dtype in _blas_gemv_fns
and y.dtype in {"float32", "float64", "complex64", "complex128"}
):
gemv = _blas_gemv_fns[y.dtype]
gemv = get_blas_funcs("gemv", dtype=y.dtype)
if A.shape[0] != y.shape[0] or A.shape[1] != x.shape[0]:
raise ValueError(
......
......@@ -2,30 +2,19 @@
Implementations of BLAS Ops based on scipy's BLAS bindings.
"""
import numpy as np
from pytensor.tensor.blas import Ger, have_fblas
if have_fblas:
from pytensor.tensor.blas import fblas
_blas_ger_fns = {
np.dtype("float32"): fblas.sger,
np.dtype("float64"): fblas.dger,
np.dtype("complex64"): fblas.cgeru,
np.dtype("complex128"): fblas.zgeru,
}
from pytensor.tensor.blas import Ger
class ScipyGer(Ger):
def perform(self, node, inputs, output_storage):
from scipy.linalg.blas import get_blas_funcs
cA, calpha, cx, cy = inputs
(cZ,) = output_storage
# N.B. some versions of scipy (e.g. mine) don't actually work
# in-place on a, even when I tell it to.
A = cA
local_ger = _blas_ger_fns[cA.dtype]
local_ger = get_blas_funcs("ger", dtype=cA.dtype)
if A.size == 0:
# We don't have to compute anything, A is empty.
# We need this special case because Numpy considers it
......
from pytensor.graph.rewriting.basic import in2out
from pytensor.tensor.blas import ger, ger_destructive, have_fblas
from pytensor.tensor.blas import ger, ger_destructive
from pytensor.tensor.blas_scipy import scipy_ger_inplace, scipy_ger_no_inplace
from pytensor.tensor.rewriting.blas import blas_optdb, node_rewriter, optdb
......@@ -19,19 +19,19 @@ def make_ger_destructive(fgraph, node):
use_scipy_blas = in2out(use_scipy_ger)
make_scipy_blas_destructive = in2out(make_ger_destructive)
if have_fblas:
# scipy_blas is scheduled in the blas_optdb very late, because scipy sortof
# sucks, but it is almost always present.
# C implementations should be scheduled earlier than this, so that they take
# precedence. Once the original Ger is replaced, then these optimizations
# have no effect.
blas_optdb.register("scipy_blas", use_scipy_blas, "fast_run", position=100)
# this matches the InplaceBlasOpt defined in blas.py
optdb.register(
"make_scipy_blas_destructive",
make_scipy_blas_destructive,
"fast_run",
"inplace",
position=50.2,
)
# scipy_blas is scheduled in the blas_optdb very late, because scipy sortof
# sucks [citation needed], but it is almost always present.
# C implementations should be scheduled earlier than this, so that they take
# precedence. Once the original Ger is replaced, then these optimizations
# have no effect.
blas_optdb.register("scipy_blas", use_scipy_blas, "fast_run", position=100)
# this matches the InplaceBlasOpt defined in blas.py
optdb.register(
"make_scipy_blas_destructive",
make_scipy_blas_destructive,
"fast_run",
"inplace",
position=50.2,
)
import pickle
import numpy as np
import pytest
import pytensor
from pytensor import tensor as pt
......@@ -12,7 +11,6 @@ from tests.tensor.test_blas import TestBlasStrides, gemm_no_inplace
from tests.unittest_tools import OptimizationTestMixin
@pytest.mark.skipif(not pytensor.tensor.blas_scipy.have_fblas, reason="fblas needed")
class TestScipyGer(OptimizationTestMixin):
def setup_method(self):
self.mode = pytensor.compile.get_default_mode()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论