提交 1d6b79f6，作者：notoraptor

Wrap op params for theano.gpuarray.blocksparse.GpuSparseBlockOuter.

上级 c67185bb
...@@ -4,7 +4,7 @@ int APPLY_SPECIFIC(blockger)(PyGpuArrayObject *o, PyGpuArrayObject *x, ...@@ -4,7 +4,7 @@ int APPLY_SPECIFIC(blockger)(PyGpuArrayObject *o, PyGpuArrayObject *x,
PyGpuArrayObject *y, PyArrayObject *xIdx, PyGpuArrayObject *y, PyArrayObject *xIdx,
PyArrayObject *yIdx, PyArrayObject *alpha, PyArrayObject *yIdx, PyArrayObject *alpha,
PyGpuArrayObject **_out, PyGpuArrayObject **_out,
PyGpuContextObject *ctx) { PARAMS_TYPE* params) {
PyGpuArrayObject *out = *_out; PyGpuArrayObject *out = *_out;
gpudata **o_list = NULL; gpudata **o_list = NULL;
gpudata **x_list = NULL; gpudata **x_list = NULL;
...@@ -14,21 +14,21 @@ int APPLY_SPECIFIC(blockger)(PyGpuArrayObject *o, PyGpuArrayObject *x, ...@@ -14,21 +14,21 @@ int APPLY_SPECIFIC(blockger)(PyGpuArrayObject *o, PyGpuArrayObject *x,
size_t *offY = NULL; size_t *offY = NULL;
int err; int err;
err = gpublas_setup(ctx->ctx); err = gpublas_setup(params->context->ctx);
if (err != GA_NO_ERROR) { if (err != GA_NO_ERROR) {
PyErr_SetString(PyExc_RuntimeError, "Can't setup blas"); PyErr_SetString(PyExc_RuntimeError, "Can't setup blas");
return -1; return -1;
} }
#ifdef INPLACE if (params->inplace) {
Py_XDECREF(out); Py_XDECREF(out);
out = o; out = o;
Py_INCREF(out); Py_INCREF(out);
#else } else {
out = theano_try_copy(out, o); out = theano_try_copy(out, o);
if (out == NULL) if (out == NULL)
return -1; return -1;
#endif }
size_t maxi = PyGpuArray_DIMS(x)[1]; size_t maxi = PyGpuArray_DIMS(x)[1];
size_t maxj = PyGpuArray_DIMS(y)[1]; size_t maxj = PyGpuArray_DIMS(y)[1];
size_t maxb = PyGpuArray_DIMS(x)[0]; size_t maxb = PyGpuArray_DIMS(x)[0];
......
...@@ -98,7 +98,7 @@ class GpuSparseBlockOuter(COp): ...@@ -98,7 +98,7 @@ class GpuSparseBlockOuter(COp):
of GpuSparseBlockGemv. The gradient is not implemented. of GpuSparseBlockGemv. The gradient is not implemented.
""" """
__props__ = ('inplace',) __props__ = ('inplace',)
params_type = gpu_context_type params_type = ParamsType(inplace=bool_t, context=gpu_context_type)
def __init__(self, inplace=False): def __init__(self, inplace=False):
COp.__init__(self, ["blockger.c"], "APPLY_SPECIFIC(blockger)") COp.__init__(self, ["blockger.c"], "APPLY_SPECIFIC(blockger)")
...@@ -107,13 +107,7 @@ class GpuSparseBlockOuter(COp): ...@@ -107,13 +107,7 @@ class GpuSparseBlockOuter(COp):
self.destroy_map = {0: [0]} self.destroy_map = {0: [0]}
def get_params(self, node): def get_params(self, node):
return node.inputs[0].type.context return self.params_type.get_params(self, context=node.inputs[0].type.context)
def get_op_params(self):
if self.inplace:
return [('INPLACE', '1')]
else:
return []
def make_node(self, o, x, y, xIdx, yIdx, alpha=None): def make_node(self, o, x, y, xIdx, yIdx, alpha=None):
ctx = infer_context_name(o, x, y) ctx = infer_context_name(o, x, y)
......
Markdown 格式
0%
您将向此讨论添加 0 人。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论