提交 93c8954c authored 作者: James Bergstra's avatar James Bergstra

cuda - added storage parameter to cuda.filter to reuse storage

上级 5086b5d7
...@@ -60,7 +60,7 @@ class GpuFromHost(Op): ...@@ -60,7 +60,7 @@ class GpuFromHost(Op):
raise TypeError(x) raise TypeError(x)
return Apply(self, [x], [CudaNdarrayType(broadcastable=x.broadcastable)()]) return Apply(self, [x], [CudaNdarrayType(broadcastable=x.broadcastable)()])
def perform(self, node, (x,), (z,)): def perform(self, node, (x,), (z,)):
z[0] = type_support_filter(theano._asarray(x, dtype='float32'), tuple([0]*x.ndim), 0) z[0] = type_support_filter(theano._asarray(x, dtype='float32'), tuple([0]*x.ndim), 0, z[0])
def grad(self, inputs, (gz,)): def grad(self, inputs, (gz,)):
return gz, return gz,
#return [HostFromGpu()(gz)] #return [HostFromGpu()(gz)]
......
...@@ -1377,12 +1377,21 @@ CudaNdarray_Dot(PyObject* _unsed, PyObject * args) ...@@ -1377,12 +1377,21 @@ CudaNdarray_Dot(PyObject* _unsed, PyObject * args)
static PyObject * static PyObject *
filter(PyObject* __unsed_self, PyObject *args) // args = (data, broadcastable, strict) filter(PyObject* __unsed_self, PyObject *args) // args = (data, broadcastable, strict)
{ {
/*
* TODO: DOC what this function should do in the various cases of
* What is 'strict' supposed to mean in the context of this function?
* What do we do with input that could be interpreted as matching the broadcastable pattern in strict vs. non-strict cases?
*
*/
PyObject *py_data=NULL; PyObject *py_data=NULL;
PyArrayObject * data = NULL; PyArrayObject * data = NULL;
int strict = 0; int strict = 0;
PyObject * broadcastable=NULL; PyObject * broadcastable=NULL;
PyObject * storage=NULL;
CudaNdarray * rval=NULL;
if (!PyArg_ParseTuple(args, "OOi", &py_data, &broadcastable, &strict)) return NULL; //Python object references which are provided to the caller are borrowed references
if (!PyArg_ParseTuple(args, "OOiO", &py_data, &broadcastable, &strict, &storage)) return NULL;
if (!PyTuple_Check(broadcastable)){ if (!PyTuple_Check(broadcastable)){
PyErr_SetString(PyExc_TypeError, "broadcastable arg should be a tuple of int."); PyErr_SetString(PyExc_TypeError, "broadcastable arg should be a tuple of int.");
...@@ -1444,7 +1453,15 @@ filter(PyObject* __unsed_self, PyObject *args) // args = (data, broadcastable, s ...@@ -1444,7 +1453,15 @@ filter(PyObject* __unsed_self, PyObject *args) // args = (data, broadcastable, s
return NULL; return NULL;
} }
} }
CudaNdarray * rval = (CudaNdarray*) CudaNdarray_new_null(); if (CudaNdarray_Check(storage))
{
rval = (CudaNdarray*) storage;
Py_INCREF(rval);
}
else
{
rval = (CudaNdarray*) CudaNdarray_new_null();
}
if (CudaNdarray_CopyFromArray(rval, data)) if (CudaNdarray_CopyFromArray(rval, data))
{ {
Py_DECREF(rval); Py_DECREF(rval);
...@@ -1460,7 +1477,7 @@ filter(PyObject* __unsed_self, PyObject *args) // args = (data, broadcastable, s ...@@ -1460,7 +1477,7 @@ filter(PyObject* __unsed_self, PyObject *args) // args = (data, broadcastable, s
static PyMethodDef module_methods[] = { static PyMethodDef module_methods[] = {
{"dot", CudaNdarray_Dot, METH_VARARGS, "Returns the matrix product of two CudaNdarray arguments."}, {"dot", CudaNdarray_Dot, METH_VARARGS, "Returns the matrix product of two CudaNdarray arguments."},
{"gpu_init", CudaNdarray_gpu_init, METH_VARARGS, "Allow to select the gpu card to use."}, {"gpu_init", CudaNdarray_gpu_init, METH_VARARGS, "Allow to select the gpu card to use."},
{"filter", filter, METH_VARARGS, "no doc!"}, {"filter", filter, METH_VARARGS, "filter(obj, broadcastable, strict, storage) returns a CudaNdarray initialized to obj if it matches the constraints of broadcastable. strict=True prevents any numeric casting. If storage is a CudaNdarray it may be overwritten and used as the return value."},
{NULL, NULL, NULL, NULL} /* Sentinel */ {NULL, NULL, NULL, NULL} /* Sentinel */
}; };
......
...@@ -51,7 +51,7 @@ class CudaNdarrayType(Type): ...@@ -51,7 +51,7 @@ class CudaNdarrayType(Type):
self.dtype_specs() # error checking is done there self.dtype_specs() # error checking is done there
def filter(self, data, strict=False): def filter(self, data, strict=False):
return cuda.filter(data, self.broadcastable, strict) return cuda.filter(data, self.broadcastable, strict, None)
@staticmethod @staticmethod
def values_eq(a, b): def values_eq(a, b):
......
...@@ -107,7 +107,7 @@ def float32_shared_constructor(value, name, strict=False, broadcastable=None): ...@@ -107,7 +107,7 @@ def float32_shared_constructor(value, name, strict=False, broadcastable=None):
if broadcastable is None: if broadcastable is None:
broadcastable = (False,) * len(value.shape) broadcastable = (False,) * len(value.shape)
type = CudaNdarrayType(broadcastable=broadcastable) type = CudaNdarrayType(broadcastable=broadcastable)
deviceval = type_support_filter(value, broadcastable, False) deviceval = type_support_filter(value, broadcastable, False, None)
try: try:
rval = CudaNdarraySharedVariable(type=type, value=deviceval, name=name, strict=strict) rval = CudaNdarraySharedVariable(type=type, value=deviceval, name=name, strict=strict)
except Exception, e: except Exception, e:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论