提交 3c332e72 作者: Arnaud Bergeron

Fix the blas ops for real this time ...

上级 452e9c92
import os.path import os.path
from theano import Apply, config from theano import Apply, config, Op
from theano.compile import optdb from theano.compile import optdb
from theano.gof import LocalOptGroup from theano.gof import LocalOptGroup
from theano.tensor.basic import as_tensor_variable from theano.tensor.basic import as_tensor_variable
from theano.tensor.opt import in2out from theano.tensor.opt import in2out
from .basic_ops import HideC, as_gpuarray_variable, infer_context_name from .basic_ops import as_gpuarray_variable, infer_context_name
from .opt_util import inplace_allocempty from .opt_util import inplace_allocempty
...@@ -19,7 +19,7 @@ except ImportError as e: ...@@ -19,7 +19,7 @@ except ImportError as e:
pass pass
class BlasOp(HideC): class BlasOp(Op):
def c_headers(self): def c_headers(self):
return ['<blas_api.h>', '<numpy_compat.h>', '<gpuarray_helper.h>'] return ['<blas_api.h>', '<numpy_compat.h>', '<gpuarray_helper.h>']
...@@ -31,6 +31,8 @@ class BlasOp(HideC): ...@@ -31,6 +31,8 @@ class BlasOp(HideC):
class GpuGemv(BlasOp): class GpuGemv(BlasOp):
__props__ = ('inplace',)
def __init__(self, inplace=False): def __init__(self, inplace=False):
self.inplace = inplace self.inplace = inplace
if self.inplace: if self.inplace:
...@@ -107,6 +109,7 @@ gpugemv_inplace = GpuGemv(inplace=True) ...@@ -107,6 +109,7 @@ gpugemv_inplace = GpuGemv(inplace=True)
class GpuGemm(BlasOp): class GpuGemm(BlasOp):
__props__ = ('inplace',)
_f16_ok = True _f16_ok = True
def __init__(self, inplace=False): def __init__(self, inplace=False):
...@@ -185,6 +188,8 @@ gpugemm_inplace = GpuGemm(inplace=True) ...@@ -185,6 +188,8 @@ gpugemm_inplace = GpuGemm(inplace=True)
class GpuGer(BlasOp): class GpuGer(BlasOp):
__props__ = ('inplace',)
def __init__(self, inplace=False): def __init__(self, inplace=False):
self.inplace = inplace self.inplace = inplace
if self.inplace: if self.inplace:
...@@ -256,6 +261,8 @@ gpuger_inplace = GpuGer(inplace=True) ...@@ -256,6 +261,8 @@ gpuger_inplace = GpuGer(inplace=True)
class GpuDot22(BlasOp): class GpuDot22(BlasOp):
__props__ = ()
def make_node(self, x, y): def make_node(self, x, y):
ctx_name = infer_context_name(x, y) ctx_name = infer_context_name(x, y)
x = as_gpuarray_variable(x, ctx_name) x = as_gpuarray_variable(x, ctx_name)
......
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论