提交 a3227224 authored 作者: Dustin Webb's avatar Dustin Webb

Optimized the zero-initialization code path when the array is C-contiguous, and added a test that warns the user.

上级 053c63fb
......@@ -405,20 +405,7 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail, force_init_beta=
{
if (PyArray_CHKFLAGS(%(zz)s, NPY_ARRAY_C_CONTIGUOUS))
{
if (PyArray_DESCR(%(zz)s)->type_num == NPY_FLOAT)
{
memset((void *)PyArray_DATA(%(zz)s), 0, PyArray_ITEMSIZE(%(zz)s));
}
else if (PyArray_DESCR(%(zz)s)->type_num == NPY_DOUBLE)
{
memset((void *)PyArray_DATA(%(zz)s), 0, PyArray_ITEMSIZE(%(zz)s));
}
else
{
PyErr_SetString(PyExc_AssertionError,
"neither float nor double dtype");
%(fail)s
}
memset((void *)PyArray_DATA(%(zz)s), 0, PyArray_SIZE(%(zz)s)*PyArray_ITEMSIZE(%(zz)s));
}
else
{
......
......@@ -15,6 +15,8 @@ from theano.tensor.blas_c import CGemv
from theano.tensor.blas_scipy import ScipyGer
from theano.tensor.blas import Gemv
from theano.tensor.blas_c import check_force_gemv_init
from theano.tests import unittest_tools
from theano.tests.unittest_tools import TestOptimizationMixin
......@@ -170,6 +172,14 @@ class TestCGemv(TestCase, TestOptimizationMixin):
assert numpy.allclose(f(self.Aval[::-1, ::-1], self.yval),
numpy.dot(self.Aval[::-1, ::-1], self.yval))
def test_force_gemv_init(self):
    ''' write a warning to stderr when check_force_gemv_init() is true

    Makes no assertions; when the check returns a false value the
    method does nothing.
    '''
    if not check_force_gemv_init():
        return
    # Adjacent literals are joined at compile time, yielding the same
    # single-line message as explicit '+' concatenation.
    warning_text = (
        "WARNING: The current BLAS requires Theano to initialize"
        " memory for some GEMV calls which will result in a minor"
        " degradation in performance for such calls."
    )
    sys.stderr.write(warning_text)
def t_gemv1(self, m_shp):
''' test vector2 + dot(matrix, vector1) '''
rng = numpy.random.RandomState(unittest_tools.fetch_seed())
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论