提交 e8c9bd36 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Fix CGemv with empty A

上级 abedb7fb
...@@ -417,20 +417,6 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, must_initialize_y=False, params=N ...@@ -417,20 +417,6 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, must_initialize_y=False, params=N
} }
} }
if (%(must_initialize_y)d && dbeta == 0)
{
// Most BLAS implementations of GEMV ignore y=nan when beta=0
// PyTensor considers that the correct behavior,
// and even exploits it to avoid copying or initializing outputs.
// By deciding to exploit this, however, it becomes our responsibility
// to ensure the behavior even in the rare cases BLAS deviates,
// or users will get errors, even for graphs that had no nan to begin with.
PyObject *zero = PyFloat_FromDouble(0.);
if (zero == NULL) %(fail)s;
if (PyArray_FillWithScalar(%(z)s, zero) != 0) %(fail)s;
Py_DECREF(zero);
}
{ {
int NA0 = PyArray_DIMS(%(A)s)[0]; int NA0 = PyArray_DIMS(%(A)s)[0];
int NA1 = PyArray_DIMS(%(A)s)[1]; int NA1 = PyArray_DIMS(%(A)s)[1];
...@@ -439,6 +425,17 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, must_initialize_y=False, params=N ...@@ -439,6 +425,17 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, must_initialize_y=False, params=N
{ {
// Non-empty A matrix // Non-empty A matrix
if (%(must_initialize_y)d && dbeta == 0)
{
// Most BLAS implementations of GEMV ignore y=nan when beta=0
// PyTensor considers that the correct behavior,
// and even exploits it to avoid copying or initializing outputs.
// By deciding to exploit this, however, it becomes our responsibility
// to ensure the behavior even in the rare cases BLAS deviates,
// or users will get errors, even for graphs that had no nan to begin with.
PyArray_FILLWBYTE(%(z)s, 0);
}
/* In the case where A is actually a row or column matrix, /* In the case where A is actually a row or column matrix,
* the strides corresponding to the dummy dimension don't matter, * the strides corresponding to the dummy dimension don't matter,
* but BLAS requires these to be no smaller than the number of elements in the array. * but BLAS requires these to be no smaller than the number of elements in the array.
...@@ -567,6 +564,18 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, must_initialize_y=False, params=N ...@@ -567,6 +564,18 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, must_initialize_y=False, params=N
"A is neither C nor F-contiguous, it should have been copied into a memory-contiguous array;"); "A is neither C nor F-contiguous, it should have been copied into a memory-contiguous array;");
%(fail)s %(fail)s
} }
} else
{
// Empty A matrix, just scale y by beta
if (dbeta != 1.0)
{
npy_intp Sz = PyArray_STRIDES(%(z)s)[0] / elemsize;
dtype_%(z)s* z_data = (dtype_%(z)s*) PyArray_DATA(%(z)s);
for (npy_intp i = 0; i < NA0; ++i)
{
z_data[i * Sz] = (dbeta == 0.0) ? 0 : z_data[i * Sz] * dbeta;
}
}
} }
} }
""" """
...@@ -598,7 +607,7 @@ class CGemv(BaseBLAS, Gemv): ...@@ -598,7 +607,7 @@ class CGemv(BaseBLAS, Gemv):
return code return code
def c_code_cache_version(self):
    """Version tag for the generated C code cache.

    Bumped to 18 by the empty-A fix so previously compiled kernels are
    invalidated; also keyed on the BLAS header and the runtime probe of
    whether y must be explicitly initialized.
    """
    return (18, blas_header_version(), must_initialize_y_gemv())
cgemv_inplace = CGemv(inplace=True) cgemv_inplace = CGemv(inplace=True)
......
...@@ -8,7 +8,15 @@ import pytensor.tensor as pt ...@@ -8,7 +8,15 @@ import pytensor.tensor as pt
from pytensor.tensor.basic import AllocEmpty from pytensor.tensor.basic import AllocEmpty
from pytensor.tensor.blas import Ger from pytensor.tensor.blas import Ger
from pytensor.tensor.blas_c import CGemv, CGer, must_initialize_y_gemv from pytensor.tensor.blas_c import CGemv, CGer, must_initialize_y_gemv
from pytensor.tensor.type import dmatrix, dvector, matrix, scalar, tensor, vector from pytensor.tensor.type import (
dmatrix,
dscalar,
dvector,
matrix,
scalar,
tensor,
vector,
)
from tests import unittest_tools from tests import unittest_tools
from tests.tensor.test_blas import BaseGemv, TestBlasStrides from tests.tensor.test_blas import BaseGemv, TestBlasStrides
from tests.unittest_tools import OptimizationTestMixin from tests.unittest_tools import OptimizationTestMixin
...@@ -143,19 +151,21 @@ class TestCGemv(OptimizationTestMixin): ...@@ -143,19 +151,21 @@ class TestCGemv(OptimizationTestMixin):
def test_nan_beta_0(self, inplace):
    """beta == 0 must zero-fill y even when y starts out as NaN.

    Many BLAS GEMV implementations ignore y entirely when beta is 0;
    PyTensor relies on that behavior, so the generated C code has to
    guarantee it for every shape of A — including degenerate (empty) ones.
    """
    mode = self.mode.including()
    mode.check_isfinite = False
    beta = self.a.type("beta")
    f = pytensor.function(
        [self.A, self.x, pytensor.In(self.y, mutable=inplace), beta],
        beta * self.y + pt.dot(self.A, self.x),
        mode=mode,
    )
    # The whole graph must have collapsed into a single CGemv node.
    [node] = f.maker.fgraph.apply_nodes
    assert isinstance(node.op, CGemv) and node.op.inplace == inplace
    # Regular, single-row/column, and empty shapes of A
    # (same iteration order as nested rows/cols loops).
    for n_rows, n_cols in [(3, 1), (3, 0), (1, 1), (1, 0), (0, 1), (0, 0)]:
        A_val = np.ones((n_rows, n_cols), dtype=self.dtype)
        x_val = np.ones((n_cols,), dtype=self.dtype)
        y_val = np.full((n_rows,), np.nan, dtype=self.dtype)
        z_val = f(A_val, x_val, y_val, beta=0)
        assert not np.isnan(z_val).any(), f"rows={n_rows}, cols={n_cols}"
def test_optimizations_vm(self): def test_optimizations_vm(self):
skip_if_blas_ldflags_empty() skip_if_blas_ldflags_empty()
...@@ -294,6 +304,26 @@ class TestCGemv(OptimizationTestMixin): ...@@ -294,6 +304,26 @@ class TestCGemv(OptimizationTestMixin):
== 2 == 2
) )
def test_empty_A(self):
    """When A has zero columns, CGemv must reduce to scaling y by beta.

    dot(A, x) is then a zero-length reduction (all zeros), so the output
    should equal ``beta * y`` for any beta, including 0 and 1.
    """
    A = dmatrix("A")
    x = dvector("x")
    y = dvector("y")
    beta = dscalar("beta")
    # alpha is irrelevant here: it only multiplies the (empty) dot product.
    fn = pytensor.function(
        [A, x, y, beta],
        CGemv(inplace=True)(y, 1.0, A, x, beta),
        accept_inplace=True,
    )
    A_val = np.empty((10, 0))
    x_val = np.empty((0,))
    y_val = np.random.random((10,))
    for beta_val in (0.0, 1.0, 2.0):
        out = fn(A_val, x_val, y_val.copy(), beta_val)
        np.testing.assert_allclose(out, beta_val * y_val)
class TestCGemvFloat32(BaseGemv, OptimizationTestMixin): class TestCGemvFloat32(BaseGemv, OptimizationTestMixin):
mode = mode_blas_opt mode = mode_blas_opt
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论