提交 a3227224 authored 作者: Dustin Webb's avatar Dustin Webb

Optimized the code for the case where the array is C-contiguous, and added a test to warn the user.

上级 053c63fb
...@@ -405,20 +405,7 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail, force_init_beta= ...@@ -405,20 +405,7 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail, force_init_beta=
{ {
if (PyArray_CHKFLAGS(%(zz)s, NPY_ARRAY_C_CONTIGUOUS)) if (PyArray_CHKFLAGS(%(zz)s, NPY_ARRAY_C_CONTIGUOUS))
{ {
if (PyArray_DESCR(%(zz)s)->type_num == NPY_FLOAT) memset((void *)PyArray_DATA(%(zz)s), 0, PyArray_SIZE(%(zz)s)*PyArray_ITEMSIZE(%(zz)s));
{
memset((void *)PyArray_DATA(%(zz)s), 0, PyArray_ITEMSIZE(%(zz)s));
}
else if (PyArray_DESCR(%(zz)s)->type_num == NPY_DOUBLE)
{
memset((void *)PyArray_DATA(%(zz)s), 0, PyArray_ITEMSIZE(%(zz)s));
}
else
{
PyErr_SetString(PyExc_AssertionError,
"neither float nor double dtype");
%(fail)s
}
} }
else else
{ {
......
...@@ -15,6 +15,8 @@ from theano.tensor.blas_c import CGemv ...@@ -15,6 +15,8 @@ from theano.tensor.blas_c import CGemv
from theano.tensor.blas_scipy import ScipyGer from theano.tensor.blas_scipy import ScipyGer
from theano.tensor.blas import Gemv from theano.tensor.blas import Gemv
from theano.tensor.blas_c import check_force_gemv_init
from theano.tests import unittest_tools from theano.tests import unittest_tools
from theano.tests.unittest_tools import TestOptimizationMixin from theano.tests.unittest_tools import TestOptimizationMixin
...@@ -170,6 +172,14 @@ class TestCGemv(TestCase, TestOptimizationMixin): ...@@ -170,6 +172,14 @@ class TestCGemv(TestCase, TestOptimizationMixin):
assert numpy.allclose(f(self.Aval[::-1, ::-1], self.yval), assert numpy.allclose(f(self.Aval[::-1, ::-1], self.yval),
numpy.dot(self.Aval[::-1, ::-1], self.yval)) numpy.dot(self.Aval[::-1, ::-1], self.yval))
def test_force_gemv_init(self):
    """Write a notice to stderr when GEMV output initialization is forced.

    This method makes no assertions. When ``check_force_gemv_init()``
    returns a true value, it writes a one-line warning to ``sys.stderr``
    about the resulting performance cost; otherwise it does nothing.
    """
    if not check_force_gemv_init():
        return
    # write() does not append a line terminator, so end the message
    # explicitly to keep it from running into whatever is printed next.
    sys.stderr.write(
        "WARNING: The current BLAS requires Theano to initialize"
        " memory for some GEMV calls which will result in a minor"
        " degradation in performance for such calls.\n"
    )
def t_gemv1(self, m_shp): def t_gemv1(self, m_shp):
''' test vector2 + dot(matrix, vector1) ''' ''' test vector2 + dot(matrix, vector1) '''
rng = numpy.random.RandomState(unittest_tools.fetch_seed()) rng = numpy.random.RandomState(unittest_tools.fetch_seed())
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论