Commit 02468f85, authored by Arnaud Bergeron

Fix gemv to accept both float32 and float64 and reenable optimization to insert it in the graph.

Parent commit: 52dcfe33
from theano import Apply, config from theano import Op, Apply, config
from theano.tensor.blas import Gemv from theano.tensor.blas import Gemv
from theano.sandbox.gpuarray.basic_ops import (HideC, as_gpuarray_variable) from theano.sandbox.gpuarray.basic_ops import (HideC, as_gpuarray_variable)
try: try:
import pygpu
from pygpu import blas from pygpu import blas
except ImportError, e: except ImportError, e:
# To make sure theano is importable # To make sure theano is importable
pass pass
class BlasOp(HideC, Op):
    """Mixin base class for GPU BLAS ops backed by pygpu.

    Provides the C-compilation hooks that let a subclass's generated C
    code call the pygpu BLAS C API (e.g. the ``pygpu_blas_rgemv``
    wrapper used by :class:`GpuGemv`).
    """

    def c_headers(self):
        # Header declaring the pygpu BLAS C-API entry points.
        return ['<blas_api.h>']

    def c_header_dirs(self):
        # Location of pygpu's shipped headers, as reported by pygpu itself.
        return [pygpu.get_include()]

    def c_init_code(self):
        # Module-initialization snippet: binds the pygpu BLAS C-API table
        # before any generated C code runs.
        return ['import_pygpu__blas();']
class GpuGemv(BlasOp, Gemv):
def make_node(self, y, alpha, A, x, beta): def make_node(self, y, alpha, A, x, beta):
res = Gemv.make_node(self, y, alpha, A, x, beta) res = Gemv.make_node(self, y, alpha, A, x, beta)
A = as_gpuarray_variable(A) A = as_gpuarray_variable(A)
...@@ -19,8 +30,8 @@ class GpuGemv(HideC, Gemv): ...@@ -19,8 +30,8 @@ class GpuGemv(HideC, Gemv):
def perform(self, node, inputs, out_storage):
    """Python-mode fallback: compute ``beta*y + alpha*dot(A, x)``.

    Uses pygpu's dtype-generic ``blas.gemv`` (rather than the
    float32-only ``sgemv``) so both float32 and float64 inputs work.
    When ``self.inplace`` is set, ``y`` is overwritten with the result.
    """
    y, alpha, A, x, beta = inputs
    out_storage[0][0] = blas.gemv(alpha, A, x, beta, y, trans=False,
                                  overwrite_y=self.inplace)
def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
vars = dict(out=out[0], y=inp[0], alpha=inp[1], A=inp[2], x=inp[3], vars = dict(out=out[0], y=inp[0], alpha=inp[1], A=inp[2], x=inp[3],
...@@ -40,7 +51,7 @@ class GpuGemv(HideC, Gemv): ...@@ -40,7 +51,7 @@ class GpuGemv(HideC, Gemv):
} }
""" % vars """ % vars
code += """ code += """
if (pygpu_blas_sgemv(cb_no_trans, if (pygpu_blas_rgemv(cb_no_trans,
((dtype_%(alpha)s *)PyArray_DATA(%(alpha)s))[0], ((dtype_%(alpha)s *)PyArray_DATA(%(alpha)s))[0],
%(A)s, %(x)s, %(A)s, %(x)s,
((dtype_%(beta)s *)PyArray_DATA(%(beta)s))[0], ((dtype_%(beta)s *)PyArray_DATA(%(beta)s))[0],
...@@ -54,6 +65,9 @@ class GpuGemv(HideC, Gemv): ...@@ -54,6 +65,9 @@ class GpuGemv(HideC, Gemv):
""" """
return code return code
def c_code_cache_version(self):
    # Version tag for Theano's compiled-C cache; bump whenever the
    # generated C code (c_code/c_headers/c_init_code) changes.
    return (0,)
# Pre-built singleton instances of the op. The non-inplace variant is the
# safe default; the inplace variant (which overwrites y) is substituted by
# later graph optimizations when it is known to be safe.
gpugemv_no_inplace = GpuGemv(inplace=False)
gpugemv_inplace = GpuGemv(inplace=True)
......
...@@ -180,7 +180,7 @@ def local_gpua_careduce(node): ...@@ -180,7 +180,7 @@ def local_gpua_careduce(node):
dtype=getattr(node.op, 'dtype', None), dtype=getattr(node.op, 'dtype', None),
acc_dtype=getattr(node.op, 'acc_dtype', None)) acc_dtype=getattr(node.op, 'acc_dtype', None))
@register_opt()
@op_lifter(tensor.blas.Gemv)
def local_gpua_gemv(node):
    """Lift a CPU ``Gemv`` node onto the GPU, preserving its inplace flag.

    Re-enabled (``@register_opt()`` uncommented) now that GpuGemv handles
    both float32 and float64 via the dtype-generic pygpu gemv.
    """
    return GpuGemv(inplace=node.op.inplace)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论