提交 da2d163e · 作者: Arnaud Bergeron

Add a Python implementation of GpuGemv.

上级 8c278a05
......@@ -3,6 +3,12 @@ from theano import Apply, config
from theano.tensor.blas import Gemv
from theano.sandbox.gpuarray.basic_ops import (HideC, as_gpuarray_variable)
try:
    from pygpu import blas
except ImportError:
    # pygpu is optional: swallow the failure so that theano itself stays
    # importable.  `blas` is left undefined in that case, so code that
    # uses it only works when pygpu is installed.
    pass
class GpuGemv(HideC, Gemv):
def make_node(self, y, alpha, A, x, beta):
res = Gemv.make_node(self, y, alpha, A, x, beta)
......@@ -11,8 +17,10 @@ class GpuGemv(HideC, Gemv):
y = as_gpuarray_variable(y)
return Apply(self, [y, alpha, A, x, beta], [y.type()])
# Stub: accepts any arguments and always raises NotImplementedError.
def perform(*args):
raise NotImplementedError
def perform(self, node, inputs, out_storage):
    """Call ``blas.sgemv`` on the inputs and store what it returns.

    `inputs` must unpack to exactly five values, in the order
    (y, alpha, A, x, beta).  They are forwarded to ``blas.sgemv`` as
    (alpha, A, x, beta, y) with ``trans=False``, and ``self.inplace`` is
    passed as ``overwrite_y``.  The value returned by the call is written
    to ``out_storage[0][0]``.  `node` is not used here.
    """
    vec_y, scale_a, matrix, vec_x, scale_b = inputs
    result = blas.sgemv(scale_a, matrix, vec_x, scale_b, vec_y,
                        trans=False, overwrite_y=self.inplace)
    out_storage[0][0] = result
def c_code(self, node, name, inp, out, sub):
vars = dict(out=out[0], y=inp[0], alpha=inp[1], A=inp[2], x=inp[3],
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论