提交 02468f85 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Fix gemv to accept both float32 and float64 and reenable optimization to insert it in the graph.

上级 52dcfe33
from theano import Apply, config
from theano import Op, Apply, config
from theano.tensor.blas import Gemv
from theano.sandbox.gpuarray.basic_ops import (HideC, as_gpuarray_variable)
try:
import pygpu
from pygpu import blas
except ImportError, e:
# To make sure theano is importable
pass
class GpuGemv(HideC, Gemv):
class BlasOp(HideC, Op):
def c_headers(self):
return ['<blas_api.h>']
def c_header_dirs(self):
return [pygpu.get_include()]
def c_init_code(self):
return ['import_pygpu__blas();']
class GpuGemv(BlasOp, Gemv):
def make_node(self, y, alpha, A, x, beta):
res = Gemv.make_node(self, y, alpha, A, x, beta)
A = as_gpuarray_variable(A)
......@@ -19,8 +30,8 @@ class GpuGemv(HideC, Gemv):
def perform(self, node, inputs, out_storage):
y, alpha, A, x, beta = inputs
out_storage[0][0] = blas.sgemv(alpha, A, x, beta, y, trans=False,
overwrite_y=self.inplace)
out_storage[0][0] = blas.gemv(alpha, A, x, beta, y, trans=False,
overwrite_y=self.inplace)
def c_code(self, node, name, inp, out, sub):
vars = dict(out=out[0], y=inp[0], alpha=inp[1], A=inp[2], x=inp[3],
......@@ -40,7 +51,7 @@ class GpuGemv(HideC, Gemv):
}
""" % vars
code += """
if (pygpu_blas_sgemv(cb_no_trans,
if (pygpu_blas_rgemv(cb_no_trans,
((dtype_%(alpha)s *)PyArray_DATA(%(alpha)s))[0],
%(A)s, %(x)s,
((dtype_%(beta)s *)PyArray_DATA(%(beta)s))[0],
......@@ -54,6 +65,9 @@ class GpuGemv(HideC, Gemv):
"""
return code
def c_code_cache_version(self):
return (0,)
gpugemv_no_inplace = GpuGemv(inplace=False)
gpugemv_inplace = GpuGemv(inplace=True)
......
......@@ -180,7 +180,7 @@ def local_gpua_careduce(node):
dtype=getattr(node.op, 'dtype', None),
acc_dtype=getattr(node.op, 'acc_dtype', None))
#@register_opt()
@register_opt()
@op_lifter(tensor.blas.Gemv)
def local_gpua_gemv(node):
return GpuGemv(inplace=node.op.inplace)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论