提交 0babc678 作者: Frédéric Bastien

Merge pull request #4168 from abergeron/fix_cgemv_nan

Fix nan handling in output buffer for CGemv
import numpy
from theano import config from theano import config
from theano.tensor.opt import in2out from theano.tensor.opt import in2out
...@@ -8,7 +6,6 @@ from theano.tensor.blas import blas_optdb, optdb, local_optimizer ...@@ -8,7 +6,6 @@ from theano.tensor.blas import blas_optdb, optdb, local_optimizer
from theano.tensor.blas import Ger, ger, ger_destructive from theano.tensor.blas import Ger, ger, ger_destructive
from theano.tensor.blas import Gemv, gemv_inplace, gemv_no_inplace from theano.tensor.blas import Gemv, gemv_inplace, gemv_no_inplace
from theano.tensor import basic as T from theano.tensor import basic as T
import theano.compile
class BaseBLAS(object): class BaseBLAS(object):
...@@ -167,7 +164,6 @@ def ger_c_code(A, a, x, y, Z, destructive, fail): ...@@ -167,7 +164,6 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
} }
else else
{ {
//fprintf(stderr, "USING A\\n");
if (%(Z)s != %(A)s) if (%(Z)s != %(A)s)
{ {
if (%(Z)s) { Py_DECREF(%(Z)s); } if (%(Z)s) { Py_DECREF(%(Z)s); }
...@@ -253,7 +249,6 @@ def ger_c_code(A, a, x, y, Z, destructive, fail): ...@@ -253,7 +249,6 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
{ {
if (PyArray_DESCR(%(Z)s)->type_num == NPY_FLOAT) if (PyArray_DESCR(%(Z)s)->type_num == NPY_FLOAT)
{ {
//fprintf(stderr, "A\\n");
float alpha = ((dtype_%(a)s*)PyArray_DATA(%(a)s))[0]; float alpha = ((dtype_%(a)s*)PyArray_DATA(%(a)s))[0];
sger_(&Nz0, &Nz1, &alpha, sger_(&Nz0, &Nz1, &alpha,
(float*)x_data, &Sx, (float*)x_data, &Sx,
...@@ -353,65 +348,34 @@ def make_c_ger_destructive(node): ...@@ -353,65 +348,34 @@ def make_c_ger_destructive(node):
# ##### ####### ####### # ##### ####### #######
def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail, def gemv_c_code(y, A, x, z, alpha, beta, destructive, fail,
force_init_beta=False): force_init_beta=False):
""" """
zz <- beta * aa + alpha * dot(xx, yy) z <- beta * y + alpha * dot(A, x)
where xx is a matrix, yy and aa are vectors (ergo zz is vector) where A is a matrix, y and x are vectors (ergo z is vector)
""" """
code = """ code = """
int elemsize ; int elemsize;
float fbeta; float fbeta;
double dbeta; double dbeta;
if (PyArray_NDIM(%(aa)s) != 1) if (PyArray_DIMS(%(A)s)[0] != PyArray_DIMS(%(y)s)[0])
{
PyErr_SetString(PyExc_NotImplementedError, "Gemv: rank(aa) != 1");
%(fail)s;
}
if (PyArray_NDIM(%(xx)s) != 2)
{
PyErr_SetString(PyExc_NotImplementedError, "Gemv: rank(xx) != 2");
%(fail)s;
}
if (PyArray_NDIM(%(yy)s) != 1)
{
PyErr_SetString(PyExc_NotImplementedError, "Gemv: rank(yy) != 1");
%(fail)s;
}
if (PyArray_NDIM(%(alpha)s) != 0)
{
PyErr_SetString(PyExc_NotImplementedError, "Gemv: rank(alpha) != 0");
%(fail)s;
}
if (PyArray_NDIM(%(beta)s) != 0)
{
PyErr_SetString(PyExc_NotImplementedError, "Gemv: rank(beta) != 0");
%(fail)s;
}
if (PyArray_DESCR(%(aa)s)->type_num != PyArray_DESCR(%(xx)s)->type_num)
{ PyErr_SetString(PyExc_TypeError, "Gemv: aa vs. xx"); %(fail)s; }
if (PyArray_DESCR(%(aa)s)->type_num != PyArray_DESCR(%(yy)s)->type_num)
{ PyErr_SetString(PyExc_TypeError, "Gemv: aa vs. yy"); %(fail)s; }
if (PyArray_DIMS(%(xx)s)[0] != PyArray_DIMS(%(aa)s)[0])
{ {
PyErr_SetString(PyExc_ValueError, PyErr_SetString(PyExc_ValueError,
"Shape mismatch: A.shape[0] != x.shape[0]"); "Shape mismatch: A.shape[0] != y.shape[0]");
%(fail)s; %(fail)s;
} }
if (PyArray_DIMS(%(xx)s)[1] != PyArray_DIMS(%(yy)s)[0]) if (PyArray_DIMS(%(A)s)[1] != PyArray_DIMS(%(x)s)[0])
{ {
PyErr_SetString(PyExc_ValueError, PyErr_SetString(PyExc_ValueError,
"Shape mismatch: A.shape[1] != y.shape[0]"); "Shape mismatch: A.shape[1] != x.shape[0]");
%(fail)s; %(fail)s;
} }
if (PyArray_DESCR(%(aa)s)->type_num == NPY_DOUBLE) { elemsize = 8; } if (PyArray_DESCR(%(y)s)->type_num == NPY_DOUBLE) { elemsize = 8; }
else if (PyArray_DESCR(%(aa)s)->type_num == NPY_FLOAT) { elemsize = 4;} else if (PyArray_DESCR(%(y)s)->type_num == NPY_FLOAT) { elemsize = 4;}
else { else {
PyErr_SetString(PyExc_NotImplementedError, "complex Gemv"); PyErr_SetString(PyExc_NotImplementedError, "complex Gemv");
%(fail)s; %(fail)s;
...@@ -419,177 +383,117 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail, ...@@ -419,177 +383,117 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail,
fbeta = dbeta = ((dtype_%(beta)s*)PyArray_DATA(%(beta)s))[0]; fbeta = dbeta = ((dtype_%(beta)s*)PyArray_DATA(%(beta)s))[0];
// copy aa if not destructive // copy y if not destructive
if (!%(destructive)s) if (!%(destructive)s)
{ {
if ((NULL == %(zz)s) if ((NULL == %(z)s)
|| (PyArray_DIMS(%(zz)s)[0] != PyArray_DIMS(%(aa)s)[0])) || (PyArray_DIMS(%(z)s)[0] != PyArray_DIMS(%(y)s)[0]))
{ {
Py_XDECREF(%(zz)s); Py_XDECREF(%(z)s);
%(zz)s = (PyArrayObject*)PyArray_SimpleNew(1, %(z)s = (PyArrayObject*)PyArray_SimpleNew(1,
PyArray_DIMS(%(aa)s), PyArray_TYPE(%(aa)s)); PyArray_DIMS(%(y)s), PyArray_TYPE(%(y)s));
if(!%(zz)s) { if(!%(z)s) {
PyErr_SetString(PyExc_MemoryError, PyErr_SetString(PyExc_MemoryError,
"failed to alloc gemv output"); "failed to alloc gemv output");
%(fail)s %(fail)s
} }
} }
if (%(zz)s == %(aa)s)
{
PyErr_SetString(PyExc_AssertionError, "%(zz)s != %(aa)s");
%(fail)s
}
if (dbeta != 0) if (dbeta != 0)
{ {
if (PyArray_DESCR(%(zz)s)->type_num == NPY_FLOAT) if (PyArray_CopyInto(%(z)s, %(y)s) != 0) {
{
float * zoutdata = (float*)PyArray_DATA(%(zz)s);
const float * zdata = (float*)PyArray_DATA(%(aa)s);
int Ai = PyArray_STRIDES(%(aa)s)[0]/sizeof(float);
int Zi = PyArray_STRIDES(%(zz)s)[0]/sizeof(float);
for (int i = 0; i < PyArray_DIMS(%(aa)s)[0]; ++i)
{
zoutdata[Zi*i] = fbeta * zdata[Ai*i];
}
}
else if (PyArray_DESCR(%(zz)s)->type_num == NPY_DOUBLE)
{
double * zoutdata = (double*) PyArray_DATA(%(zz)s);
const double * zdata = (double*)PyArray_DATA(%(aa)s);
int Ai = PyArray_STRIDES(%(aa)s)[0]/sizeof(double);
int Zi = PyArray_STRIDES(%(zz)s)[0]/sizeof(double);
for (int i = 0; i < PyArray_DIMS(%(aa)s)[0]; ++i)
{
zoutdata[Zi*i] = dbeta * zdata[Ai*i];
}
}
else
{
PyErr_SetString(PyExc_AssertionError,
"neither float nor double dtype");
%(fail)s %(fail)s
} }
fbeta = dbeta = 1.0;
} }
else if (%(force_init_beta)d) else if (%(force_init_beta)d)
{ {
if (PyArray_CHKFLAGS(%(zz)s, NPY_ARRAY_C_CONTIGUOUS)) PyObject *zero = PyFloat_FromDouble(0.);
{ if (zero == NULL) %(fail)s;
memset((void *)PyArray_DATA(%(zz)s), 0, PyArray_SIZE(%(zz)s)*PyArray_ITEMSIZE(%(zz)s)); if (PyArray_FillWithScalar(%(z)s, zero) != 0) %(fail)s;
} Py_DECREF(zero);
else
{
if (PyArray_DESCR(%(zz)s)->type_num == NPY_FLOAT)
{
float *zoutdata = (float *)PyArray_DATA(%(zz)s);
int Zi = PyArray_STRIDES(%(zz)s)[0]/sizeof(float);
for (int i = 0; i < PyArray_DIMS(%(aa)s)[0]; ++i)
{
zoutdata[Zi*i] = 0.0f;
}
}
else if (PyArray_DESCR(%(zz)s)->type_num == NPY_DOUBLE)
{
double *zoutdata = (double *)PyArray_DATA(%(zz)s);
int Zi = PyArray_STRIDES(%(zz)s)[0]/sizeof(double);
for (int i = 0; i < PyArray_DIMS(%(aa)s)[0]; ++i)
{
zoutdata[Zi*i] = 0.0;
}
}
else
{
PyErr_SetString(PyExc_AssertionError,
"neither float nor double dtype");
%(fail)s
}
}
} }
} }
else else
{ {
//fprintf(stderr, "Gemv working in-place \\n"); if (%(z)s != %(y)s)
if (%(zz)s != %(aa)s)
{ {
if (%(zz)s) { Py_DECREF(%(zz)s); } Py_XDECREF(%(z)s);
%(zz)s = %(aa)s; %(z)s = %(y)s;
Py_INCREF(%(zz)s); Py_INCREF(%(z)s);
} }
} }
{ {
char TRANS = 'T'; char TRANS = 'T';
char NOTRANS = 'N'; char NOTRANS = 'N';
int Nx0 = PyArray_DIMS(%(xx)s)[0]; int NA0 = PyArray_DIMS(%(A)s)[0];
int Nx1 = PyArray_DIMS(%(xx)s)[1]; int NA1 = PyArray_DIMS(%(A)s)[1];
/* This formula is needed in the case where xx is actually a row or /* This formula is needed in the case where A is actually a row or
* column matrix, because BLAS sometimes insists that the strides: * column matrix, because BLAS sometimes insists that the strides:
* - are not smaller than the number of elements in the array * - are not smaller than the number of elements in the array
* - are not 0. * - are not 0.
*/ */
int Sx0 = (Nx0 > 1) ? (PyArray_STRIDES(%(xx)s)[0] / elemsize) : (Nx1 + 1); int SA0 = (NA0 > 1) ? (PyArray_STRIDES(%(A)s)[0] / elemsize) : (NA1 + 1);
int Sx1 = (Nx1 > 1) ? (PyArray_STRIDES(%(xx)s)[1] / elemsize) : (Nx0 + 1); int SA1 = (NA1 > 1) ? (PyArray_STRIDES(%(A)s)[1] / elemsize) : (NA0 + 1);
int Sz = PyArray_STRIDES(%(zz)s)[0] / elemsize; int Sz = PyArray_STRIDES(%(z)s)[0] / elemsize;
int Sy = PyArray_STRIDES(%(yy)s)[0] / elemsize; int Sx = PyArray_STRIDES(%(x)s)[0] / elemsize;
dtype_%(yy)s* yy_data = (dtype_%(yy)s*) PyArray_DATA(%(yy)s); dtype_%(x)s* x_data = (dtype_%(x)s*) PyArray_DATA(%(x)s);
dtype_%(zz)s* zz_data = (dtype_%(zz)s*) PyArray_DATA(%(zz)s); dtype_%(z)s* z_data = (dtype_%(z)s*) PyArray_DATA(%(z)s);
// gemv expects pointers to the beginning of memory arrays, // gemv expects pointers to the beginning of memory arrays,
// but numpy provides a pointer to the first element, // but numpy provides a pointer to the first element,
// so when the stride is negative, we need to get the last one. // so when the stride is negative, we need to get the last one.
if (Sy < 0) if (Sx < 0)
yy_data += (Nx1 - 1) * Sy; x_data += (NA1 - 1) * Sx;
if (Sz < 0) if (Sz < 0)
zz_data += (Nx0 - 1) * Sz; z_data += (NA0 - 1) * Sz;
if (Nx0 * Nx1) if (NA0 * NA1)
{ {
// If xx is neither C- nor F-contiguous, we make a copy. // If A is neither C- nor F-contiguous, we make a copy.
// TODO: // TODO:
// - if one stride is equal to "- elemsize", we can still call // - if one stride is equal to "- elemsize", we can still call
// gemv on reversed matrix and vectors // gemv on reversed matrix and vectors
// - if the copy is too long, maybe call vector/vector dot on // - if the copy is too long, maybe call vector/vector dot on
// each row instead // each row instead
if ((PyArray_STRIDES(%(xx)s)[0] < 0) if ((PyArray_STRIDES(%(A)s)[0] < 0)
|| (PyArray_STRIDES(%(xx)s)[1] < 0) || (PyArray_STRIDES(%(A)s)[1] < 0)
|| ((PyArray_STRIDES(%(xx)s)[0] != elemsize) || ((PyArray_STRIDES(%(A)s)[0] != elemsize)
&& (PyArray_STRIDES(%(xx)s)[1] != elemsize))) && (PyArray_STRIDES(%(A)s)[1] != elemsize)))
{ {
npy_intp dims[2]; npy_intp dims[2];
dims[0] = Nx0; dims[0] = NA0;
dims[1] = Nx1; dims[1] = NA1;
PyArrayObject * xx_copy = (PyArrayObject *) PyArray_Copy( PyArrayObject * A_copy = (PyArrayObject *) PyArray_Copy(
%(xx)s); %(A)s);
if (!xx_copy) if (!A_copy)
%(fail)s %(fail)s
Py_XDECREF(%(xx)s); Py_XDECREF(%(A)s);
%(xx)s = xx_copy; %(A)s = A_copy;
Sx0 = (Nx0 > 1) ? (PyArray_STRIDES(%(xx)s)[0] / elemsize) : (Nx1 + 1); SA0 = (NA0 > 1) ? (PyArray_STRIDES(%(A)s)[0] / elemsize) : (NA1 + 1);
Sx1 = (Nx1 > 1) ? (PyArray_STRIDES(%(xx)s)[1] / elemsize) : (Nx0 + 1); SA1 = (NA1 > 1) ? (PyArray_STRIDES(%(A)s)[1] / elemsize) : (NA0 + 1);
} }
if (PyArray_STRIDES(%(xx)s)[0] == elemsize) if (PyArray_STRIDES(%(A)s)[0] == elemsize)
{ {
if (PyArray_DESCR(%(xx)s)->type_num == NPY_FLOAT) if (PyArray_DESCR(%(A)s)->type_num == NPY_FLOAT)
{ {
//fprintf(stderr, "A\\n");
float alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0]; float alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0];
sgemv_(&NOTRANS, &Nx0, &Nx1, sgemv_(&NOTRANS, &NA0, &NA1,
&alpha, &alpha,
(float*)(PyArray_DATA(%(xx)s)), &Sx1, (float*)(PyArray_DATA(%(A)s)), &SA1,
(float*)yy_data, &Sy, (float*)x_data, &Sx,
&fbeta, &fbeta,
(float*)zz_data, &Sz); (float*)z_data, &Sz);
} }
else if (PyArray_DESCR(%(xx)s)->type_num == NPY_DOUBLE) else if (PyArray_DESCR(%(A)s)->type_num == NPY_DOUBLE)
{ {
double alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0]; double alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0];
dgemv_(&NOTRANS, &Nx0, &Nx1, dgemv_(&NOTRANS, &NA0, &NA1,
&alpha, &alpha,
(double*)(PyArray_DATA(%(xx)s)), &Sx1, (double*)(PyArray_DATA(%(A)s)), &SA1,
(double*)yy_data, &Sy, (double*)x_data, &Sx,
&dbeta, &dbeta,
(double*)zz_data, &Sz); (double*)z_data, &Sz);
} }
else else
{ {
...@@ -598,52 +502,62 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail, ...@@ -598,52 +502,62 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail,
%(fail)s %(fail)s
} }
} }
else if (PyArray_STRIDES(%(xx)s)[1] == elemsize) else if (PyArray_STRIDES(%(A)s)[1] == elemsize)
{ {
if (PyArray_DESCR(%(xx)s)->type_num == NPY_FLOAT) if (PyArray_DESCR(%(A)s)->type_num == NPY_FLOAT)
{ {
float alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0]; float alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0];
// Check for vector-vector dot (Nx0 == 1). The code may work // Check for vector-vector dot (NA0 == 1). The code may work
// for Sx1 != 1 as well, but has not been tested for this case, // for SA1 != 1 as well, but has not been tested for this case,
// so Sx1 == 1 is required for safety. // so SA1 == 1 is required for safety.
if (Nx0 == 1 && Sx1 == 1) if (NA0 == 1 && SA1 == 1)
{ {
zz_data[0] = fbeta*zz_data[0] + alpha*sdot_(&Nx1, if (fbeta != 0.f) {
(float*)(PyArray_DATA(%(xx)s)), &Sx1, z_data[0] = fbeta*z_data[0];
(float*)yy_data, &Sy); } else {
z_data[0] = 0.f;
}
z_data[0] += alpha*sdot_(&NA1,
(float*)(PyArray_DATA(%(A)s)), &SA1,
(float*)x_data, &Sx);
} }
else else
{ {
sgemv_(&TRANS, &Nx1, &Nx0, sgemv_(&TRANS, &NA1, &NA0,
&alpha, &alpha,
(float*)(PyArray_DATA(%(xx)s)), &Sx0, (float*)(PyArray_DATA(%(A)s)), &SA0,
(float*)yy_data, &Sy, (float*)x_data, &Sx,
&fbeta, &fbeta,
(float*)zz_data, &Sz); (float*)z_data, &Sz);
} }
} }
else if (PyArray_DESCR(%(xx)s)->type_num == NPY_DOUBLE) else if (PyArray_DESCR(%(A)s)->type_num == NPY_DOUBLE)
{ {
double alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0]; double alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0];
// Check for vector-vector dot (Nx0 == 1). The code may work // Check for vector-vector dot (NA0 == 1). The code may work
// for Sx1 != 1 as well, but has not been tested for this case, // for SA1 != 1 as well, but has not been tested for this case,
// so Sx1 == 1 is required for safety. // so SA1 == 1 is required for safety.
if (Nx0 == 1 && Sx1 == 1) if (NA0 == 1 && SA1 == 1)
{ {
zz_data[0] = dbeta*zz_data[0] + alpha*ddot_(&Nx1, if (dbeta != 0.) {
(double*)(PyArray_DATA(%(xx)s)), &Sx1, z_data[0] = dbeta*z_data[0];
(double*)yy_data, &Sy); } else {
z_data[0] = 0.;
}
z_data[0] += alpha*ddot_(&NA1,
(double*)(PyArray_DATA(%(A)s)), &SA1,
(double*)x_data, &Sx);
} }
else else
{ {
dgemv_(&TRANS, &Nx1, &Nx0, dgemv_(&TRANS, &NA1, &NA0,
&alpha, &alpha,
(double*)(PyArray_DATA(%(xx)s)), &Sx0, (double*)(PyArray_DATA(%(A)s)), &SA0,
(double*)yy_data, &Sy, (double*)x_data, &Sx,
&dbeta, &dbeta,
(double*)zz_data, &Sz); (double*)z_data, &Sz);
} }
} }
else else
...@@ -665,95 +579,89 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail, ...@@ -665,95 +579,89 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail,
{ {
// the matrix has at least one dim of length 0 // the matrix has at least one dim of length 0
// so we do this loop, which either iterates over 0 elements // so we do this loop, which either iterates over 0 elements
// or else it does the right thing for length-0 x. // or else it does the right thing for length-0 A.
dtype_%(zz)s * zptr = (dtype_%(zz)s*)(PyArray_DATA(%(zz)s)); dtype_%(z)s * zptr = (dtype_%(z)s*)(PyArray_DATA(%(z)s));
for (int i = 0; i < Nx0; ++i) for (int i = 0; i < NA0; ++i)
{ {
zptr[i * Sz] *= dbeta; zptr[i * Sz] = (dbeta == 0.0 ? 0.0 : zptr[i * Sz] * dbeta);
} }
} }
} }
""" """
return code % locals() return code % locals()
class CGemv(BaseBLAS, Gemv): class CGemv(BaseBLAS, Gemv):
def __init__(self, inplace, force_init_beta=False): def __init__(self, inplace):
super(CGemv, self).__init__(inplace) super(CGemv, self).__init__(inplace)
self.force_init_beta = force_init_beta
def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
aa, alpha, xx, yy, beta = inp y, alpha, A, x, beta = inp
zz, = out z, = out
code = gemv_c_code( code = gemv_c_code(
aa, xx, yy, zz, alpha, beta, y, A, x, z, alpha, beta,
destructive=int(self.inplace), destructive=int(self.inplace),
fail=sub['fail'], fail=sub['fail'],
force_init_beta=self.force_init_beta force_init_beta=check_force_gemv_init()
) )
return code return code
def c_code_cache_version(self): def c_code_cache_version(self):
return (12, blas_header_version()) return (13, blas_header_version(), check_force_gemv_init())
cgemv_inplace = CGemv(inplace=True) cgemv_inplace = CGemv(inplace=True)
cgemv_no_inplace = CGemv(inplace=False) cgemv_no_inplace = CGemv(inplace=False)
def check_force_gemv_init(): def check_force_gemv_init():
if check_force_gemv_init._force_init_beta is None: if check_force_gemv_init._force_init_beta is None:
from theano.gof.cmodule import GCC_compiler
""" """
Test issue 1569. Test issue 1569.
Namely when evaulating Namely when evaluating
beta*aa + alpha*dot(xx, yy) beta*y + alpha*dot(A, x)
where we set aa = betas = zeros of the correct dimensions we do not where we set y * beta = zeros of the correct dimensions we
actually set aa = zeros and instead let the BLAS perform beta*aa with do not actually set y = zeros and instead let the BLAS
uninitialized memory for speed. Occasionally the memory contains values perform beta*y with uninitialized memory for
that are equivalent to NaN in which case the product beta*aa contains speed. Occasionally the memory contains values that are
NaN's for correctly implemented BLAS libraries. In this situation, since equivalent to NaN in which case the product beta*y contains
we are introducing the NaN's, we need to test whether the BLAS performs NaN's for correctly implemented BLAS libraries. In this
correctly. If it *does*, i.e. it actually performs the multiplication situation, since we are introducing the NaN's, we need to test
beta*aa which will result in NaN's in the result, then we need intialize whether the BLAS performs correctly. If it *does*, i.e. it
the memory to zeros. actually performs the multiplication beta*y which will result
in NaN's in the result, then we need to initialize the memory to
zeros.
""" """
tv = theano.config.compute_test_value test_code = """
tvo = theano.config.compute_test_value_opt #include <math.h>
theano.config.compute_test_value = 'off' extern "C" void dgemv_(char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int *);
theano.config.compute_test_value_opt = 'off' int main() {
try: double A[2][2] = {{1., 1.}, {1., 1.}};
aa = T.vector('aa') double x[2] = {1., 1.};
yy = T.vector('yy') double y[2] = {NAN, NAN};
xx = T.matrix('xx') const int s = 2;
f = theano.function( const int inc = 1;
[aa, yy, xx], const double alpha = 1.0;
gemv_no_inplace(aa, 1., xx, yy, 0.), const double beta = 0.0;
theano.compile.Mode(optimizer='fast_compile').excluding('gpu',
'gpuarray'), dgemv_("T", &s, &s, &alpha, A, &s, x, &inc, &beta, &y, &inc);
profile=False)
finally: return (isnan(y[0]) || isnan(y[1])) ? 1 : 0;
theano.config.compute_test_value = tv }
theano.config.compute_test_value_opt = tvo """
res = GCC_compiler.try_compile_tmp(test_code, tmp_prefix='check_beta_',
# Here we introduce NaNs into the data, if they are returned by the BLAS flags=ldflags(libs=True, flags=True,
# then we want gemv_c_code to initiliaze the memory to 0 so that we libs_dir=True),
# don't inadvertantly introduce NaNs to the users data. try_run=True)
aa_data = numpy.array( if res:
float('NaN') * numpy.ones((2,)), if res[0]:
dtype=theano.config.floatX check_force_gemv_init._force_init_beta = res[1]
) else:
yy_data = numpy.array( check_force_gemv_init._force_init_beta = False
numpy.ones((2,)) * 2, else:
dtype=theano.config.floatX check_force_gemv_init._force_init_beta = False
)
xx_data = numpy.array(
numpy.ones((2, 2)),
dtype=theano.config.floatX
)
zz = f(aa_data, yy_data, xx_data)
check_force_gemv_init._force_init_beta = numpy.isnan(zz).any()
return check_force_gemv_init._force_init_beta return check_force_gemv_init._force_init_beta
...@@ -767,33 +675,10 @@ def use_c_gemv(node): ...@@ -767,33 +675,10 @@ def use_c_gemv(node):
# Only float32 and float64 are supported for now. # Only float32 and float64 are supported for now.
if (node.op == gemv_no_inplace and if (node.op == gemv_no_inplace and
node.outputs[0].dtype in ['float32', 'float64']): node.outputs[0].dtype in ['float32', 'float64']):
return [cgemv_no_inplace(*node.inputs)]
"""
We want to maintain the behavior of any operation that the user adds
even if it results in NaNs. However we do not want optimizations to
introduce NaNs.
GEMV is not always implemented consistently across BLAS libraries.
Sometimes, when beta is 0, they do not perform the multiplication with
beta. Other implementations do. This can cause problems for the inplace
GEMV implementation if NaNs happen to be in the newly allocated but
uninitialized memory. When the multiplication is not done we do not need
to initialize the output memory resulting in a speed up. Otherwise we
must initialize the memory to avoid introducing NaN's in the output
that weren't in the original graph.
The following check determines whether the output memory needs to be
initialized. It is done here, as opposed to in global scope, because
the setup has not been completed at that time and therefore the check
cannot be performed at that time.
"""
force_init_beta = check_force_gemv_init()
return [CGemv(inplace=False,
force_init_beta=force_init_beta)(*node.inputs)]
if (node.op == gemv_inplace and if (node.op == gemv_inplace and
node.outputs[0].dtype in ['float32', 'float64']): node.outputs[0].dtype in ['float32', 'float64']):
return [CGemv(inplace=True)(*node.inputs)] return [cgemv_inplace(*node.inputs)]
@local_optimizer([CGemv(inplace=False)]) @local_optimizer([CGemv(inplace=False)])
......
...@@ -294,6 +294,11 @@ class TestCpuConv2d(BaseTestConv2d): ...@@ -294,6 +294,11 @@ class TestCpuConv2d(BaseTestConv2d):
def setUp(self): def setUp(self):
super(TestCpuConv2d, self).setUp() super(TestCpuConv2d, self).setUp()
self.mode = theano.compile.mode.get_default_mode().excluding('conv_gemm') self.mode = theano.compile.mode.get_default_mode().excluding('conv_gemm')
self.opt_err = theano.config.on_opt_error
theano.config.on_opt_error = 'ignore'
def tearDown(self):
theano.config.on_opt_error = self.opt_err
def tcase(self, i, f, s, b, flip, provide_shape): def tcase(self, i, f, s, b, flip, provide_shape):
mode = self.mode mode = self.mode
......
...@@ -130,6 +130,16 @@ class TestCGemv(TestCase, TestOptimizationMixin): ...@@ -130,6 +130,16 @@ class TestCGemv(TestCase, TestOptimizationMixin):
# scalar # scalar
self.a = tensor.tensor(dtype=dtype, broadcastable=()) self.a = tensor.tensor(dtype=dtype, broadcastable=())
def test_nan_beta_0(self):
f = theano.function([self.A, self.x, self.y, self.a],
self.a*self.y + theano.dot(self.A, self.x),
mode=self.mode)
Aval = numpy.ones((3, 1), dtype=self.dtype)
xval = numpy.ones((1,), dtype=self.dtype)
yval = float('NaN') * numpy.ones((3,), dtype=self.dtype)
zval = f(Aval, xval, yval, 0)
assert not numpy.isnan(zval).any()
def test_optimizations_vm(self): def test_optimizations_vm(self):
''' Test vector dot matrix ''' ''' Test vector dot matrix '''
f = theano.function([self.x, self.A], f = theano.function([self.x, self.A],
...@@ -140,7 +150,7 @@ class TestCGemv(TestCase, TestOptimizationMixin): ...@@ -140,7 +150,7 @@ class TestCGemv(TestCase, TestOptimizationMixin):
self.assertFunctionContains0(f, tensor.dot) self.assertFunctionContains0(f, tensor.dot)
self.assertFunctionContains1( self.assertFunctionContains1(
f, f,
CGemv(inplace=True, force_init_beta=True) CGemv(inplace=True)
) )
# Assert they produce the same output # Assert they produce the same output
...@@ -161,7 +171,7 @@ class TestCGemv(TestCase, TestOptimizationMixin): ...@@ -161,7 +171,7 @@ class TestCGemv(TestCase, TestOptimizationMixin):
self.assertFunctionContains0(f, tensor.dot) self.assertFunctionContains0(f, tensor.dot)
self.assertFunctionContains1( self.assertFunctionContains1(
f, f,
CGemv(inplace=True, force_init_beta=True) CGemv(inplace=True)
) )
# Assert they produce the same output # Assert they produce the same output
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论