提交 e0b3ba8d authored 作者: Harm de Vries's avatar Harm de Vries

rebase

上级 a7900bd8
...@@ -92,7 +92,14 @@ class GpuGemv(BlasOp): ...@@ -92,7 +92,14 @@ class GpuGemv(BlasOp):
} }
""" % vars """ % vars
code += """ code += """
if (pygpu_blas_rgemv(cb_no_trans, if (PyGpuArray_DIM(%(A)s, 1) == 0) {
int code;
code = GpuArray_memset(&%(out)s->ga, 0);
if (code != GA_NO_ERROR) {
PyErr_SetString(PyExc_RuntimeError, "Memset failed");
%(fail)s
}
} else if (pygpu_blas_rgemv(cb_no_trans,
((dtype_%(alpha)s *)PyArray_DATA(%(alpha)s))[0], ((dtype_%(alpha)s *)PyArray_DATA(%(alpha)s))[0],
%(A)s, %(x)s, %(A)s, %(x)s,
((dtype_%(beta)s *)PyArray_DATA(%(beta)s))[0], ((dtype_%(beta)s *)PyArray_DATA(%(beta)s))[0],
...@@ -107,7 +114,8 @@ class GpuGemv(BlasOp): ...@@ -107,7 +114,8 @@ class GpuGemv(BlasOp):
return code return code
def c_code_cache_version(self): def c_code_cache_version(self):
return (4,) return ()
# return (4,)
gpugemv_no_inplace = GpuGemv(inplace=False) gpugemv_no_inplace = GpuGemv(inplace=False)
gpugemv_inplace = GpuGemv(inplace=True) gpugemv_inplace = GpuGemv(inplace=True)
......
...@@ -2,6 +2,7 @@ from __future__ import absolute_import, print_function, division ...@@ -2,6 +2,7 @@ from __future__ import absolute_import, print_function, division
from unittest import TestCase from unittest import TestCase
from nose.plugins.skip import SkipTest from nose.plugins.skip import SkipTest
import itertools import itertools
import numpy
import theano import theano
from theano import tensor from theano import tensor
...@@ -128,3 +129,17 @@ GpuDot22Tester = makeTester( ...@@ -128,3 +129,17 @@ GpuDot22Tester = makeTester(
# test9=[rand(0, 0), rand(0, 0)], # test9=[rand(0, 0), rand(0, 0)],
) )
) )
def test_gemv_zeros():
    """Matrix-vector product with a zero-width matrix must yield zeros.

    Compiles ``W.dot(v)`` with ``mode_with_gpu`` and evaluates it on a
    ``(dim, 0)`` matrix and a ``(0,)`` vector. The sum over an empty
    inner dimension is zero, so the result must be a zero vector of
    length ``dim``.
    """
    W = tensor.matrix()
    v = tensor.vector()
    f = theano.function([W, v], W.dot(v), mode=mode_with_gpu)

    # Apply to an empty matrix of shape (dim, 0) and an empty vector of
    # shape (0,).
    dim = 1000
    A = numpy.zeros((dim, 0), dtype=theano.config.floatX)
    b = numpy.zeros((0,), dtype=theano.config.floatX)
    tmp = numpy.asarray(f(A, b))

    # numpy.allclose broadcasts its arguments, so a scalar or wrongly
    # shaped result would still compare equal to zeros; pin the shape.
    assert tmp.shape == (dim,)
    assert numpy.allclose(tmp, numpy.zeros((dim,)))
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论