提交 fee6befb authored 作者: Ian Goodfellow's avatar Ian Goodfellow

IncSubtensor: generalized the allocation of the output view

上级 9728b350
...@@ -2433,6 +2433,10 @@ class GpuIncSubtensor(tensor.IncSubtensor, GpuOp): ...@@ -2433,6 +2433,10 @@ class GpuIncSubtensor(tensor.IncSubtensor, GpuOp):
def copy_of_x(self, x): def copy_of_x(self, x):
return """(CudaNdarray*) CudaNdarray_Copy(%(x)s)""" % locals() return """(CudaNdarray*) CudaNdarray_Copy(%(x)s)""" % locals()
def make_view_buffer(self, x, view_ndim):
    """
    Return the C statement that declares `xview` for the GPU op.

    x: a string identifying the array to be viewed (not used in the
        generated code; only the number of dimensions is needed here)
    view_ndim: the number of dimensions to give the new CudaNdarray

    The returned snippet has no trailing semicolon.
    """
    # Only `view_ndim` is substituted, so pass it explicitly rather than
    # handing the whole local namespace to the format operator.
    template = """CudaNdarray* xview = (CudaNdarray*)
CudaNdarray_New(%(view_ndim)s)"""
    return template % {"view_ndim": view_ndim}
def c_code_cache_version(self): def c_code_cache_version(self):
# TODO: cooperate with parent class' C code # TODO: cooperate with parent class' C code
return () return ()
......
...@@ -439,6 +439,10 @@ CudaNdarray_TakeFrom(CudaNdarray * self, PyObject *args); ...@@ -439,6 +439,10 @@ CudaNdarray_TakeFrom(CudaNdarray * self, PyObject *args);
static int fprint_CudaNdarray(FILE * fd, const CudaNdarray *self); static int fprint_CudaNdarray(FILE * fd, const CudaNdarray *self);
/* Declaration only; the definition is not visible in this hunk.
 * NOTE(review): presumably returns a new object that views the data of
 * `self` without copying it -- confirm against the implementation. */
PyObject * CudaNdarray_View(const CudaNdarray * self);
#endif #endif
/* /*
Local Variables: Local Variables:
......
...@@ -4575,20 +4575,14 @@ class IncSubtensor(Op): ...@@ -4575,20 +4575,14 @@ class IncSubtensor(Op):
%(z)s = %(copy_of_x)s; } %(z)s = %(copy_of_x)s; }
""" % locals() """ % locals()
alloc_view_of_z = self.make_view_buffer(z, view_ndim)
#Make a first view on the output, as we will write into it. #Make a first view on the output, as we will write into it.
build_view = """ build_view = """
//TODO: give this Op a second output so that this view can be cached //TODO: give this Op a second output so that this view can be cached
//TODO: alternatively, fix the memory leak on failure //TODO: alternatively, fix the memory leak on failure
Py_INCREF(PyArray_DESCR(%(z)s)); Py_INCREF(PyArray_DESCR(%(z)s));
PyArrayObject * xview = (PyArrayObject*)PyArray_NewFromDescr( %(alloc_view_of_z)s;
&PyArray_Type,
PyArray_DESCR(%(z)s),
%(view_ndim)s,
PyArray_DIMS(%(z)s),
PyArray_STRIDES(%(z)s),
PyArray_DATA(%(z)s),
%(z)s->flags,
NULL);
if (!xview) if (!xview)
{ {
%(fail)s; %(fail)s;
...@@ -4670,6 +4664,27 @@ class IncSubtensor(Op): ...@@ -4670,6 +4664,27 @@ class IncSubtensor(Op):
return """(PyArrayObject*)PyArray_FromAny(py_%(x)s, NULL, 0, 0, return """(PyArrayObject*)PyArray_FromAny(py_%(x)s, NULL, 0, 0,
NPY_ARRAY_ENSURECOPY, NULL)""" % locals() NPY_ARRAY_ENSURECOPY, NULL)""" % locals()
def make_view_buffer(self, x, view_ndim):
    """
    Return the C statement that declares and allocates `xview`.

    x: a string identifying an array to be viewed
    view_ndim: a string specifying the number of dimensions
        to have in the view

    The generated code builds a PyArrayObject with
    PyArray_NewFromDescr, reusing the descr, dims, strides, data
    pointer and flags of `x`. It carries no trailing semicolon; the
    calling template supplies it.

    This doesn't need to actually set up the view with the
    right indexing; we'll do that manually later.
    """
    # `self` is required: this is invoked as
    # self.make_view_buffer(z, view_ndim), and the GPU subclass override
    # uses the same (self, x, view_ndim) signature.
    return """PyArrayObject * xview =
(PyArrayObject*)PyArray_NewFromDescr(
&PyArray_Type,
PyArray_DESCR(%(x)s),
%(view_ndim)s,
PyArray_DIMS(%(x)s),
PyArray_STRIDES(%(x)s),
PyArray_DATA(%(x)s),
%(x)s->flags,
NULL)""" % locals()
def infer_shape(self, node, shapes): def infer_shape(self, node, shapes):
return [shapes[0]] return [shapes[0]]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论