提交 0babc678 authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #4168 from abergeron/fix_cgemv_nan

Fix nan handling in output buffer for CGemv
import numpy
from theano import config
from theano.tensor.opt import in2out
......@@ -8,7 +6,6 @@ from theano.tensor.blas import blas_optdb, optdb, local_optimizer
from theano.tensor.blas import Ger, ger, ger_destructive
from theano.tensor.blas import Gemv, gemv_inplace, gemv_no_inplace
from theano.tensor import basic as T
import theano.compile
class BaseBLAS(object):
......@@ -167,7 +164,6 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
}
else
{
//fprintf(stderr, "USING A\\n");
if (%(Z)s != %(A)s)
{
if (%(Z)s) { Py_DECREF(%(Z)s); }
......@@ -253,7 +249,6 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
{
if (PyArray_DESCR(%(Z)s)->type_num == NPY_FLOAT)
{
//fprintf(stderr, "A\\n");
float alpha = ((dtype_%(a)s*)PyArray_DATA(%(a)s))[0];
sger_(&Nz0, &Nz1, &alpha,
(float*)x_data, &Sx,
......@@ -353,65 +348,34 @@ def make_c_ger_destructive(node):
# ##### ####### #######
def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail,
def gemv_c_code(y, A, x, z, alpha, beta, destructive, fail,
force_init_beta=False):
"""
zz <- beta * aa + alpha * dot(xx, yy)
z <- beta * y + alpha * dot(A, x)
where xx is a matrix, yy and aa are vectors (ergo zz is vector)
where A is a matrix, y and x are vectors (ergo z is vector)
"""
code = """
int elemsize ;
int elemsize;
float fbeta;
double dbeta;
if (PyArray_NDIM(%(aa)s) != 1)
{
PyErr_SetString(PyExc_NotImplementedError, "Gemv: rank(aa) != 1");
%(fail)s;
}
if (PyArray_NDIM(%(xx)s) != 2)
{
PyErr_SetString(PyExc_NotImplementedError, "Gemv: rank(xx) != 2");
%(fail)s;
}
if (PyArray_NDIM(%(yy)s) != 1)
{
PyErr_SetString(PyExc_NotImplementedError, "Gemv: rank(yy) != 1");
%(fail)s;
}
if (PyArray_NDIM(%(alpha)s) != 0)
{
PyErr_SetString(PyExc_NotImplementedError, "Gemv: rank(alpha) != 0");
%(fail)s;
}
if (PyArray_NDIM(%(beta)s) != 0)
{
PyErr_SetString(PyExc_NotImplementedError, "Gemv: rank(beta) != 0");
%(fail)s;
}
if (PyArray_DESCR(%(aa)s)->type_num != PyArray_DESCR(%(xx)s)->type_num)
{ PyErr_SetString(PyExc_TypeError, "Gemv: aa vs. xx"); %(fail)s; }
if (PyArray_DESCR(%(aa)s)->type_num != PyArray_DESCR(%(yy)s)->type_num)
{ PyErr_SetString(PyExc_TypeError, "Gemv: aa vs. yy"); %(fail)s; }
if (PyArray_DIMS(%(xx)s)[0] != PyArray_DIMS(%(aa)s)[0])
if (PyArray_DIMS(%(A)s)[0] != PyArray_DIMS(%(y)s)[0])
{
PyErr_SetString(PyExc_ValueError,
"Shape mismatch: A.shape[0] != x.shape[0]");
"Shape mismatch: A.shape[0] != y.shape[0]");
%(fail)s;
}
if (PyArray_DIMS(%(xx)s)[1] != PyArray_DIMS(%(yy)s)[0])
if (PyArray_DIMS(%(A)s)[1] != PyArray_DIMS(%(x)s)[0])
{
PyErr_SetString(PyExc_ValueError,
"Shape mismatch: A.shape[1] != y.shape[0]");
"Shape mismatch: A.shape[1] != x.shape[0]");
%(fail)s;
}
if (PyArray_DESCR(%(aa)s)->type_num == NPY_DOUBLE) { elemsize = 8; }
else if (PyArray_DESCR(%(aa)s)->type_num == NPY_FLOAT) { elemsize = 4;}
if (PyArray_DESCR(%(y)s)->type_num == NPY_DOUBLE) { elemsize = 8; }
else if (PyArray_DESCR(%(y)s)->type_num == NPY_FLOAT) { elemsize = 4;}
else {
PyErr_SetString(PyExc_NotImplementedError, "complex Gemv");
%(fail)s;
......@@ -419,177 +383,117 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail,
fbeta = dbeta = ((dtype_%(beta)s*)PyArray_DATA(%(beta)s))[0];
// copy aa if not destructive
// copy y if not destructive
if (!%(destructive)s)
{
if ((NULL == %(zz)s)
|| (PyArray_DIMS(%(zz)s)[0] != PyArray_DIMS(%(aa)s)[0]))
if ((NULL == %(z)s)
|| (PyArray_DIMS(%(z)s)[0] != PyArray_DIMS(%(y)s)[0]))
{
Py_XDECREF(%(zz)s);
%(zz)s = (PyArrayObject*)PyArray_SimpleNew(1,
PyArray_DIMS(%(aa)s), PyArray_TYPE(%(aa)s));
if(!%(zz)s) {
Py_XDECREF(%(z)s);
%(z)s = (PyArrayObject*)PyArray_SimpleNew(1,
PyArray_DIMS(%(y)s), PyArray_TYPE(%(y)s));
if(!%(z)s) {
PyErr_SetString(PyExc_MemoryError,
"failed to alloc gemv output");
%(fail)s
}
}
if (%(zz)s == %(aa)s)
{
PyErr_SetString(PyExc_AssertionError, "%(zz)s != %(aa)s");
%(fail)s
}
if (dbeta != 0)
{
if (PyArray_DESCR(%(zz)s)->type_num == NPY_FLOAT)
{
float * zoutdata = (float*)PyArray_DATA(%(zz)s);
const float * zdata = (float*)PyArray_DATA(%(aa)s);
int Ai = PyArray_STRIDES(%(aa)s)[0]/sizeof(float);
int Zi = PyArray_STRIDES(%(zz)s)[0]/sizeof(float);
for (int i = 0; i < PyArray_DIMS(%(aa)s)[0]; ++i)
{
zoutdata[Zi*i] = fbeta * zdata[Ai*i];
}
}
else if (PyArray_DESCR(%(zz)s)->type_num == NPY_DOUBLE)
{
double * zoutdata = (double*) PyArray_DATA(%(zz)s);
const double * zdata = (double*)PyArray_DATA(%(aa)s);
int Ai = PyArray_STRIDES(%(aa)s)[0]/sizeof(double);
int Zi = PyArray_STRIDES(%(zz)s)[0]/sizeof(double);
for (int i = 0; i < PyArray_DIMS(%(aa)s)[0]; ++i)
{
zoutdata[Zi*i] = dbeta * zdata[Ai*i];
}
}
else
{
PyErr_SetString(PyExc_AssertionError,
"neither float nor double dtype");
if (PyArray_CopyInto(%(z)s, %(y)s) != 0) {
%(fail)s
}
fbeta = dbeta = 1.0;
}
else if (%(force_init_beta)d)
{
if (PyArray_CHKFLAGS(%(zz)s, NPY_ARRAY_C_CONTIGUOUS))
{
memset((void *)PyArray_DATA(%(zz)s), 0, PyArray_SIZE(%(zz)s)*PyArray_ITEMSIZE(%(zz)s));
}
else
{
if (PyArray_DESCR(%(zz)s)->type_num == NPY_FLOAT)
{
float *zoutdata = (float *)PyArray_DATA(%(zz)s);
int Zi = PyArray_STRIDES(%(zz)s)[0]/sizeof(float);
for (int i = 0; i < PyArray_DIMS(%(aa)s)[0]; ++i)
{
zoutdata[Zi*i] = 0.0f;
}
}
else if (PyArray_DESCR(%(zz)s)->type_num == NPY_DOUBLE)
{
double *zoutdata = (double *)PyArray_DATA(%(zz)s);
int Zi = PyArray_STRIDES(%(zz)s)[0]/sizeof(double);
for (int i = 0; i < PyArray_DIMS(%(aa)s)[0]; ++i)
{
zoutdata[Zi*i] = 0.0;
}
}
else
{
PyErr_SetString(PyExc_AssertionError,
"neither float nor double dtype");
%(fail)s
}
}
PyObject *zero = PyFloat_FromDouble(0.);
if (zero == NULL) %(fail)s;
if (PyArray_FillWithScalar(%(z)s, zero) != 0) %(fail)s;
Py_DECREF(zero);
}
}
else
{
//fprintf(stderr, "Gemv working in-place \\n");
if (%(zz)s != %(aa)s)
if (%(z)s != %(y)s)
{
if (%(zz)s) { Py_DECREF(%(zz)s); }
%(zz)s = %(aa)s;
Py_INCREF(%(zz)s);
Py_XDECREF(%(z)s);
%(z)s = %(y)s;
Py_INCREF(%(z)s);
}
}
{
char TRANS = 'T';
char NOTRANS = 'N';
int Nx0 = PyArray_DIMS(%(xx)s)[0];
int Nx1 = PyArray_DIMS(%(xx)s)[1];
/* This formula is needed in the case where xx is actually a row or
int NA0 = PyArray_DIMS(%(A)s)[0];
int NA1 = PyArray_DIMS(%(A)s)[1];
/* This formula is needed in the case where A is actually a row or
* column matrix, because BLAS sometimes insists that the strides:
* - are not smaller than the number of elements in the array
* - are not 0.
*/
int Sx0 = (Nx0 > 1) ? (PyArray_STRIDES(%(xx)s)[0] / elemsize) : (Nx1 + 1);
int Sx1 = (Nx1 > 1) ? (PyArray_STRIDES(%(xx)s)[1] / elemsize) : (Nx0 + 1);
int Sz = PyArray_STRIDES(%(zz)s)[0] / elemsize;
int Sy = PyArray_STRIDES(%(yy)s)[0] / elemsize;
int SA0 = (NA0 > 1) ? (PyArray_STRIDES(%(A)s)[0] / elemsize) : (NA1 + 1);
int SA1 = (NA1 > 1) ? (PyArray_STRIDES(%(A)s)[1] / elemsize) : (NA0 + 1);
int Sz = PyArray_STRIDES(%(z)s)[0] / elemsize;
int Sx = PyArray_STRIDES(%(x)s)[0] / elemsize;
dtype_%(yy)s* yy_data = (dtype_%(yy)s*) PyArray_DATA(%(yy)s);
dtype_%(zz)s* zz_data = (dtype_%(zz)s*) PyArray_DATA(%(zz)s);
dtype_%(x)s* x_data = (dtype_%(x)s*) PyArray_DATA(%(x)s);
dtype_%(z)s* z_data = (dtype_%(z)s*) PyArray_DATA(%(z)s);
// gemv expects pointers to the beginning of memory arrays,
// but numpy provides provides a pointer to the first element,
// so when the stride is negative, we need to get the last one.
if (Sy < 0)
yy_data += (Nx1 - 1) * Sy;
if (Sx < 0)
x_data += (NA1 - 1) * Sx;
if (Sz < 0)
zz_data += (Nx0 - 1) * Sz;
z_data += (NA0 - 1) * Sz;
if (Nx0 * Nx1)
if (NA0 * NA1)
{
// If xx is neither C- nor F-contiguous, we make a copy.
// If A is neither C- nor F-contiguous, we make a copy.
// TODO:
// - if one stride is equal to "- elemsize", we can still call
// gemv on reversed matrix and vectors
// - if the copy is too long, maybe call vector/vector dot on
// each row instead
if ((PyArray_STRIDES(%(xx)s)[0] < 0)
|| (PyArray_STRIDES(%(xx)s)[1] < 0)
|| ((PyArray_STRIDES(%(xx)s)[0] != elemsize)
&& (PyArray_STRIDES(%(xx)s)[1] != elemsize)))
if ((PyArray_STRIDES(%(A)s)[0] < 0)
|| (PyArray_STRIDES(%(A)s)[1] < 0)
|| ((PyArray_STRIDES(%(A)s)[0] != elemsize)
&& (PyArray_STRIDES(%(A)s)[1] != elemsize)))
{
npy_intp dims[2];
dims[0] = Nx0;
dims[1] = Nx1;
dims[0] = NA0;
dims[1] = NA1;
PyArrayObject * xx_copy = (PyArrayObject *) PyArray_Copy(
%(xx)s);
if (!xx_copy)
PyArrayObject * A_copy = (PyArrayObject *) PyArray_Copy(
%(A)s);
if (!A_copy)
%(fail)s
Py_XDECREF(%(xx)s);
%(xx)s = xx_copy;
Sx0 = (Nx0 > 1) ? (PyArray_STRIDES(%(xx)s)[0] / elemsize) : (Nx1 + 1);
Sx1 = (Nx1 > 1) ? (PyArray_STRIDES(%(xx)s)[1] / elemsize) : (Nx0 + 1);
Py_XDECREF(%(A)s);
%(A)s = A_copy;
SA0 = (NA0 > 1) ? (PyArray_STRIDES(%(A)s)[0] / elemsize) : (NA1 + 1);
SA1 = (NA1 > 1) ? (PyArray_STRIDES(%(A)s)[1] / elemsize) : (NA0 + 1);
}
if (PyArray_STRIDES(%(xx)s)[0] == elemsize)
if (PyArray_STRIDES(%(A)s)[0] == elemsize)
{
if (PyArray_DESCR(%(xx)s)->type_num == NPY_FLOAT)
if (PyArray_DESCR(%(A)s)->type_num == NPY_FLOAT)
{
//fprintf(stderr, "A\\n");
float alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0];
sgemv_(&NOTRANS, &Nx0, &Nx1,
sgemv_(&NOTRANS, &NA0, &NA1,
&alpha,
(float*)(PyArray_DATA(%(xx)s)), &Sx1,
(float*)yy_data, &Sy,
(float*)(PyArray_DATA(%(A)s)), &SA1,
(float*)x_data, &Sx,
&fbeta,
(float*)zz_data, &Sz);
(float*)z_data, &Sz);
}
else if (PyArray_DESCR(%(xx)s)->type_num == NPY_DOUBLE)
else if (PyArray_DESCR(%(A)s)->type_num == NPY_DOUBLE)
{
double alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0];
dgemv_(&NOTRANS, &Nx0, &Nx1,
dgemv_(&NOTRANS, &NA0, &NA1,
&alpha,
(double*)(PyArray_DATA(%(xx)s)), &Sx1,
(double*)yy_data, &Sy,
(double*)(PyArray_DATA(%(A)s)), &SA1,
(double*)x_data, &Sx,
&dbeta,
(double*)zz_data, &Sz);
(double*)z_data, &Sz);
}
else
{
......@@ -598,52 +502,62 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail,
%(fail)s
}
}
else if (PyArray_STRIDES(%(xx)s)[1] == elemsize)
else if (PyArray_STRIDES(%(A)s)[1] == elemsize)
{
if (PyArray_DESCR(%(xx)s)->type_num == NPY_FLOAT)
if (PyArray_DESCR(%(A)s)->type_num == NPY_FLOAT)
{
float alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0];
// Check for vector-vector dot (Nx0 == 1). The code may work
// for Sx1 != 1 as well, but has not been tested for this case,
// so Sx1 == 1 is required for safety.
if (Nx0 == 1 && Sx1 == 1)
// Check for vector-vector dot (NA0 == 1). The code may work
// for SA1 != 1 as well, but has not been tested for this case,
// so SA1 == 1 is required for safety.
if (NA0 == 1 && SA1 == 1)
{
zz_data[0] = fbeta*zz_data[0] + alpha*sdot_(&Nx1,
(float*)(PyArray_DATA(%(xx)s)), &Sx1,
(float*)yy_data, &Sy);
if (fbeta != 0.f) {
z_data[0] = fbeta*z_data[0];
} else {
z_data[0] = 0.f;
}
z_data[0] += alpha*sdot_(&NA1,
(float*)(PyArray_DATA(%(A)s)), &SA1,
(float*)x_data, &Sx);
}
else
{
sgemv_(&TRANS, &Nx1, &Nx0,
sgemv_(&TRANS, &NA1, &NA0,
&alpha,
(float*)(PyArray_DATA(%(xx)s)), &Sx0,
(float*)yy_data, &Sy,
(float*)(PyArray_DATA(%(A)s)), &SA0,
(float*)x_data, &Sx,
&fbeta,
(float*)zz_data, &Sz);
(float*)z_data, &Sz);
}
}
else if (PyArray_DESCR(%(xx)s)->type_num == NPY_DOUBLE)
else if (PyArray_DESCR(%(A)s)->type_num == NPY_DOUBLE)
{
double alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0];
// Check for vector-vector dot (Nx0 == 1). The code may work
// for Sx1 != 1 as well, but has not been tested for this case,
// so Sx1 == 1 is required for safety.
if (Nx0 == 1 && Sx1 == 1)
// Check for vector-vector dot (NA0 == 1). The code may work
// for SA1 != 1 as well, but has not been tested for this case,
// so SA1 == 1 is required for safety.
if (NA0 == 1 && SA1 == 1)
{
zz_data[0] = dbeta*zz_data[0] + alpha*ddot_(&Nx1,
(double*)(PyArray_DATA(%(xx)s)), &Sx1,
(double*)yy_data, &Sy);
if (dbeta != 0.) {
z_data[0] = dbeta*z_data[0];
} else {
z_data[0] = 0.;
}
z_data[0] += alpha*ddot_(&NA1,
(double*)(PyArray_DATA(%(A)s)), &SA1,
(double*)x_data, &Sx);
}
else
{
dgemv_(&TRANS, &Nx1, &Nx0,
dgemv_(&TRANS, &NA1, &NA0,
&alpha,
(double*)(PyArray_DATA(%(xx)s)), &Sx0,
(double*)yy_data, &Sy,
(double*)(PyArray_DATA(%(A)s)), &SA0,
(double*)x_data, &Sx,
&dbeta,
(double*)zz_data, &Sz);
(double*)z_data, &Sz);
}
}
else
......@@ -665,95 +579,89 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail,
{
// the matrix has at least one dim of length 0
// so we do this loop, which either iterates over 0 elements
// or else it does the right thing for length-0 x.
dtype_%(zz)s * zptr = (dtype_%(zz)s*)(PyArray_DATA(%(zz)s));
for (int i = 0; i < Nx0; ++i)
// or else it does the right thing for length-0 A.
dtype_%(z)s * zptr = (dtype_%(z)s*)(PyArray_DATA(%(z)s));
for (int i = 0; i < NA0; ++i)
{
zptr[i * Sz] *= dbeta;
zptr[i * Sz] = (dbeta == 0.0 ? 0.0 : zptr[i * Sz] * dbeta);
}
}
}
"""
return code % locals()
class CGemv(BaseBLAS, Gemv):
def __init__(self, inplace, force_init_beta=False):
def __init__(self, inplace):
super(CGemv, self).__init__(inplace)
self.force_init_beta = force_init_beta
def c_code(self, node, name, inp, out, sub):
aa, alpha, xx, yy, beta = inp
zz, = out
y, alpha, A, x, beta = inp
z, = out
code = gemv_c_code(
aa, xx, yy, zz, alpha, beta,
y, A, x, z, alpha, beta,
destructive=int(self.inplace),
fail=sub['fail'],
force_init_beta=self.force_init_beta
force_init_beta=check_force_gemv_init()
)
return code
def c_code_cache_version(self):
return (12, blas_header_version())
return (13, blas_header_version(), check_force_gemv_init())
cgemv_inplace = CGemv(inplace=True)
cgemv_no_inplace = CGemv(inplace=False)
def check_force_gemv_init():
if check_force_gemv_init._force_init_beta is None:
from theano.gof.cmodule import GCC_compiler
"""
Test issue 1569.
Namely when evaulating
beta*aa + alpha*dot(xx, yy)
where we set aa = betas = zeros of the correct dimensions we do not
actually set aa = zeros and instead let the BLAS perform beta*aa with
uninitialized memory for speed. Occasionally the memory contains values
that are equivalent to NaN in which case the product beta*aa contains
NaN's for correctly implemented BLAS libraries. In this situation, since
we are introducing the NaN's, we need to test whether the BLAS performs
correctly. If it *does*, i.e. it actually performs the multiplication
beta*aa which will result in NaN's in the result, then we need intialize
the memory to zeros.
Namely when evaluating
beta*y + alpha*dot(A, x)
where we set y * beta = zeros of the correct dimensions we
do not actually set y = zeros and instead let the BLAS
perform beta*y with uninitialized memory for
speed. Occasionally the memory contains values that are
equivalent to NaN in which case the product beta*y contains
NaN's for correctly implemented BLAS libraries. In this
situation, since we are introducing the NaN's, we need to test
whether the BLAS performs correctly. If it *does*, i.e. it
actually performs the multiplication beta*y which will result
in NaN's in the result, then we need intialize the memory to
zeros.
"""
tv = theano.config.compute_test_value
tvo = theano.config.compute_test_value_opt
theano.config.compute_test_value = 'off'
theano.config.compute_test_value_opt = 'off'
try:
aa = T.vector('aa')
yy = T.vector('yy')
xx = T.matrix('xx')
f = theano.function(
[aa, yy, xx],
gemv_no_inplace(aa, 1., xx, yy, 0.),
theano.compile.Mode(optimizer='fast_compile').excluding('gpu',
'gpuarray'),
profile=False)
finally:
theano.config.compute_test_value = tv
theano.config.compute_test_value_opt = tvo
# Here we introduce NaNs into the data, if they are returned by the BLAS
# then we want gemv_c_code to initiliaze the memory to 0 so that we
# don't inadvertantly introduce NaNs to the users data.
aa_data = numpy.array(
float('NaN') * numpy.ones((2,)),
dtype=theano.config.floatX
)
yy_data = numpy.array(
numpy.ones((2,)) * 2,
dtype=theano.config.floatX
)
xx_data = numpy.array(
numpy.ones((2, 2)),
dtype=theano.config.floatX
)
zz = f(aa_data, yy_data, xx_data)
check_force_gemv_init._force_init_beta = numpy.isnan(zz).any()
test_code = """
#include <math.h>
extern "C" void dgemv_(char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int *);
int main() {
double A[2][2] = {{1., 1.}, {1., 1.}};
double x[2] = {1., 1.};
double y[2] = {NAN, NAN};
const int s = 2;
const int inc = 1;
const double alpha = 1.0;
const double beta = 0.0;
dgemv_("T", &s, &s, &alpha, A, &s, x, &inc, &beta, &y, &inc);
return (isnan(y[0]) || isnan(y[1]) ? 1 : 0;
}
"""
res = GCC_compiler.try_compile_tmp(test_code, tmp_prefix='check_beta_',
flags=ldflags(libs=True, flags=True,
libs_dir=True),
try_run=True)
if res:
if res[0]:
check_force_gemv_init._force_init_beta = res[1]
else:
check_force_gemv_init._force_init_beta = False
else:
check_force_gemv_init._force_init_beta = False
return check_force_gemv_init._force_init_beta
......@@ -767,33 +675,10 @@ def use_c_gemv(node):
# Only float32 and float64 are supported for now.
if (node.op == gemv_no_inplace and
node.outputs[0].dtype in ['float32', 'float64']):
"""
We want to maintain the behavoir of any operation that the user adds
even if it results in NaNs. However we do not want optimizations to
introduce NaNs.
GEMV is not always implemented consistenly across BLAS libraries.
Sometimes, when beta is 0, they do not perform the multiplication with
beta. Other implmentations do. This can cause problems for the inplace
GEMV implementation if NaNs happen to be in the newly allocated but
uninitalized memory. When the multiplication is not done we do not need
to initialize the output memory resulting in a speed up. Otherwise we
must initialize the memory to avoid introducing NaN's in the output
that weren't in the original graph.
The following check determines whether the output memory needs to be
initiliazed. It is done here, as opposed to in global scope, because
the setup has not been completed at that time and therefore the check
cannot be performed at that time.
"""
force_init_beta = check_force_gemv_init()
return [CGemv(inplace=False,
force_init_beta=force_init_beta)(*node.inputs)]
return [cgemv_no_inplace(*node.inputs)]
if (node.op == gemv_inplace and
node.outputs[0].dtype in ['float32', 'float64']):
return [CGemv(inplace=True)(*node.inputs)]
return [cgemv_inplace(*node.inputs)]
@local_optimizer([CGemv(inplace=False)])
......
......@@ -294,6 +294,11 @@ class TestCpuConv2d(BaseTestConv2d):
def setUp(self):
super(TestCpuConv2d, self).setUp()
self.mode = theano.compile.mode.get_default_mode().excluding('conv_gemm')
self.opt_err = theano.config.on_opt_error
theano.config.on_opt_error = 'ignore'
def tearDown(self):
theano.config.on_opt_error = self.opt_err
def tcase(self, i, f, s, b, flip, provide_shape):
mode = self.mode
......
......@@ -130,6 +130,16 @@ class TestCGemv(TestCase, TestOptimizationMixin):
# scalar
self.a = tensor.tensor(dtype=dtype, broadcastable=())
def test_nan_beta_0(self):
f = theano.function([self.A, self.x, self.y, self.a],
self.a*self.y + theano.dot(self.A, self.x),
mode=self.mode)
Aval = numpy.ones((3, 1), dtype=self.dtype)
xval = numpy.ones((1,), dtype=self.dtype)
yval = float('NaN') * numpy.ones((3,), dtype=self.dtype)
zval = f(Aval, xval, yval, 0)
assert not numpy.isnan(zval).any()
def test_optimizations_vm(self):
''' Test vector dot matrix '''
f = theano.function([self.x, self.A],
......@@ -140,7 +150,7 @@ class TestCGemv(TestCase, TestOptimizationMixin):
self.assertFunctionContains0(f, tensor.dot)
self.assertFunctionContains1(
f,
CGemv(inplace=True, force_init_beta=True)
CGemv(inplace=True)
)
# Assert they produce the same output
......@@ -161,7 +171,7 @@ class TestCGemv(TestCase, TestOptimizationMixin):
self.assertFunctionContains0(f, tensor.dot)
self.assertFunctionContains1(
f,
CGemv(inplace=True, force_init_beta=True)
CGemv(inplace=True)
)
# Assert they produce the same output
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论