提交 6d236f14 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix GEMV dot case with empty output and beta 0

Bug introduced in 709f745c
上级 148477cb
...@@ -113,23 +113,22 @@ from pytensor.tensor.type import DenseTensorType, tensor ...@@ -113,23 +113,22 @@ from pytensor.tensor.type import DenseTensorType, tensor
_logger = logging.getLogger("pytensor.tensor.blas") _logger = logging.getLogger("pytensor.tensor.blas")
# If check_init_y() == True we need to initialize y when beta == 0. def must_initialize_y_gemv():
def check_init_y(): # Check whether Scipy GEMV could output nan if y is not initialized
# TODO: What is going on here?
from scipy.linalg.blas import get_blas_funcs from scipy.linalg.blas import get_blas_funcs
if check_init_y._result is None: if must_initialize_y_gemv._result is None:
y = float("NaN") * np.ones((2,)) y = np.full((2,), np.nan)
x = np.ones((2,)) x = np.ones((2,))
A = np.ones((2, 2)) A = np.ones((2, 2))
gemv = get_blas_funcs("gemv", dtype=y.dtype) gemv = get_blas_funcs("gemv", dtype=y.dtype)
gemv(1.0, A.T, x, 0.0, y, overwrite_y=True, trans=True) gemv(1.0, A.T, x, 0.0, y, overwrite_y=True, trans=True)
check_init_y._result = np.isnan(y).any() must_initialize_y_gemv._result = np.isnan(y).any()
return check_init_y._result return must_initialize_y_gemv._result
check_init_y._result = None # type: ignore must_initialize_y_gemv._result = None # type: ignore
class Gemv(Op): class Gemv(Op):
...@@ -197,7 +196,13 @@ class Gemv(Op): ...@@ -197,7 +196,13 @@ class Gemv(Op):
f"(beta * y + alpha * dot(A, x)). y: {y.shape}, A: {A.shape}, x: {x.shape}" f"(beta * y + alpha * dot(A, x)). y: {y.shape}, A: {A.shape}, x: {x.shape}"
) )
if beta == 0 and check_init_y(): if beta == 0 and must_initialize_y_gemv():
# Most BLAS implementations of GEMV ignore y=nan when beta=0
# PyTensor considers this the correct behavior,
# and even exploits it to avoid copying or initializing outputs.
# By deciding to exploit this, however, it becomes our responsibility
# to ensure the behavior even in the rare cases BLAS deviates,
# or users will get errors, even for graphs that had no nan to begin with.
y.fill(0) y.fill(0)
# Here I suppose that A is in c order. If we don't make it # Here I suppose that A is in c order. If we don't make it
......
...@@ -336,11 +336,12 @@ cger_no_inplace = CGer(False) ...@@ -336,11 +336,12 @@ cger_no_inplace = CGer(False)
# ##### ####### ####### # ##### ####### #######
def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=None): def gemv_c_code(y, A, x, z, alpha, beta, fail, must_initialize_y=False, params=None):
""" """
z <- beta * y + alpha * dot(A, x) z <- beta * y + alpha * dot(A, x)
where A is a matrix, y and x are vectors (ergo z is vector) where A is a matrix, y and x are vectors (ergo z is vector)
z = y if inplace else y.copy()
""" """
code = """ code = """
...@@ -400,17 +401,11 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non ...@@ -400,17 +401,11 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
} }
if (dbeta != 0) if (dbeta != 0)
{ {
// If dbeta is zero, we avoid doing the copy
if (PyArray_CopyInto(%(z)s, %(y)s) != 0) { if (PyArray_CopyInto(%(z)s, %(y)s) != 0) {
%(fail)s %(fail)s
} }
} }
else if (%(force_init_beta)d)
{
PyObject *zero = PyFloat_FromDouble(0.);
if (zero == NULL) %(fail)s;
if (PyArray_FillWithScalar(%(z)s, zero) != 0) %(fail)s;
Py_DECREF(zero);
}
} }
else else
{ {
...@@ -422,6 +417,20 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non ...@@ -422,6 +417,20 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
} }
} }
if (%(must_initialize_y)d && dbeta == 0)
{
// Most BLAS implementations of GEMV ignore y=nan when beta=0
// PyTensor considers this the correct behavior,
// and even exploits it to avoid copying or initializing outputs.
// By deciding to exploit this, however, it becomes our responsibility
// to ensure the behavior even in the rare cases BLAS deviates,
// or users will get errors, even for graphs that had no nan to begin with.
PyObject *zero = PyFloat_FromDouble(0.);
if (zero == NULL) %(fail)s;
if (PyArray_FillWithScalar(%(z)s, zero) != 0) %(fail)s;
Py_DECREF(zero);
}
{ {
int NA0 = PyArray_DIMS(%(A)s)[0]; int NA0 = PyArray_DIMS(%(A)s)[0];
int NA1 = PyArray_DIMS(%(A)s)[1]; int NA1 = PyArray_DIMS(%(A)s)[1];
...@@ -491,13 +500,13 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non ...@@ -491,13 +500,13 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
if (is_float) if (is_float)
{ {
z_data[0] *= fbeta; z_data[0] = dbeta != 0 ? dbeta * z_data[0] : 0.f;
z_data[0] += alpha * sdot_(&NA1, (float*)(A_data), &SA1, z_data[0] += alpha * sdot_(&NA1, (float*)(A_data), &SA1,
(float*)x_data, &Sx); (float*)x_data, &Sx);
} }
else else
{ {
z_data[0] *= dbeta; z_data[0] = dbeta != 0 ? dbeta * z_data[0] : 0.;
z_data[0] += alpha * ddot_(&NA1, (double*)(A_data), &SA1, z_data[0] += alpha * ddot_(&NA1, (double*)(A_data), &SA1,
(double*)x_data, &Sx); (double*)x_data, &Sx);
} }
...@@ -583,21 +592,21 @@ class CGemv(BaseBLAS, Gemv): ...@@ -583,21 +592,21 @@ class CGemv(BaseBLAS, Gemv):
alpha, alpha,
beta, beta,
fail=sub["fail"], fail=sub["fail"],
force_init_beta=check_force_gemv_init(), must_initialize_y=must_initialize_y_gemv(),
params=sub["params"], params=sub["params"],
) )
return code return code
def c_code_cache_version(self): def c_code_cache_version(self):
return (16, blas_header_version(), check_force_gemv_init()) return (17, blas_header_version(), must_initialize_y_gemv())
cgemv_inplace = CGemv(inplace=True) cgemv_inplace = CGemv(inplace=True)
cgemv_no_inplace = CGemv(inplace=False) cgemv_no_inplace = CGemv(inplace=False)
def check_force_gemv_init(): def must_initialize_y_gemv():
if check_force_gemv_init._force_init_beta is None: if must_initialize_y_gemv._force_init_beta is None:
from pytensor.link.c.cmodule import GCC_compiler from pytensor.link.c.cmodule import GCC_compiler
""" """
...@@ -643,13 +652,13 @@ int main() { ...@@ -643,13 +652,13 @@ int main() {
) )
if res: if res:
if res[0]: if res[0]:
check_force_gemv_init._force_init_beta = res[1] must_initialize_y_gemv._force_init_beta = res[1]
else: else:
check_force_gemv_init._force_init_beta = False must_initialize_y_gemv._force_init_beta = False
else: else:
check_force_gemv_init._force_init_beta = False must_initialize_y_gemv._force_init_beta = False
return check_force_gemv_init._force_init_beta return must_initialize_y_gemv._force_init_beta
check_force_gemv_init._force_init_beta = None must_initialize_y_gemv._force_init_beta = None
...@@ -700,7 +700,7 @@ def local_dot22_to_ger_or_gemv(fgraph, node): ...@@ -700,7 +700,7 @@ def local_dot22_to_ger_or_gemv(fgraph, node):
new_out = [rval] new_out = [rval]
elif xb[0] and yb[1]: elif xb[0] and yb[1]:
# x and y are both vectors so this qualifies for a sdot / ddot # x and y are both vectors so this qualifies for a sdot / ddot
# TODO: PyTensor doesn't have a sdot, but gemv is better than _dot22 # PyTensor's CGemv will call sdot/ddot at runtime, the Scipy Gemv may not
xv = x.dimshuffle(1) xv = x.dimshuffle(1)
zeros = ptb.AllocEmpty(x.dtype)(1) zeros = ptb.AllocEmpty(x.dtype)(1)
rval = gemv_no_inplace(zeros, one, y.T, xv, zero) rval = gemv_no_inplace(zeros, one, y.T, xv, zero)
......
...@@ -7,7 +7,7 @@ import pytensor ...@@ -7,7 +7,7 @@ import pytensor
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor.tensor.basic import AllocEmpty from pytensor.tensor.basic import AllocEmpty
from pytensor.tensor.blas import Ger from pytensor.tensor.blas import Ger
from pytensor.tensor.blas_c import CGemv, CGer, check_force_gemv_init from pytensor.tensor.blas_c import CGemv, CGer, must_initialize_y_gemv
from pytensor.tensor.blas_scipy import ScipyGer from pytensor.tensor.blas_scipy import ScipyGer
from pytensor.tensor.type import dmatrix, dvector, matrix, scalar, tensor, vector from pytensor.tensor.type import dmatrix, dvector, matrix, scalar, tensor, vector
from tests import unittest_tools from tests import unittest_tools
...@@ -130,29 +130,33 @@ class TestCGemv(OptimizationTestMixin): ...@@ -130,29 +130,33 @@ class TestCGemv(OptimizationTestMixin):
self.dtype = dtype self.dtype = dtype
self.mode = pytensor.compile.get_default_mode().including("fast_run") self.mode = pytensor.compile.get_default_mode().including("fast_run")
# matrix # matrix
self.A = tensor(dtype=dtype, shape=(None, None)) self.A = tensor("A", dtype=dtype, shape=(None, None))
self.Aval = np.ones((2, 3), dtype=dtype) self.Aval = np.ones((2, 3), dtype=dtype)
# vector # vector
self.x = tensor(dtype=dtype, shape=(None,)) self.x = tensor("x", dtype=dtype, shape=(None,))
self.y = tensor(dtype=dtype, shape=(None,)) self.y = tensor("y", dtype=dtype, shape=(None,))
self.xval = np.asarray([1, 2], dtype=dtype) self.xval = np.asarray([1, 2], dtype=dtype)
self.yval = np.asarray([1.5, 2.7, 3.9], dtype=dtype) self.yval = np.asarray([1.5, 2.7, 3.9], dtype=dtype)
# scalar # scalar
self.a = tensor(dtype=dtype, shape=()) self.a = tensor("a", dtype=dtype, shape=())
def test_nan_beta_0(self): @pytest.mark.parametrize("inplace", [True, False])
def test_nan_beta_0(self, inplace):
mode = self.mode.including() mode = self.mode.including()
mode.check_isfinite = False mode.check_isfinite = False
f = pytensor.function( f = pytensor.function(
[self.A, self.x, self.y, self.a], [self.A, self.x, pytensor.In(self.y, mutable=inplace), self.a],
self.a * self.y + pt.dot(self.A, self.x), self.a * self.y + pt.dot(self.A, self.x),
mode=mode, mode=mode,
) )
Aval = np.ones((3, 1), dtype=self.dtype) [node] = f.maker.fgraph.apply_nodes
assert isinstance(node.op, CGemv) and node.op.inplace == inplace
for rows in (3, 1):
Aval = np.ones((rows, 1), dtype=self.dtype)
xval = np.ones((1,), dtype=self.dtype) xval = np.ones((1,), dtype=self.dtype)
yval = float("NaN") * np.ones((3,), dtype=self.dtype) yval = np.full((rows,), np.nan, dtype=self.dtype)
zval = f(Aval, xval, yval, 0) zval = f(Aval, xval, yval, 0)
assert not np.isnan(zval).any() assert not np.isnan(zval).any()
...@@ -191,8 +195,10 @@ class TestCGemv(OptimizationTestMixin): ...@@ -191,8 +195,10 @@ class TestCGemv(OptimizationTestMixin):
np.dot(self.Aval[::-1, ::-1], self.yval), np.dot(self.Aval[::-1, ::-1], self.yval),
) )
def test_force_gemv_init(self): def test_must_initialize_y_gemv(self):
if check_force_gemv_init(): if must_initialize_y_gemv():
# FIXME: This warning should be emitted by the function if we find it relevant
# Not in a test that doesn't care about the outcome either way
warn( warn(
"WARNING: The current BLAS requires PyTensor to initialize" "WARNING: The current BLAS requires PyTensor to initialize"
" memory for some GEMV calls which will result in a minor" " memory for some GEMV calls which will result in a minor"
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论