提交 c67185bb，作者：notoraptor

Wrap op params for theano.gpuarray.blocksparse.GpuSparseBlockGemv.

上级 d7f7854a
......@@ -4,19 +4,19 @@ int APPLY_SPECIFIC(blockgemv)(PyGpuArrayObject *o, PyGpuArrayObject *W,
PyGpuArrayObject *h, PyArrayObject *inputIdx,
PyArrayObject *outputIdx,
PyGpuArrayObject **_out,
PyGpuContextObject *ctx) {
PARAMS_TYPE* params) {
PyGpuArrayObject *out = *_out;
#ifdef INPLACE
Py_XDECREF(out);
out = o;
Py_INCREF(out);
#else
out = theano_try_copy(out, o);
if (out == NULL) {
// Error already set
return -1;
if (params->inplace) {
Py_XDECREF(out);
out = o;
Py_INCREF(out);
} else {
out = theano_try_copy(out, o);
if (out == NULL) {
// Error already set
return -1;
}
}
#endif
gpudata **W_list = NULL;
gpudata **inp_list = NULL;
......@@ -26,7 +26,7 @@ int APPLY_SPECIFIC(blockgemv)(PyGpuArrayObject *o, PyGpuArrayObject *W,
size_t *offOut = NULL;
int err;
err = gpublas_setup(ctx->ctx);
err = gpublas_setup(params->context->ctx);
if (err != GA_NO_ERROR) {
PyErr_SetString(PyExc_RuntimeError, "Can't setup blas");
return -1;
......
......@@ -4,8 +4,9 @@ import os
import numpy as np
from theano import Apply, tensor
from theano.gof import COp
from theano.gof import COp, ParamsType
from theano.tensor import discrete_dtypes, as_tensor_variable
from theano.scalar import bool as bool_t
from theano.gradient import grad_undefined
......@@ -25,7 +26,8 @@ class GpuSparseBlockGemv(COp):
function for a stable interface.
"""
__props__ = ('inplace',)
params_type = gpu_context_type
params_type = ParamsType(inplace=bool_t, context=gpu_context_type)
# NB: DTYPE_INPUT_* is used in C code, so I think we should not set check_input to False.
def __init__(self, inplace=False):
COp.__init__(self, "blockgemv.c", "APPLY_SPECIFIC(blockgemv)")
......@@ -34,13 +36,7 @@ class GpuSparseBlockGemv(COp):
self.destroy_map = {0: [0]}
def get_params(self, node):
return node.inputs[0].type.context
def get_op_params(self):
if self.inplace:
return [('INPLACE', '1')]
else:
return []
return self.params_type.get_params(self, context=node.inputs[0].type.context)
def c_header_dirs(self):
return [os.path.dirname(__file__)]
......
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论