提交 07068846 authored 作者: Frederic's avatar Frederic

Make IncSubtensor use the new NumPy API.

上级 0be9c072
......@@ -534,13 +534,12 @@ class Subtensor(Op):
return {
"c_prefix": "PyArray",
"strides_mul": 1,
"view_name": "xview"}
}
@staticmethod
def helper_c_code(node, name, inputs, outputs, sub, idx_list, view_ndim,
c_prefix=None,
strides_mul=None,
view_name=None
):
"""
The parameters c_prefix are there to allow reusing this
......@@ -557,9 +556,6 @@ class Subtensor(Op):
if c_prefix is None:
c_prefix = default_args['c_prefix']
if view_name is None:
view_name = default_args['view_name']
#
# two arrays are created in C code:
# is_slice: len == ndim, 0 means int, 1 means slice
......@@ -633,8 +629,6 @@ class Subtensor(Op):
x, = inputs[:1]
z, = outputs
xview = view_name
rval = """
// The subtensor is created by iterating over the dimensions
......@@ -648,7 +642,7 @@ class Subtensor(Op):
int outer_ii = 0; // current dimension of z
// Argument of the view
char* xview_ptr = (char*) %(c_prefix)s_BYTES(%(x)s);
ssize_t xview_offset = 0;
ssize_t xview_dims[%(view_ndim)s];
ssize_t xview_strides[%(view_ndim)s];
......@@ -720,7 +714,7 @@ class Subtensor(Op):
assert (slicelength <= length);
xview_ptr += %(c_prefix)s_STRIDES(%(x)s)[outer_ii] * start *
xview_offset += %(c_prefix)s_STRIDES(%(x)s)[outer_ii] * start *
%(strides_mul)s;
xview_dims[inner_ii] = slicelength;
xview_strides[inner_ii] = %(c_prefix)s_STRIDES(%(x)s)[outer_ii] * step;
......@@ -736,7 +730,7 @@ class Subtensor(Op):
{
if (idx < %(c_prefix)s_DIMS(%(x)s)[outer_ii])
{
xview_ptr += %(c_prefix)s_STRIDES(%(x)s)[outer_ii] * idx *
xview_offset += %(c_prefix)s_STRIDES(%(x)s)[outer_ii] * idx *
%(strides_mul)s;
}
else
......@@ -795,9 +789,9 @@ class Subtensor(Op):
&PyArray_Type,
PyArray_DESCR(%(x)s),
%(view_ndim)s,
xview_dims, //PyArray_DIMS(%(x)s),
xview_strides, //PyArray_STRIDES(%(x)s),
xview_ptr, //PyArray_DATA(%(x)s),
xview_dims,
xview_strides,
PyArray_BYTES(%(x)s) + xview_offset,
PyArray_FLAGS(%(x)s),
NULL);
assert (PyArray_NDIM(xview) == %(view_ndim)s);
......@@ -808,8 +802,7 @@ class Subtensor(Op):
""" % locals()
finish_view = """
//Is this update needed? The doc of PyArray_NewFromDescr do not
//tell us the information needed
//This is needed for NumPy 1.5, but not 1.7.2
PyArray_UpdateFlags(xview, NPY_ARRAY_C_CONTIGUOUS| NPY_ARRAY_F_CONTIGUOUS);
Py_XDECREF(%(z)s);
Py_INCREF(py_%(x)s);
......@@ -1178,6 +1171,7 @@ class IncSubtensor(Op):
numpy.sum([not isinstance(idx, slice)
for idx in self.idx_list]))
decl = "PyArrayObject * zview = NULL;"
copy_of_x = self.copy_of_x(x)
copy_input_if_necessary = """
......@@ -1201,7 +1195,21 @@ class IncSubtensor(Op):
# On GPU, it takes two steps to make a view
link_zview = self.link_view_array(z, fail)
#Make a first view on the output, as we will write into it.
# get info needed to make zview: a view of %(z)s
helper_args = self.get_helper_c_code_args()
get_zview = Subtensor.helper_c_code(
node=node,
name=name,
inputs=outputs[:1] + inputs[2:],
outputs=outputs,
sub=sub,
idx_list=self.idx_list,
view_ndim=view_ndim,
** helper_args
)
#Make a view on the output, as we will write into it.
build_view = """
//TODO: give this Op a second output so that this view can be cached
//TODO: alternatively, fix the memory leak on failure
......@@ -1212,19 +1220,6 @@ class IncSubtensor(Op):
}
%(link_zview)s;
""" % locals()
# make zview actually a view of %(z)s
helper_args = self.get_helper_c_code_args()
helper_args['view_name'] = 'zview'
get_zview = self.define_set_data() + \
Subtensor.helper_c_code(
node=node,
name=name,
inputs=outputs[:1] + inputs[2:],
outputs=outputs,
sub=sub,
idx_list=self.idx_list,
** helper_args
)
copy_into = self.copy_into("zview", y)
......@@ -1245,11 +1240,12 @@ class IncSubtensor(Op):
}
""" % locals()
return (copy_input_if_necessary
+ build_view
+ "{" + get_zview + "}"
+ make_modification
+ "Py_DECREF(zview);"
return (decl +
copy_input_if_necessary +
get_zview +
build_view +
make_modification +
"Py_DECREF(zview);"
)
def do_type_checking(self, node):
......@@ -1299,16 +1295,18 @@ class IncSubtensor(Op):
"""
return """Py_INCREF(PyArray_DESCR(%(x)s));
PyArrayObject * zview =
(PyArrayObject*)PyArray_NewFromDescr(
zview = (PyArrayObject*)PyArray_NewFromDescr(
&PyArray_Type,
PyArray_DESCR(%(x)s),
%(view_ndim)s,
PyArray_DIMS(%(x)s),
PyArray_STRIDES(%(x)s),
PyArray_DATA(%(x)s),
%(x)s->flags,
NULL)""" % locals()
xview_dims, //PyArray_DIMS(%(x)s),
xview_strides, //PyArray_STRIDES(%(x)s),
PyArray_BYTES(%(x)s) + xview_offset, //PyArray_DATA(%(x)s),
PyArray_FLAGS(%(x)s),
NULL);
//This is needed for NumPy 1.5, but not 1.7.2
PyArray_UpdateFlags(zview, NPY_ARRAY_C_CONTIGUOUS| NPY_ARRAY_F_CONTIGUOUS);
""" % locals()
def get_helper_c_code_args(self):
""" Return a dictionary of arguments to pass to helper_c_code."""
......@@ -1324,11 +1322,6 @@ class IncSubtensor(Op):
"""
return """PyArray_CopyInto(%(view)s, %(source)s)""" % locals()
def define_set_data(self):
""" Returns C code used to define any macros used in the
set data argument to the helper C code. """
return ""
def link_view_array(self, x, fail):
""" Returns code to complete making zview a view of x"""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论