提交 c67185bb，作者：notoraptor

Wrap op params for theano.gpuarray.blocksparse.GpuSparseBlockGemv.

上级 d7f7854a
...@@ -4,19 +4,19 @@ int APPLY_SPECIFIC(blockgemv)(PyGpuArrayObject *o, PyGpuArrayObject *W, ...@@ -4,19 +4,19 @@ int APPLY_SPECIFIC(blockgemv)(PyGpuArrayObject *o, PyGpuArrayObject *W,
PyGpuArrayObject *h, PyArrayObject *inputIdx, PyGpuArrayObject *h, PyArrayObject *inputIdx,
PyArrayObject *outputIdx, PyArrayObject *outputIdx,
PyGpuArrayObject **_out, PyGpuArrayObject **_out,
PyGpuContextObject *ctx) { PARAMS_TYPE* params) {
PyGpuArrayObject *out = *_out; PyGpuArrayObject *out = *_out;
#ifdef INPLACE if (params->inplace) {
Py_XDECREF(out); Py_XDECREF(out);
out = o; out = o;
Py_INCREF(out); Py_INCREF(out);
#else } else {
out = theano_try_copy(out, o); out = theano_try_copy(out, o);
if (out == NULL) { if (out == NULL) {
// Error already set // Error already set
return -1; return -1;
} }
#endif }
gpudata **W_list = NULL; gpudata **W_list = NULL;
gpudata **inp_list = NULL; gpudata **inp_list = NULL;
...@@ -26,7 +26,7 @@ int APPLY_SPECIFIC(blockgemv)(PyGpuArrayObject *o, PyGpuArrayObject *W, ...@@ -26,7 +26,7 @@ int APPLY_SPECIFIC(blockgemv)(PyGpuArrayObject *o, PyGpuArrayObject *W,
size_t *offOut = NULL; size_t *offOut = NULL;
int err; int err;
err = gpublas_setup(ctx->ctx); err = gpublas_setup(params->context->ctx);
if (err != GA_NO_ERROR) { if (err != GA_NO_ERROR) {
PyErr_SetString(PyExc_RuntimeError, "Can't setup blas"); PyErr_SetString(PyExc_RuntimeError, "Can't setup blas");
return -1; return -1;
......
...@@ -4,8 +4,9 @@ import os ...@@ -4,8 +4,9 @@ import os
import numpy as np import numpy as np
from theano import Apply, tensor from theano import Apply, tensor
from theano.gof import COp from theano.gof import COp, ParamsType
from theano.tensor import discrete_dtypes, as_tensor_variable from theano.tensor import discrete_dtypes, as_tensor_variable
from theano.scalar import bool as bool_t
from theano.gradient import grad_undefined from theano.gradient import grad_undefined
...@@ -25,7 +26,8 @@ class GpuSparseBlockGemv(COp): ...@@ -25,7 +26,8 @@ class GpuSparseBlockGemv(COp):
function for a stable interface. function for a stable interface.
""" """
__props__ = ('inplace',) __props__ = ('inplace',)
params_type = gpu_context_type params_type = ParamsType(inplace=bool_t, context=gpu_context_type)
# NB: DTYPE_INPUT_* is used in C code, so I think we should not set check_input to False.
def __init__(self, inplace=False): def __init__(self, inplace=False):
COp.__init__(self, "blockgemv.c", "APPLY_SPECIFIC(blockgemv)") COp.__init__(self, "blockgemv.c", "APPLY_SPECIFIC(blockgemv)")
...@@ -34,13 +36,7 @@ class GpuSparseBlockGemv(COp): ...@@ -34,13 +36,7 @@ class GpuSparseBlockGemv(COp):
self.destroy_map = {0: [0]} self.destroy_map = {0: [0]}
def get_params(self, node): def get_params(self, node):
return node.inputs[0].type.context return self.params_type.get_params(self, context=node.inputs[0].type.context)
def get_op_params(self):
if self.inplace:
return [('INPLACE', '1')]
else:
return []
def c_header_dirs(self): def c_header_dirs(self):
return [os.path.dirname(__file__)] return [os.path.dirname(__file__)]
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论