提交 fee6befb authored 作者: Ian Goodfellow's avatar Ian Goodfellow

IncSubtensor: generalized the allocation of the output view

上级 9728b350
......@@ -2433,6 +2433,10 @@ class GpuIncSubtensor(tensor.IncSubtensor, GpuOp):
def copy_of_x(self, x):
return """(CudaNdarray*) CudaNdarray_Copy(%(x)s)""" % locals()
def make_view_buffer(self, x, view_ndim):
return """CudaNdarray* xview = (CudaNdarray*)
CudaNdarray_New(%(view_ndim)s)""" % locals()
def c_code_cache_version(self):
# TODO: cooperate with parent class' C code
return ()
......
......@@ -439,6 +439,10 @@ CudaNdarray_TakeFrom(CudaNdarray * self, PyObject *args);
static int fprint_CudaNdarray(FILE * fd, const CudaNdarray *self);
PyObject * CudaNdarray_View(const CudaNdarray * self);
#endif
/*
Local Variables:
......
......@@ -4575,20 +4575,14 @@ class IncSubtensor(Op):
%(z)s = %(copy_of_x)s; }
""" % locals()
alloc_view_of_z = self.make_view_buffer(z, view_ndim)
#Make a first view on the output, as we will write into it.
build_view = """
//TODO: give this Op a second output so that this view can be cached
//TODO: alternatively, fix the memory leak on failure
Py_INCREF(PyArray_DESCR(%(z)s));
PyArrayObject * xview = (PyArrayObject*)PyArray_NewFromDescr(
&PyArray_Type,
PyArray_DESCR(%(z)s),
%(view_ndim)s,
PyArray_DIMS(%(z)s),
PyArray_STRIDES(%(z)s),
PyArray_DATA(%(z)s),
%(z)s->flags,
NULL);
%(alloc_view_of_z)s;
if (!xview)
{
%(fail)s;
......@@ -4670,6 +4664,27 @@ class IncSubtensor(Op):
return """(PyArrayObject*)PyArray_FromAny(py_%(x)s, NULL, 0, 0,
NPY_ARRAY_ENSURECOPY, NULL)""" % locals()
def make_view_buffer(x, view_ndim):
"""
x: a string identifying an array to be viewed
view_ndim: a string specifying the number of dimensions
to have in the view
This doesn't need to actually set up the view with the
right indexing; we'll do that manually later.
"""
return """PyArrayObject * xview =
(PyArrayObject*)PyArray_NewFromDescr(
&PyArray_Type,
PyArray_DESCR(%(x)s),
%(view_ndim)s,
PyArray_DIMS(%(x)s),
PyArray_STRIDES(%(x)s),
PyArray_DATA(%(x)s),
%(x)s->flags,
NULL)""" % locals()
def infer_shape(self, node, shapes):
return [shapes[0]]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论