提交 07068846,作者:Frederic

Make IncSubtensor use the new NumPy API.

上级 0be9c072
...@@ -534,13 +534,12 @@ class Subtensor(Op): ...@@ -534,13 +534,12 @@ class Subtensor(Op):
return { return {
"c_prefix": "PyArray", "c_prefix": "PyArray",
"strides_mul": 1, "strides_mul": 1,
"view_name": "xview"} }
@staticmethod @staticmethod
def helper_c_code(node, name, inputs, outputs, sub, idx_list, view_ndim, def helper_c_code(node, name, inputs, outputs, sub, idx_list, view_ndim,
c_prefix=None, c_prefix=None,
strides_mul=None, strides_mul=None,
view_name=None
): ):
""" """
The parameters c_prefix are there to allow reusing this The parameters c_prefix are there to allow reusing this
...@@ -557,9 +556,6 @@ class Subtensor(Op): ...@@ -557,9 +556,6 @@ class Subtensor(Op):
if c_prefix is None: if c_prefix is None:
c_prefix = default_args['c_prefix'] c_prefix = default_args['c_prefix']
if view_name is None:
view_name = default_args['view_name']
# #
# two arrays are created in C code: # two arrays are created in C code:
# is_slice: len == ndim, 0 means int, 1 means slice # is_slice: len == ndim, 0 means int, 1 means slice
...@@ -633,8 +629,6 @@ class Subtensor(Op): ...@@ -633,8 +629,6 @@ class Subtensor(Op):
x, = inputs[:1] x, = inputs[:1]
z, = outputs z, = outputs
xview = view_name
rval = """ rval = """
// The subtensor is created by iterating over the dimensions // The subtensor is created by iterating over the dimensions
...@@ -648,7 +642,7 @@ class Subtensor(Op): ...@@ -648,7 +642,7 @@ class Subtensor(Op):
int outer_ii = 0; // current dimension of z int outer_ii = 0; // current dimension of z
// Argument of the view // Argument of the view
char* xview_ptr = (char*) %(c_prefix)s_BYTES(%(x)s); ssize_t xview_offset = 0;
ssize_t xview_dims[%(view_ndim)s]; ssize_t xview_dims[%(view_ndim)s];
ssize_t xview_strides[%(view_ndim)s]; ssize_t xview_strides[%(view_ndim)s];
...@@ -720,7 +714,7 @@ class Subtensor(Op): ...@@ -720,7 +714,7 @@ class Subtensor(Op):
assert (slicelength <= length); assert (slicelength <= length);
xview_ptr += %(c_prefix)s_STRIDES(%(x)s)[outer_ii] * start * xview_offset += %(c_prefix)s_STRIDES(%(x)s)[outer_ii] * start *
%(strides_mul)s; %(strides_mul)s;
xview_dims[inner_ii] = slicelength; xview_dims[inner_ii] = slicelength;
xview_strides[inner_ii] = %(c_prefix)s_STRIDES(%(x)s)[outer_ii] * step; xview_strides[inner_ii] = %(c_prefix)s_STRIDES(%(x)s)[outer_ii] * step;
...@@ -736,7 +730,7 @@ class Subtensor(Op): ...@@ -736,7 +730,7 @@ class Subtensor(Op):
{ {
if (idx < %(c_prefix)s_DIMS(%(x)s)[outer_ii]) if (idx < %(c_prefix)s_DIMS(%(x)s)[outer_ii])
{ {
xview_ptr += %(c_prefix)s_STRIDES(%(x)s)[outer_ii] * idx * xview_offset += %(c_prefix)s_STRIDES(%(x)s)[outer_ii] * idx *
%(strides_mul)s; %(strides_mul)s;
} }
else else
...@@ -795,9 +789,9 @@ class Subtensor(Op): ...@@ -795,9 +789,9 @@ class Subtensor(Op):
&PyArray_Type, &PyArray_Type,
PyArray_DESCR(%(x)s), PyArray_DESCR(%(x)s),
%(view_ndim)s, %(view_ndim)s,
xview_dims, //PyArray_DIMS(%(x)s), xview_dims,
xview_strides, //PyArray_STRIDES(%(x)s), xview_strides,
xview_ptr, //PyArray_DATA(%(x)s), PyArray_BYTES(%(x)s) + xview_offset,
PyArray_FLAGS(%(x)s), PyArray_FLAGS(%(x)s),
NULL); NULL);
assert (PyArray_NDIM(xview) == %(view_ndim)s); assert (PyArray_NDIM(xview) == %(view_ndim)s);
...@@ -808,8 +802,7 @@ class Subtensor(Op): ...@@ -808,8 +802,7 @@ class Subtensor(Op):
""" % locals() """ % locals()
finish_view = """ finish_view = """
//Is this update needed? The doc of PyArray_NewFromDescr do not //This is needed for NumPy 1.5, but not 1.7.2
//tell us the information needed
PyArray_UpdateFlags(xview, NPY_ARRAY_C_CONTIGUOUS| NPY_ARRAY_F_CONTIGUOUS); PyArray_UpdateFlags(xview, NPY_ARRAY_C_CONTIGUOUS| NPY_ARRAY_F_CONTIGUOUS);
Py_XDECREF(%(z)s); Py_XDECREF(%(z)s);
Py_INCREF(py_%(x)s); Py_INCREF(py_%(x)s);
...@@ -1178,6 +1171,7 @@ class IncSubtensor(Op): ...@@ -1178,6 +1171,7 @@ class IncSubtensor(Op):
numpy.sum([not isinstance(idx, slice) numpy.sum([not isinstance(idx, slice)
for idx in self.idx_list])) for idx in self.idx_list]))
decl = "PyArrayObject * zview = NULL;"
copy_of_x = self.copy_of_x(x) copy_of_x = self.copy_of_x(x)
copy_input_if_necessary = """ copy_input_if_necessary = """
...@@ -1201,7 +1195,21 @@ class IncSubtensor(Op): ...@@ -1201,7 +1195,21 @@ class IncSubtensor(Op):
# On GPU, it takes two steps to make a view # On GPU, it takes two steps to make a view
link_zview = self.link_view_array(z, fail) link_zview = self.link_view_array(z, fail)
#Make a first view on the output, as we will write into it. # get info needed to make zview: a view of %(z)s
helper_args = self.get_helper_c_code_args()
get_zview = Subtensor.helper_c_code(
node=node,
name=name,
inputs=outputs[:1] + inputs[2:],
outputs=outputs,
sub=sub,
idx_list=self.idx_list,
view_ndim=view_ndim,
** helper_args
)
#Make a view on the output, as we will write into it.
build_view = """ build_view = """
//TODO: give this Op a second output so that this view can be cached //TODO: give this Op a second output so that this view can be cached
//TODO: alternatively, fix the memory leak on failure //TODO: alternatively, fix the memory leak on failure
...@@ -1212,19 +1220,6 @@ class IncSubtensor(Op): ...@@ -1212,19 +1220,6 @@ class IncSubtensor(Op):
} }
%(link_zview)s; %(link_zview)s;
""" % locals() """ % locals()
# make zview actually a view of %(z)s
helper_args = self.get_helper_c_code_args()
helper_args['view_name'] = 'zview'
get_zview = self.define_set_data() + \
Subtensor.helper_c_code(
node=node,
name=name,
inputs=outputs[:1] + inputs[2:],
outputs=outputs,
sub=sub,
idx_list=self.idx_list,
** helper_args
)
copy_into = self.copy_into("zview", y) copy_into = self.copy_into("zview", y)
...@@ -1245,11 +1240,12 @@ class IncSubtensor(Op): ...@@ -1245,11 +1240,12 @@ class IncSubtensor(Op):
} }
""" % locals() """ % locals()
return (copy_input_if_necessary return (decl +
+ build_view copy_input_if_necessary +
+ "{" + get_zview + "}" get_zview +
+ make_modification build_view +
+ "Py_DECREF(zview);" make_modification +
"Py_DECREF(zview);"
) )
def do_type_checking(self, node): def do_type_checking(self, node):
...@@ -1299,16 +1295,18 @@ class IncSubtensor(Op): ...@@ -1299,16 +1295,18 @@ class IncSubtensor(Op):
""" """
return """Py_INCREF(PyArray_DESCR(%(x)s)); return """Py_INCREF(PyArray_DESCR(%(x)s));
PyArrayObject * zview = zview = (PyArrayObject*)PyArray_NewFromDescr(
(PyArrayObject*)PyArray_NewFromDescr(
&PyArray_Type, &PyArray_Type,
PyArray_DESCR(%(x)s), PyArray_DESCR(%(x)s),
%(view_ndim)s, %(view_ndim)s,
PyArray_DIMS(%(x)s), xview_dims, //PyArray_DIMS(%(x)s),
PyArray_STRIDES(%(x)s), xview_strides, //PyArray_STRIDES(%(x)s),
PyArray_DATA(%(x)s), PyArray_BYTES(%(x)s) + xview_offset, //PyArray_DATA(%(x)s),
%(x)s->flags, PyArray_FLAGS(%(x)s),
NULL)""" % locals() NULL);
//This is needed for NumPy 1.5, but not 1.7.2
PyArray_UpdateFlags(zview, NPY_ARRAY_C_CONTIGUOUS| NPY_ARRAY_F_CONTIGUOUS);
""" % locals()
def get_helper_c_code_args(self): def get_helper_c_code_args(self):
""" Return a dictionary of arguments to pass to helper_c_code.""" """ Return a dictionary of arguments to pass to helper_c_code."""
...@@ -1324,11 +1322,6 @@ class IncSubtensor(Op): ...@@ -1324,11 +1322,6 @@ class IncSubtensor(Op):
""" """
return """PyArray_CopyInto(%(view)s, %(source)s)""" % locals() return """PyArray_CopyInto(%(view)s, %(source)s)""" % locals()
def define_set_data(self):
""" Returns C code used to define any macros used in the
set data argument to the helper C code. """
return ""
def link_view_array(self, x, fail): def link_view_array(self, x, fail):
""" Returns code to complete making zview a view of x""" """ Returns code to complete making zview a view of x"""
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论