提交 93c8954c authored 作者: James Bergstra's avatar James Bergstra

cuda - added storage parameter to cuda.filter to reuse storage

上级 5086b5d7
......@@ -60,7 +60,7 @@ class GpuFromHost(Op):
raise TypeError(x)
return Apply(self, [x], [CudaNdarrayType(broadcastable=x.broadcastable)()])
def perform(self, node, (x,), (z,)):
z[0] = type_support_filter(theano._asarray(x, dtype='float32'), tuple([0]*x.ndim), 0)
z[0] = type_support_filter(theano._asarray(x, dtype='float32'), tuple([0]*x.ndim), 0, z[0])
def grad(self, inputs, (gz,)):
return gz,
#return [HostFromGpu()(gz)]
......
......@@ -1377,12 +1377,21 @@ CudaNdarray_Dot(PyObject* _unsed, PyObject * args)
static PyObject *
filter(PyObject* __unsed_self, PyObject *args) // args = (data, broadcastable, strict)
{
/*
* TODO: document what this function should do in each of the following cases:
* - What is 'strict' supposed to mean in the context of this function?
* - What should happen with input that could be interpreted as matching the
*   broadcastable pattern, in the strict case vs. the non-strict case?
*/
PyObject *py_data=NULL;
PyArrayObject * data = NULL;
int strict = 0;
PyObject * broadcastable=NULL;
PyObject * storage=NULL;
CudaNdarray * rval=NULL;
if (!PyArg_ParseTuple(args, "OOi", &py_data, &broadcastable, &strict)) return NULL;
// The objects PyArg_ParseTuple hands back for "O" are borrowed references: do not DECREF them unless they have been INCREF'ed first.
if (!PyArg_ParseTuple(args, "OOiO", &py_data, &broadcastable, &strict, &storage)) return NULL;
if (!PyTuple_Check(broadcastable)){
PyErr_SetString(PyExc_TypeError, "broadcastable arg should be a tuple of int.");
......@@ -1444,7 +1453,15 @@ filter(PyObject* __unsed_self, PyObject *args) // args = (data, broadcastable, s
return NULL;
}
}
CudaNdarray * rval = (CudaNdarray*) CudaNdarray_new_null();
if (CudaNdarray_Check(storage))
{
rval = (CudaNdarray*) storage;
Py_INCREF(rval);
}
else
{
rval = (CudaNdarray*) CudaNdarray_new_null();
}
if (CudaNdarray_CopyFromArray(rval, data))
{
Py_DECREF(rval);
......@@ -1460,7 +1477,7 @@ filter(PyObject* __unsed_self, PyObject *args) // args = (data, broadcastable, s
static PyMethodDef module_methods[] = {
{"dot", CudaNdarray_Dot, METH_VARARGS, "Returns the matrix product of two CudaNdarray arguments."},
{"gpu_init", CudaNdarray_gpu_init, METH_VARARGS, "Allow to select the gpu card to use."},
{"filter", filter, METH_VARARGS, "no doc!"},
{"filter", filter, METH_VARARGS, "filter(obj, broadcastable, strict, storage) returns a CudaNdarray initialized to obj if it matches the constraints of broadcastable. strict=True prevents any numeric casting. If storage is a CudaNdarray it may be overwritten and used as the return value."},
{NULL, NULL, NULL, NULL} /* Sentinel */
};
......
......@@ -51,7 +51,7 @@ class CudaNdarrayType(Type):
self.dtype_specs() # error checking is done there
def filter(self, data, strict=False):
return cuda.filter(data, self.broadcastable, strict)
return cuda.filter(data, self.broadcastable, strict, None)
@staticmethod
def values_eq(a, b):
......
......@@ -107,7 +107,7 @@ def float32_shared_constructor(value, name, strict=False, broadcastable=None):
if broadcastable is None:
broadcastable = (False,) * len(value.shape)
type = CudaNdarrayType(broadcastable=broadcastable)
deviceval = type_support_filter(value, broadcastable, False)
deviceval = type_support_filter(value, broadcastable, False, None)
try:
rval = CudaNdarraySharedVariable(type=type, value=deviceval, name=name, strict=strict)
except Exception, e:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论